Introduce TraceType for collections

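This adds a TraceType protocol under tensorflow/python/types/trace.py and
teaches function_trace_type.get_arg_spec to encode tuples, lists, and dicts
structurally, gated behind the new USE_FULL_TRACE_TYPE flag (off by default).

Illustrative sketch of the new behavior, based on the added tests (the
positional flags are include_tensor_ranks_only,
encode_variables_by_resource_id, enable_full_trace_type):

    trace_a = function_trace_type.get_arg_spec((1, [2]), False, False, True)
    trace_b = function_trace_type.get_arg_spec((1, [2]), False, False, True)
    assert trace_a == trace_b              # TupleType with equal components
    assert trace_a.is_subtype_of(trace_b)  # component-wise subtyping
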
PiperOrigin-RevId: 401577102
Change-Id: Ie9849d327f224b2612b56af1554b5884f4ea8fd8
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index eb282dd..ba42517 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -629,6 +629,7 @@
         "//tensorflow/python/framework:composite_tensor",
         "//tensorflow/python/framework:ops",
         "//tensorflow/python/framework:tensor_spec",
+        "//tensorflow/python/types",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 7c3e170..a487d53 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -95,6 +95,8 @@
 # are not detected by Global TAP.
 # TODO(jiaweix): remove this flag and related args (b/198782192)
 ENCODE_VARIABLES_BY_RESOURCE_ID = True
+# TODO(b/201533914): Remove this flag and related args
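+# When True, tf.function cache keys encode collections with the full
+# TraceType protocol from function_trace_type.py.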
+USE_FULL_TRACE_TYPE = False
 
 _graph_building_time_counter = monitoring.Counter(
     "/tensorflow/core/tf_function/graph_building_time_usecs",
@@ -3213,7 +3215,8 @@
       # kwargs is empty.
       inputs = (args, kwargs)
       hashable_input_signature = function_trace_type.get_arg_spec(
-          inputs, include_tensor_ranks_only, ENCODE_VARIABLES_BY_RESOURCE_ID)
+          inputs, include_tensor_ranks_only, ENCODE_VARIABLES_BY_RESOURCE_ID,
+          USE_FULL_TRACE_TYPE)
     else:
       del args, kwargs
       assert not include_tensor_ranks_only
diff --git a/tensorflow/python/eager/function_trace_type.py b/tensorflow/python/eager/function_trace_type.py
index 5db82be..30d662c 100644
--- a/tensorflow/python/eager/function_trace_type.py
+++ b/tensorflow/python/eager/function_trace_type.py
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Utitiles for Cache Key generation based on Function Trace Type."""
 
+from typing import Dict, Optional, Sequence
 import weakref
 
 import numpy as np
@@ -23,10 +24,179 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.types import trace
+
+
+class GenericType(trace.TraceType):
+  """Represents an arbitrary Python object."""
+
+  def __init__(self, obj):
+    self._object = obj
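+    # Precompute the hash once: hashing arbitrary objects requires the
+    # special-casing in _make_hash (weakref'd variables, ndarrays, etc.).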
+    self._object_hash = self._make_hash(obj)
+
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    return self == other
+
+  def most_specific_common_supertype(
+      self, others: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+    return None
+
+  def __eq__(self, other) -> bool:
+    return isinstance(other, GenericType) and self._object == other._object
+
+  def __hash__(self) -> int:
+    return self._object_hash
+
+  # TODO(b/195985838): Cleanup once Tensor protocol is implemented.
+  def _make_hash(self, elem):
+    """Deals with special cases while hashing arbitrary Python objects."""
+    try:
+      return hash(elem)
+    except TypeError:
+      # TODO(slebedev): consider using nest.
+      if isinstance(elem, tuple):
+        return hash(tuple(map(self._make_hash, elem)))
+
+      # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
+      # all recognized types to be hashable.
+      assert isinstance(elem, weakref.ReferenceType)
+      v = elem()
+
+      if resource_variable_ops.is_resource_variable(v):
+        # We special case variables here to use unique_id as the cache key. This
+        # ensures we have to retrace whenever a different variable is passed in.
+        # This is needed to support cases where the user may use the id of a
+        # variable in the function perhaps as a lookup in a dictionary.
+        #
+        # This choice leads to more retracing when we could have possibly used
+        # the shape and dtype instead. However, we expect the number of
+        # variables in a program to be bounded, and correspondingly the number
+        # of retraces.
+        #
+        # Note we also include the class name to avoid collisions with strings.
+        return hash((v.__class__, v._unique_id))  # pylint: disable=protected-access
+
+      if self._is_ndarray(v):
+        # Numpy arrays are not hashable, but when calling functions we treat
+        # them in the same way as tf.Tensors.
+        if not hasattr(v, "shape") or not hasattr(v, "dtype"):
+          # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
+          v = self._as_ndarray(v)
+        return hash(tensor_spec.TensorSpec(v.shape, v.dtype))
+
+      raise ValueError(
+          "Arguments to a tf.function must be a nested structure of "
+          "Tensors, Variables, NumPy arrays, or hashable Python "
+          f"objects, got {type(v)}.")
+
+  def _as_ndarray(self, value):
+    """Converts value to an ndarray, assumes _is_ndarray(value)."""
+    # TODO(tomhennigan) Support __array_interface__ too (including for
+    # _convert_numpy_inputs).
+    return value.__array__()
+
+  def _is_ndarray(self, value):
+    """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
+    # TODO(tomhennigan) Support __array_interface__ too.
+    return hasattr(value, "__array__") and not (
+        isinstance(value, ops.Tensor)
+        or isinstance(value, resource_variable_ops.BaseResourceVariable)
+        or hasattr(value, "_should_act_as_resource_variable")
+        # For legacy reasons we do not automatically promote Numpy strings.
+        or isinstance(value, np.str_)
+        # NumPy dtypes have __array__ as unbound methods.
+        or isinstance(value, type)
+        # CompositeTensors should be flattened instead.
+        or isinstance(value, composite_tensor.CompositeTensor))
+
+
+class CollectionType(trace.TraceType):
+  """Represents a collection of TraceType objects.
+
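+  Equality and subtyping are component-wise: two collections match only if
+  they have the same type, the same length, and matching components in the
+  same order (see the equality tests in function_trace_type_test.py).
+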
+  Attributes:
+    components: The group of TraceType objects that this class represents.
+  """
+
+  def __init__(self, *components: trace.TraceType):
+    self.components = components
+
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    if not isinstance(other, type(self)):
+      return False
+
+    if len(self.components) != len(other.components):
+      return False
+
+    return all(
+        self_component.is_subtype_of(other_component)
+        for self_component, other_component in zip(self.components,
+                                                   other.components))
+
+  def most_specific_common_supertype(
+      self, others: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+    if not all(
+        isinstance(other, type(self)) and
+        len(self.components) == len(other.components) for other in others):
+      return None
+
+    new_components = []
+    for i, component in enumerate(self.components):
+      common = component.most_specific_common_supertype(
+          [other.components[i] for other in others])
+      if common is None:
+        return None
+      new_components.append(common)
+
+    # Note: subclasses whose constructors do not accept components directly
+    # (e.g. DictType) would need to override this to rebuild themselves.
+    return type(self)(*new_components)
+
+  def __eq__(self, other) -> bool:
+    if not isinstance(other, type(self)):
+      return False
+
+    if len(self.components) != len(other.components):
+      return False
+
+    return all(
+        self_component == other_component
+        for self_component, other_component in zip(self.components,
+                                                   other.components))
+
+  def __hash__(self) -> int:
+    return hash((type(self), self.components))
+
+
+class TupleType(CollectionType):
+  """Represents a tuple of TraceType objects."""
+
+
+class ListType(CollectionType):
+  """Represents a list of TraceType objects."""
+
+
+class DictType(CollectionType):
+  """Represents a dictionary of TraceType objects."""
+
+  def __init__(self, mapping: Dict[trace.TraceType, trace.TraceType]):
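+    # Sort keys by hash so that component order is deterministic and
+    # independent of dict insertion order ({1: 2, 3: 4} and {3: 4, 1: 2}
+    # produce equal DictTypes; see testDictEquality).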
+    sorted_keys = sorted(mapping.keys(), key=hash)
+    components = []
+    for k in sorted_keys:
+      components.append(TupleType(k, mapping[k]))
+
+    super().__init__(*components)
 
 
 def get_arg_spec(inputs, include_tensor_ranks_only,
-                 encode_variables_by_resource_id):
+                 encode_variables_by_resource_id, enable_full_trace_type):
   """Returns the trace type specification of a function's arguments.
 
   Args:
@@ -34,85 +204,32 @@
     include_tensor_ranks_only: If Tensors should be considered by rank
     encode_variables_by_resource_id: If Variables should be considered by
       resource id
+    enable_full_trace_type: If the full TraceType protocol should be used to
+      encode inputs, turning tuples, lists, and dicts into CollectionType
+      instances. Otherwise, only a GenericType wrapper is added over the
+      final result.
 
   Returns:
     A hashable object representing the function arguments.
   """
-  return _make_input_signature_hashable(pywrap_tfe.TFE_Py_EncodeArg(
-      inputs, include_tensor_ranks_only, encode_variables_by_resource_id))
+  if enable_full_trace_type:
 
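+    # Recurse with the same flags so that nested collections are also
+    # encoded structurally.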
+    def parametrized_get_arg_spec(arg):
+      return get_arg_spec(arg, include_tensor_ranks_only,
+                          encode_variables_by_resource_id, True)
 
-# TODO(b/195985838): Cleanup this function once Tensor protocol is implemented.
-def _make_input_signature_hashable(elem):
-  """Rewrite input signature to be hashable.
+    if isinstance(inputs, tuple):
+      return TupleType(*map(parametrized_get_arg_spec, inputs))
 
-  We replace nested variables in the input signature with TensorSpec in order to
-  be hashable.
+    if isinstance(inputs, list):
+      return ListType(*map(parametrized_get_arg_spec, inputs))
 
-  Args:
-    elem: Input signature element
+    if isinstance(inputs, dict):
+      traced = {
+          parametrized_get_arg_spec(k): parametrized_get_arg_spec(v)
+          for k, v in inputs.items()
+      }
+      return DictType(traced)
 
-  Returns:
-    A hashable object for the requested input signature
-  """
-  try:
-    hash(elem)
-  except TypeError:
-    # TODO(slebedev): consider using nest.
-    if isinstance(elem, tuple):
-      return tuple(map(_make_input_signature_hashable, elem))
-
-    # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
-    # all recognized types to be hashable.
-    assert isinstance(elem, weakref.ReferenceType)
-    v = elem()
-
-    if resource_variable_ops.is_resource_variable(v):
-      # We special case variables here to use unique_id as the cache key. This
-      # ensures we have to retrace whenever a different variable is passed in.
-      # This is needed to support cases where the user may use the id of a
-      # variable in the function perhaps as a lookup in a dictionary.
-      #
-      # This choice leads to more retracing when we could have possibly used the
-      # shape and dtype instead. However, we expect the number of variables in a
-      # program to be bounded, and correspondingly the number of retraces.
-      #
-      # Note we also include the class name to avoid collisions with strings.
-      return v.__class__, v._unique_id  # pylint: disable=protected-access
-
-    if _is_ndarray(v):
-      # Numpy arrays are not hashable, but when calling functions we treat them
-      # in the same way as tf.Tensors.
-      if not hasattr(v, "shape") or not hasattr(v, "dtype"):
-        # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
-        v = _as_ndarray(v)
-      return tensor_spec.TensorSpec(v.shape, v.dtype)
-
-    raise ValueError("Arguments to a tf.function must be a nested structure of "
-                     "Tensors, Variables, NumPy arrays, or hashable Python "
-                     f"objects, got {type(v)}.")
-
-  return elem
-
-
-def _as_ndarray(value):
-  """Converts value to an ndarray, assumes _is_ndarray(value)."""
-  # TODO(tomhennigan) Support __array_interface__ too (including for
-  # _convert_numpy_inputs).
-  return value.__array__()
-
-
-def _is_ndarray(value):
-  """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
-  # TODO(tomhennigan) Support __array_interface__ too.
-  return hasattr(value, "__array__") and not (
-      isinstance(value, ops.Tensor)
-      or isinstance(value, resource_variable_ops.BaseResourceVariable)
-      or hasattr(value, "_should_act_as_resource_variable")
-
-      # For legacy reasons we do not automatically promote Numpy strings.
-      or isinstance(value, np.str_)
-      # NumPy dtypes have __array__ as unbound methods.
-      or isinstance(value, type)
-      # CompositeTensors should be flattened instead.
-      or isinstance(value, composite_tensor.CompositeTensor))
+  return GenericType(
+      pywrap_tfe.TFE_Py_EncodeArg(inputs, include_tensor_ranks_only,
+                                  encode_variables_by_resource_id))
diff --git a/tensorflow/python/eager/function_trace_type_test.py b/tensorflow/python/eager/function_trace_type_test.py
index a8fa530..a2b9537 100644
--- a/tensorflow/python/eager/function_trace_type_test.py
+++ b/tensorflow/python/eager/function_trace_type_test.py
@@ -27,6 +27,50 @@
 from tensorflow.python.platform import test
 
 
+class CacheKeyGenerationTest(test.TestCase):
+
+  def testTupleEquality(self):
+    trace_a = function_trace_type.get_arg_spec((1, 2, 3, 4), False, False, True)
+    trace_b = function_trace_type.get_arg_spec((1, 2, 2, 4), False, False, True)
+    trace_c = function_trace_type.get_arg_spec((1, 2, 3), False, False, True)
+    trace_d = function_trace_type.get_arg_spec((1, 2, 3, 4), False, False, True)
+
+    self.assertNotEqual(trace_a, trace_b)
+    self.assertNotEqual(trace_a, trace_c)
+    self.assertNotEqual(trace_b, trace_c)
+    self.assertEqual(trace_a, trace_d)
+
+  def testListEquality(self):
+    trace_a = function_trace_type.get_arg_spec([1, 2, 3, 4], False, False, True)
+    trace_b = function_trace_type.get_arg_spec([1, 2, 2, 4], False, False, True)
+    trace_c = function_trace_type.get_arg_spec([1, 2, 3], False, False, True)
+    trace_d = function_trace_type.get_arg_spec([1, 2, 3, 4], False, False, True)
+
+    self.assertNotEqual(trace_a, trace_b)
+    self.assertNotEqual(trace_a, trace_c)
+    self.assertNotEqual(trace_b, trace_c)
+    self.assertEqual(trace_a, trace_d)
+
+  def testDictEquality(self):
+    trace_a = function_trace_type.get_arg_spec({1: 2, 3: 4}, False, False, True)
+    trace_b = function_trace_type.get_arg_spec({1: 2, 3: 2}, False, False, True)
+    trace_c = function_trace_type.get_arg_spec({1: 2, 3: 0}, False, False, True)
+    trace_d = function_trace_type.get_arg_spec({3: 4, 1: 2}, False, False, True)
+
+    self.assertNotEqual(trace_a, trace_b)
+    self.assertNotEqual(trace_a, trace_c)
+    self.assertNotEqual(trace_b, trace_c)
+    self.assertEqual(trace_a, trace_d)
+
+  def testComplexStruct(self):
+    struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
+    trace_a = function_trace_type.get_arg_spec(struct, False, False, True)
+    trace_b = function_trace_type.get_arg_spec(struct, False, False, True)
+    self.assertEqual(trace_a, trace_b)
+    self.assertTrue(trace_a.is_subtype_of(trace_b))
+    self.assertTrue(trace_b.is_subtype_of(trace_a))
+
+
 class CacheKeyGenerationBenchmark(test.Benchmark):
 
   def benchmarkTensor(self):
@@ -36,7 +80,8 @@
       tensors.append(array_ops.zeros(s))
 
     def encode_tensors(tensors):
-      function_trace_type.get_arg_spec(tensors, False, False)
+      function_trace_type.get_arg_spec(tensors, False, False,
+                                       function.USE_FULL_TRACE_TYPE)
 
     iterations = 100000
     t = timeit.timeit(lambda: encode_tensors(tensors), number=iterations)
@@ -56,7 +101,8 @@
       tensor_specs.append(tensor_spec.TensorSpec(s, dtypes.int32))
 
     def encode_tensor_specs(tensor_specs):
-      function_trace_type.get_arg_spec(tensor_specs, False, False)
+      function_trace_type.get_arg_spec(tensor_specs, False, False,
+                                       function.USE_FULL_TRACE_TYPE)
 
     iterations = 100000
     t = timeit.timeit(
@@ -78,7 +124,8 @@
     ]
 
     def encode_variables(var_list):
-      function_trace_type.get_arg_spec(var_list, False, False)
+      function_trace_type.get_arg_spec(var_list, False, False,
+                                       function.USE_FULL_TRACE_TYPE)
 
     iterations = 1000000
     t = timeit.timeit(lambda: encode_variables(var_list), number=iterations)
@@ -98,7 +145,8 @@
     model = keras.Model(inputs=inputs, outputs=outputs)
 
     def encode_model(model):
-      function_trace_type.get_arg_spec(model, False, False)
+      function_trace_type.get_arg_spec(model, False, False,
+                                       function.USE_FULL_TRACE_TYPE)
 
     iterations = 100000
     t = timeit.timeit(lambda: encode_model(model), number=iterations)
@@ -146,7 +194,8 @@
     struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
 
     def encode_struct(struct):
-      function_trace_type.get_arg_spec(struct, False, False)
+      function_trace_type.get_arg_spec(struct, False, False,
+                                       function.USE_FULL_TRACE_TYPE)
 
     iterations = 100000
     t = timeit.timeit(lambda: encode_struct(struct), number=iterations)
diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD
index a2c3dc6..22d9b7b 100644
--- a/tensorflow/python/types/BUILD
+++ b/tensorflow/python/types/BUILD
@@ -25,6 +25,7 @@
         "core.py",
         "distribute.py",
         "internal.py",
+        "trace.py",
     ],
     srcs_version = "PY3",
     visibility = [
diff --git a/tensorflow/python/types/trace.py b/tensorflow/python/types/trace.py
new file mode 100644
index 0000000..895d68a
--- /dev/null
+++ b/tensorflow/python/types/trace.py
@@ -0,0 +1,45 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Function Tracing Type."""
+
+import abc
+from typing import Optional, Sequence
+
+
+class TraceType(abc.ABC):
+  """Represents the type of object(s) for Function Tracing purposes.
+
+  `TraceType` is an abstract class that other classes might inherit from to
+  provide information regarding associated class(es) for the purposes of
+  Function Tracing. The typing logic provided through this mechanism will be
+  used to make decisions regarding usage of cached functions and retracing.
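+
+  A minimal sketch of a concrete subtype (illustrative only, not part of
+  this change):
+
+    class AnyType(TraceType):
+      # All AnyType instances are interchangeable for tracing purposes.
+      def is_subtype_of(self, other):
+        return self == other
+      def most_specific_common_supertype(self, others):
+        return self if all(self == other for other in others) else None
+      def __hash__(self):
+        return hash(AnyType)
+      def __eq__(self, other):
+        return isinstance(other, AnyType)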
+  """
+
+  @abc.abstractmethod
+  def is_subtype_of(self, other: "TraceType") -> bool:
+    """Returns True if `self` is a subtype of `other`."""
+
+  @abc.abstractmethod
+  def most_specific_common_supertype(
+      self, others: Sequence["TraceType"]) -> Optional["TraceType"]:
+    """Returns the most specific supertype of `self` and `others`, if any."""
+
+  @abc.abstractmethod
+  def __hash__(self) -> int:
+    """TraceTypes must be hashable so they can serve as cache keys."""
+
+  @abc.abstractmethod
+  def __eq__(self, other) -> bool:
+    """Equality must be consistent with `__hash__`."""