`tf.data.experimental.dense_to_ragged_batch` to output variable ragged rank.

This means:
* For input tensors with shapes known statically, the output will still be a `tf.Tensor`.
* For input tensors with a `None` shape in the i-th axis, the output will be a `tf.RaggedTensor` with ragged rank of `i`, where it is the higher axis with `None` shape.

PiperOrigin-RevId: 303704654
Change-Id: Id4d232688d7a5e4ee0dbca6093743d27432e5de8
diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py
index 8fb92ec..e66f401 100644
--- a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py
@@ -45,11 +45,27 @@
   return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x], x))
 
 
-def _make_matrix_ds(nrows):
+def _make_matrix_ds1(nrows):
   """Create a test dataset with matrix elements (of varying size)."""
   return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x, 2], x))
 
 
+def _make_matrix_ds2(nrows):
+  """Create a test dataset with matrix elements (of varying size)."""
+  return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, x], x))
+
+
+def _make_matrix_ds_fully_defined(nrows):
+  """Create a test dataset with matrix elements (of varying size)."""
+  return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, 3], x))
+
+
+def _make_5dtensor_ds(nrows):
+  """Create a test dataset with matrix elements (of varying size)."""
+  return _make_scalar_ds(nrows).map(
+      lambda x: array_ops.fill([2, x, 3, 2*x, 4], x))
+
+
 def _make_ragged_ds(nrows):
   """Create a test dataset with RaggedTensor elements (of varying size)."""
   values = [[[i] * (i % 3) for i in range(j)] * (j % 3) for j in range(nrows)]
@@ -64,6 +80,8 @@
         'shape=[]': ops.convert_to_tensor(x),
         'shape=[x]': math_ops.range(x),
         'shape=[x, 2]': array_ops.fill([x, 2], x),
+        'shape=[2, x]': array_ops.fill([2, x], x),
+        'shape=[2, x, 3, 2x, 4]': array_ops.fill([2, x, 3, 2*x, 4], x)
     }
   return _make_scalar_ds(nrows).map(transform)
 
@@ -88,8 +106,9 @@
           test_base.default_test_combinations(),
           combinations.combine(
               make_dataset=[
-                  _make_scalar_ds, _make_vector_ds, _make_matrix_ds,
-                  _make_ragged_ds, _make_dict_ds, _make_tuple_ds,
+                  _make_scalar_ds, _make_vector_ds, _make_matrix_ds1,
+                  _make_matrix_ds2, _make_ragged_ds, _make_5dtensor_ds,
+                  _make_dict_ds, _make_tuple_ds, _make_matrix_ds_fully_defined,
               ],
               nrows=[0, 20, 23],
               batch_size=[4],
diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py
index a2d3e00..398ec98 100644
--- a/tensorflow/python/data/experimental/ops/batching.py
+++ b/tensorflow/python/data/experimental/ops/batching.py
@@ -50,11 +50,23 @@
   batch from being produced.
 
   Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
-  different shapes, and each batch will be encoded as a `tf.RaggedTensor`.
+  different shapes:
+
+  *  If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
+     fully defined, then it is batched as normal.
+  *  If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains
+     one or more axes with unknown size (i.e., `shape[i]=None`), then the output
+     will contain a `tf.RaggedTensor` that is ragged up to any of such
+     dimensions.
+  *  If an input element is a `tf.RaggedTensor` or any other type, then it is
+     batched as normal.
+
   Example:
 
   >>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
   >>> dataset = dataset.map(lambda x: tf.range(x))
+  >>> dataset.element_spec.shape
+  TensorShape([None])
   >>> dataset = dataset.apply(
   ...     tf.data.experimental.dense_to_ragged_batch(batch_size=2))
   >>> for batch in dataset:
@@ -385,32 +397,44 @@
         any new ragged tensors.  Existing `tf.RaggedTensor` elements do *not*
         have their row_splits dtype changed.
     """
-
     # Replace each TensorSpec in the input dataset's structure with a
     # corresponding RaggedTensorSpec.
     def to_ragged_spec(spec):
-      if isinstance(spec, tensor_spec.TensorSpec) and spec.shape.ndims != 0:
+      """Returns the new spec based on RaggedTensors."""
+      if (not isinstance(spec, tensor_spec.TensorSpec) or
+          spec.shape.rank is None or
+          spec.shape.is_fully_defined()):
+        return spec
+      else:
+        ragged_rank = max([
+            axis for (axis, size) in enumerate(spec.shape.as_list())
+            if size is None
+        ])
         return ragged_tensor.RaggedTensorSpec(
             shape=spec.shape,
             dtype=spec.dtype,
-            ragged_rank=0,
+            ragged_rank=ragged_rank,
             row_splits_dtype=row_splits_dtype)
-      else:
-        return spec
 
     self._structure = nest.map_structure(to_ragged_spec,
                                          input_dataset.element_spec)
 
     # Replace each tf.Tensor value in the input dataset with a variant-encoded
-    # RaggedTensor.  Since we're updating the corresponding structure to be
+    # RaggedTensor. Since we're updating the corresponding structure to be
     # a RaggedTensorSpec, this variant-encoded tensor will be decoded with
     # RaggedTensorSpec._from_tensor_list.
     def to_ragged_variant(value):
-      if isinstance(value, ops.Tensor) and value.shape.ndims != 0:
-        spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
-        return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access
-      else:
+      """Re-encode Tensors as RaggedTensors."""
+      if (not isinstance(value, ops.Tensor) or
+          value.shape.rank is None or
+          value.shape.is_fully_defined()):
         return value
+      else:
+        spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
+        if spec._ragged_rank > 0:  # pylint: disable=protected-access
+          value = ragged_tensor.RaggedTensor.from_tensor(
+              value, ragged_rank=spec._ragged_rank)  # pylint: disable=protected-access
+        return spec._to_tensor_list(value)[0]  # pylint: disable=protected-access
 
     # Tuples are automatically unpacked by `dataset.map` so we repack them.
     if dataset_ops._should_unpack_args(input_dataset.element_spec):  # pylint: disable=protected-access