Changing RowPartition.static_nrows to return Optional[int] instead of Optional[Dimension].
This now agrees with the documentation.
PiperOrigin-RevId: 428650715
Change-Id: Ic36466594b3c75acc88073328b88fb3da57674ed
diff --git a/tensorflow/python/ops/ragged/row_partition.py b/tensorflow/python/ops/ragged/row_partition.py
index 605c60f..b5dc971 100644
--- a/tensorflow/python/ops/ragged/row_partition.py
+++ b/tensorflow/python/ops/ragged/row_partition.py
@@ -875,15 +875,15 @@
or `None` (otherwise).
"""
if self._row_splits is not None:
- nrows = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1
- if nrows.value is not None:
- return nrows
+ nrows_plus_one = tensor_shape.dimension_value(self._row_splits.shape[0])
+ if nrows_plus_one is not None:
+ return nrows_plus_one - 1
if self._row_lengths is not None:
- nrows = tensor_shape.dimension_at_index(self._row_lengths.shape, 0)
- if nrows.value is not None:
+ nrows = tensor_shape.dimension_value(self._row_lengths.shape[0])
+ if nrows is not None:
return nrows
if self._nrows is not None:
- return tensor_shape.Dimension(tensor_util.constant_value(self._nrows))
+ return tensor_util.constant_value(self._nrows)
return None
@property
diff --git a/tensorflow/python/ops/ragged/row_partition_test.py b/tensorflow/python/ops/ragged/row_partition_test.py
index 27b5d72..580f31f 100644
--- a/tensorflow/python/ops/ragged/row_partition_test.py
+++ b/tensorflow/python/ops/ragged/row_partition_test.py
@@ -852,6 +852,21 @@
rp = RowPartition.from_row_starts([0, 3, 6], nvals=12)
self.assertAllEqual(12, rp.static_nvals)
+ def testStaticNrows(self):
+ rp = RowPartition.from_row_splits([0, 3, 4, 5])
+ static_nrows = rp.static_nrows
+ self.assertIsInstance(static_nrows, int)
+ self.assertAllEqual(3, static_nrows)
+
+ def testStaticNrowsUnknown(self):
+ @def_function.function(
+ input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
+ def foo(rs):
+ rp = RowPartition.from_row_splits(rs)
+ static_nrows = rp.static_nrows
+ self.assertIsNone(static_nrows)
+ foo(array_ops.constant([0, 3, 4, 5], dtype=dtypes.int32))
+
@test_util.run_all_in_graph_and_eager_modes
class RowPartitionSpecTest(test_util.TensorFlowTestCase,