# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests and benchmarks for the trace_type module."""
import timeit
from absl.testing import parameterized
from tensorflow.core.function import trace_type
from tensorflow.core.function.trace_type import default_types
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.types import trace
class TestAttr:
"""Helps test attrs collections."""
def __init__(self, name):
self.name = name
class TestAttrsClass:
"""Helps test attrs collections."""
__attrs_attrs__ = (TestAttr('a'), TestAttr('b'))
def __init__(self, a, b):
self.a = a
self.b = b
class DummyGenericClass:
"""Helps test memory leaks for GenericType."""
pass
def make_function_signature_with_context(inputs):
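  """Builds a trace_type signature for `inputs` using a fresh SignatureContext."""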
return trace_type.make_function_signature(
inputs, trace_type.SignatureContext())
class CacheKeyGenerationTest(test.TestCase, parameterized.TestCase):
@combinations.generate(combinations.combine(mode=['eager']))
def testIteratorAliasing(self):
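    """Signatures capture iterator aliasing, not iterator identity."""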
it1 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
it2 = iter(dataset_ops.DatasetV2.from_tensor_slices([1, 2, 3]))
self.assertEqual(
make_function_signature_with_context((it1, it1)),
make_function_signature_with_context((it2, it2)))
self.assertEqual(
make_function_signature_with_context((it1, it2)),
make_function_signature_with_context((it2, it1)))
self.assertNotEqual(
make_function_signature_with_context((it1, it1)),
make_function_signature_with_context((it1, it2)))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testIteratorTypesImplementTracing(self):
self.assertTrue(
issubclass(iterator_ops.OwnedIterator, trace.SupportsTracingProtocol))
self.assertTrue(
issubclass(iterator_ops.IteratorSpec, trace.SupportsTracingProtocol))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testCompositeAndSpec(self):
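    """A composite tensor and its matching TypeSpec produce equal signatures."""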
composite_tensor = ragged_tensor.RaggedTensor.from_row_splits(
values=[1, 2, 3], row_splits=[0, 2, 3])
spec = ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)
self.assertEqual(
make_function_signature_with_context(composite_tensor),
make_function_signature_with_context(spec))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testVariableAliasing(self):
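    """Signatures capture variable aliasing, not variable identity."""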
v1 = resource_variable_ops.ResourceVariable([1])
v2 = resource_variable_ops.ResourceVariable([1])
v3 = resource_variable_ops.ResourceVariable([1])
all_unique = make_function_signature_with_context((v1, v2, v3))
all_same = make_function_signature_with_context((v1, v1, v1))
self.assertNotEqual(all_unique, all_same)
v3 = resource_variable_ops.ResourceVariable([2])
v4 = resource_variable_ops.ResourceVariable([2])
v5 = resource_variable_ops.ResourceVariable([2])
all_unique_again = make_function_signature_with_context((v3, v4, v5))
all_same_again = make_function_signature_with_context((v4, v4, v4))
self.assertEqual(all_unique, all_unique_again)
self.assertEqual(all_same, all_same_again)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testTensorEquality(self):
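    """Tensor tracing types compare by shape and dtype, not by value."""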
context = trace_type.SignatureContext()
tensor_a = array_ops.zeros([11, 3, 5],
dtype=dtypes.int32).__tf_tracing_type__(context)
tensor_b = array_ops.zeros([11, 4, 5],
dtype=dtypes.int32).__tf_tracing_type__(context)
tensor_c = array_ops.zeros(
[11, 3, 5], dtype=dtypes.float32).__tf_tracing_type__(context)
tensor_d = array_ops.ones([11, 3, 5],
dtype=dtypes.int32).__tf_tracing_type__(context)
self.assertNotEqual(tensor_a, tensor_b)
self.assertNotEqual(tensor_a, tensor_c)
self.assertNotEqual(tensor_b, tensor_c)
self.assertEqual(tensor_a, tensor_d)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testTensorAndSpecEquality(self):
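    """A tensor matches an equivalent unnamed TensorSpec; a named spec differs."""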
context = trace_type.SignatureContext()
tensor = array_ops.zeros([11, 3, 5],
dtype=dtypes.int32).__tf_tracing_type__(context)
spec = tensor_spec.TensorSpec(
[11, 3, 5], dtype=dtypes.int32).__tf_tracing_type__(context)
spec_with_name = tensor_spec.TensorSpec(
[11, 3, 5], dtype=dtypes.int32,
name='name').__tf_tracing_type__(context)
self.assertEqual(tensor, spec)
self.assertNotEqual(tensor, spec_with_name)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testTensorShapeUnknown(self):
context = trace_type.SignatureContext()
spec_1 = tensor_spec.TensorSpec(
None, dtype=dtypes.int32).__tf_tracing_type__(context)
spec_2 = tensor_spec.TensorSpec(
None, dtype=dtypes.int32).__tf_tracing_type__(context)
self.assertEqual(spec_1, spec_2)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testAttrsCacheKeyGeneration(self):
trace_a = make_function_signature_with_context(TestAttrsClass(1, 2))
expected = default_types.Attrs(
TestAttrsClass,
(default_types.Generic(1), default_types.Generic(2)))
self.assertEqual(trace_a, expected)
self.assertTrue(trace_a.is_subtype_of(trace_a))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testTupleEquality(self):
trace_a = make_function_signature_with_context((1, 2, 3, 4))
trace_b = make_function_signature_with_context((1, 2, 2, 4))
trace_c = make_function_signature_with_context((1, 2, 3))
trace_d = make_function_signature_with_context((1, 2, 3, 4))
self.assertNotEqual(trace_a, trace_b)
self.assertNotEqual(trace_a, trace_c)
self.assertNotEqual(trace_b, trace_c)
self.assertEqual(trace_a, trace_d)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testListEquality(self):
trace_a = make_function_signature_with_context([1, 2, 3, 4])
trace_b = make_function_signature_with_context([1, 2, 2, 4])
trace_c = make_function_signature_with_context([1, 2, 3])
trace_d = make_function_signature_with_context([1, 2, 3, 4])
self.assertNotEqual(trace_a, trace_b)
self.assertNotEqual(trace_a, trace_c)
self.assertNotEqual(trace_b, trace_c)
self.assertEqual(trace_a, trace_d)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testDictEquality(self):
trace_a = make_function_signature_with_context({1: 2, 3: 4})
trace_b = make_function_signature_with_context({1: 2, 3: 2})
trace_c = make_function_signature_with_context({1: 2, 3: 0})
trace_d = make_function_signature_with_context({3: 4, 1: 2})
self.assertNotEqual(trace_a, trace_b)
self.assertNotEqual(trace_a, trace_c)
self.assertNotEqual(trace_b, trace_c)
self.assertEqual(trace_a, trace_d)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testComplexStruct(self):
struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
trace_a = make_function_signature_with_context(struct)
trace_b = make_function_signature_with_context(struct)
self.assertEqual(trace_a, trace_b)
self.assertTrue(trace_a.is_subtype_of(trace_b))
self.assertTrue(trace_b.is_subtype_of(trace_a))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testCustomUnequableTypeSucceeds(self):
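    """Tracing succeeds for objects whose __eq__ raises."""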
class CustomUnequable:
def __eq__(self, o):
raise ValueError
def __hash__(self):
return 0
object_a = CustomUnequable()
object_b = CustomUnequable()
trace_a_1 = make_function_signature_with_context(object_a)
trace_a_2 = make_function_signature_with_context(object_a)
trace_b = make_function_signature_with_context(object_b)
self.assertEqual(trace_a_1, trace_a_2)
with self.assertRaises(ValueError):
trace_a_1.__eq__(trace_b)
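    # Deleting the underlying object invalidates its previously created traces.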
del object_a
self.assertNotEqual(trace_a_1, trace_a_2)
self.assertNotEqual(trace_a_2, trace_a_1)
del object_b
self.assertNotEqual(trace_a_1, trace_a_2)
self.assertNotEqual(trace_a_2, trace_a_1)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def testCustomUnhashableTypeFailsGracefully(self):
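    """Unhashable objects cannot be traced and raise an informative TypeError."""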
class CustomUnhashable:
def __eq__(self, o):
return True
obj = CustomUnhashable()
with self.assertRaisesRegex(
TypeError,
r'could not be represented through the generic tracing type'):
make_function_signature_with_context(obj)
class CacheKeyMemoryTest(test.TestCase):
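  """Checks that signature generation does not leak Python objects."""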
@test_util.assert_no_new_pyobjects_executing_eagerly
def testGeneric(self):
make_function_signature_with_context(1)
make_function_signature_with_context(DummyGenericClass())
@test_util.assert_no_new_pyobjects_executing_eagerly
def testTensor(self):
tensor = array_ops.zeros([10])
make_function_signature_with_context(tensor)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testTuple(self):
make_function_signature_with_context((1, 2, 3))
@test_util.assert_no_new_pyobjects_executing_eagerly
def testDict(self):
make_function_signature_with_context({1: 1, 2: 2, 3: 3})
@test_util.assert_no_new_pyobjects_executing_eagerly
def testList(self):
make_function_signature_with_context([1, 2, 3])
@test_util.assert_no_new_pyobjects_executing_eagerly
def testAttrs(self):
make_function_signature_with_context(TestAttrsClass(1, 2))
class CacheKeyGenerationBenchmark(test.Benchmark):
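  """Benchmarks the cost of generating and looking up cache keys."""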
def benchmarkTensor(self):
shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
tensors = []
for s in shapes:
tensors.append(array_ops.zeros(s))
def encode_tensors(tensors):
make_function_signature_with_context(tensors)
iterations = 100000
t = timeit.timeit(lambda: encode_tensors(tensors), number=iterations)
self.report_benchmark(
name='tensor_cache_key_generation',
iters=iterations,
wall_time=t,
metrics=[{
'name': 'tensor_cache_key_generation_avg_ms',
'value': t / iterations * 1000
}])
def benchmarkTensorSpec(self):
shapes = [[1], [2, 19], [5, 11, 24], [4, 5, 9, 23]]
tensor_specs = []
for s in shapes:
tensor_specs.append(tensor_spec.TensorSpec(s, dtypes.int32))
def encode_tensor_specs(tensor_specs):
make_function_signature_with_context(tensor_specs)
iterations = 100000
t = timeit.timeit(
lambda: encode_tensor_specs(tensor_specs), number=iterations)
self.report_benchmark(
name='tensor_spec_cache_key_generation',
iters=iterations,
wall_time=t,
metrics=[{
'name': 'tensor_spec_cache_key_generation_avg_ms',
'value': t / iterations * 1000
}])
def benchmarkVariable(self):
var_list = [
variables.Variable(1.0),
variables.Variable(1),
variables.Variable([1])
]
def encode_variables(var_list):
make_function_signature_with_context(var_list)
iterations = 10000
t = timeit.timeit(lambda: encode_variables(var_list), number=iterations)
self.report_benchmark(
name='variable_cache_key_generation',
iters=iterations,
wall_time=t,
metrics=[{
'name': 'variable_cache_key_generation_avg_ms',
'value': t / iterations * 1000
}])
def benchmarkCacheKeyLookup(self):
@function.defun
def defined(t):
return t
call_arg_list = [
1,
array_ops.zeros([5, 13]),
array_ops.zeros([9, 22, 24]),
array_ops.zeros([5, 13, 2])
]
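    # Prime the cache with several distinct signatures before timing lookups.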
for c in call_arg_list:
defined(c)
lookup_call_arg = array_ops.zeros([5, 13])
iterations = 10000
t = timeit.timeit(stmt=lambda: defined(lookup_call_arg), number=iterations)
self.report_benchmark(
name='cache_key_lookup',
iters=iterations,
wall_time=t,
metrics=[{
'name': 'cache_key_lookup_avg_ms',
'value': t / iterations * 1000
}])
def benchmarkNestedStruct(self):
struct = {(1, 2, 3): {(1, 2): {12: 2}}, (3, 2, 3): (2, {2: 3})}
def encode_struct(struct):
make_function_signature_with_context(struct)
iterations = 100000
t = timeit.timeit(lambda: encode_struct(struct), number=iterations)
self.report_benchmark(
name='nested_struct_cache_key_generation',
iters=iterations,
wall_time=t,
metrics=[{
'name': 'nested_struct_cache_key_generation_avg_ms',
'value': t / iterations * 1000
}])
def benchmarkFunctionInvocation(self):
struct = (variables.Variable(1.0), array_ops.zeros([5, 13]), {
'tensor': array_ops.zeros([5, 20]),
'variable': variables.Variable(1.0)
})
@function.defun
def defined(t):
return t
defined(struct) # Get it traced and cached.
iterations = 10000
t = timeit.timeit(lambda: defined(struct), number=iterations)
self.report_benchmark(
name='function_invocation',
iters=iterations,
wall_time=t,
metrics=[{
'name': 'function_invocation_time_avg_ms',
'value': t / iterations * 1000
}])
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
test.main()