Allow exact type recreation of OrderedCollection TraceTypes

PiperOrigin-RevId: 436569302
diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD
index 7721646..2d2ed38 100644
--- a/tensorflow/core/function/trace_type/BUILD
+++ b/tensorflow/core/function/trace_type/BUILD
@@ -16,6 +16,8 @@
     deps = [
         ":default_types",
         "//tensorflow/python/types",
+        # TODO(b/225045380): Depend on the more specific `leaf` target once the used utils are moved.
+        "//tensorflow/python/util",
     ],
 )
 
@@ -63,5 +65,6 @@
     deps = [
         ":default_types",
         "//tensorflow/python/platform:client_testlib",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index c306970..a2661fc 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -15,8 +15,8 @@
 
 """TraceType implementations for common Python types."""
 
+from typing import Any, Hashable, Optional, Sequence, Type
 from typing import Dict as PythonDict
-from typing import Hashable, Optional, Sequence, Type, Any
 from typing import Tuple as PythonTuple
 
 from tensorflow.python.types import trace
@@ -33,7 +33,7 @@
     return self == other
 
   def most_specific_common_supertype(
-      self, types: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+      self, types: Sequence[trace.TraceType]) -> Optional["Generic"]:
     return self if all(self == other for other in types) else None
 
   def _placeholder_value(self) -> Any:
@@ -85,84 +85,136 @@
   """Represents an ordered collection of TraceType objects.
 
   Attributes:
+    collection_type: Python type for the collection (list, tuple etc.)
     components: A corresponding sequence of TraceTypes to the values in the
       collection.
   """
 
-  def __init__(self, *components: trace.TraceType):
+  def __init__(self, collection_type: Type[Any],
+               components: PythonTuple[trace.TraceType]):
+    self.collection_type = collection_type
     self.components = components
 
-  def _has_same_structure(self, other):
-    if not isinstance(other, type(self)):
-      return False
+  def _shallow_equal(self, other):
+    return (isinstance(other, OrderedCollection) and
+            self.collection_type == other.collection_type and
+            len(self.components) == len(other.components))
 
-    if len(self.components) != len(other.components):
-      return False
-
-    return True
-
-  def is_subtype_of(self, other: trace.TraceType) -> bool:
-    """See base class."""
-    if not self._has_same_structure(other):
-      return False
-
-    if not all([
-        component.is_subtype_of(other.components[i])
-        for i, component in enumerate(self.components)
-    ]):
-      return False
-
-    return True
-
-  def most_specific_common_supertype(self, types: Sequence[trace.TraceType]):
-    """See base class."""
-    if not all(self._has_same_structure(other) for other in types):
-      return None
-
+  def _supertype_components(
+      self, others: Sequence["OrderedCollection"]
+  ) -> Optional[Sequence[trace.TraceType]]:
+    """Helper that generates a list of per-component supertypes or None."""
     new_components = []
     for i, component in enumerate(self.components):
       common = component.most_specific_common_supertype(
-          [other.components[i] for other in types])
+          [other.components[i] for other in others])
       if common is None:
         return None
       else:
         new_components.append(common)
+    return new_components
 
-    return type(self)(*new_components)
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    """See base class."""
+    if not self._shallow_equal(other):
+      return False
+
+    return all(
+        self_component.is_subtype_of(other_component) for self_component,
+        other_component in zip(self.components, other.components))
 
   def __eq__(self, other) -> bool:
     if not isinstance(other, trace.TraceType):
       return NotImplemented
 
-    if not self._has_same_structure(other):
+    if not self._shallow_equal(other):
       return False
 
     return self.components == other.components
 
   def __hash__(self) -> int:
-    return hash(self.components)
+    return hash((self.collection_type, self.components))
 
   def __repr__(self):
-    return f"{self.__class__.__name__}(components={self.components!r})"
+    return (f"{self.__class__.__name__}(collection_type="
+            f"{self.collection_type!r}, components={self.components!r})")
 
 
 class List(OrderedCollection):
+  """Represents a list of TraceType objects."""
+
+  def __init__(self, *components: trace.TraceType):
+    super().__init__(list, components)
+
+  def most_specific_common_supertype(
+      self, types: Sequence[trace.TraceType]) -> Optional["List"]:
+    """See base class."""
+    if not all(self._shallow_equal(other) for other in types):
+      return None
+
+    new_components = self._supertype_components(types)
+
+    return None if new_components is None else List(*new_components)
 
   def _placeholder_value(self) -> Any:
-    return [
+    components = [
         component._placeholder_value()  # pylint: disable=protected-access
         for component in self.components
     ]
+    return list(components)
 
 
 class Tuple(OrderedCollection):
+  """Represents a tuple of TraceType objects."""
+
+  def __init__(self, *components: trace.TraceType):
+    super().__init__(tuple, components)
+
+  def most_specific_common_supertype(
+      self, types: Sequence[trace.TraceType]) -> Optional["Tuple"]:
+    """See base class."""
+    if not all(self._shallow_equal(other) for other in types):
+      return None
+
+    new_components = self._supertype_components(types)
+
+    return None if new_components is None else Tuple(*new_components)
 
   def _placeholder_value(self) -> Any:
-    return tuple(component._placeholder_value()  # pylint: disable=protected-access
-                 for component in self.components)
+    components = [
+        component._placeholder_value()  # pylint: disable=protected-access
+        for component in self.components
+    ]
+    return tuple(components)
 
 
-class Attrs(OrderedCollection):
+class NamedTuple(OrderedCollection):
+  """Represents a NamedTuple of TraceType objects."""
+
+  def __init__(self, collection_type: Type[object],
+               attributes: PythonTuple[trace.TraceType]):
+    super().__init__(collection_type, attributes)
+
+  def most_specific_common_supertype(
+      self, types: Sequence[trace.TraceType]) -> Optional["NamedTuple"]:
+    """See base class."""
+    if not all(self._shallow_equal(other) for other in types):
+      return None
+
+    new_components = self._supertype_components(types)
+
+    return None if new_components is None else type(self)(self.collection_type,
+                                                          tuple(new_components))
+
+  def _placeholder_value(self) -> Any:
+    components = [
+        component._placeholder_value()  # pylint: disable=protected-access
+        for component in self.components
+    ]
+    return self.collection_type(*components)
+
+
+class Attrs(NamedTuple):
   """Represents a class annotated by attr.s.
 
   Each attr.s class has a fixed, ordered set of attributes. Therefore, we only
@@ -170,15 +222,6 @@
   metadata including attribute names can be ignored.
   """
 
-  def __init__(self, classtype: Type[object],
-               attributes: PythonTuple[trace.TraceType]):
-    super().__init__(Generic(classtype), *attributes)
-
-  def _placeholder_value(self) -> Any:
-    attrs_class = self.components[0]._placeholder_value()  # pylint: disable=protected-access
-    return attrs_class(*(component._placeholder_value()  # pylint: disable=protected-access
-                         for component in self.components[1:]))
-
 
 class Dict(trace.TraceType):
   """Represents a dictionary of TraceType objects.
@@ -209,7 +252,8 @@
     return all(self.mapping[key].is_subtype_of(other.mapping[key])
                for key in self.mapping)
 
-  def most_specific_common_supertype(self, types: Sequence[trace.TraceType]):
+  def most_specific_common_supertype(
+      self, types: Sequence[trace.TraceType]) -> Optional["Dict"]:
     """See base class."""
     if not all(self._has_same_structure(other) for other in types):
       return None
@@ -266,7 +310,7 @@
     return False
 
   def most_specific_common_supertype(
-      self, types: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+      self, types: Sequence[trace.TraceType]) -> Optional["Reference"]:
     if all(
         isinstance(other, Reference) and self.identifier == other.identifier
         for other in types):
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
index af6545a..9c268d2 100644
--- a/tensorflow/core/function/trace_type/default_types_test.py
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -14,6 +14,8 @@
 # ==============================================================================
 """Tests for default_types."""
 
+import collections
+
 from tensorflow.core.function.trace_type import default_types
 from tensorflow.python.platform import test
 
@@ -32,34 +34,7 @@
                      generic_a.most_specific_common_supertype([generic_c]))
     self.assertIsNone(generic_a.most_specific_common_supertype([generic_b]))
 
-  def testOrderedCollectionTypeEquality(self):
-    collection = default_types.OrderedCollection
-    generic = default_types.Generic
-    collection_a = collection(generic(1), generic(2), generic(3))
-    collection_b = collection(generic(1), generic(2), generic(1))
-    collection_c = collection(generic(1), generic(2), generic(3))
-
-    self.assertNotEqual(collection_a, collection_b)
-    self.assertEqual(collection_a, collection_c)
-    self.assertEqual(hash(collection_a), hash(collection_c))
-
-  def testOrderedCollectionTypeSubtype(self):
-
-    class Subtypable(default_types.Generic):
-
-      def is_subtype_of(self, other):
-        return self._object == 2 or other._object == 3
-
-    collection = default_types.OrderedCollection
-    collection_a = collection(Subtypable(1), Subtypable(2), Subtypable(3))
-    collection_b = collection(Subtypable(2), Subtypable(1), Subtypable(2))
-    collection_c = collection(Subtypable(1), Subtypable(3), Subtypable(3))
-
-    self.assertTrue(collection_b.is_subtype_of(collection_c))
-    self.assertFalse(collection_a.is_subtype_of(collection_b))
-    self.assertFalse(collection_c.is_subtype_of(collection_a))
-
-  def testOrderedCollectionTypeSupertype(self):
+  def testListSupertype(self):
 
     class Supertypable(default_types.Generic):
 
@@ -72,17 +47,67 @@
         else:
           return None
 
-    collection = default_types.OrderedCollection
-    collection_a = collection(Supertypable(1), Supertypable(2), Supertypable(3))
-    collection_b = collection(Supertypable(2), Supertypable(2), Supertypable(2))
+    list_a = default_types.List(
+        Supertypable(1), Supertypable(2), Supertypable(3))
+    list_b = default_types.List(
+        Supertypable(2), Supertypable(2), Supertypable(2))
 
-    self.assertEqual(collection_a,
-                     collection_a.most_specific_common_supertype([]))
-    self.assertIsNone(
-        collection_a.most_specific_common_supertype([collection_b]))
+    self.assertEqual(list_a, list_a.most_specific_common_supertype([]))
+    self.assertIsNone(list_a.most_specific_common_supertype([list_b]))
     self.assertEqual(
-        collection_b.most_specific_common_supertype([collection_a]),
-        collection(Supertypable(3), Supertypable(3), Supertypable(3)))
+        list_b.most_specific_common_supertype([list_a]),
+        default_types.List(Supertypable(3), Supertypable(3), Supertypable(3)))
+
+  def testTupleSupertype(self):
+
+    class Supertypable(default_types.Generic):
+
+      def most_specific_common_supertype(self, others):
+        if not others:
+          return self
+
+        if self._object == 2 and isinstance(others[0]._object, int):
+          return Supertypable(3)
+        else:
+          return None
+
+    tuple_a = default_types.Tuple(
+        Supertypable(1), Supertypable(2), Supertypable(3))
+    tuple_b = default_types.Tuple(
+        Supertypable(2), Supertypable(2), Supertypable(2))
+
+    self.assertEqual(tuple_a, tuple_a.most_specific_common_supertype([]))
+    self.assertIsNone(tuple_a.most_specific_common_supertype([tuple_b]))
+    self.assertEqual(
+        tuple_b.most_specific_common_supertype([tuple_a]),
+        default_types.Tuple(Supertypable(3), Supertypable(3), Supertypable(3)))
+
+  def testNamedTupleSupertype(self):
+
+    class Supertypable(default_types.Generic):
+
+      def most_specific_common_supertype(self, others):
+        if not others:
+          return self
+
+        if self._object == 2 and isinstance(others[0]._object, int):
+          return Supertypable(3)
+        else:
+          return None
+
+    named_tuple_type = collections.namedtuple('MyNamedTuple', 'x y z')
+    tuple_a = default_types.NamedTuple(
+        named_tuple_type, (Supertypable(1), Supertypable(2), Supertypable(3)))
+    tuple_b = default_types.NamedTuple(
+        named_tuple_type, (Supertypable(2), Supertypable(2), Supertypable(2)))
+
+    self.assertEqual(tuple_a, tuple_a.most_specific_common_supertype([]))
+    self.assertIsNone(tuple_a.most_specific_common_supertype([tuple_b]))
+    self.assertEqual(
+        tuple_b.most_specific_common_supertype([tuple_a]),
+        default_types.NamedTuple(
+            named_tuple_type,
+            (Supertypable(3), Supertypable(3), Supertypable(3))))
 
   def testDictTypeSubtype(self):
 
diff --git a/tensorflow/core/function/trace_type/signature_builder.py b/tensorflow/core/function/trace_type/signature_builder.py
index a5e12cb..836c0fe 100644
--- a/tensorflow/core/function/trace_type/signature_builder.py
+++ b/tensorflow/core/function/trace_type/signature_builder.py
@@ -20,6 +20,7 @@
 
 from tensorflow.core.function.trace_type import default_types
 from tensorflow.python.types import trace
+from tensorflow.python.util import nest
 
 
 class WeakrefDeletionObserver:
@@ -102,16 +103,22 @@
     return default_types.List(*(create_trace_type(c, context) for c in obj))
 
   if isinstance(obj, tuple):
-    return default_types.Tuple(*(create_trace_type(c, context) for c in obj))
+    if nest.is_namedtuple(obj):
+      return default_types.NamedTuple(
+          type(obj), tuple(create_trace_type(c, context) for c in obj))
+    else:
+      return default_types.Tuple(*(create_trace_type(c, context) for c in obj))
 
   if isinstance(obj, collections.abc.Mapping):
     return default_types.Dict(
         {k: create_trace_type(obj[k], context) for k in obj})
 
-  if hasattr(type(obj), "__attrs_attrs__"):
+  if nest.is_attrs(obj):
     return default_types.Attrs(
-        type(obj), (create_trace_type(getattr(obj, a.name), context)
-                    for a in obj.__attrs_attrs__))
+        type(obj),
+        tuple(
+            create_trace_type(getattr(obj, a.name), context)
+            for a in obj.__attrs_attrs__))
 
   if hasattr(obj, "__wrapped__"):
     return create_trace_type(obj.__wrapped__, context)
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index f38096f..42a660c 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -143,6 +143,8 @@
     raise TypeError("nest only supports dicts with sortable keys.")
 
 
+# TODO(b/225045380): Move utils like these to a "leaf" library, since they are
+# used in other places like TraceType.
 def is_namedtuple(instance, strict=False):
   """Returns True iff `instance` is a `namedtuple`.