Set up TraceType as an independent core/function package
PiperOrigin-RevId: 418082697
Change-Id: I74ce638f178c8712d20f00a9db5d7ef11484de71
diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD
new file mode 100644
index 0000000..5a0a056
--- /dev/null
+++ b/tensorflow/core/function/trace_type/BUILD
@@ -0,0 +1,65 @@
+load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
+load("//tensorflow:tensorflow.bzl", "py_strict_test")
+
+package(
+ licenses = ["notice"],
+)
+
+pytype_strict_library(
+ name = "trace_type",
+ srcs = [
+ "__init__.py",
+ "signature_builder.py",
+ ],
+ srcs_version = "PY3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":default_types",
+ "//tensorflow/python/types",
+ ],
+)
+
+py_strict_test(
+ name = "trace_type_test",
+ srcs = ["trace_type_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":default_types",
+ ":trace_type",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/compat:v2_compat",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/eager:function",
+ "//tensorflow/python/framework:combinations",
+ "//tensorflow/python/framework:dtypes",
+ "//tensorflow/python/framework:tensor_spec",
+ "//tensorflow/python/framework:test_lib",
+ "//tensorflow/python/ops/ragged:ragged_tensor",
+ "//tensorflow/python/platform:client_testlib",
+ "//tensorflow/python/types",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+pytype_strict_library(
+ name = "default_types",
+ srcs = [
+ "default_types.py",
+ ],
+ srcs_version = "PY3",
+ visibility = ["//tensorflow:internal"],
+ deps = ["//tensorflow/python/types"],
+)
+
+py_strict_test(
+ name = "default_types_test",
+ srcs = ["default_types_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":default_types",
+ "//tensorflow/python/platform:client_testlib",
+ ],
+)
diff --git a/tensorflow/core/function/trace_type/__init__.py b/tensorflow/core/function/trace_type/__init__.py
new file mode 100644
index 0000000..1e35416
--- /dev/null
+++ b/tensorflow/core/function/trace_type/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tracing Protocol for tf.function.
+
+TODO(b/202447704): Briefly describe the tracing, retracing, and how trace types
+control it.
+"""
+
+
+from tensorflow.core.function.trace_type.signature_builder import make_function_signature
+from tensorflow.core.function.trace_type.signature_builder import SignatureContext
+from tensorflow.core.function.trace_type.signature_builder import WeakrefDeletionObserver
+
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
new file mode 100644
index 0000000..a307633
--- /dev/null
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -0,0 +1,232 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""TraceType implementations for common Python types."""
+
+from typing import Dict as PythonDict
+from typing import Hashable, Optional, Sequence, Type
+from typing import Tuple as PythonTuple
+
+from tensorflow.python.types import trace
+
+
+class Generic(trace.TraceType):
+ """Represents an arbitrary Python object."""
+
+ def __init__(self, obj):
+ self._object = obj
+ self._object_hash = hash(obj)
+
+ def is_subtype_of(self, other: trace.TraceType) -> bool:
+ return self == other
+
+ def most_specific_common_supertype(
+ self, types: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+ if not types:
+ raise ValueError(f"`types` must be a non-empty sequence, got{types}")
+
+ return None
+
+ def __eq__(self, other) -> bool:
+ if not isinstance(other, trace.TraceType):
+ return NotImplemented
+
+ return isinstance(other, Generic) and self._object == other._object
+
+ def __hash__(self) -> int:
+ return self._object_hash
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(obj={self._object!r})"
+
+
+class Weakref(Generic):
+ """Represents weakref of an arbitrary Python object.
+
+ When a function argument is a custom class, instead of making a copy of it
+ just for the sake of function cache, a weakref is instead kept to save memory.
+ """
+
+ def __eq__(self, other):
+ if not isinstance(other, trace.TraceType):
+ return NotImplemented
+
+ if not isinstance(other, Weakref):
+ return False
+
+ if self._object() is None or other._object() is None:
+ return False
+
+ if self._object() is other._object():
+ return True
+
+ return self._object == other._object
+
+ def __hash__(self):
+ return self._object_hash
+
+
+class OrderedCollection(trace.TraceType):
+ """Represents an ordered collection of TraceType objects.
+
+ Attributes:
+ components: A corresponding sequence of TraceTypes to the values in the
+ collection.
+ """
+
+ def __init__(self, *components: trace.TraceType):
+ self.components = components
+
+ def _has_same_structure(self, other):
+ if not isinstance(other, type(self)):
+ return False
+
+ if len(self.components) != len(other.components):
+ return False
+
+ return True
+
+ def is_subtype_of(self, other: trace.TraceType) -> bool:
+ """See base class."""
+ if not self._has_same_structure(other):
+ return False
+
+ if not all([
+ component.is_subtype_of(other.components[i])
+ for i, component in enumerate(self.components)
+ ]):
+ return False
+
+ return True
+
+ def most_specific_common_supertype(self, types: Sequence[trace.TraceType]):
+ """See base class."""
+ if not types:
+ raise ValueError(f"`types` must be a non-empty sequence, got{types}")
+
+ if not all(self._has_same_structure(other) for other in types):
+ return None
+
+ new_components = []
+ for i, component in enumerate(self.components):
+ common = component.most_specific_common_supertype(
+ [other.components[i] for other in types])
+ if common is None:
+ return None
+ else:
+ new_components.append(common)
+
+ return type(self)(*new_components)
+
+ def __eq__(self, other) -> bool:
+ if not isinstance(other, trace.TraceType):
+ return NotImplemented
+
+ if not self._has_same_structure(other):
+ return False
+
+ return self.components == other.components
+
+ def __hash__(self) -> int:
+ return hash(self.components)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(components={self.components!r})"
+
+
+class List(OrderedCollection):
+ pass
+
+
+class Tuple(OrderedCollection):
+ pass
+
+
+class Attrs(OrderedCollection):
+ """Represents a class annotated by attr.s.
+
+ Each attr.s class has a fixed, ordered set of attributes. Therefore, we only
+ need to consider the class type and the underlying attributes. Extra
+ metadata including attribute names can be ignored.
+ """
+
+ def __init__(self, classtype: Type[object],
+ attributes: PythonTuple[trace.TraceType]):
+ super().__init__(Generic(classtype), *attributes)
+
+
+class Dict(trace.TraceType):
+ """Represents a dictionary of TraceType objects.
+
+ Attributes:
+ mapping: A mapping from keys to corresponding TraceTypes of the dict values.
+ """
+
+ def __init__(self, mapping: PythonDict[Hashable, trace.TraceType]):
+ self.mapping = mapping
+
+ def _has_same_structure(self, other):
+ if not isinstance(other, Dict):
+ return False
+
+ return self.mapping.keys() == other.mapping.keys()
+
+ def is_subtype_of(self, other: trace.TraceType) -> bool:
+ """See base class."""
+ if not self._has_same_structure(other):
+ return False
+
+ # We need all keys to be present because there can be logic relying on
+ # their existence or lack thereof and hence can not guarantee subtype based
+ # on a subset or superset of keys.
+ # Only the tracing code can explicitly check for key dependencies and inform
+ # that decision.
+ return all(self.mapping[key].is_subtype_of(other.mapping[key])
+ for key in self.mapping)
+
+ def most_specific_common_supertype(self, types: Sequence[trace.TraceType]):
+ """See base class."""
+
+ if not types:
+ raise ValueError(f"`types` must be a non-empty sequence, got{types}")
+
+ if not all(self._has_same_structure(other) for other in types):
+ return None
+
+ new_mapping = {}
+ for key in self.mapping.keys():
+ common = self.mapping[key].most_specific_common_supertype(
+ [other.mapping[key] for other in types])
+ if common is None:
+ return None
+ else:
+ new_mapping[key] = common
+
+ return Dict(new_mapping)
+
+ def __eq__(self, other) -> bool:
+ if not isinstance(other, trace.TraceType):
+ return NotImplemented
+
+ if not isinstance(other, Dict):
+ return False
+
+ return self.mapping == other.mapping
+
+ def __hash__(self) -> int:
+ return hash(frozenset(self.mapping.keys()))
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(mapping={self.mapping!r})"
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
new file mode 100644
index 0000000..f28128f
--- /dev/null
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -0,0 +1,155 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for default_types."""
+
+from tensorflow.core.function.trace_type import default_types
+from tensorflow.python.platform import test
+
+
+class DefaultTypesTest(test.TestCase):
+
+ def testOrderedCollectionTypeEquality(self):
+ collection = default_types.OrderedCollection
+ generic = default_types.Generic
+ collection_a = collection(generic(1), generic(2), generic(3))
+ collection_b = collection(generic(1), generic(2), generic(1))
+ collection_c = collection(generic(1), generic(2), generic(3))
+
+ self.assertNotEqual(collection_a, collection_b)
+ self.assertEqual(collection_a, collection_c)
+ self.assertEqual(hash(collection_a), hash(collection_c))
+
+ def testOrderedCollectionTypeSubtype(self):
+
+ class Subtypable(default_types.Generic):
+
+ def is_subtype_of(self, other):
+ return self._object == 2 or other._object == 3
+
+ collection = default_types.OrderedCollection
+ collection_a = collection(Subtypable(1), Subtypable(2), Subtypable(3))
+ collection_b = collection(Subtypable(2), Subtypable(1), Subtypable(2))
+ collection_c = collection(Subtypable(1), Subtypable(3), Subtypable(3))
+
+ self.assertTrue(collection_b.is_subtype_of(collection_c))
+ self.assertFalse(collection_a.is_subtype_of(collection_b))
+ self.assertFalse(collection_c.is_subtype_of(collection_a))
+
+ def testOrderedCollectionTypeSupertype(self):
+
+ class Supertypable(default_types.Generic):
+
+ def most_specific_common_supertype(self, others):
+ if self._object == 2 and isinstance(others[0]._object, int):
+ return Supertypable(3)
+ else:
+ return None
+
+ collection = default_types.OrderedCollection
+ collection_a = collection(Supertypable(1), Supertypable(2), Supertypable(3))
+ collection_b = collection(Supertypable(2), Supertypable(2), Supertypable(2))
+
+ self.assertIsNone(
+ collection_a.most_specific_common_supertype([collection_b]))
+ self.assertEqual(
+ collection_b.most_specific_common_supertype([collection_a]),
+ collection(Supertypable(3), Supertypable(3), Supertypable(3)))
+
+ def testDictTypeSubtype(self):
+
+ class MockSubtypeOf2(default_types.Generic):
+
+ def is_subtype_of(self, other):
+ return other._object == 2
+
+ dict_type = default_types.Dict
+ dict_a = dict_type({
+ 'a': MockSubtypeOf2(1),
+ 'b': MockSubtypeOf2(1),
+ 'c': MockSubtypeOf2(1)
+ })
+ dict_b = dict_type({
+ 'a': MockSubtypeOf2(2),
+ 'b': MockSubtypeOf2(2),
+ 'c': MockSubtypeOf2(2)
+ })
+ dict_c = dict_type({'a': MockSubtypeOf2(1), 'b': MockSubtypeOf2(1)})
+
+ self.assertTrue(dict_a.is_subtype_of(dict_b))
+ self.assertFalse(dict_c.is_subtype_of(dict_b))
+ self.assertFalse(dict_c.is_subtype_of(dict_a))
+
+ def testDictTypeSupertype(self):
+
+ class MockSupertypes2With3(default_types.Generic):
+
+ def most_specific_common_supertype(self, others):
+ if not others:
+ return self
+
+ if self._object == 2 and isinstance(others[0]._object, int):
+ return MockSupertypes2With3(3)
+ else:
+ return None
+
+ dict_type = default_types.Dict
+ dict_a = dict_type({
+ 'a': MockSupertypes2With3(1),
+ 'b': MockSupertypes2With3(2),
+ 'c': MockSupertypes2With3(3)
+ })
+ dict_b = dict_type({
+ 'a': MockSupertypes2With3(2),
+ 'b': MockSupertypes2With3(2),
+ 'c': MockSupertypes2With3(2)
+ })
+
+ self.assertIsNone(dict_a.most_specific_common_supertype([dict_b]))
+ self.assertEqual(
+ dict_b.most_specific_common_supertype([dict_a]),
+ dict_type({
+ 'a': MockSupertypes2With3(3),
+ 'b': MockSupertypes2With3(3),
+ 'c': MockSupertypes2With3(3)
+ }))
+
+ def testListTupleInequality(self):
+ generic = default_types.Generic
+
+ list_a = default_types.List(generic(1), generic(2), generic(3))
+ list_b = default_types.List(generic(1), generic(2), generic(3))
+
+ tuple_a = default_types.Tuple(generic(1), generic(2), generic(3))
+ tuple_b = default_types.Tuple(generic(1), generic(2), generic(3))
+
+ self.assertEqual(list_a, list_b)
+ self.assertEqual(tuple_a, tuple_b)
+ self.assertNotEqual(list_a, tuple_a)
+ self.assertNotEqual(tuple_a, list_a)
+
+ def testDictTypeEquality(self):
+ dict_type = default_types.Dict
+ generic = default_types.Generic
+
+ dict_a = dict_type({generic(1): generic(2), generic(3): generic(4)})
+ dict_b = dict_type({generic(1): generic(2)})
+ dict_c = dict_type({generic(3): generic(4), generic(1): generic(2)})
+
+ self.assertEqual(dict_a, dict_c)
+ self.assertNotEqual(dict_a, dict_b)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/core/function/trace_type/signature_builder.py b/tensorflow/core/function/trace_type/signature_builder.py
new file mode 100644
index 0000000..4fa42c1
--- /dev/null
+++ b/tensorflow/core/function/trace_type/signature_builder.py
@@ -0,0 +1,146 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utitiles for Cache Key generation based on Function Trace Type."""
+
+import collections.abc
+from typing import Any, Callable
+import weakref
+
+from tensorflow.core.function.trace_type import default_types
+from tensorflow.python.types import trace
+
+
+class WeakrefDeletionObserver:
+ """An observer for the event of deleting a weakref.
+
+ This allows users of FunctionTraceType to be notified when an instance which
+ depends on a weakref becomes invalid by the deletion of the weakref. In
+ particular, tf.function caches can use this mechanism to clear the cache of
+ keys that are no longer valid.
+
+ We use the observer pattern and not just basic callbacks because the keys
+ are typically created before they are used by the cache.
+ """
+
+ def __init__(self):
+ self._triggered = False
+ self._callables = []
+
+ def add_listener(self, on_delete: Callable[[], None]):
+ if self._triggered:
+ on_delete()
+ else:
+ self._callables.append(on_delete)
+
+ def weakref_deleted(self):
+ self._triggered = True
+ for c in self._callables:
+ c()
+
+ def __call__(self, _):
+ """Call handler for convenience of use with weakref."""
+ self.weakref_deleted()
+
+
+class SignatureContext(trace.TracingContext):
+ """Container for variables and flags shared across signature tracing."""
+
+ def __init__(self, include_tensor_ranks_only=False):
+ self._deletion_observer = WeakrefDeletionObserver()
+ self._include_tensor_ranks_only = include_tensor_ranks_only
+ self._global_to_local_id = {}
+
+ # TODO(b/202772221): Consider dropping after alias pattern matching is
+ # supported.
+ def get_local_id(self, local_id):
+
+ if local_id not in self._global_to_local_id:
+ self._global_to_local_id[local_id] = len(self._global_to_local_id)
+
+ return self._global_to_local_id[local_id]
+
+ # TODO(b/202430155): Remove this flag after TraceType shape relaxation.
+ @property
+ def include_tensor_ranks_only(self):
+ return self._include_tensor_ranks_only
+
+ @property
+ def deletion_observer(self):
+ """Returns a functor which invalidates the current key when called."""
+ return self._deletion_observer
+
+
+def create_trace_type(obj: Any,
+ context: SignatureContext) -> trace.TraceType:
+ """Returns a TraceType corresponding to the object based on the context.
+
+ Args:
+ obj: The object to generate a TraceType for.
+ context: The TracingContext to be shared during protocol calls.
+
+ Returns:
+ A TraceType object representing the given object.
+ """
+
+ if isinstance(obj, trace.SupportsTracingProtocol):
+ return obj.__tf_tracing_type__(context)
+
+ if isinstance(obj, list):
+ return default_types.List(*(create_trace_type(c, context) for c in obj))
+
+ if isinstance(obj, tuple):
+ return default_types.Tuple(*(create_trace_type(c, context) for c in obj))
+
+ if isinstance(obj, collections.abc.Mapping):
+ return default_types.Dict(
+ {k: create_trace_type(obj[k], context) for k in obj})
+
+ if hasattr(type(obj), "__attrs_attrs__"):
+ return default_types.Attrs(
+ type(obj), (create_trace_type(getattr(obj, a.name), context)
+ for a in obj.__attrs_attrs__))
+
+ if hasattr(obj, "__wrapped__"):
+ return create_trace_type(obj.__wrapped__, context)
+
+ try:
+ ref = weakref.ref(obj, context.deletion_observer)
+ if ref is None:
+ raise TypeError(
+ f"Deleted objects are not valid tf.function arguments, Got {obj!r}")
+ else:
+ return default_types.Weakref(ref)
+ except TypeError:
+ try:
+ return default_types.Generic(obj)
+ except:
+ raise TypeError(
+ f"Python object could not be represented through the generic tracing "
+ f"type. Consider implementing the Tracing Protocol for it: {obj!r}")
+
+
+def make_function_signature(
+ function_args,
+ signature_context: SignatureContext) -> trace.TraceType:
+ """Returns the trace type specification of a function's arguments.
+
+ Args:
+ function_args: Tuple/List/Dict structure containing the function arguments
+ signature_context: The SignatureContext to be shared during protocol calls.
+
+ Returns:
+ A TraceType object representing all the given inputs.
+ """
+ return create_trace_type(function_args, signature_context)
diff --git a/tensorflow/core/function/trace_type/trace_type_test.py b/tensorflow/core/function/trace_type/trace_type_test.py
new file mode 100644
index 0000000..fc29d0f
--- /dev/null
+++ b/tensorflow/core/function/trace_type/trace_type_test.py
@@ -0,0 +1,417 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests and benchmarks for the trace_type module."""
+
+import timeit
+
+from absl.testing import parameterized
+
+from tensorflow.core.function import trace_type
+from tensorflow.core.function.trace_type import default_types
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.eager import function
+from tensorflow.python.framework import combinations
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.platform import test
+from tensorflow.python.types import trace
+
+
+class TestAttr:
+ """Helps test attrs collections."""
+
+ def __init__(self, name):
+ self.name = name
+
+
+class TestAttrsClass:
+ """Helps test attrs collections."""
+
+ __attrs_attrs__ = (TestAttr('a'), TestAttr('b'))
+
+ def __init__(self, a, b):
+ self.a = a
+ self.b = b
+
+
+class DummyGenericClass:
+ """Helps test memory leaks for GenericType."""
+ pass
+
+
+def make_function_signature_with_context(inputs):
+ return trace_type.make_function_signature(
+ inputs, trace_type.SignatureContext())
+
+
+class CacheKeyGenerationTest(test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(combinations.combine(mode=['eager']))
+ def testIteratorAliasing(self):
+ it1 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
+ it2 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
+
+ self.assertEqual(
+ make_function_signature_with_context((it1, it1)),
+ make_function_signature_with_context((it2, it2)))
+ self.assertEqual(
+ make_function_signature_with_context((it1, it2)),
+ make_function_signature_with_context((it2, it1)))
+ self.assertNotEqual(
+ make_function_signature_with_context((it1, it1)),
+ make_function_signature_with_context((it1, it2)))
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testIteratorTypesImplementTracing(self):
+ self.assertTrue(
+ issubclass(iterator_ops.OwnedIterator, trace.SupportsTracingProtocol))
+ self.assertTrue(
+ issubclass(iterator_ops.IteratorSpec, trace.SupportsTracingProtocol))
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testCompositeAndSpec(self):
+ composite_tensor = ragged_tensor.RaggedTensor.from_row_splits(
+ values=[1, 2, 3], row_splits=[0, 2, 3])
+ spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)
+
+ self.assertEqual(
+ make_function_signature_with_context(composite_tensor),
+ make_function_signature_with_context(spec))
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testVariableAliasing(self):
+ v1 = resource_variable_ops.ResourceVariable([1])
+ v2 = resource_variable_ops.ResourceVariable([1])
+ v3 = resource_variable_ops.ResourceVariable([1])
+ all_unique = make_function_signature_with_context((v1, v2, v3))
+ all_same = make_function_signature_with_context((v1, v1, v1))
+ self.assertNotEqual(all_unique, all_same)
+
+ v3 = resource_variable_ops.ResourceVariable([2])
+ v4 = resource_variable_ops.ResourceVariable([2])
+ v5 = resource_variable_ops.ResourceVariable([2])
+ all_unique_again = make_function_signature_with_context((v3, v4, v5))
+ all_same_again = make_function_signature_with_context((v4, v4, v4))
+ self.assertEqual(all_unique, all_unique_again)
+ self.assertEqual(all_same, all_same_again)
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testTensorEquality(self):
+ context = trace_type.SignatureContext()
+ tensor_a = array_ops.zeros([11, 3, 5],
+ dtype=dtypes.int32).__tf_tracing_type__(context)
+ tensor_b = array_ops.zeros([11, 4, 5],
+ dtype=dtypes.int32).__tf_tracing_type__(context)
+ tensor_c = array_ops.zeros(
+ [11, 3, 5], dtype=dtypes.float32).__tf_tracing_type__(context)
+ tensor_d = array_ops.ones([11, 3, 5],
+ dtype=dtypes.int32).__tf_tracing_type__(context)
+
+ self.assertNotEqual(tensor_a, tensor_b)
+ self.assertNotEqual(tensor_a, tensor_c)
+ self.assertNotEqual(tensor_b, tensor_c)
+ self.assertEqual(tensor_a, tensor_d)
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testTensorAndSpecEquality(self):
+ context = trace_type.SignatureContext()
+ tensor = array_ops.zeros([11, 3, 5],
+ dtype=dtypes.int32).__tf_tracing_type__(context)
+ spec = tensor_spec.TensorSpec(
+ [11, 3, 5], dtype=dtypes.int32).__tf_tracing_type__(context)
+ spec_with_name = tensor_spec.TensorSpec(
+ [11, 3, 5], dtype=dtypes.int32,
+ name='name').__tf_tracing_type__(context)
+
+ self.assertEqual(tensor, spec)
+ self.assertNotEqual(tensor, spec_with_name)
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testTensorShapeUnknown(self):
+ context = trace_type.SignatureContext()
+ spec_1 = tensor_spec.TensorSpec(
+ None, dtype=dtypes.int32).__tf_tracing_type__(context)
+ spec_2 = tensor_spec.TensorSpec(
+ None, dtype=dtypes.int32).__tf_tracing_type__(context)
+ self.assertEqual(spec_1, spec_2)
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testAttrsCacheKeyGeneration(self):
+ trace_a = make_function_signature_with_context(TestAttrsClass(1, 2))
+ expected = default_types.Attrs(
+ TestAttrsClass,
+ (default_types.Generic(1), default_types.Generic(2)))
+ self.assertEqual(trace_a, expected)
+ self.assertTrue(trace_a.is_subtype_of(trace_a))
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testTupleEquality(self):
+ trace_a = make_function_signature_with_context((1, 2, 3, 4))
+ trace_b = make_function_signature_with_context((1, 2, 2, 4))
+ trace_c = make_function_signature_with_context((1, 2, 3))
+ trace_d = make_function_signature_with_context((1, 2, 3, 4))
+ self.assertNotEqual(trace_a, trace_b)
+ self.assertNotEqual(trace_a, trace_c)
+ self.assertNotEqual(trace_b, trace_c)
+ self.assertEqual(trace_a, trace_d)
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testListEquality(self):
+ trace_a = make_function_signature_with_context([1, 2, 3, 4])
+ trace_b = make_function_signature_with_context([1, 2, 2, 4])
+ trace_c = make_function_signature_with_context([1, 2, 3])
+ trace_d = make_function_signature_with_context([1, 2, 3, 4])
+ self.assertNotEqual(trace_a, trace_b)
+ self.assertNotEqual(trace_a, trace_c)
+ self.assertNotEqual(trace_b, trace_c)
+ self.assertEqual(trace_a, trace_d)
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testDictEquality(self):
+ trace_a = make_function_signature_with_context({1: 2, 3: 4})
+ trace_b = make_function_signature_with_context({1: 2, 3: 2})
+ trace_c = make_function_signature_with_context({1: 2, 3: 0})
+ trace_d = make_function_signature_with_context({3: 4, 1: 2})
+ self.assertNotEqual(trace_a, trace_b)
+ self.assertNotEqual(trace_a, trace_c)
+ self.assertNotEqual(trace_b, trace_c)
+ self.assertEqual(trace_a, trace_d)
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testComplexStruct(self):
+ struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
+ trace_a = make_function_signature_with_context(struct)
+ trace_b = make_function_signature_with_context(struct)
+ self.assertEqual(trace_a, trace_b)
+ self.assertTrue(trace_a.is_subtype_of(trace_b))
+ self.assertTrue(trace_b.is_subtype_of(trace_a))
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testCustomUnequableTypeSucceeds(self):
+
+ class CustomUnequable:
+
+ def __eq__(self, o):
+ raise ValueError
+
+ def __hash__(self):
+ return 0
+
+ object_a = CustomUnequable()
+ object_b = CustomUnequable()
+ trace_a_1 = make_function_signature_with_context(object_a)
+ trace_a_2 = make_function_signature_with_context(object_a)
+ trace_b = make_function_signature_with_context(object_b)
+ self.assertEqual(trace_a_1, trace_a_2)
+
+ with self.assertRaises(ValueError):
+ trace_a_1.__eq__(trace_b)
+
+ del object_a
+ self.assertNotEqual(trace_a_1, trace_a_2)
+ self.assertNotEqual(trace_a_2, trace_a_1)
+
+ del object_b
+ self.assertNotEqual(trace_a_1, trace_a_2)
+ self.assertNotEqual(trace_a_2, trace_a_1)
+
+ @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+ def testCustomUnhashableTypeFailsGracefully(self):
+
+ class CustomUnhashable:
+
+ def __eq__(self, o):
+ return True
+
+ obj = CustomUnhashable()
+ with self.assertRaisesRegex(
+ TypeError,
+ r'could not be represented through the generic tracing type'):
+ make_function_signature_with_context(obj)
+
+
+class CacheKeyMemoryTest(test.TestCase):
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testGeneric(self):
+ make_function_signature_with_context(1)
+ make_function_signature_with_context(DummyGenericClass())
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testTensor(self):
+ tensor = array_ops.zeros([10])
+ make_function_signature_with_context(tensor)
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testTuple(self):
+ make_function_signature_with_context((1, 2, 3))
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testDict(self):
+ make_function_signature_with_context({1: 1, 2: 2, 3: 3})
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testList(self):
+ make_function_signature_with_context([1, 2, 3])
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testAttrs(self):
+ make_function_signature_with_context(TestAttrsClass(1, 2))
+
+
+class CacheKeyGenerationBenchmark(test.Benchmark):
+
+ def benchmarkTensor(self):
+ shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
+ tensors = []
+ for s in shapes:
+ tensors.append(array_ops.zeros(s))
+
+ def encode_tensors(tensors):
+ make_function_signature_with_context(tensors)
+
+ iterations = 100000
+ t = timeit.timeit(lambda: encode_tensors(tensors), number=iterations)
+ self.report_benchmark(
+ name='tensor_cache_key_generation',
+ iters=iterations,
+ wall_time=t,
+ metrics=[{
+ 'name': 'tensor_cache_key_generation_avg_ms',
+ 'value': t / iterations * 1000
+ }])
+
+ def benchmarkTensorSpec(self):
+ shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
+ tensor_specs = []
+ for s in shapes:
+ tensor_specs.append(tensor_spec.TensorSpec(s, dtypes.int32))
+
+ def encode_tensor_specs(tensor_specs):
+ make_function_signature_with_context(tensor_specs)
+
+ iterations = 100000
+ t = timeit.timeit(
+ lambda: encode_tensor_specs(tensor_specs), number=iterations)
+ self.report_benchmark(
+ name='tensor_spec_cache_key_generation',
+ iters=iterations,
+ wall_time=t,
+ metrics=[{
+ 'name': 'tensor_spec_cache_key_generation_avg_ms',
+ 'value': t / iterations * 1000
+ }])
+
+ def benchmarkVariable(self):
+ var_list = [
+ variables.Variable(1.0),
+ variables.Variable(1),
+ variables.Variable([1])
+ ]
+
+ def encode_variables(var_list):
+ make_function_signature_with_context(var_list)
+
+ iterations = 10000
+ t = timeit.timeit(lambda: encode_variables(var_list), number=iterations)
+ self.report_benchmark(
+ name='variable_cache_key_generation',
+ iters=iterations,
+ wall_time=t,
+ metrics=[{
+ 'name': 'variable_cache_key_generation_avg_ms',
+ 'value': t / iterations * 1000
+ }])
+
+ def benchmarkCacheKeyLookup(self):
+
+ @function.defun
+ def defined(t):
+ return t
+
+ call_arg_list = [
+ 1,
+ array_ops.zeros([5, 13]),
+ array_ops.zeros([9, 22, 24]),
+ array_ops.zeros([5, 13, 2])
+ ]
+
+ for c in call_arg_list:
+ defined(c)
+
+ lookup_call_arg = array_ops.zeros([5, 13])
+
+ iterations = 10000
+ t = timeit.timeit(stmt=lambda: defined(lookup_call_arg), number=iterations)
+
+ self.report_benchmark(
+ name='cache_key_lookup',
+ iters=iterations,
+ wall_time=t,
+ metrics=[{
+ 'name': 'cache_key_lookup_avg_ms',
+ 'value': t / iterations * 1000
+ }])
+
+ def benchmarkNestedStruct(self):
+ struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
+
+ def encode_struct(struct):
+ make_function_signature_with_context(struct)
+
+ iterations = 100000
+ t = timeit.timeit(lambda: encode_struct(struct), number=iterations)
+ self.report_benchmark(
+ name='nested_struct_cache_key_generation',
+ iters=iterations,
+ wall_time=t,
+ metrics=[{
+ 'name': 'nested_struct_cache_key_generation_avg_ms',
+ 'value': t / iterations * 1000
+ }])
+
+ def benchmarkFunctionInvocation(self):
+ struct = (variables.Variable(1.0), array_ops.zeros([5, 13]), {
+ 'tensor': array_ops.zeros([5, 20]),
+ 'variable': variables.Variable(1.0)
+ })
+
+ @function.defun
+ def defined(t):
+ return t
+
+ defined(struct) # Get it traced and cached.
+
+ iterations = 10000
+ t = timeit.timeit(lambda: defined(struct), number=iterations)
+ self.report_benchmark(
+ name='function_invocation',
+ iters=iterations,
+ wall_time=t,
+ metrics=[{
+ 'name': 'function_invocation_time_avg_ms',
+ 'value': t / iterations * 1000
+ }])
+
+if __name__ == '__main__':
+ v2_compat.enable_v2_behavior()
+ test.main()