Changing RowPartition.static_nrows to return Optional[int] instead of Optional[Dimension].

This now agrees with the documentation.

PiperOrigin-RevId: 428650715
Change-Id: Ic36466594b3c75acc88073328b88fb3da57674ed
diff --git a/tensorflow/python/ops/ragged/row_partition.py b/tensorflow/python/ops/ragged/row_partition.py
index 605c60f..b5dc971 100644
--- a/tensorflow/python/ops/ragged/row_partition.py
+++ b/tensorflow/python/ops/ragged/row_partition.py
@@ -875,15 +875,15 @@
       or `None` (otherwise).
     """
     if self._row_splits is not None:
-      nrows = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1
-      if nrows.value is not None:
-        return nrows
+      nrows_plus_one = tensor_shape.dimension_value(self._row_splits.shape[0])
+      if nrows_plus_one is not None:
+        return nrows_plus_one - 1
     if self._row_lengths is not None:
-      nrows = tensor_shape.dimension_at_index(self._row_lengths.shape, 0)
-      if nrows.value is not None:
+      nrows = tensor_shape.dimension_value(self._row_lengths.shape[0])
+      if nrows is not None:
         return nrows
     if self._nrows is not None:
-      return tensor_shape.Dimension(tensor_util.constant_value(self._nrows))
+      return tensor_util.constant_value(self._nrows)
     return None
 
   @property
diff --git a/tensorflow/python/ops/ragged/row_partition_test.py b/tensorflow/python/ops/ragged/row_partition_test.py
index 27b5d72..580f31f 100644
--- a/tensorflow/python/ops/ragged/row_partition_test.py
+++ b/tensorflow/python/ops/ragged/row_partition_test.py
@@ -852,6 +852,21 @@
     rp = RowPartition.from_row_starts([0, 3, 6], nvals=12)
     self.assertAllEqual(12, rp.static_nvals)
 
+  def testStaticNrows(self):
+    rp = RowPartition.from_row_splits([0, 3, 4, 5])
+    static_nrows = rp.static_nrows
+    self.assertIsInstance(static_nrows, int)
+    self.assertAllEqual(3, static_nrows)
+
+  def testStaticNrowsUnknown(self):
+    @def_function.function(
+        input_signature=[tensor_spec.TensorSpec(None, dtypes.int32)])
+    def foo(rs):
+      rp = RowPartition.from_row_splits(rs)
+      static_nrows = rp.static_nrows
+      self.assertIsNone(static_nrows)
+    foo(array_ops.constant([0, 3, 4, 5], dtype=dtypes.int32))
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class RowPartitionSpecTest(test_util.TensorFlowTestCase,