In RaggedTensorDynamicShape.from_tensor: Check for uniform_row_lengths.
PiperOrigin-RevId: 348125868
Change-Id: I287be5a37a93b63956ac5c3a71ebe2cd5f3602c7
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_shape.py b/tensorflow/python/ops/ragged/ragged_tensor_shape.py
index c8635ac..eb6ac54 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_shape.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_shape.py
@@ -50,9 +50,7 @@
Furthermore, there are two ways a dimension might be encoded:
* "Partitioned dimensions" are dimensions that are encoded using a
- `RaggedTensor`'s `nested_row_splits`. The outermostmost partitioned
- dimension must be uniform, and the innermost partitioned dimension must
- be ragged.
+ `RowPartition`. The outermostmost partitioned dimension must be uniform.
* "Inner dimensions" are dimensions that are encoded using a
`RaggedTensor`'s `flat_values`. Inner dimensions are always uniform.
@@ -120,8 +118,6 @@
dimension_size.shape.with_rank_at_most(1)
if partitioned_dim_sizes[0].shape.ndims == 1:
raise ValueError('outermost partitioned dimension must be uniform')
- if partitioned_dim_sizes[-1].shape.ndims == 0:
- raise ValueError('innermost partitioned dimension must be ragged')
inner_dim_sizes.shape.assert_has_rank(1)
# Convert dimension size tensors to a single dtype.
@@ -185,10 +181,17 @@
if not ragged_tensor.is_ragged(rt_input):
return cls([], array_ops.shape(rt_input))
else:
- partitioned_dim_sizes = (
- (rt_input.nrows(),) + rt_input.nested_row_lengths())
+ partitioned_dim_sizes = [rt_input.nrows()]
+ rt = rt_input
+ while ragged_tensor.is_ragged(rt):
+ if rt.uniform_row_length is None:
+ partitioned_dim_sizes.append(rt.row_lengths())
+ else:
+ partitioned_dim_sizes.append(rt.uniform_row_length)
+ rt = rt.values
+
return RaggedTensorDynamicShape(
- partitioned_dim_sizes,
+ tuple(partitioned_dim_sizes),
array_ops.shape(rt_input.flat_values)[1:],
dim_size_dtype=dim_size_dtype)
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_shape_test.py b/tensorflow/python/ops/ragged/ragged_tensor_shape_test.py
index 1e8aeee..2f588ea 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_shape_test.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_shape_test.py
@@ -30,6 +30,9 @@
from tensorflow.python.platform import googletest
+# pylint: disable=g-long-lambda
+
+
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@@ -80,8 +83,15 @@
dict(
value=ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]),
expected_dim_sizes=[2, [2, 1], [2, 1, 2]]),
+ dict(
+ value=lambda: ragged_tensor.RaggedTensor.from_uniform_row_length(
+ ragged_factory_ops.constant([[1, 2], [3, 4, 5], [], [6]]),
+ uniform_row_length=2),
+ expected_dim_sizes=[2, 2, [2, 3, 0, 1]]),
])
def testFromTensor(self, value, expected_dim_sizes):
+ if callable(value):
+ value = value()
shape = RaggedTensorDynamicShape.from_tensor(value)
expected = RaggedTensorDynamicShape.from_dim_sizes(expected_dim_sizes)
self.assertShapeEq(shape, expected)