Fix bug in StructuredTensorSpec._from_components.

PiperOrigin-RevId: 376041533
Change-Id: If11f9f70fb1ae2268463c4b685135ef37a64f768
diff --git a/tensorflow/python/ops/structured/structured_tensor.py b/tensorflow/python/ops/structured/structured_tensor.py
index 762f366..7e317d0 100644
--- a/tensorflow/python/ops/structured/structured_tensor.py
+++ b/tensorflow/python/ops/structured/structured_tensor.py
@@ -19,6 +19,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import logging
 import re
 from typing import Callable, Dict, List, Sequence, Tuple, Union
 
@@ -37,6 +38,7 @@
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.ops.ragged import row_partition as row_partition_lib
 from tensorflow.python.ops.ragged.row_partition import RowPartition
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
@@ -1132,30 +1134,40 @@
     return StructuredTensor
 
   def _to_components(self, value):
-    if value._fields:
-      return value._fields
-    elif value.nrows() is None:
-      return ((), value.row_partitions)  # empty rank-0 structured tensor
-    else:
-      return (value.nrows(), value.row_partitions)
+    nrows = () if value.nrows() is None else value.nrows()
+    return (value._fields, nrows, value.row_partitions)
 
   def _from_components(self, components):
     if isinstance(components, dict):
-      fields = components
-      nrows = None
-      row_partitions = None
-    else:
+      logging.warning('Loading deprecated encoding for StructuredTensorSpec.')
+      return StructuredTensor.from_fields(components, self._shape,
+                                          validate=False)
+    elif not isinstance(components[0], dict):
+      logging.warning('Loading deprecated encoding for StructuredTensorSpec.')
       fields = {}
       nrows, row_partitions = components
       if isinstance(nrows, tuple) and not nrows:
         nrows = None  # empty rank-0 structured tensor
-    return StructuredTensor.from_fields(fields, self._shape, nrows=nrows,
-                                        row_partitions=row_partitions,
-                                        validate=False)
+      return StructuredTensor.from_fields(fields, self._shape, nrows=nrows,
+                                          row_partitions=row_partitions,
+                                          validate=False)
+
+    (fields, nrows, row_partitions) = components
+    if isinstance(nrows, tuple) and not nrows:
+      nrows = None  # empty rank-0 structured tensor
+    return StructuredTensor(fields, self._shape, nrows, row_partitions,
+                            internal=_structured_tensor_factory_key)
 
   @property
   def _component_specs(self):
-    return self._field_specs
+    if self._shape.rank == 0:
+      nrows_spec = ()
+    else:
+      nrows_spec = tensor_spec.TensorSpec([], dtypes.int64)
+
+    row_partition_specs = ((row_partition_lib.RowPartitionSpec(),)
+                           * (self._shape.rank - 1))
+    return (self._field_specs, nrows_spec, row_partition_specs)
 
   @classmethod
   def from_value(cls, value):
diff --git a/tensorflow/python/ops/structured/structured_tensor_spec_test.py b/tensorflow/python/ops/structured/structured_tensor_spec_test.py
index 5c63456..4fc182a 100644
--- a/tensorflow/python/ops/structured/structured_tensor_spec_test.py
+++ b/tensorflow/python/ops/structured/structured_tensor_spec_test.py
@@ -20,6 +20,7 @@
 
 from absl.testing import parameterized
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
@@ -44,6 +45,10 @@
 R_1_N_N = ragged_tensor.RaggedTensorSpec([1, None, None])
 R_2_1_N = ragged_tensor.RaggedTensorSpec([2, 1, None])
 
+# TensorSpecs for nrows & row_splits in the _to_components encoding.
+NROWS_SPEC = tensor_spec.TensorSpec([], dtypes.int64)
+PARTITION_SPEC = row_partition.RowPartitionSpec()
+
 
 # pylint: disable=g-long-lambda
 @test_util.run_all_in_graph_and_eager_modes
@@ -117,11 +122,14 @@
     self.assertEqual(serialization, expected)
 
   @parameterized.parameters([
-      (StructuredTensorSpec([1, 2, 3], {}), {}),
-      (StructuredTensorSpec([], {'a': T_1_2}), {'a': T_1_2}),
+      (StructuredTensorSpec([1, 2, 3], {}),
+       ({}, NROWS_SPEC, (PARTITION_SPEC, PARTITION_SPEC))),
+      (StructuredTensorSpec([], {'a': T_1_2}),
+       ({'a': T_1_2}, (), ())),
       (StructuredTensorSpec([1, 2], {'a': T_1_2, 'b': R_1_N}),
-       {'a': T_1_2, 'b': R_1_N}),
-      (StructuredTensorSpec([], {'a': T_1_2}), {'a': T_1_2}),
+       ({'a': T_1_2, 'b': R_1_N}, NROWS_SPEC, (PARTITION_SPEC,))),
+      (StructuredTensorSpec([], {'a': T_1_2}),
+       ({'a': T_1_2}, (), ())),
   ])  # pyformat: disable
   def testComponentSpecs(self, spec, expected):
     self.assertEqual(spec._component_specs, expected)
@@ -141,11 +149,11 @@
       },
   ])  # pyformat: disable
   def testToFromComponents(self, shape, fields, field_specs):
-    components = fields
     struct = StructuredTensor.from_fields(fields, shape)
     spec = StructuredTensorSpec(shape, field_specs)
     actual_components = spec._to_components(struct)
-    self.assertAllTensorsEqual(actual_components, components)
+    self.assertLen(actual_components, 3)
+    self.assertAllTensorsEqual(actual_components[0], fields)
     rt_reconstructed = spec._from_components(actual_components)
     self.assertAllEqual(struct, rt_reconstructed)
 
@@ -155,7 +163,7 @@
     components = spec._to_components(struct)
     rt_reconstructed = spec._from_components(components)
     self.assertAllEqual(struct, rt_reconstructed)
-    self.assertEqual(components, ((), ()))
+    self.assertEqual(components, ({}, (), ()))
 
   def testToFromComponentsEmptyTensor(self):
     struct = StructuredTensor.from_fields(fields={}, shape=[1, 2, 3])
@@ -163,8 +171,9 @@
     components = spec._to_components(struct)
     rt_reconstructed = spec._from_components(components)
     self.assertAllEqual(struct, rt_reconstructed)
-    self.assertLen(components, 2)
-    nrows, row_partitions = components
+    self.assertLen(components, 3)
+    fields, nrows, row_partitions = components
+    self.assertEmpty(fields)
     self.assertAllEqual(nrows, 1)
     self.assertLen(row_partitions, 2)
     self.assertIsInstance(row_partitions[0], row_partition.RowPartition)