Introduce TraceType for collections
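
Tuples, lists and dicts of function arguments are now encoded as
TupleType/ListType/DictType trees of TraceType objects (gated behind
function.USE_FULL_TRACE_TYPE) instead of a single opaque hashable.
For example, mirroring the new unit tests:

    trace_a = function_trace_type.get_arg_spec({1: 2, 3: 4}, False, False, True)
    trace_b = function_trace_type.get_arg_spec({3: 4, 1: 2}, False, False, True)
    trace_a == trace_b              # True: dict key order does not matter.
    trace_a.is_subtype_of(trace_b)  # True: equal types are mutual subtypes.
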
PiperOrigin-RevId: 401577102
Change-Id: Ie9849d327f224b2612b56af1554b5884f4ea8fd8
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index eb282dd..ba42517 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -629,6 +629,7 @@
"//tensorflow/python/framework:composite_tensor",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:tensor_spec",
+ "//tensorflow/python/types",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 7c3e170..a487d53 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -95,6 +95,8 @@
# are not detected by Global TAP.
# TODO(jiaweix): remove this flag and related args (b/198782192)
ENCODE_VARIABLES_BY_RESOURCE_ID = True
+# TODO(b/201533914): Remove this flag and related args
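+# When True, collections of function arguments are encoded as structured
+# TraceTypes instead of a single GenericType wrapper (see
+# function_trace_type.get_arg_spec).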
+USE_FULL_TRACE_TYPE = False
_graph_building_time_counter = monitoring.Counter(
"/tensorflow/core/tf_function/graph_building_time_usecs",
@@ -3213,7 +3215,8 @@
# kwargs is empty.
inputs = (args, kwargs)
hashable_input_signature = function_trace_type.get_arg_spec(
- inputs, include_tensor_ranks_only, ENCODE_VARIABLES_BY_RESOURCE_ID)
+ inputs, include_tensor_ranks_only, ENCODE_VARIABLES_BY_RESOURCE_ID,
+ USE_FULL_TRACE_TYPE)
else:
del args, kwargs
assert not include_tensor_ranks_only
diff --git a/tensorflow/python/eager/function_trace_type.py b/tensorflow/python/eager/function_trace_type.py
index 5db82be..30d662c 100644
--- a/tensorflow/python/eager/function_trace_type.py
+++ b/tensorflow/python/eager/function_trace_type.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Utitiles for Cache Key generation based on Function Trace Type."""
+from typing import Dict, Optional, Sequence
import weakref
import numpy as np
@@ -23,10 +24,179 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.types import trace
+
+
+class GenericType(trace.TraceType):
+ """Represents an arbitrary Python object."""
+
+ def __init__(self, obj):
+ self._object = obj
+ self._object_hash = self._make_hash(obj)
+
+ def is_subtype_of(self, other: trace.TraceType) -> bool:
+ return self == other
+
+  def most_specific_common_supertype(
+      self, others: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+    # Generic objects are only compared by equality; there is no non-trivial
+    # common supertype.
+    return None
+
+ def __eq__(self, other) -> bool:
+ return isinstance(other, GenericType) and self._object == other._object
+
+ def __hash__(self) -> int:
+ return self._object_hash
+
+ # TODO(b/195985838): Cleanup once Tensor protocol is implemented.
+ def _make_hash(self, elem):
+ """Deals with special cases while hashing arbitrary Python objects."""
+ try:
+ return hash(elem)
+ except TypeError:
+ # TODO(slebedev): consider using nest.
+ if isinstance(elem, tuple):
+ return hash(tuple(map(self._make_hash, elem)))
+
+ # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
+ # all recognized types to be hashable.
+ assert isinstance(elem, weakref.ReferenceType)
+ v = elem()
+
+ if resource_variable_ops.is_resource_variable(v):
+ # We special case variables here to use unique_id as the cache key. This
+ # ensures we have to retrace whenever a different variable is passed in.
+ # This is needed to support cases where the user may use the id of a
+ # variable in the function perhaps as a lookup in a dictionary.
+ #
+ # This choice leads to more retracing when we could have possibly used
+ # the shape and dtype instead. However, we expect the number of
+ # variables in a program to be bounded, and correspondingly the number
+ # of retraces.
+ #
+ # Note we also include the class name to avoid collisions with strings.
+ return hash((v.__class__, v._unique_id)) # pylint: disable=protected-access
+
+ if self._is_ndarray(v):
+ # Numpy arrays are not hashable, but when calling functions we treat
+ # them in the same way as tf.Tensors.
+ if not hasattr(v, "shape") or not hasattr(v, "dtype"):
+ # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
+ v = self._as_ndarray(v)
+ return hash(tensor_spec.TensorSpec(v.shape, v.dtype))
+
+ raise ValueError(
+ "Arguments to a tf.function must be a nested structure of "
+ "Tensors, Variables, NumPy arrays, or hashable Python "
+ f"objects, got {type(v)}.")
+
+ def _as_ndarray(self, value):
+ """Converts value to an ndarray, assumes _is_ndarray(value)."""
+ # TODO(tomhennigan) Support __array_interface__ too (including for
+ # _convert_numpy_inputs).
+ return value.__array__()
+
+ def _is_ndarray(self, value):
+ """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
+ # TODO(tomhennigan) Support __array_interface__ too.
+ return hasattr(value, "__array__") and not (
+ isinstance(value, ops.Tensor) or
+ isinstance(value, resource_variable_ops.BaseResourceVariable) or
+ hasattr(value, "_should_act_as_resource_variable")
+
+ # For legacy reasons we do not automatically promote Numpy strings.
+ or isinstance(value, np.str_)
+ # NumPy dtypes have __array__ as unbound methods.
+ or isinstance(value, type)
+ # CompositeTensors should be flattened instead.
+ or isinstance(value, composite_tensor.CompositeTensor))
+
+
+class CollectionType(trace.TraceType):
+ """Represents a collection of TraceType objects.
+
+ Attributes:
+    components: The group of TraceType objects that this class represents.
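+
+  For example (illustrative):
+
+    TupleType(GenericType(1), GenericType(2)) ==
+        TupleType(GenericType(1), GenericType(2))  # True, component-wise.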
+ """
+
+ def __init__(self, *components: trace.TraceType):
+ self.components = components
+
+  def is_subtype_of(self, other: trace.TraceType) -> bool:
+    if not isinstance(other, type(self)):
+      return False
+
+    if len(self.components) != len(other.components):
+      return False
+
+    return all(
+        component.is_subtype_of(other.components[i])
+        for i, component in enumerate(self.components))
+
+  def most_specific_common_supertype(
+      self, others: Sequence[trace.TraceType]) -> Optional[trace.TraceType]:
+    if not all(
+        isinstance(other, type(self)) and
+        len(self.components) == len(other.components) for other in others):
+      return None
+
+    new_components = []
+    for i, component in enumerate(self.components):
+      # Each component's supertype is computed over the corresponding
+      # components of `others`, passed as a sequence per the protocol.
+      common = component.most_specific_common_supertype(
+          [other.components[i] for other in others])
+      if common is None:
+        return None
+      new_components.append(common)
+
+    # Rewrap so the result is the same kind of collection, not a raw list.
+    return type(self)(*new_components)
+
+  def __eq__(self, other) -> bool:
+    if not isinstance(other, type(self)):
+      return False
+
+    if len(self.components) != len(other.components):
+      return False
+
+    return all(component == other.components[i]
+               for i, component in enumerate(self.components))
+
+ def __hash__(self) -> int:
+ return hash((type(self), self.components))
+
+
+class TupleType(CollectionType):
+ """Represents a tuple of TraceType objects."""
+
+
+class ListType(CollectionType):
+ """Represents a list of TraceType objects."""
+
+
+class DictType(CollectionType):
+ """Represents a dictionary of TraceType objects."""
+
+ def __init__(self, mapping: Dict[trace.TraceType, trace.TraceType]):
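+    # Sort keys by hash so the encoding is independent of the insertion
+    # order of the mapping; e.g. {1: 2, 3: 4} and {3: 4, 1: 2} encode to
+    # equal DictTypes (see testDictEquality in the accompanying test).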
+ sorted_keys = sorted(mapping.keys(), key=hash)
+ components = []
+ for k in sorted_keys:
+ components.append(TupleType(k, mapping[k]))
+
+ super().__init__(*components)
def get_arg_spec(inputs, include_tensor_ranks_only,
- encode_variables_by_resource_id):
+ encode_variables_by_resource_id, enable_full_trace_type):
"""Returns the trace type specification of a function's arguments.
Args:
@@ -34,85 +204,32 @@
include_tensor_ranks_only: If Tensors should be considered by rank
encode_variables_by_resource_id: If Variables should be considered by
resource id
+    enable_full_trace_type: Whether to make full use of the trace type
+      protocol. Otherwise, only a GenericType wrapper is added over the
+      final result.
Returns:
A hashable object representing the function arguments.
"""
- return _make_input_signature_hashable(pywrap_tfe.TFE_Py_EncodeArg(
- inputs, include_tensor_ranks_only, encode_variables_by_resource_id))
+ if enable_full_trace_type:
+ def parametrized_get_arg_spec(arg):
+ return get_arg_spec(arg, include_tensor_ranks_only,
+ encode_variables_by_resource_id, True)
-# TODO(b/195985838): Cleanup this function once Tensor protocol is implemented.
-def _make_input_signature_hashable(elem):
- """Rewrite input signature to be hashable.
+ if isinstance(inputs, tuple):
+ return TupleType(*map(parametrized_get_arg_spec, inputs))
- We replace nested variables in the input signature with TensorSpec in order to
- be hashable.
+ if isinstance(inputs, list):
+ return ListType(*map(parametrized_get_arg_spec, inputs))
- Args:
- elem: Input signature element
+ if isinstance(inputs, dict):
+ traced = {
+ parametrized_get_arg_spec(k): parametrized_get_arg_spec(v)
+ for k, v in inputs.items()
+ }
+ return DictType(traced)
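+
+  # Leaf (non-collection) inputs, and all inputs when enable_full_trace_type
+  # is False, are encoded by TFE_Py_EncodeArg and wrapped in a GenericType.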
- Returns:
- A hashable object for the requested input signature
- """
- try:
- hash(elem)
- except TypeError:
- # TODO(slebedev): consider using nest.
- if isinstance(elem, tuple):
- return tuple(map(_make_input_signature_hashable, elem))
-
- # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
- # all recognized types to be hashable.
- assert isinstance(elem, weakref.ReferenceType)
- v = elem()
-
- if resource_variable_ops.is_resource_variable(v):
- # We special case variables here to use unique_id as the cache key. This
- # ensures we have to retrace whenever a different variable is passed in.
- # This is needed to support cases where the user may use the id of a
- # variable in the function perhaps as a lookup in a dictionary.
- #
- # This choice leads to more retracing when we could have possibly used the
- # shape and dtype instead. However, we expect the number of variables in a
- # program to be bounded, and correspondingly the number of retraces.
- #
- # Note we also include the class name to avoid collisions with strings.
- return v.__class__, v._unique_id # pylint: disable=protected-access
-
- if _is_ndarray(v):
- # Numpy arrays are not hashable, but when calling functions we treat them
- # in the same way as tf.Tensors.
- if not hasattr(v, "shape") or not hasattr(v, "dtype"):
- # TODO(tomhennigan) De-dup with _as_ndarray in _convert_numpy_inputs.
- v = _as_ndarray(v)
- return tensor_spec.TensorSpec(v.shape, v.dtype)
-
- raise ValueError("Arguments to a tf.function must be a nested structure of "
- "Tensors, Variables, NumPy arrays, or hashable Python "
- f"objects, got {type(v)}.")
-
- return elem
-
-
-def _as_ndarray(value):
- """Converts value to an ndarray, assumes _is_ndarray(value)."""
- # TODO(tomhennigan) Support __array_interface__ too (including for
- # _convert_numpy_inputs).
- return value.__array__()
-
-
-def _is_ndarray(value):
- """Tests whether the given value is an ndarray (and not a TF tensor/var)."""
- # TODO(tomhennigan) Support __array_interface__ too.
- return hasattr(value, "__array__") and not (
- isinstance(value, ops.Tensor)
- or isinstance(value, resource_variable_ops.BaseResourceVariable)
- or hasattr(value, "_should_act_as_resource_variable")
-
- # For legacy reasons we do not automatically promote Numpy strings.
- or isinstance(value, np.str_)
- # NumPy dtypes have __array__ as unbound methods.
- or isinstance(value, type)
- # CompositeTensors should be flattened instead.
- or isinstance(value, composite_tensor.CompositeTensor))
+ return GenericType(
+ pywrap_tfe.TFE_Py_EncodeArg(inputs, include_tensor_ranks_only,
+ encode_variables_by_resource_id))
diff --git a/tensorflow/python/eager/function_trace_type_test.py b/tensorflow/python/eager/function_trace_type_test.py
index a8fa530..a2b9537 100644
--- a/tensorflow/python/eager/function_trace_type_test.py
+++ b/tensorflow/python/eager/function_trace_type_test.py
@@ -27,6 +27,50 @@
from tensorflow.python.platform import test
+class CacheKeyGenerationTest(test.TestCase):
+
+ def testTupleEquality(self):
+ trace_a = function_trace_type.get_arg_spec((1, 2, 3, 4), False, False, True)
+ trace_b = function_trace_type.get_arg_spec((1, 2, 2, 4), False, False, True)
+ trace_c = function_trace_type.get_arg_spec((1, 2, 3), False, False, True)
+ trace_d = function_trace_type.get_arg_spec((1, 2, 3, 4), False, False, True)
+
+ self.assertNotEqual(trace_a, trace_b)
+ self.assertNotEqual(trace_a, trace_c)
+ self.assertNotEqual(trace_b, trace_c)
+ self.assertEqual(trace_a, trace_d)
+
+ def testListEquality(self):
+ trace_a = function_trace_type.get_arg_spec([1, 2, 3, 4], False, False, True)
+ trace_b = function_trace_type.get_arg_spec([1, 2, 2, 4], False, False, True)
+ trace_c = function_trace_type.get_arg_spec([1, 2, 3], False, False, True)
+ trace_d = function_trace_type.get_arg_spec([1, 2, 3, 4], False, False, True)
+
+ self.assertNotEqual(trace_a, trace_b)
+ self.assertNotEqual(trace_a, trace_c)
+ self.assertNotEqual(trace_b, trace_c)
+ self.assertEqual(trace_a, trace_d)
+
+ def testDictEquality(self):
+ trace_a = function_trace_type.get_arg_spec({1: 2, 3: 4}, False, False, True)
+ trace_b = function_trace_type.get_arg_spec({1: 2, 3: 2}, False, False, True)
+ trace_c = function_trace_type.get_arg_spec({1: 2, 3: 0}, False, False, True)
+ trace_d = function_trace_type.get_arg_spec({3: 4, 1: 2}, False, False, True)
+
+ self.assertNotEqual(trace_a, trace_b)
+ self.assertNotEqual(trace_a, trace_c)
+ self.assertNotEqual(trace_b, trace_c)
+ self.assertEqual(trace_a, trace_d)
+
+ def testComplexStruct(self):
+ struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
+ trace_a = function_trace_type.get_arg_spec(struct, False, False, True)
+ trace_b = function_trace_type.get_arg_spec(struct, False, False, True)
+ self.assertEqual(trace_a, trace_b)
+ self.assertTrue(trace_a.is_subtype_of(trace_b))
+ self.assertTrue(trace_b.is_subtype_of(trace_a))
+
+
class CacheKeyGenerationBenchmark(test.Benchmark):
def benchmarkTensor(self):
@@ -36,7 +80,8 @@
tensors.append(array_ops.zeros(s))
def encode_tensors(tensors):
- function_trace_type.get_arg_spec(tensors, False, False)
+ function_trace_type.get_arg_spec(tensors, False, False,
+ function.USE_FULL_TRACE_TYPE)
iterations = 100000
t = timeit.timeit(lambda: encode_tensors(tensors), number=iterations)
@@ -56,7 +101,8 @@
tensor_specs.append(tensor_spec.TensorSpec(s, dtypes.int32))
def encode_tensor_specs(tensor_specs):
- function_trace_type.get_arg_spec(tensor_specs, False, False)
+ function_trace_type.get_arg_spec(tensor_specs, False, False,
+ function.USE_FULL_TRACE_TYPE)
iterations = 100000
t = timeit.timeit(
@@ -78,7 +124,8 @@
]
def encode_variables(var_list):
- function_trace_type.get_arg_spec(var_list, False, False)
+ function_trace_type.get_arg_spec(var_list, False, False,
+ function.USE_FULL_TRACE_TYPE)
iterations = 1000000
t = timeit.timeit(lambda: encode_variables(var_list), number=iterations)
@@ -98,7 +145,8 @@
model = keras.Model(inputs=inputs, outputs=outputs)
def encode_model(model):
- function_trace_type.get_arg_spec(model, False, False)
+ function_trace_type.get_arg_spec(model, False, False,
+ function.USE_FULL_TRACE_TYPE)
iterations = 100000
t = timeit.timeit(lambda: encode_model(model), number=iterations)
@@ -146,7 +194,8 @@
struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
def encode_struct(struct):
- function_trace_type.get_arg_spec(struct, False, False)
+ function_trace_type.get_arg_spec(struct, False, False,
+ function.USE_FULL_TRACE_TYPE)
iterations = 100000
t = timeit.timeit(lambda: encode_struct(struct), number=iterations)
diff --git a/tensorflow/python/types/BUILD b/tensorflow/python/types/BUILD
index a2c3dc6..22d9b7b 100644
--- a/tensorflow/python/types/BUILD
+++ b/tensorflow/python/types/BUILD
@@ -25,6 +25,7 @@
"core.py",
"distribute.py",
"internal.py",
+ "trace.py",
],
srcs_version = "PY3",
visibility = [
diff --git a/tensorflow/python/types/trace.py b/tensorflow/python/types/trace.py
new file mode 100644
index 0000000..895d68a
--- /dev/null
+++ b/tensorflow/python/types/trace.py
@@ -0,0 +1,45 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Function Tracing Type."""
+
+import abc
+from typing import Optional, Sequence
+
+
+class TraceType(abc.ABC):
+ """Represents the type of object(s) for Function Tracing purposes.
+
+  `TraceType` is an abstract class that other classes can inherit from to
+  describe the kind of objects they stand in for during Function Tracing.
+  The typing logic provided through this mechanism is used to decide whether
+  a cached function can be reused or a retrace is required.
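+
+  A minimal sketch of a concrete subtype (hypothetical, for illustration
+  only):
+
+    class SingletonType(TraceType):
+
+      def is_subtype_of(self, other):
+        return isinstance(other, SingletonType)
+
+      def most_specific_common_supertype(self, others):
+        if all(isinstance(other, SingletonType) for other in others):
+          return SingletonType()
+        return None
+
+      def __hash__(self):
+        return hash(SingletonType)
+
+      def __eq__(self, other):
+        return isinstance(other, SingletonType)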
+ """
+
+  @abc.abstractmethod
+  def is_subtype_of(self, other: "TraceType") -> bool:
+    """Returns True if `self` is a subtype of `other`."""
+
+  @abc.abstractmethod
+  def most_specific_common_supertype(
+      self, others: Sequence["TraceType"]) -> Optional["TraceType"]:
+    """Returns the most specific supertype of `self` and all of `others`."""
+
+ @abc.abstractmethod
+ def __hash__(self) -> int:
+ pass
+
+ @abc.abstractmethod
+ def __eq__(self, other) -> bool:
+ pass