Implement Serialization for List and Tuple

PiperOrigin-RevId: 455646308
diff --git a/tensorflow/core/function/trace_type/default_types.proto b/tensorflow/core/function/trace_type/default_types.proto
index ee46e74..2a00a97 100644
--- a/tensorflow/core/function/trace_type/default_types.proto
+++ b/tensorflow/core/function/trace_type/default_types.proto
@@ -14,6 +14,17 @@
   }
 }
 
+// Represents a serialized Tuple type.
+message SerializedTuple {
+  repeated tensorflow.core.function.trace_type.serialization.SerializedTraceType
+      components = 1;
+}
+
+// Represents a serialized List type.
+message SerializedList {
+  optional SerializedTuple components_tuple = 1;
+}
+
 // Represents a serialized Dict type.
 message SerializedDict {
   repeated SerializedLiteral keys = 1;
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index f437def..cdc99a8 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -133,6 +133,131 @@
     return f"{self.__class__.__name__}(ref={self._ref!r})"
 
 
+class Tuple(trace.TraceType, serialization.Serializable):
+  """Represents a tuple of TraceType objects."""
+
+  def __init__(self, *components: trace.TraceType):
+    self.components = components
+
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    if (not isinstance(other, Tuple) or
+        len(self.components) != len(other.components)):
+      return False
+
+    return all(
+        self_component.is_subtype_of(other_component) for self_component,
+        other_component in zip(self.components, other.components))
+
+  def most_specific_common_supertype(
+      self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
+    """See base class."""
+    if not all(
+        isinstance(other, Tuple) and
+        len(self.components) == len(other.components) for other in others):
+      return None
+
+    supertyped_components = []
+    for i, component in enumerate(self.components):
+      supertyped_component = component.most_specific_common_supertype(
+          [other.components[i] for other in others])
+      if supertyped_component is None:
+        return None
+      supertyped_components.append(supertyped_component)
+
+    return Tuple(*supertyped_components)
+
+  @classmethod
+  def type_proto(cls) -> Type[default_types_pb2.SerializedTuple]:
+    return default_types_pb2.SerializedTuple
+
+  @classmethod
+  def from_proto(cls, proto: default_types_pb2.SerializedTuple) -> "Tuple":
+    return Tuple(*[serialization.deserialize(c) for c in proto.components])
+
+  def to_proto(self) -> default_types_pb2.SerializedTuple:
+    return default_types_pb2.SerializedTuple(
+        components=[serialization.serialize(c) for c in self.components])
+
+  def _placeholder_value(self) -> Any:
+    components = [
+        component._placeholder_value()  # pylint: disable=protected-access
+        for component in self.components
+    ]
+    return tuple(components)
+
+  def __eq__(self, other: Any) -> bool:
+    if not isinstance(other, trace.TraceType):
+      return NotImplemented
+
+    if not isinstance(other, Tuple):
+      return False
+
+    return self.components == other.components
+
+  def __hash__(self) -> int:
+    return hash(self.components)
+
+  def __repr__(self):
+    return f"Tuple(components={self.components!r})"
+
+
+class List(trace.TraceType, serialization.Serializable):
+  """Represents a list of TraceType objects."""
+
+  def __init__(self, *components: trace.TraceType):
+    self.components_tuple = Tuple(*components)
+
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    if not isinstance(other, List):
+      return False
+
+    return self.components_tuple.is_subtype_of(other.components_tuple)
+
+  def most_specific_common_supertype(
+      self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
+    """See base class."""
+    if not all(isinstance(other, List) for other in others):
+      return None
+
+    supertyped_components_tuple = self.components_tuple.most_specific_common_supertype(
+        [other.components_tuple for other in others])
+
+    if supertyped_components_tuple is None:
+      return None
+
+    return List(*supertyped_components_tuple.components)
+
+  @classmethod
+  def type_proto(cls) -> Type[default_types_pb2.SerializedList]:
+    return default_types_pb2.SerializedList
+
+  @classmethod
+  def from_proto(cls, proto: default_types_pb2.SerializedList) -> "List":
+    return List(*Tuple.from_proto(proto.components_tuple).components)
+
+  def to_proto(self) -> default_types_pb2.SerializedList:
+    return default_types_pb2.SerializedList(
+        components_tuple=self.components_tuple.to_proto())
+
+  def _placeholder_value(self) -> Any:
+    return list(self.components_tuple._placeholder_value())  # pylint: disable=protected-access
+
+  def __eq__(self, other: Any) -> bool:
+    if not isinstance(other, trace.TraceType):
+      return NotImplemented
+
+    if not isinstance(other, List):
+      return False
+
+    return self.components_tuple == other.components_tuple
+
+  def __hash__(self) -> int:
+    return hash(self.components_tuple)
+
+  def __repr__(self):
+    return f"List(components={self.components_tuple.components!r})"
+
+
 class OrderedCollection(trace.TraceType):
   """Represents an ordered collection of TraceType objects.
 
@@ -192,54 +317,6 @@
             f"{self.collection_type!r}, components={self.components!r})")
 
 
-class List(OrderedCollection):
-  """Represents a list of TraceType objects."""
-
-  def __init__(self, *components: trace.TraceType):
-    super().__init__(list, components)
-
-  def most_specific_common_supertype(
-      self, types: Sequence[trace.TraceType]) -> Optional["List"]:
-    """See base class."""
-    if not all(self._shallow_equal(other) for other in types):
-      return None
-
-    new_components = self._supertype_components(types)
-
-    return None if new_components is None else List(*new_components)
-
-  def _placeholder_value(self) -> Any:
-    components = [
-        component._placeholder_value()  # pylint: disable=protected-access
-        for component in self.components
-    ]
-    return list(components)
-
-
-class Tuple(OrderedCollection):
-  """Represents a tuple of TraceType objects."""
-
-  def __init__(self, *components: trace.TraceType):
-    super().__init__(tuple, components)
-
-  def most_specific_common_supertype(
-      self, types: Sequence[trace.TraceType]) -> Optional["Tuple"]:
-    """See base class."""
-    if not all(self._shallow_equal(other) for other in types):
-      return None
-
-    new_components = self._supertype_components(types)
-
-    return None if new_components is None else Tuple(*new_components)
-
-  def _placeholder_value(self) -> Any:
-    components = [
-        component._placeholder_value()  # pylint: disable=protected-access
-        for component in self.components
-    ]
-    return tuple(components)
-
-
 class NamedTuple(OrderedCollection):
   """Represents a NamedTuple of TraceType objects."""
 
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
index 66ef271..52f5586 100644
--- a/tensorflow/core/function/trace_type/default_types_test.py
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -107,6 +107,15 @@
             MockSupertypes2With3(3), MockSupertypes2With3(3),
             MockSupertypes2With3(3)))
 
+  def testListSerialization(self):
+    list_original = default_types.List(
+        default_types.Literal(1), default_types.Literal(2),
+        default_types.Literal(3))
+
+    self.assertEqual(
+        serialization.deserialize(serialization.serialize(list_original)),
+        list_original)
+
   def testTupleSupertype(self):
     tuple_a = default_types.Tuple(
         MockSupertypes2With3(1), MockSupertypes2With3(2),
@@ -123,6 +132,15 @@
             MockSupertypes2With3(3), MockSupertypes2With3(3),
             MockSupertypes2With3(3)))
 
+  def testTupleSerialization(self):
+    tuple_original = default_types.Tuple(
+        default_types.Literal(1), default_types.Literal(2),
+        default_types.Literal(3))
+
+    self.assertEqual(
+        serialization.deserialize(serialization.serialize(tuple_original)),
+        tuple_original)
+
   def testNamedTupleSupertype(self):
     named_tuple_type = collections.namedtuple('MyNamedTuple', 'x y z')
     tuple_a = default_types.NamedTuple(