Set up TraceType as an independent core/function package

PiperOrigin-RevId: 418082697
Change-Id: I74ce638f178c8712d20f00a9db5d7ef11484de71
diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD
new file mode 100644
index 0000000..5a0a056
--- /dev/null
+++ b/tensorflow/core/function/trace_type/BUILD
@@ -0,0 +1,65 @@
+load("//tensorflow:tensorflow.bzl", "pytype_strict_library")
+load("//tensorflow:tensorflow.bzl", "py_strict_test")
+
+package(
+    licenses = ["notice"],
+)
+
+pytype_strict_library(
+    name = "trace_type",
+    srcs = [
+        "__init__.py",
+        "signature_builder.py",
+    ],
+    srcs_version = "PY3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":default_types",
+        "//tensorflow/python/types",
+    ],
+)
+
+py_strict_test(
+    name = "trace_type_test",
+    srcs = ["trace_type_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":default_types",
+        ":trace_type",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:resource_variable_ops",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/ops:iterator_ops",
+        "//tensorflow/python/eager:function",
+        "//tensorflow/python/framework:combinations",
+        "//tensorflow/python/framework:dtypes",
+        "//tensorflow/python/framework:tensor_spec",
+        "//tensorflow/python/framework:test_lib",
+        "//tensorflow/python/ops/ragged:ragged_tensor",
+        "//tensorflow/python/platform:client_testlib",
+        "//tensorflow/python/types",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+pytype_strict_library(
+    name = "default_types",
+    srcs = [
+        "default_types.py",
+    ],
+    srcs_version = "PY3",
+    visibility = ["//tensorflow:internal"],
+    deps = ["//tensorflow/python/types"],
+)
+
+py_strict_test(
+    name = "default_types_test",
+    srcs = ["default_types_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":default_types",
+        "//tensorflow/python/platform:client_testlib",
+    ],
+)
diff --git a/tensorflow/core/function/trace_type/__init__.py b/tensorflow/core/function/trace_type/__init__.py
new file mode 100644
index 0000000..1e35416
--- /dev/null
+++ b/tensorflow/core/function/trace_type/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tracing Protocol for tf.function.
+
+TODO(b/202447704): Briefly describe the tracing, retracing, and how trace types
+control it.
+"""
+
+
+from tensorflow.core.function.trace_type.signature_builder import make_function_signature
+from tensorflow.core.function.trace_type.signature_builder import SignatureContext
+from tensorflow.core.function.trace_type.signature_builder import WeakrefDeletionObserver
+
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
new file mode 100644
index 0000000..a307633
--- /dev/null
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -0,0 +1,232 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""TraceType implementations for common Python types."""
+
+from typing import Dict as PythonDict
+from typing import Hashable, Optional, Sequence, Type
+from typing import Tuple as PythonTuple
+
+from tensorflow.python.types import trace
+
+
+class Generic(trace.TraceType):
+  """Represents an arbitrary Python object."""
+
+  def __init__(self, obj):
+    self._object = obj
+    self._object_hash = hash(obj)
+
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    return self == other
+
+  def most_specific_common_supertype(
+      self, types: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+    if not types:
+      raise ValueError(f"`types` must be a non-empty sequence, got {types}")
+
+    return None
+
+  def __eq__(self, other) -> bool:
+    if not isinstance(other, trace.TraceType):
+      return NotImplemented
+
+    return isinstance(other, Generic) and self._object == other._object
+
+  def __hash__(self) -> int:
+    return self._object_hash
+
+  def __repr__(self):
+    return f"{self.__class__.__name__}(obj={self._object!r})"
+
+
+class Weakref(Generic):
+  """Represents weakref of an arbitrary Python object.
+
+  When a function argument is a custom class, instead of making a copy of it
+  just for the sake of the function cache, a weakref is kept to save memory.
+  """
+
+  def __eq__(self, other):
+    if not isinstance(other, trace.TraceType):
+      return NotImplemented
+
+    if not isinstance(other, Weakref):
+      return False
+
+    if self._object() is None or other._object() is None:
+      return False
+
+    if self._object() is other._object():
+      return True
+
+    return self._object == other._object
+
+  def __hash__(self):
+    return self._object_hash
+
+
+class OrderedCollection(trace.TraceType):
+  """Represents an ordered collection of TraceType objects.
+
+  Attributes:
+    components: A corresponding sequence of TraceTypes to the values in the
+      collection.
+  """
+
+  def __init__(self, *components: trace.TraceType):
+    self.components = components
+
+  def _has_same_structure(self, other):
+    if not isinstance(other, type(self)):
+      return False
+
+    if len(self.components) != len(other.components):
+      return False
+
+    return True
+
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    """See base class."""
+    if not self._has_same_structure(other):
+      return False
+
+    if not all([
+        component.is_subtype_of(other.components[i])
+        for i, component in enumerate(self.components)
+    ]):
+      return False
+
+    return True
+
+  def most_specific_common_supertype(self, types: Sequence[trace.TraceType]):
+    """See base class."""
+    if not types:
+      raise ValueError(f"`types` must be a non-empty sequence, got {types}")
+
+    if not all(self._has_same_structure(other) for other in types):
+      return None
+
+    new_components = []
+    for i, component in enumerate(self.components):
+      common = component.most_specific_common_supertype(
+          [other.components[i] for other in types])
+      if common is None:
+        return None
+      else:
+        new_components.append(common)
+
+    return type(self)(*new_components)
+
+  def __eq__(self, other) -> bool:
+    if not isinstance(other, trace.TraceType):
+      return NotImplemented
+
+    if not self._has_same_structure(other):
+      return False
+
+    return self.components == other.components
+
+  def __hash__(self) -> int:
+    return hash(self.components)
+
+  def __repr__(self):
+    return f"{self.__class__.__name__}(components={self.components!r})"
+
+
+class List(OrderedCollection):
+  pass
+
+
+class Tuple(OrderedCollection):
+  pass
+
+
+class Attrs(OrderedCollection):
+  """Represents a class annotated by attr.s.
+
+  Each attr.s class has a fixed, ordered set of attributes. Therefore, we only
+  need to consider the class type and the underlying attributes. Extra
+  metadata including attribute names can be ignored.
+  """
+
+  def __init__(self, classtype: Type[object],
+               attributes: PythonTuple[trace.TraceType, ...]):
+    super().__init__(Generic(classtype), *attributes)
+
+
+class Dict(trace.TraceType):
+  """Represents a dictionary of TraceType objects.
+
+  Attributes:
+    mapping: A mapping from keys to corresponding TraceTypes of the dict values.
+  """
+
+  def __init__(self, mapping: PythonDict[Hashable, trace.TraceType]):
+    self.mapping = mapping
+
+  def _has_same_structure(self, other):
+    if not isinstance(other, Dict):
+      return False
+
+    return self.mapping.keys() == other.mapping.keys()
+
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    """See base class."""
+    if not self._has_same_structure(other):
+      return False
+
+    # We need all keys to be present because there can be logic relying on
+    # their existence or lack thereof and hence can not guarantee subtype based
+    # on a subset or superset of keys.
+    # Only the tracing code can explicitly check for key dependencies and inform
+    # that decision.
+    return all(self.mapping[key].is_subtype_of(other.mapping[key])
+               for key in self.mapping)
+
+  def most_specific_common_supertype(self, types: Sequence[trace.TraceType]):
+    """See base class."""
+
+    if not types:
+      raise ValueError(f"`types` must be a non-empty sequence, got {types}")
+
+    if not all(self._has_same_structure(other) for other in types):
+      return None
+
+    new_mapping = {}
+    for key in self.mapping.keys():
+      common = self.mapping[key].most_specific_common_supertype(
+          [other.mapping[key] for other in types])
+      if common is None:
+        return None
+      else:
+        new_mapping[key] = common
+
+    return Dict(new_mapping)
+
+  def __eq__(self, other) -> bool:
+    if not isinstance(other, trace.TraceType):
+      return NotImplemented
+
+    if not isinstance(other, Dict):
+      return False
+
+    return self.mapping == other.mapping
+
+  def __hash__(self) -> int:
+    return hash(frozenset(self.mapping.keys()))
+
+  def __repr__(self):
+    return f"{self.__class__.__name__}(mapping={self.mapping!r})"
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
new file mode 100644
index 0000000..f28128f
--- /dev/null
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -0,0 +1,155 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for default_types."""
+
+from tensorflow.core.function.trace_type import default_types
+from tensorflow.python.platform import test
+
+
+class DefaultTypesTest(test.TestCase):
+
+  def testOrderedCollectionTypeEquality(self):
+    collection = default_types.OrderedCollection
+    generic = default_types.Generic
+    collection_a = collection(generic(1), generic(2), generic(3))
+    collection_b = collection(generic(1), generic(2), generic(1))
+    collection_c = collection(generic(1), generic(2), generic(3))
+
+    self.assertNotEqual(collection_a, collection_b)
+    self.assertEqual(collection_a, collection_c)
+    self.assertEqual(hash(collection_a), hash(collection_c))
+
+  def testOrderedCollectionTypeSubtype(self):
+
+    class Subtypable(default_types.Generic):
+
+      def is_subtype_of(self, other):
+        return self._object == 2 or other._object == 3
+
+    collection = default_types.OrderedCollection
+    collection_a = collection(Subtypable(1), Subtypable(2), Subtypable(3))
+    collection_b = collection(Subtypable(2), Subtypable(1), Subtypable(2))
+    collection_c = collection(Subtypable(1), Subtypable(3), Subtypable(3))
+
+    self.assertTrue(collection_b.is_subtype_of(collection_c))
+    self.assertFalse(collection_a.is_subtype_of(collection_b))
+    self.assertFalse(collection_c.is_subtype_of(collection_a))
+
+  def testOrderedCollectionTypeSupertype(self):
+
+    class Supertypable(default_types.Generic):
+
+      def most_specific_common_supertype(self, others):
+        if self._object == 2 and isinstance(others[0]._object, int):
+          return Supertypable(3)
+        else:
+          return None
+
+    collection = default_types.OrderedCollection
+    collection_a = collection(Supertypable(1), Supertypable(2), Supertypable(3))
+    collection_b = collection(Supertypable(2), Supertypable(2), Supertypable(2))
+
+    self.assertIsNone(
+        collection_a.most_specific_common_supertype([collection_b]))
+    self.assertEqual(
+        collection_b.most_specific_common_supertype([collection_a]),
+        collection(Supertypable(3), Supertypable(3), Supertypable(3)))
+
+  def testDictTypeSubtype(self):
+
+    class MockSubtypeOf2(default_types.Generic):
+
+      def is_subtype_of(self, other):
+        return other._object == 2
+
+    dict_type = default_types.Dict
+    dict_a = dict_type({
+        'a': MockSubtypeOf2(1),
+        'b': MockSubtypeOf2(1),
+        'c': MockSubtypeOf2(1)
+    })
+    dict_b = dict_type({
+        'a': MockSubtypeOf2(2),
+        'b': MockSubtypeOf2(2),
+        'c': MockSubtypeOf2(2)
+    })
+    dict_c = dict_type({'a': MockSubtypeOf2(1), 'b': MockSubtypeOf2(1)})
+
+    self.assertTrue(dict_a.is_subtype_of(dict_b))
+    self.assertFalse(dict_c.is_subtype_of(dict_b))
+    self.assertFalse(dict_c.is_subtype_of(dict_a))
+
+  def testDictTypeSupertype(self):
+
+    class MockSupertypes2With3(default_types.Generic):
+
+      def most_specific_common_supertype(self, others):
+        if not others:
+          return self
+
+        if self._object == 2 and isinstance(others[0]._object, int):
+          return MockSupertypes2With3(3)
+        else:
+          return None
+
+    dict_type = default_types.Dict
+    dict_a = dict_type({
+        'a': MockSupertypes2With3(1),
+        'b': MockSupertypes2With3(2),
+        'c': MockSupertypes2With3(3)
+    })
+    dict_b = dict_type({
+        'a': MockSupertypes2With3(2),
+        'b': MockSupertypes2With3(2),
+        'c': MockSupertypes2With3(2)
+    })
+
+    self.assertIsNone(dict_a.most_specific_common_supertype([dict_b]))
+    self.assertEqual(
+        dict_b.most_specific_common_supertype([dict_a]),
+        dict_type({
+            'a': MockSupertypes2With3(3),
+            'b': MockSupertypes2With3(3),
+            'c': MockSupertypes2With3(3)
+        }))
+
+  def testListTupleInequality(self):
+    generic = default_types.Generic
+
+    list_a = default_types.List(generic(1), generic(2), generic(3))
+    list_b = default_types.List(generic(1), generic(2), generic(3))
+
+    tuple_a = default_types.Tuple(generic(1), generic(2), generic(3))
+    tuple_b = default_types.Tuple(generic(1), generic(2), generic(3))
+
+    self.assertEqual(list_a, list_b)
+    self.assertEqual(tuple_a, tuple_b)
+    self.assertNotEqual(list_a, tuple_a)
+    self.assertNotEqual(tuple_a, list_a)
+
+  def testDictTypeEquality(self):
+    dict_type = default_types.Dict
+    generic = default_types.Generic
+
+    dict_a = dict_type({generic(1): generic(2), generic(3): generic(4)})
+    dict_b = dict_type({generic(1): generic(2)})
+    dict_c = dict_type({generic(3): generic(4), generic(1): generic(2)})
+
+    self.assertEqual(dict_a, dict_c)
+    self.assertNotEqual(dict_a, dict_b)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/core/function/trace_type/signature_builder.py b/tensorflow/core/function/trace_type/signature_builder.py
new file mode 100644
index 0000000..4fa42c1
--- /dev/null
+++ b/tensorflow/core/function/trace_type/signature_builder.py
@@ -0,0 +1,146 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for Cache Key generation based on Function Trace Type."""
+
+import collections.abc
+from typing import Any, Callable
+import weakref
+
+from tensorflow.core.function.trace_type import default_types
+from tensorflow.python.types import trace
+
+
+class WeakrefDeletionObserver:
+  """An observer for the event of deleting a weakref.
+
+  This allows users of FunctionTraceType to be notified when an instance which
+  depends on a weakref becomes invalid by the deletion of the weakref. In
+  particular, tf.function caches can use this mechanism to clear the cache of
+  keys that are no longer valid.
+
+  We use the observer pattern and not just basic callbacks because the keys
+  are typically created before they are used by the cache.
+  """
+
+  def __init__(self):
+    self._triggered = False
+    self._callables = []
+
+  def add_listener(self, on_delete: Callable[[], None]):
+    if self._triggered:
+      on_delete()
+    else:
+      self._callables.append(on_delete)
+
+  def weakref_deleted(self):
+    self._triggered = True
+    for c in self._callables:
+      c()
+
+  def __call__(self, _):
+    """Call handler for convenience of use with weakref."""
+    self.weakref_deleted()
+
+
+class SignatureContext(trace.TracingContext):
+  """Container for variables and flags shared across signature tracing."""
+
+  def __init__(self, include_tensor_ranks_only=False):
+    self._deletion_observer = WeakrefDeletionObserver()
+    self._include_tensor_ranks_only = include_tensor_ranks_only
+    self._global_to_local_id = {}
+
+  # TODO(b/202772221): Consider dropping after alias pattern matching is
+  # supported.
+  def get_local_id(self, local_id):
+
+    if local_id not in self._global_to_local_id:
+      self._global_to_local_id[local_id] = len(self._global_to_local_id)
+
+    return self._global_to_local_id[local_id]
+
+  # TODO(b/202430155): Remove this flag after TraceType shape relaxation.
+  @property
+  def include_tensor_ranks_only(self):
+    return self._include_tensor_ranks_only
+
+  @property
+  def deletion_observer(self):
+    """Returns a functor which invalidates the current key when called."""
+    return self._deletion_observer
+
+
+def create_trace_type(obj: Any,
+                      context: SignatureContext) -> trace.TraceType:
+  """Returns a TraceType corresponding to the object based on the context.
+
+  Args:
+    obj: The object to generate a TraceType for.
+    context: The TracingContext to be shared during protocol calls.
+
+  Returns:
+    A TraceType object representing the given object.
+  """
+
+  if isinstance(obj, trace.SupportsTracingProtocol):
+    return obj.__tf_tracing_type__(context)
+
+  if isinstance(obj, list):
+    return default_types.List(*(create_trace_type(c, context) for c in obj))
+
+  if isinstance(obj, tuple):
+    return default_types.Tuple(*(create_trace_type(c, context) for c in obj))
+
+  if isinstance(obj, collections.abc.Mapping):
+    return default_types.Dict(
+        {k: create_trace_type(obj[k], context) for k in obj})
+
+  if hasattr(type(obj), "__attrs_attrs__"):
+    return default_types.Attrs(
+        type(obj), (create_trace_type(getattr(obj, a.name), context)
+                    for a in obj.__attrs_attrs__))
+
+  if hasattr(obj, "__wrapped__"):
+    return create_trace_type(obj.__wrapped__, context)
+
+  try:
+    ref = weakref.ref(obj, context.deletion_observer)
+    if ref is None:
+      raise TypeError(
+          f"Deleted objects are not valid tf.function arguments. Got {obj!r}")
+    else:
+      return default_types.Weakref(ref)
+  except TypeError:
+    try:
+      return default_types.Generic(obj)
+    except Exception:
+      raise TypeError(
+          f"Python object could not be represented through the generic tracing "
+          f"type. Consider implementing the Tracing Protocol for it: {obj!r}")
+
+
+def make_function_signature(
+    function_args,
+    signature_context: SignatureContext) -> trace.TraceType:
+  """Returns the trace type specification of a function's arguments.
+
+  Args:
+    function_args: Tuple/List/Dict structure containing the function arguments
+    signature_context: The SignatureContext to be shared during protocol calls.
+
+  Returns:
+    A TraceType object representing all the given inputs.
+  """
+  return create_trace_type(function_args, signature_context)
diff --git a/tensorflow/core/function/trace_type/trace_type_test.py b/tensorflow/core/function/trace_type/trace_type_test.py
new file mode 100644
index 0000000..fc29d0f
--- /dev/null
+++ b/tensorflow/core/function/trace_type/trace_type_test.py
@@ -0,0 +1,417 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests and benchmarks for the trace_type module."""
+
+import timeit
+
+from absl.testing import parameterized
+
+from tensorflow.core.function import trace_type
+from tensorflow.core.function.trace_type import default_types
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.eager import function
+from tensorflow.python.framework import combinations
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.platform import test
+from tensorflow.python.types import trace
+
+
+class TestAttr:
+  """Helps test attrs collections."""
+
+  def __init__(self, name):
+    self.name = name
+
+
+class TestAttrsClass:
+  """Helps test attrs collections."""
+
+  __attrs_attrs__ = (TestAttr('a'), TestAttr('b'))
+
+  def __init__(self, a, b):
+    self.a = a
+    self.b = b
+
+
+class DummyGenericClass:
+  """Helps test memory leaks for GenericType."""
+  pass
+
+
+def make_function_signature_with_context(inputs):
+  return trace_type.make_function_signature(
+      inputs, trace_type.SignatureContext())
+
+
+class CacheKeyGenerationTest(test.TestCase, parameterized.TestCase):
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def testIteratorAliasing(self):
+    it1 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
+    it2 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
+
+    self.assertEqual(
+        make_function_signature_with_context((it1, it1)),
+        make_function_signature_with_context((it2, it2)))
+    self.assertEqual(
+        make_function_signature_with_context((it1, it2)),
+        make_function_signature_with_context((it2, it1)))
+    self.assertNotEqual(
+        make_function_signature_with_context((it1, it1)),
+        make_function_signature_with_context((it1, it2)))
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testIteratorTypesImplementTracing(self):
+    self.assertTrue(
+        issubclass(iterator_ops.OwnedIterator, trace.SupportsTracingProtocol))
+    self.assertTrue(
+        issubclass(iterator_ops.IteratorSpec, trace.SupportsTracingProtocol))
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testCompositeAndSpec(self):
+    composite_tensor = ragged_tensor.RaggedTensor.from_row_splits(
+        values=[1, 2, 3], row_splits=[0, 2, 3])
+    spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)
+
+    self.assertEqual(
+        make_function_signature_with_context(composite_tensor),
+        make_function_signature_with_context(spec))
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testVariableAliasing(self):
+    v1 = resource_variable_ops.ResourceVariable([1])
+    v2 = resource_variable_ops.ResourceVariable([1])
+    v3 = resource_variable_ops.ResourceVariable([1])
+    all_unique = make_function_signature_with_context((v1, v2, v3))
+    all_same = make_function_signature_with_context((v1, v1, v1))
+    self.assertNotEqual(all_unique, all_same)
+
+    v3 = resource_variable_ops.ResourceVariable([2])
+    v4 = resource_variable_ops.ResourceVariable([2])
+    v5 = resource_variable_ops.ResourceVariable([2])
+    all_unique_again = make_function_signature_with_context((v3, v4, v5))
+    all_same_again = make_function_signature_with_context((v4, v4, v4))
+    self.assertEqual(all_unique, all_unique_again)
+    self.assertEqual(all_same, all_same_again)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testTensorEquality(self):
+    context = trace_type.SignatureContext()
+    tensor_a = array_ops.zeros([11, 3, 5],
+                               dtype=dtypes.int32).__tf_tracing_type__(context)
+    tensor_b = array_ops.zeros([11, 4, 5],
+                               dtype=dtypes.int32).__tf_tracing_type__(context)
+    tensor_c = array_ops.zeros(
+        [11, 3, 5], dtype=dtypes.float32).__tf_tracing_type__(context)
+    tensor_d = array_ops.ones([11, 3, 5],
+                              dtype=dtypes.int32).__tf_tracing_type__(context)
+
+    self.assertNotEqual(tensor_a, tensor_b)
+    self.assertNotEqual(tensor_a, tensor_c)
+    self.assertNotEqual(tensor_b, tensor_c)
+    self.assertEqual(tensor_a, tensor_d)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testTensorAndSpecEquality(self):
+    context = trace_type.SignatureContext()
+    tensor = array_ops.zeros([11, 3, 5],
+                             dtype=dtypes.int32).__tf_tracing_type__(context)
+    spec = tensor_spec.TensorSpec(
+        [11, 3, 5], dtype=dtypes.int32).__tf_tracing_type__(context)
+    spec_with_name = tensor_spec.TensorSpec(
+        [11, 3, 5], dtype=dtypes.int32,
+        name='name').__tf_tracing_type__(context)
+
+    self.assertEqual(tensor, spec)
+    self.assertNotEqual(tensor, spec_with_name)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testTensorShapeUnknown(self):
+    context = trace_type.SignatureContext()
+    spec_1 = tensor_spec.TensorSpec(
+        None, dtype=dtypes.int32).__tf_tracing_type__(context)
+    spec_2 = tensor_spec.TensorSpec(
+        None, dtype=dtypes.int32).__tf_tracing_type__(context)
+    self.assertEqual(spec_1, spec_2)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testAttrsCacheKeyGeneration(self):
+    trace_a = make_function_signature_with_context(TestAttrsClass(1, 2))
+    expected = default_types.Attrs(
+        TestAttrsClass,
+        (default_types.Generic(1), default_types.Generic(2)))
+    self.assertEqual(trace_a, expected)
+    self.assertTrue(trace_a.is_subtype_of(trace_a))
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testTupleEquality(self):
+    trace_a = make_function_signature_with_context((1, 2, 3, 4))
+    trace_b = make_function_signature_with_context((1, 2, 2, 4))
+    trace_c = make_function_signature_with_context((1, 2, 3))
+    trace_d = make_function_signature_with_context((1, 2, 3, 4))
+    self.assertNotEqual(trace_a, trace_b)
+    self.assertNotEqual(trace_a, trace_c)
+    self.assertNotEqual(trace_b, trace_c)
+    self.assertEqual(trace_a, trace_d)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testListEquality(self):
+    trace_a = make_function_signature_with_context([1, 2, 3, 4])
+    trace_b = make_function_signature_with_context([1, 2, 2, 4])
+    trace_c = make_function_signature_with_context([1, 2, 3])
+    trace_d = make_function_signature_with_context([1, 2, 3, 4])
+    self.assertNotEqual(trace_a, trace_b)
+    self.assertNotEqual(trace_a, trace_c)
+    self.assertNotEqual(trace_b, trace_c)
+    self.assertEqual(trace_a, trace_d)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testDictEquality(self):
+    trace_a = make_function_signature_with_context({1: 2, 3: 4})
+    trace_b = make_function_signature_with_context({1: 2, 3: 2})
+    trace_c = make_function_signature_with_context({1: 2, 3: 0})
+    trace_d = make_function_signature_with_context({3: 4, 1: 2})
+    self.assertNotEqual(trace_a, trace_b)
+    self.assertNotEqual(trace_a, trace_c)
+    self.assertNotEqual(trace_b, trace_c)
+    self.assertEqual(trace_a, trace_d)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testComplexStruct(self):
+    struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
+    trace_a = make_function_signature_with_context(struct)
+    trace_b = make_function_signature_with_context(struct)
+    self.assertEqual(trace_a, trace_b)
+    self.assertTrue(trace_a.is_subtype_of(trace_b))
+    self.assertTrue(trace_b.is_subtype_of(trace_a))
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testCustomUnequableTypeSucceeds(self):
+
+    class CustomUnequable:
+
+      def __eq__(self, o):
+        raise ValueError
+
+      def __hash__(self):
+        return 0
+
+    object_a = CustomUnequable()
+    object_b = CustomUnequable()
+    trace_a_1 = make_function_signature_with_context(object_a)
+    trace_a_2 = make_function_signature_with_context(object_a)
+    trace_b = make_function_signature_with_context(object_b)
+    self.assertEqual(trace_a_1, trace_a_2)
+
+    with self.assertRaises(ValueError):
+      trace_a_1.__eq__(trace_b)
+
+    del object_a
+    self.assertNotEqual(trace_a_1, trace_a_2)
+    self.assertNotEqual(trace_a_2, trace_a_1)
+
+    del object_b
+    self.assertNotEqual(trace_a_1, trace_a_2)
+    self.assertNotEqual(trace_a_2, trace_a_1)
+
+  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
+  def testCustomUnhashableTypeFailsGracefully(self):
+
+    class CustomUnhashable:
+
+      def __eq__(self, o):
+        return True
+
+    obj = CustomUnhashable()
+    with self.assertRaisesRegex(
+        TypeError,
+        r'could not be represented through the generic tracing type'):
+      make_function_signature_with_context(obj)
+
+
class CacheKeyMemoryTest(test.TestCase):
  """Checks that trace generation does not leak Python objects."""

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testGeneric(self):
    """Plain ints and generic user objects leave no residue."""
    make_function_signature_with_context(1)
    make_function_signature_with_context(DummyGenericClass())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testTensor(self):
    """Tracing a tensor leaves no residue."""
    make_function_signature_with_context(array_ops.zeros([10]))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testTuple(self):
    """Tracing a tuple leaves no residue."""
    make_function_signature_with_context((1, 2, 3))

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testDict(self):
    """Tracing a dict leaves no residue."""
    make_function_signature_with_context({1: 1, 2: 2, 3: 3})

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testList(self):
    """Tracing a list leaves no residue."""
    make_function_signature_with_context([1, 2, 3])

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testAttrs(self):
    """Tracing an attrs-decorated instance leaves no residue."""
    make_function_signature_with_context(TestAttrsClass(1, 2))
+
+
class CacheKeyGenerationBenchmark(test.Benchmark):
  """Benchmarks trace/cache-key generation and function cache lookup."""

  def _report(self, name, iterations, total_time, metric_name=None):
    """Reports one benchmark with an average-milliseconds metric.

    Args:
      name: Benchmark name passed to `report_benchmark`.
      iterations: Number of timed iterations.
      total_time: Total wall time in seconds across all iterations.
      metric_name: Optional metric name; defaults to `name + '_avg_ms'`.
    """
    self.report_benchmark(
        name=name,
        iters=iterations,
        wall_time=total_time,
        metrics=[{
            'name': metric_name or name + '_avg_ms',
            'value': total_time / iterations * 1000
        }])

  def benchmarkTensor(self):
    """Times trace generation for a list of eager tensors."""
    shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
    tensors = [array_ops.zeros(s) for s in shapes]

    iterations = 100000
    t = timeit.timeit(
        lambda: make_function_signature_with_context(tensors),
        number=iterations)
    self._report('tensor_cache_key_generation', iterations, t)

  def benchmarkTensorSpec(self):
    """Times trace generation for a list of TensorSpecs."""
    shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
    tensor_specs = [tensor_spec.TensorSpec(s, dtypes.int32) for s in shapes]

    iterations = 100000
    t = timeit.timeit(
        lambda: make_function_signature_with_context(tensor_specs),
        number=iterations)
    self._report('tensor_spec_cache_key_generation', iterations, t)

  def benchmarkVariable(self):
    """Times trace generation for variables of various dtypes/shapes."""
    var_list = [
        variables.Variable(1.0),
        variables.Variable(1),
        variables.Variable([1])
    ]

    iterations = 10000
    t = timeit.timeit(
        lambda: make_function_signature_with_context(var_list),
        number=iterations)
    self._report('variable_cache_key_generation', iterations, t)

  def benchmarkCacheKeyLookup(self):
    """Times cache lookup for an already-traced tensor signature."""

    @function.defun
    def defined(t):
      return t

    # Populate the function cache with several distinct shapes so the
    # lookup has to discriminate among multiple entries.
    call_arg_list = [
        1,
        array_ops.zeros([5, 13]),
        array_ops.zeros([9, 22, 24]),
        array_ops.zeros([5, 13, 2])
    ]
    for c in call_arg_list:
      defined(c)

    lookup_call_arg = array_ops.zeros([5, 13])

    iterations = 10000
    t = timeit.timeit(stmt=lambda: defined(lookup_call_arg), number=iterations)
    self._report('cache_key_lookup', iterations, t)

  def benchmarkNestedStruct(self):
    """Times trace generation for a deeply nested container structure."""
    struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}

    iterations = 100000
    t = timeit.timeit(
        lambda: make_function_signature_with_context(struct),
        number=iterations)
    self._report('nested_struct_cache_key_generation', iterations, t)

  def benchmarkFunctionInvocation(self):
    """Times end-to-end invocation of a cached concrete function."""
    struct = (variables.Variable(1.0), array_ops.zeros([5, 13]), {
        'tensor': array_ops.zeros([5, 20]),
        'variable': variables.Variable(1.0)
    })

    @function.defun
    def defined(t):
      return t

    defined(struct)  # Get it traced and cached.

    iterations = 10000
    t = timeit.timeit(lambda: defined(struct), number=iterations)
    # Historical metric name differs from the default pattern; keep it stable.
    self._report(
        'function_invocation',
        iterations,
        t,
        metric_name='function_invocation_time_avg_ms')
+
if __name__ == '__main__':
  # These tests exercise tracing-type behavior under TF2 semantics
  # (eager execution by default), so enable v2 behavior before running.
  v2_compat.enable_v2_behavior()
  test.main()