Allow exact type recreation of OrderedCollection TraceTypes
PiperOrigin-RevId: 436569302
diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD
index 7721646..2d2ed38 100644
--- a/tensorflow/core/function/trace_type/BUILD
+++ b/tensorflow/core/function/trace_type/BUILD
@@ -16,6 +16,8 @@
deps = [
":default_types",
"//tensorflow/python/types",
+ # TODO(b/225045380): Depend on the more specific `leaf` target once the used utils are moved.
+ "//tensorflow/python/util",
],
)
@@ -63,5 +65,6 @@
deps = [
":default_types",
"//tensorflow/python/platform:client_testlib",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index c306970..a2661fc 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -15,8 +15,8 @@
"""TraceType implementations for common Python types."""
+from typing import Any, Hashable, Optional, Sequence, Type
from typing import Dict as PythonDict
-from typing import Hashable, Optional, Sequence, Type, Any
from typing import Tuple as PythonTuple
from tensorflow.python.types import trace
@@ -33,7 +33,7 @@
return self == other
def most_specific_common_supertype(
- self, types: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+ self, types: Sequence[trace.TraceType]) -> Optional["Generic"]:
return self if all(self == other for other in types) else None
def _placeholder_value(self) -> Any:
@@ -85,84 +85,136 @@
"""Represents an ordered collection of TraceType objects.
Attributes:
+ collection_type: Python type for the collection (list, tuple etc.)
components: A corresponding sequence of TraceTypes to the values in the
collection.
"""
- def __init__(self, *components: trace.TraceType):
+ def __init__(self, collection_type: Type[Any],
+ components: PythonTuple[trace.TraceType]):
+ self.collection_type = collection_type
self.components = components
- def _has_same_structure(self, other):
- if not isinstance(other, type(self)):
- return False
+ def _shallow_equal(self, other):
+ return (isinstance(other, OrderedCollection) and
+ self.collection_type == other.collection_type and
+ len(self.components) == len(other.components))
- if len(self.components) != len(other.components):
- return False
-
- return True
-
- def is_subtype_of(self, other: trace.TraceType) -> bool:
- """See base class."""
- if not self._has_same_structure(other):
- return False
-
- if not all([
- component.is_subtype_of(other.components[i])
- for i, component in enumerate(self.components)
- ]):
- return False
-
- return True
-
- def most_specific_common_supertype(self, types: Sequence[trace.TraceType]):
- """See base class."""
- if not all(self._has_same_structure(other) for other in types):
- return None
-
+ def _supertype_components(
+ self, others: Sequence["OrderedCollection"]
+ ) -> Optional[Sequence[trace.TraceType]]:
+ """Helper that generates a list of per-component supertypes or None."""
new_components = []
for i, component in enumerate(self.components):
common = component.most_specific_common_supertype(
- [other.components[i] for other in types])
+ [other.components[i] for other in others])
if common is None:
return None
else:
new_components.append(common)
+ return new_components
- return type(self)(*new_components)
+ def is_subtype_of(self, other: trace.TraceType) -> bool:
+ """See base class."""
+ if not self._shallow_equal(other):
+ return False
+
+ return all(
+ self_component.is_subtype_of(other_component) for self_component,
+ other_component in zip(self.components, other.components))
def __eq__(self, other) -> bool:
if not isinstance(other, trace.TraceType):
return NotImplemented
- if not self._has_same_structure(other):
+ if not self._shallow_equal(other):
return False
return self.components == other.components
def __hash__(self) -> int:
- return hash(self.components)
+ return hash((self.collection_type, self.components))
def __repr__(self):
- return f"{self.__class__.__name__}(components={self.components!r})"
+ return (f"{self.__class__.__name__}(collection_type="
+ f"{self.collection_type!r}, components={self.components!r})")
class List(OrderedCollection):
+ """Represents a list of TraceType objects."""
+
+ def __init__(self, *components: trace.TraceType):
+ super().__init__(list, components)
+
+ def most_specific_common_supertype(
+ self, types: Sequence[trace.TraceType]) -> Optional["List"]:
+ """See base class."""
+ if not all(self._shallow_equal(other) for other in types):
+ return None
+
+ new_components = self._supertype_components(types)
+
+ return None if new_components is None else List(*new_components)
def _placeholder_value(self) -> Any:
- return [
+ components = [
component._placeholder_value() # pylint: disable=protected-access
for component in self.components
]
+ return list(components)
class Tuple(OrderedCollection):
+ """Represents a tuple of TraceType objects."""
+
+ def __init__(self, *components: trace.TraceType):
+ super().__init__(tuple, components)
+
+ def most_specific_common_supertype(
+ self, types: Sequence[trace.TraceType]) -> Optional["Tuple"]:
+ """See base class."""
+ if not all(self._shallow_equal(other) for other in types):
+ return None
+
+ new_components = self._supertype_components(types)
+
+ return None if new_components is None else Tuple(*new_components)
def _placeholder_value(self) -> Any:
- return tuple(component._placeholder_value() # pylint: disable=protected-access
- for component in self.components)
+ components = [
+ component._placeholder_value() # pylint: disable=protected-access
+ for component in self.components
+ ]
+ return tuple(components)
-class Attrs(OrderedCollection):
+class NamedTuple(OrderedCollection):
+ """Represents a NamedTuple of TraceType objects."""
+
+ def __init__(self, collection_type: Type[object],
+ attributes: PythonTuple[trace.TraceType]):
+ super().__init__(collection_type, attributes)
+
+ def most_specific_common_supertype(
+ self, types: Sequence[trace.TraceType]) -> Optional["NamedTuple"]:
+ """See base class."""
+ if not all(self._shallow_equal(other) for other in types):
+ return None
+
+ new_components = self._supertype_components(types)
+
+ return None if new_components is None else type(self)(self.collection_type,
+ tuple(new_components))
+
+ def _placeholder_value(self) -> Any:
+ components = [
+ component._placeholder_value() # pylint: disable=protected-access
+ for component in self.components
+ ]
+ return self.collection_type(*components)
+
+
+class Attrs(NamedTuple):
"""Represents a class annotated by attr.s.
Each attr.s class has a fixed, ordered set of attributes. Therefore, we only
@@ -170,15 +222,6 @@
metadata including attribute names can be ignored.
"""
- def __init__(self, classtype: Type[object],
- attributes: PythonTuple[trace.TraceType]):
- super().__init__(Generic(classtype), *attributes)
-
- def _placeholder_value(self) -> Any:
- attrs_class = self.components[0]._placeholder_value() # pylint: disable=protected-access
- return attrs_class(*(component._placeholder_value() # pylint: disable=protected-access
- for component in self.components[1:]))
-
class Dict(trace.TraceType):
"""Represents a dictionary of TraceType objects.
@@ -209,7 +252,8 @@
return all(self.mapping[key].is_subtype_of(other.mapping[key])
for key in self.mapping)
- def most_specific_common_supertype(self, types: Sequence[trace.TraceType]):
+ def most_specific_common_supertype(
+ self, types: Sequence[trace.TraceType]) -> Optional["Dict"]:
"""See base class."""
if not all(self._has_same_structure(other) for other in types):
return None
@@ -266,7 +310,7 @@
return False
def most_specific_common_supertype(
- self, types: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+ self, types: Sequence[trace.TraceType]) -> Optional["Reference"]:
if all(
isinstance(other, Reference) and self.identifier == other.identifier
for other in types):
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
index af6545a..9c268d2 100644
--- a/tensorflow/core/function/trace_type/default_types_test.py
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -14,6 +14,8 @@
# ==============================================================================
"""Tests for default_types."""
+import collections
+
from tensorflow.core.function.trace_type import default_types
from tensorflow.python.platform import test
@@ -32,34 +34,7 @@
generic_a.most_specific_common_supertype([generic_c]))
self.assertIsNone(generic_a.most_specific_common_supertype([generic_b]))
- def testOrderedCollectionTypeEquality(self):
- collection = default_types.OrderedCollection
- generic = default_types.Generic
- collection_a = collection(generic(1), generic(2), generic(3))
- collection_b = collection(generic(1), generic(2), generic(1))
- collection_c = collection(generic(1), generic(2), generic(3))
-
- self.assertNotEqual(collection_a, collection_b)
- self.assertEqual(collection_a, collection_c)
- self.assertEqual(hash(collection_a), hash(collection_c))
-
- def testOrderedCollectionTypeSubtype(self):
-
- class Subtypable(default_types.Generic):
-
- def is_subtype_of(self, other):
- return self._object == 2 or other._object == 3
-
- collection = default_types.OrderedCollection
- collection_a = collection(Subtypable(1), Subtypable(2), Subtypable(3))
- collection_b = collection(Subtypable(2), Subtypable(1), Subtypable(2))
- collection_c = collection(Subtypable(1), Subtypable(3), Subtypable(3))
-
- self.assertTrue(collection_b.is_subtype_of(collection_c))
- self.assertFalse(collection_a.is_subtype_of(collection_b))
- self.assertFalse(collection_c.is_subtype_of(collection_a))
-
- def testOrderedCollectionTypeSupertype(self):
+ def testListSupertype(self):
class Supertypable(default_types.Generic):
@@ -72,17 +47,67 @@
else:
return None
- collection = default_types.OrderedCollection
- collection_a = collection(Supertypable(1), Supertypable(2), Supertypable(3))
- collection_b = collection(Supertypable(2), Supertypable(2), Supertypable(2))
+ list_a = default_types.List(
+ Supertypable(1), Supertypable(2), Supertypable(3))
+ list_b = default_types.List(
+ Supertypable(2), Supertypable(2), Supertypable(2))
- self.assertEqual(collection_a,
- collection_a.most_specific_common_supertype([]))
- self.assertIsNone(
- collection_a.most_specific_common_supertype([collection_b]))
+ self.assertEqual(list_a, list_a.most_specific_common_supertype([]))
+ self.assertIsNone(list_a.most_specific_common_supertype([list_b]))
self.assertEqual(
- collection_b.most_specific_common_supertype([collection_a]),
- collection(Supertypable(3), Supertypable(3), Supertypable(3)))
+ list_b.most_specific_common_supertype([list_a]),
+ default_types.List(Supertypable(3), Supertypable(3), Supertypable(3)))
+
+ def testTupleSupertype(self):
+
+ class Supertypable(default_types.Generic):
+
+ def most_specific_common_supertype(self, others):
+ if not others:
+ return self
+
+ if self._object == 2 and isinstance(others[0]._object, int):
+ return Supertypable(3)
+ else:
+ return None
+
+ tuple_a = default_types.Tuple(
+ Supertypable(1), Supertypable(2), Supertypable(3))
+ tuple_b = default_types.Tuple(
+ Supertypable(2), Supertypable(2), Supertypable(2))
+
+ self.assertEqual(tuple_a, tuple_a.most_specific_common_supertype([]))
+ self.assertIsNone(tuple_a.most_specific_common_supertype([tuple_b]))
+ self.assertEqual(
+ tuple_b.most_specific_common_supertype([tuple_a]),
+ default_types.Tuple(Supertypable(3), Supertypable(3), Supertypable(3)))
+
+ def testNamedTupleSupertype(self):
+
+ class Supertypable(default_types.Generic):
+
+ def most_specific_common_supertype(self, others):
+ if not others:
+ return self
+
+ if self._object == 2 and isinstance(others[0]._object, int):
+ return Supertypable(3)
+ else:
+ return None
+
+ named_tuple_type = collections.namedtuple('MyNamedTuple', 'x y z')
+ tuple_a = default_types.NamedTuple(
+ named_tuple_type, (Supertypable(1), Supertypable(2), Supertypable(3)))
+ tuple_b = default_types.NamedTuple(
+ named_tuple_type, (Supertypable(2), Supertypable(2), Supertypable(2)))
+
+ self.assertEqual(tuple_a, tuple_a.most_specific_common_supertype([]))
+ self.assertIsNone(tuple_a.most_specific_common_supertype([tuple_b]))
+ self.assertEqual(
+ tuple_b.most_specific_common_supertype([tuple_a]),
+ default_types.NamedTuple(
+ named_tuple_type,
+ (Supertypable(3), Supertypable(3), Supertypable(3))))
def testDictTypeSubtype(self):
diff --git a/tensorflow/core/function/trace_type/signature_builder.py b/tensorflow/core/function/trace_type/signature_builder.py
index a5e12cb..836c0fe 100644
--- a/tensorflow/core/function/trace_type/signature_builder.py
+++ b/tensorflow/core/function/trace_type/signature_builder.py
@@ -20,6 +20,7 @@
from tensorflow.core.function.trace_type import default_types
from tensorflow.python.types import trace
+from tensorflow.python.util import nest
class WeakrefDeletionObserver:
@@ -102,16 +103,22 @@
return default_types.List(*(create_trace_type(c, context) for c in obj))
if isinstance(obj, tuple):
- return default_types.Tuple(*(create_trace_type(c, context) for c in obj))
+ if nest.is_namedtuple(obj):
+ return default_types.NamedTuple(
+ type(obj), tuple(create_trace_type(c, context) for c in obj))
+ else:
+ return default_types.Tuple(*(create_trace_type(c, context) for c in obj))
if isinstance(obj, collections.abc.Mapping):
return default_types.Dict(
{k: create_trace_type(obj[k], context) for k in obj})
- if hasattr(type(obj), "__attrs_attrs__"):
+ if nest.is_attrs(obj):
return default_types.Attrs(
- type(obj), (create_trace_type(getattr(obj, a.name), context)
- for a in obj.__attrs_attrs__))
+ type(obj),
+ tuple(
+ create_trace_type(getattr(obj, a.name), context)
+ for a in obj.__attrs_attrs__))
if hasattr(obj, "__wrapped__"):
return create_trace_type(obj.__wrapped__, context)
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index f38096f..42a660c 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -143,6 +143,8 @@
raise TypeError("nest only supports dicts with sortable keys.")
+# TODO(b/225045380): Move utils like these to a "leaf" library, since they are
+# used in other places like TraceType.
def is_namedtuple(instance, strict=False):
"""Returns True iff `instance` is a `namedtuple`.