Implement Serialization for List and Tuple
PiperOrigin-RevId: 455646308
diff --git a/tensorflow/core/function/trace_type/default_types.proto b/tensorflow/core/function/trace_type/default_types.proto
index ee46e74..2a00a97 100644
--- a/tensorflow/core/function/trace_type/default_types.proto
+++ b/tensorflow/core/function/trace_type/default_types.proto
@@ -14,6 +14,17 @@
}
}
+// Represents a serialized Tuple type.
+message SerializedTuple {
+ repeated tensorflow.core.function.trace_type.serialization.SerializedTraceType
+ components = 1;
+}
+
+// Represents a serialized List type.
+message SerializedList {
+ optional SerializedTuple components_tuple = 1;
+}
+
// Represents a serialized Dict type.
message SerializedDict {
repeated SerializedLiteral keys = 1;
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index f437def..cdc99a8 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -133,6 +133,131 @@
return f"{self.__class__.__name__}(ref={self._ref!r})"
+class Tuple(trace.TraceType, serialization.Serializable):
+ """Represents a tuple of TraceType objects."""
+
+ def __init__(self, *components: trace.TraceType):
+ self.components = components
+
+ def is_subtype_of(self, other: trace.TraceType) -> bool:
+ if (not isinstance(other, Tuple) or
+ len(self.components) != len(other.components)):
+ return False
+
+ return all(
+ self_component.is_subtype_of(other_component) for self_component,
+ other_component in zip(self.components, other.components))
+
+ def most_specific_common_supertype(
+ self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
+ """See base class."""
+ if not all(
+ isinstance(other, Tuple) and
+ len(self.components) == len(other.components) for other in others):
+ return None
+
+ supertyped_components = []
+ for i, component in enumerate(self.components):
+ supertyped_component = component.most_specific_common_supertype(
+ [other.components[i] for other in others])
+ if supertyped_component is None:
+ return None
+ supertyped_components.append(supertyped_component)
+
+ return Tuple(*supertyped_components)
+
+ @classmethod
+ def type_proto(cls) -> Type[default_types_pb2.SerializedTuple]:
+ return default_types_pb2.SerializedTuple
+
+ @classmethod
+ def from_proto(cls, proto: default_types_pb2.SerializedTuple) -> "Tuple":
+ return Tuple(*[serialization.deserialize(c) for c in proto.components])
+
+ def to_proto(self) -> default_types_pb2.SerializedTuple:
+ return default_types_pb2.SerializedTuple(
+ components=[serialization.serialize(c) for c in self.components])
+
+ def _placeholder_value(self) -> Any:
+ components = [
+ component._placeholder_value() # pylint: disable=protected-access
+ for component in self.components
+ ]
+ return tuple(components)
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, trace.TraceType):
+ return NotImplemented
+
+ if not isinstance(other, Tuple):
+ return False
+
+ return self.components == other.components
+
+ def __hash__(self) -> int:
+ return hash(self.components)
+
+ def __repr__(self):
+ return f"Tuple(components={self.components!r})"
+
+
+class List(trace.TraceType, serialization.Serializable):
+ """Represents a list of TraceType objects."""
+
+ def __init__(self, *components: trace.TraceType):
+ self.components_tuple = Tuple(*components)
+
+ def is_subtype_of(self, other: trace.TraceType) -> bool:
+ if not isinstance(other, List):
+ return False
+
+ return self.components_tuple.is_subtype_of(other.components_tuple)
+
+ def most_specific_common_supertype(
+ self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
+ """See base class."""
+ if not all(isinstance(other, List) for other in others):
+ return None
+
+ supertyped_components_tuple = self.components_tuple.most_specific_common_supertype(
+ [other.components_tuple for other in others])
+
+ if supertyped_components_tuple is None:
+ return None
+
+ return List(*supertyped_components_tuple.components)
+
+ @classmethod
+ def type_proto(cls) -> Type[default_types_pb2.SerializedList]:
+ return default_types_pb2.SerializedList
+
+ @classmethod
+ def from_proto(cls, proto: default_types_pb2.SerializedList) -> "List":
+ return List(*Tuple.from_proto(proto.components_tuple).components)
+
+ def to_proto(self) -> default_types_pb2.SerializedList:
+ return default_types_pb2.SerializedList(
+ components_tuple=self.components_tuple.to_proto())
+
+ def _placeholder_value(self) -> Any:
+ return list(self.components_tuple._placeholder_value()) # pylint: disable=protected-access
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, trace.TraceType):
+ return NotImplemented
+
+ if not isinstance(other, List):
+ return False
+
+ return self.components_tuple == other.components_tuple
+
+ def __hash__(self) -> int:
+ return hash(self.components_tuple)
+
+ def __repr__(self):
+ return f"List(components={self.components_tuple.components!r})"
+
+
class OrderedCollection(trace.TraceType):
"""Represents an ordered collection of TraceType objects.
@@ -192,54 +317,6 @@
f"{self.collection_type!r}, components={self.components!r})")
-class List(OrderedCollection):
- """Represents a list of TraceType objects."""
-
- def __init__(self, *components: trace.TraceType):
- super().__init__(list, components)
-
- def most_specific_common_supertype(
- self, types: Sequence[trace.TraceType]) -> Optional["List"]:
- """See base class."""
- if not all(self._shallow_equal(other) for other in types):
- return None
-
- new_components = self._supertype_components(types)
-
- return None if new_components is None else List(*new_components)
-
- def _placeholder_value(self) -> Any:
- components = [
- component._placeholder_value() # pylint: disable=protected-access
- for component in self.components
- ]
- return list(components)
-
-
-class Tuple(OrderedCollection):
- """Represents a tuple of TraceType objects."""
-
- def __init__(self, *components: trace.TraceType):
- super().__init__(tuple, components)
-
- def most_specific_common_supertype(
- self, types: Sequence[trace.TraceType]) -> Optional["Tuple"]:
- """See base class."""
- if not all(self._shallow_equal(other) for other in types):
- return None
-
- new_components = self._supertype_components(types)
-
- return None if new_components is None else Tuple(*new_components)
-
- def _placeholder_value(self) -> Any:
- components = [
- component._placeholder_value() # pylint: disable=protected-access
- for component in self.components
- ]
- return tuple(components)
-
-
class NamedTuple(OrderedCollection):
"""Represents a NamedTuple of TraceType objects."""
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
index 66ef271..52f5586 100644
--- a/tensorflow/core/function/trace_type/default_types_test.py
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -107,6 +107,15 @@
MockSupertypes2With3(3), MockSupertypes2With3(3),
MockSupertypes2With3(3)))
+ def testListSerialization(self):
+ list_original = default_types.List(
+ default_types.Literal(1), default_types.Literal(2),
+ default_types.Literal(3))
+
+ self.assertEqual(
+ serialization.deserialize(serialization.serialize(list_original)),
+ list_original)
+
def testTupleSupertype(self):
tuple_a = default_types.Tuple(
MockSupertypes2With3(1), MockSupertypes2With3(2),
@@ -123,6 +132,15 @@
MockSupertypes2With3(3), MockSupertypes2With3(3),
MockSupertypes2With3(3)))
+ def testTupleSerialization(self):
+ tuple_original = default_types.Tuple(
+ default_types.Literal(1), default_types.Literal(2),
+ default_types.Literal(3))
+
+ self.assertEqual(
+ serialization.deserialize(serialization.serialize(tuple_original)),
+ tuple_original)
+
def testNamedTupleSupertype(self):
named_tuple_type = collections.namedtuple('MyNamedTuple', 'x y z')
tuple_a = default_types.NamedTuple(