In RaggedTensorDynamicShape.from_tensor: Check for uniform_row_lengths.

PiperOrigin-RevId: 348125868
Change-Id: I287be5a37a93b63956ac5c3a71ebe2cd5f3602c7
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_shape.py b/tensorflow/python/ops/ragged/ragged_tensor_shape.py
index c8635ac..eb6ac54 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_shape.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_shape.py
@@ -50,9 +50,7 @@
   Furthermore, there are two ways a dimension might be encoded:
 
     * "Partitioned dimensions" are dimensions that are encoded using a
-      `RaggedTensor`'s `nested_row_splits`.  The outermostmost partitioned
-      dimension must be uniform, and the innermost partitioned dimension must
-      be ragged.
+      `RowPartition`.  The outermostmost partitioned dimension must be uniform.
 
     * "Inner dimensions" are dimensions that are encoded using a
       `RaggedTensor`'s `flat_values`.  Inner dimensions are always uniform.
@@ -120,8 +118,6 @@
           dimension_size.shape.with_rank_at_most(1)
         if partitioned_dim_sizes[0].shape.ndims == 1:
           raise ValueError('outermost partitioned dimension must be uniform')
-        if partitioned_dim_sizes[-1].shape.ndims == 0:
-          raise ValueError('innermost partitioned dimension must be ragged')
       inner_dim_sizes.shape.assert_has_rank(1)
 
       # Convert dimension size tensors to a single dtype.
@@ -185,10 +181,17 @@
       if not ragged_tensor.is_ragged(rt_input):
         return cls([], array_ops.shape(rt_input))
       else:
-        partitioned_dim_sizes = (
-            (rt_input.nrows(),) + rt_input.nested_row_lengths())
+        partitioned_dim_sizes = [rt_input.nrows()]
+        rt = rt_input
+        while ragged_tensor.is_ragged(rt):
+          if rt.uniform_row_length is None:
+            partitioned_dim_sizes.append(rt.row_lengths())
+          else:
+            partitioned_dim_sizes.append(rt.uniform_row_length)
+          rt = rt.values
+
         return RaggedTensorDynamicShape(
-            partitioned_dim_sizes,
+            tuple(partitioned_dim_sizes),
             array_ops.shape(rt_input.flat_values)[1:],
             dim_size_dtype=dim_size_dtype)
 
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_shape_test.py b/tensorflow/python/ops/ragged/ragged_tensor_shape_test.py
index 1e8aeee..2f588ea 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_shape_test.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_shape_test.py
@@ -30,6 +30,9 @@
 from tensorflow.python.platform import googletest
 
 
+# pylint: disable=g-long-lambda
+
+
 @test_util.run_all_in_graph_and_eager_modes
 class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
                             parameterized.TestCase):
@@ -80,8 +83,15 @@
       dict(
           value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
           expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
+      dict(
+          value=lambda: ragged_tensor.RaggedTensor.from_uniform_row_length(
+              ragged_factory_ops.constant([[1, 2], [3, 4, 5], [], [6]]),
+              uniform_row_length=2),
+          expected_dim_sizes=[2, 2, [2, 3, 0, 1]]),
   ])
   def testFromTensor(self, value, expected_dim_sizes):
+    if callable(value):
+      value = value()
     shape = RaggedTensorDynamicShape.from_tensor(value)
     expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
     self.assertShapeEq(shape, expected)