| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import copy |
| import functools |
| import itertools |
| import multiprocessing.pool |
| import sys |
| import time |
| import weakref |
| |
| from absl.testing import parameterized |
| import numpy |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.protobuf import rewriter_config_pb2 |
| from tensorflow.python.autograph.core import ag_ctx |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.data.ops import iterator_ops |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import cancellation |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import function |
| from tensorflow.python.framework import composite_tensor |
| from tensorflow.python.framework import config |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import func_graph |
| from tensorflow.python.framework import function as tf_function |
| from tensorflow.python.framework import indexed_slices |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import random_seed |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.framework import test_ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.framework import type_spec |
| from tensorflow.python.layers import convolutional |
| from tensorflow.python.module import module |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import check_ops |
| from tensorflow.python.ops import clip_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import data_flow_ops |
| from tensorflow.python.ops import functional_ops |
| from tensorflow.python.ops import gen_functional_ops |
| from tensorflow.python.ops import gen_random_ops |
| from tensorflow.python.ops import gen_resource_variable_ops |
| from tensorflow.python.ops import gen_sendrecv_ops |
| from tensorflow.python.ops import gradients_impl |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import list_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import string_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops.ragged import ragged_factory_ops |
| from tensorflow.python.ops.ragged import ragged_tensor |
| from tensorflow.python.ops.structured import structured_tensor |
| from tensorflow.python.platform import test |
| from tensorflow.python.training import training_ops |
| from tensorflow.python.util import compat |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import tf_inspect |
| |
| try: |
| import attr # pylint:disable=g-import-not-at-top |
| except ImportError: |
| attr = None |
| |
| |
| def total_function_cache(defined): |
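  """Returns the union of the primary and shape-relaxed function caches."""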
| # pylint: disable=protected-access |
| return (set(defined._function_cache.primary) |
| | set(defined._function_cache.arg_relaxed)) |
| # pylint: enable=protected-access |
| |
| |
| def _example_indexed_slices_with_dense_shape(): |
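  """Returns an example IndexedSlices that includes a dense_shape."""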
| return indexed_slices.IndexedSlices( |
| constant_op.constant([1, 2]), constant_op.constant([0, 1]), |
| constant_op.constant([2])) |
| |
| |
| def _example_indexed_slices_without_dense_shape(): |
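  """Returns an example IndexedSlices without a dense_shape."""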
| return indexed_slices.IndexedSlices( |
| constant_op.constant([1, 2]), constant_op.constant([0, 1])) |
| |
| |
| def _spec_for_value(value): |
| """Returns the (nested) TypeSpec for a value.""" |
| if nest.is_sequence(value): |
| return nest.map_structure(_spec_for_value, value) |
| elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)): |
| return type_spec.type_spec_from_value(value) |
| else: |
| return value |
| |
| |
| class FunctionTest(test.TestCase, parameterized.TestCase): |
| |
| def setUp(self): |
| super(FunctionTest, self).setUp() |
| cpus = config.list_physical_devices('CPU') |
    # Set up 4 virtual CPUs so tests can place ops on /cpu:0 through /cpu:3.
| config.set_logical_device_configuration(cpus[0], [ |
| context.LogicalDeviceConfiguration(), |
| context.LogicalDeviceConfiguration(), |
| context.LogicalDeviceConfiguration(), |
| context.LogicalDeviceConfiguration() |
| ]) |
| |
| def testBasic(self): |
| matmul = def_function.function(math_ops.matmul) |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| sq = matmul(t, t, transpose_a=True) |
| sq2 = matmul(sq, t, transpose_a=True) |
| self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) |
| self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108]) |
| |
| def testOnExitCallback(self): |
| values = [] |
| def append_1(): |
| values.append(1) |
| |
| def append_2(): |
| values.append(2) |
| |
| def g(x): |
| old_values = list(values) |
| ops.add_exit_callback_to_default_func_graph(append_1) |
| self.assertEqual(old_values, values) |
| return x + 1 |
| |
| tf_g = def_function.function(g) |
| |
| def f(x): |
| old_values = list(values) |
| ops.add_exit_callback_to_default_func_graph(append_2) |
| self.assertEqual(old_values, values) |
| return tf_g(x) |
| |
| tf_f = def_function.function(f) |
| self.assertEmpty(values) |
| tf_f(constant_op.constant(1.0)) |
| self.assertEqual(values, [1, 2]) # Once for g, once for f. |
| tf_f(constant_op.constant([1.0])) # force a retrace |
| self.assertEqual(values, [1, 2, 1, 2]) # And again. |
| |
| def testCannotAddExitCallbackWhenNotInFunctionScope(self): |
| with self.assertRaisesRegex(RuntimeError, 'when not building a function.'): |
| ops.add_exit_callback_to_default_func_graph(lambda: None) |
| |
| def testVariable(self): |
| v1 = variables.Variable(1.0) |
| add = def_function.function(lambda x, v: x + v1 + v) |
| v2 = variables.Variable(1.0) |
| x = constant_op.constant(1.0) |
| r = add(x, v2) |
| self.assertEqual(3.0, self.evaluate(r)) |
| |
| def testVariableOnly(self): |
| v = variables.Variable(1.0) |
| add = def_function.function(lambda x: x.assign_add(1.0)) |
| r1 = add(v) |
| self.assertEqual(2.0, self.evaluate(r1)) |
| c = constant_op.constant(1.0) |
| with self.assertRaisesRegex(AttributeError, 'no attribute'): |
| add(c) |
| |
| @test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.') |
| def testPackedVariable(self): |
| with ops.device('/cpu:0'): |
| v0_0 = resource_variable_ops.ResourceVariable(1.0) |
| with ops.device('/cpu:1'): |
| v0_1 = resource_variable_ops.ResourceVariable(2.0) |
| v1_0 = resource_variable_ops.ResourceVariable(3.0) |
| with ops.device('/cpu:2'): |
| v1_1 = resource_variable_ops.ResourceVariable(4.0) |
| |
| packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle]) |
| packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle]) |
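    # Packing aggregates the per-device variable handles into one handle on a
    # composite device, so a function can take them as a single argument.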
| |
| # TODO(b/145922293): use ResourceVariable.assign_add and |
| # ResourceVariable.read_value directly once we support packing multiple |
| # ResourceVariable into one ResourceVariable. |
| @def_function.function |
| def read_var(): |
| resource_variable_ops.assign_add_variable_op( |
| packed_var_0, constant_op.constant(5.0)) |
| resource_variable_ops.assign_add_variable_op( |
| packed_var_1, constant_op.constant(6.0)) |
| with ops.device('/cpu:0'): |
| read0 = resource_variable_ops.read_variable_op( |
| packed_var_0, dtype=dtypes.float32) |
| with ops.device('/cpu:1'): |
| read1 = resource_variable_ops.read_variable_op( |
| packed_var_0, dtype=dtypes.float32) |
| read2 = resource_variable_ops.read_variable_op( |
| packed_var_1, dtype=dtypes.float32) |
| with ops.device('/cpu:2'): |
| read3 = resource_variable_ops.read_variable_op( |
| packed_var_1, dtype=dtypes.float32) |
| |
| return read0, read1, read2, read3 |
| |
| arg_attrs = read_var.get_concrete_function().function_def.arg_attr |
| self.assertLen(arg_attrs, 2) |
| self.assertEqual(arg_attrs[0].attr['_composite_device'].s, |
| compat.as_bytes(packed_var_0.device)) |
| self.assertEqual(arg_attrs[1].attr['_composite_device'].s, |
| compat.as_bytes(packed_var_1.device)) |
| |
| self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6)) |
| |
| def testImplementsAttributeBasic(self): |
| v = def_function.function( |
| experimental_implements='func')(lambda x, y: x + y) |
| with context.graph_mode(), self.cached_session(): |
| a = array_ops.placeholder(dtypes.float32, ()) |
| b = array_ops.placeholder(dtypes.float32, ()) |
| v(a, b) |
| gradients_impl.gradients(v(a, b), [a, b]) |
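      # Taking gradients creates forward/backward variants of the function;
      # only the original function should carry the implements attribute.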
| fdefs = ops.get_default_graph().as_graph_def().library.function |
| self.assertLen(fdefs, 3) |
| not_present = 0 |
| present = 0 |
| for f in fdefs: |
| name = f.signature.name |
| if 'forward' in name or 'backward' in name: |
| not_present += 1 |
| self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) |
| else: |
| present += 1 |
| self.assertEqual(f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME].s, |
| 'func'.encode('ascii'), f) |
| self.assertEqual(not_present, 2, fdefs) |
| self.assertEqual(present, 1, fdefs) |
| |
| def testImplementsAttributeAssertsOnSideInput(self): |
| with context.graph_mode(), self.cached_session(): |
| z = array_ops.zeros(0) |
| v = def_function.function( |
| experimental_implements='func')(lambda x, y: x + y + z) |
| a = array_ops.ones((1.0,)) |
| b = array_ops.ones((1.0,)) |
| with self.assertRaisesRegex(AssertionError, |
| 'variables are always captured'): |
| v(a, b) |
| functions = ops.get_default_graph().as_graph_def().library.function |
| self.assertEmpty(functions) |
| |
| def testImplementsAttributeWorksOnVariables(self): |
| with context.graph_mode(), self.cached_session(): |
| v = def_function.function( |
| experimental_implements='func')(lambda x, y: x + y) |
| a = variables.Variable((1.0,)) |
| b = variables.Variable((1.0,)) |
| r1 = v(a, b) |
| _ = v(a, a) |
| functions = ops.get_default_graph().as_graph_def().library.function |
| # Verify that we created only one function |
| self.assertLen(functions, 1) |
| # Verify that eval() reads the current values. |
| a.initializer.run() |
| b.initializer.run() |
| self.assertEqual(r1.eval(), 2) |
| |
| a.assign_add([1]).eval() |
| self.assertEqual(r1.eval(), 3) |
| |
| def testImplementsAttributeWorksOnConstants(self): |
| with context.graph_mode(), self.cached_session(): |
| v = def_function.function( |
| experimental_implements='func')(lambda x, y: x + y) |
| a = variables.Variable(1.0) |
| r1 = v(a, 2.) |
| r2 = v(2., a) |
| functions = ops.get_default_graph().as_graph_def().library.function |
| self.assertLen(functions, 1) |
| self.assertLen(functions[0].signature.input_arg, 2) |
| # Verify that eval() reads the current values. |
| a.initializer.run() |
| self.assertEqual(r1.eval(), 3) |
| self.assertEqual(r2.eval(), 3) |
| |
| def testImplementsAttributeSpecializes(self): |
| with context.graph_mode(), self.cached_session(): |
| v = def_function.function( |
| experimental_implements='func')(lambda x, y: x + y) |
| a = variables.Variable(1.0) |
| r1 = v(a, [2.]) |
| r2 = v([2., 2], a) |
| functions = ops.get_default_graph().as_graph_def().library.function |
| self.assertLen(functions, 2) |
| # Ensure that all parameters are still there and haven't been inlined! |
| |
| self.assertLen(functions[0].signature.input_arg, 2) |
| self.assertLen(functions[1].signature.input_arg, 2) |
| # Verify that eval() reads the current values. |
| a.initializer.run() |
| numpy.testing.assert_equal(r1.eval(), [3.]) |
| numpy.testing.assert_equal(r2.eval(), [3., 3.]) |
| |
| def testImplementsAttributeAsNameAttrList(self): |
| implements_attr = ( |
| 'name: "embedding_matmul" attr { key: "key1" value { i: 2 } ' |
| '} attr { key: "key2" value { b: false } }') |
| v = def_function.function( |
| experimental_implements=implements_attr)(lambda x, y: x + y) |
| with context.graph_mode(), self.cached_session(): |
| a = array_ops.placeholder(dtypes.float32, ()) |
| b = array_ops.placeholder(dtypes.float32, ()) |
| v(a, b) |
| gradients_impl.gradients(v(a, b), [a, b]) |
| fdefs = ops.get_default_graph().as_graph_def().library.function |
| self.assertLen(fdefs, 3) |
| not_present = 0 |
| present = 0 |
| for f in fdefs: |
| name = f.signature.name |
| if 'forward' in name or 'backward' in name: |
| not_present += 1 |
| self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f) |
| else: |
| present += 1 |
| attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME] |
| self.assertIsNotNone(attr_value.func, f) |
| self.assertEqual(attr_value.func.name, 'embedding_matmul') |
| name_attrs = attr_value.func.attr |
| self.assertLen(name_attrs, 2) |
| self.assertEqual(not_present, 2, fdefs) |
| self.assertEqual(present, 1, fdefs) |
| |
| def testExternalControlDependency(self): |
| with ops.Graph().as_default(), self.test_session(): |
| v = variables.Variable(1.0) |
| v.initializer.run() |
| |
| op = v.assign_add(1.0) |
| |
| @function.defun |
| def f(): |
| with ops.control_dependencies([op]): |
| return 1.0 |
| |
| self.evaluate(f()) |
| self.assertAllEqual(self.evaluate(v), 2.0) |
| |
| def testInputShapeFunctionRelaxation(self): |
| unknown_dim = [False] |
| |
| @function.defun(experimental_relax_shapes=True) |
| def func(a): |
| if a._shape_tuple()[0] is None: |
| unknown_dim[0] = True |
| return a + 1 |
| |
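    # The first two distinct shapes are traced separately; the third call
    # triggers shape relaxation, so later shape changes reuse a relaxed trace.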
| func(constant_op.constant([])) |
| self.assertFalse(unknown_dim[0]) |
| self.assertLen(total_function_cache(func), 1) |
| |
| func(constant_op.constant([1.0])) |
| self.assertFalse(unknown_dim[0]) |
| self.assertLen(total_function_cache(func), 2) |
| |
| func(constant_op.constant([1.0, 2.0])) |
| self.assertTrue(unknown_dim[0]) |
| self.assertLen(total_function_cache(func), 2) |
| |
| def testInputShapeRelaxationOnInstanceMethod(self): |
    # Test that experimental_relax_shapes is passed through when the
    # instance method is bound.
| unknown_dim = [False] |
| |
| class Foo(object): |
| |
| @def_function.function(experimental_relax_shapes=True) |
| def func(self, a): |
| if a._shape_tuple()[0] is None: |
| unknown_dim[0] = True |
| return a + 1 |
| |
| foo = Foo() |
| foo.func(constant_op.constant([])) |
| self.assertFalse(unknown_dim[0]) |
| |
| foo.func(constant_op.constant([1.0])) |
| self.assertFalse(unknown_dim[0]) |
| |
| foo.func(constant_op.constant([1.0, 2.0])) |
| self.assertTrue(unknown_dim[0]) |
| |
| def testInputShapeFunctionRelaxationWithRaggedTensors(self): |
| traced_type_spec = [None] |
| |
| @def_function.function(experimental_relax_shapes=True) |
| def func(x): |
| traced_type_spec[0] = x._type_spec |
| return x |
| |
| def check_trace(x, expected_trace): |
| traced_type_spec[0] = None |
| func(x) |
| self.assertEqual(traced_type_spec[0], expected_trace) |
| |
| check_trace( # Initial call gets traced. |
| ragged_factory_ops.constant([[1], [2, 3, 4]]), |
| ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32)) |
| check_trace( # Input TypeSpec is the same -> no retrace. |
| ragged_factory_ops.constant([[1, 2], [3, 4]]), None) |
| check_trace( # Even if component tensor shapes change -> no retrace. |
| ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None) |
| check_trace( # Different TypeSpec shape (nrows): retrace |
| ragged_factory_ops.constant([[1], [2], [3]]), |
| ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32)) |
| check_trace( # Different nrows again: relax & retrace |
| ragged_factory_ops.constant([[1], [2], [3], [4]]), |
| ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)) |
    check_trace(  # Different nrows yet again: no retrace (relaxed graph)
| ragged_factory_ops.constant([[1]]), None) |
| check_trace( # Different ragged_rank: retrace |
| ragged_factory_ops.constant([[[1]]]), |
| ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32)) |
| check_trace( # Different ragged_rank again: retrace & relax |
| ragged_factory_ops.constant([[[1]], [[2]]]), |
| ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)) |
| |
| def testInputShapeFunctionRelaxationWithStructuredTensors(self): |
| traced_type_spec = [None] |
| |
| @def_function.function(experimental_relax_shapes=True) |
| def func(x): |
| traced_type_spec[0] = x._type_spec |
| return x |
| |
| def check_trace(x, expected_trace): |
| traced_type_spec[0] = None |
| func(x) |
| self.assertEqual(traced_type_spec[0], expected_trace) |
| |
| # If we have TypeSpecs that differ in ways other than just their shape, |
| # then retrace each time. |
| check_trace( |
| structured_tensor.StructuredTensor.from_pyval({'a': [1]}), |
| structured_tensor.StructuredTensorSpec( |
| [], {'a': tensor_spec.TensorSpec((1,), dtypes.int32)})) |
| check_trace( |
| structured_tensor.StructuredTensor.from_pyval({'b': [1]}), |
| structured_tensor.StructuredTensorSpec( |
| [], {'b': tensor_spec.TensorSpec((1,), dtypes.int32)})) |
| check_trace( |
| structured_tensor.StructuredTensor.from_pyval({'c': [1]}), |
| structured_tensor.StructuredTensorSpec( |
| [], {'c': tensor_spec.TensorSpec((1,), dtypes.int32)})) |
| |
    # But if we call again with only the shape different, relaxation kicks in:
| check_trace( # retrace |
| structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}), |
| structured_tensor.StructuredTensorSpec( |
| [], {'a': tensor_spec.TensorSpec((2,), dtypes.int32)})) |
| check_trace( # relax & retrace |
| structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}), |
| structured_tensor.StructuredTensorSpec( |
| [], {'a': tensor_spec.TensorSpec((None,), dtypes.int32)})) |
| check_trace( # use relaxed graph |
| structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}), |
| None) |
| |
| def testInputShapeFunctionRelaxationWithDatasetIterators(self): |
| # For dataset iterators, the TypeSpec includes type information that's |
| # not derivable from the component tensors. Make sure that the TypeSpec |
| # shapes get relaxed as appropriate. |
| |
| traced_type_spec = [None] |
| |
| @def_function.function(experimental_relax_shapes=True) |
| def func(x): |
| traced_type_spec[0] = x._type_spec |
| return x |
| |
| def check_trace(x, expected_trace): |
| traced_type_spec[0] = None |
| func(x) |
| self.assertEqual(traced_type_spec[0], expected_trace) |
| |
| ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2])) |
| ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2])) |
| ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2])) |
| ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2])) |
| ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1])) |
| check_trace( # shape=[1, 2]: retrace |
| dataset_ops.make_one_shot_iterator(ds_1_2), |
| iterator_ops.IteratorSpec( |
| tensor_spec.TensorSpec([1, 2], dtypes.float32))) |
| check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph) |
| dataset_ops.make_one_shot_iterator(ds_1_2), None) |
| check_trace( # shape=[2, 2]: retrace |
| dataset_ops.make_one_shot_iterator(ds_2_2), |
| iterator_ops.IteratorSpec( |
| tensor_spec.TensorSpec([2, 2], dtypes.float32))) |
| check_trace( # shape=[3, 2]: relax to [None, 2] and retrace |
| dataset_ops.make_one_shot_iterator(ds_3_2), |
| iterator_ops.IteratorSpec( |
| tensor_spec.TensorSpec([None, 2], dtypes.float32))) |
| check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph) |
| dataset_ops.make_one_shot_iterator(ds_4_2), None) |
| check_trace( # shape=[2, 1]: relax to [None, None] and retrace |
| dataset_ops.make_one_shot_iterator(ds_2_1), |
| iterator_ops.IteratorSpec( |
| tensor_spec.TensorSpec([None, None], dtypes.float32))) |
| |
| def testCapturesVariables(self): |
| a = variables.Variable(1.0, trainable=False) |
| b = variables.Variable(1.0) |
| cc = [None] |
| |
| @def_function.function |
| def f(): |
| c = cc[0] |
| if c is None: |
| c = cc[0] = variables.Variable(1.) |
| return a + b + c + 1 |
| |
| cf = f.get_concrete_function() |
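    # Tracing above created the lazily-initialized variable held in cc[0].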
| c = cc[0] |
| |
| captured_variables = {v.ref() for v in (a, b, c)} |
| trainable_variables = {v.ref() for v in (b, c)} |
| self.assertEqual({v.ref() for v in cf.variables}, captured_variables) |
| self.assertEqual({v.ref() for v in cf.trainable_variables}, |
| trainable_variables) |
| self.assertEqual(cf.variables, cf.graph.variables) |
| self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables) |
| |
| def testNestedInputShapeFunctionRelaxation(self): |
| unknown_dim = [False] |
| |
| @function.defun(experimental_relax_shapes=True) |
| def func(a_, b_=None): |
| del a_ # Only used to check which cache is used. |
| self.assertEqual(b_[0]._shape_tuple(), ()) |
| if b_[1]._shape_tuple()[0] is None: |
| unknown_dim[0] = True |
| return b_[0] + 1 |
| |
| a = 'hi' |
| b0 = constant_op.constant(1.0) |
| func(a, b_=[b0, constant_op.constant([])]) |
| self.assertFalse(unknown_dim[0]) |
| self.assertLen(total_function_cache(func), 1) |
| |
| func(a, b_=[b0, constant_op.constant([1.0])]) |
| self.assertFalse(unknown_dim[0]) |
| self.assertLen(total_function_cache(func), 2) |
| |
| func(a, b_=[b0, constant_op.constant([1.0, 1.0])]) |
| self.assertTrue(unknown_dim[0]) |
| self.assertLen(total_function_cache(func), 2) |
| |
| unknown_dim[0] = False |
| |
    # Now do the same, but with a new value for `a` which is not a tensor;
    # this should change the cache key.
| a = 'bye' |
| func(a, b_=[b0, constant_op.constant([])]) |
| self.assertFalse(unknown_dim[0]) |
| self.assertLen(total_function_cache(func), 3) |
| |
    # Since a cache miss was already recorded for this non-tensor signature,
    # shape relaxation kicks in immediately on this call.
| func(a, b_=[b0, constant_op.constant([1.0])]) |
| self.assertTrue(unknown_dim[0]) |
| self.assertLen(total_function_cache(func), 3) |
| |
| def testNestedShapeFunctionRelaxation(self): |
| |
| got_shape = [None] |
| |
| # The inner function will go through shape relaxation because the shapes it |
| # receives will be [1], [2], [3], ... |
| @def_function.function(experimental_relax_shapes=True) |
| def bar(x_shape): |
| got_shape[0] = x_shape._shape_tuple() |
| return x_shape |
| |
    # The outer function will not go through shape relaxation because the
    # shapes it receives ([1], [1, 1], [1, 1, 1], ...) differ in rank.
| @def_function.function(experimental_relax_shapes=True) |
| def foo(ones): |
| return bar(array_ops.shape(ones)) |
| |
| for rank in range(1, 6): |
| x_shape = self.evaluate(foo(array_ops.ones([1] * rank))) |
| self.assertAllEqual(x_shape, [1] * rank) |
| if rank < 3: |
| self.assertEqual(got_shape[0], (rank,)) |
| else: |
| self.assertEqual(got_shape[0], (None,)) |
| |
| def testNoHash(self): |
| |
| @def_function.function() |
| def f(_): |
| return 1.0 |
| |
| with self.assertRaisesRegex(ValueError, r'Got type: set'): |
| f(set([])) |
| |
| def testFuncName(self): |
| |
| @function.defun_with_attributes(attributes={'func_name': 'multiply'}) |
| def add(x, y): |
| _ = x * y |
| return x + y |
| |
| @function.defun |
| def add_2(x, y): |
| _ = x * y |
| return x + y |
| |
| self.assertEqual(add._name, 'multiply') |
| self.assertEqual(add_2._name, 'add_2') |
| |
| def testBasicGraphMode(self): |
| matmul = def_function.function(math_ops.matmul) |
| |
| @def_function.function |
| def sq(a): |
| return matmul(a, a) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| out = sq(t) |
| self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) |
| |
| def testNestedInputsGraphMode(self): |
| matmul = def_function.function(math_ops.matmul) |
| |
| pair = collections.namedtuple('pair', ['a', 'b']) |
| |
| @def_function.function |
| def a_times_b(inputs): |
| return matmul(inputs.a['a'], inputs.b['b']) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| |
| out = a_times_b(pair({'a': t}, {'b': t})) |
| self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) |
| |
| def testNestedOutputsGraphMode(self): |
| matmul = def_function.function(math_ops.matmul) |
| |
| pair = collections.namedtuple('pair', ['a', 'b']) |
| |
| @def_function.function() |
| def pairs_mul(pair_a, pair_b): |
| return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b)) |
| |
| a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]]) |
| b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]]) |
| |
| out = pairs_mul(pair(a, b), pair(b, a)) |
| expected = pair(math_ops.matmul(a, b).numpy(), |
| math_ops.matmul(b, a).numpy()) |
| self.assertAllClose(out, expected) |
| |
| @parameterized.named_parameters( |
| dict(testcase_name='Defun', |
| function_decorator=function.defun), |
| dict(testcase_name='DefFunction', |
| function_decorator=def_function.function)) |
| def testNestedFunctionGraphNotOutOfDate(self, function_decorator): |
| @function_decorator |
| def f(): |
| return constant_op.constant(1.) |
| |
| class _Model(object): |
| |
| @function_decorator |
| def g(self): |
| self.f = f.get_concrete_function() |
| |
| model = _Model() |
| model.g() |
| concrete = model.f |
| weak_g_graph = weakref.ref(model.g.get_concrete_function().graph) |
| self.assertIs(weak_g_graph(), concrete.graph.outer_graph) |
| weak_g = weakref.ref(model.g) |
| del model |
| self.assertIsNone(weak_g()) |
| self.assertIsNone(weak_g_graph()) |
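    # With the enclosing function gone, the concrete function's outer graph
    # falls back to the current default graph.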
| self.assertIsNotNone(concrete.graph.outer_graph) |
| self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph) |
| |
| def testGraphEagerIsolation(self): |
| |
| @function.defun |
| def f(): |
| self.v = variables.Variable(1.0) |
| return self.v.read_value() |
| |
| self.assertAllEqual(f(), 1.0) |
| |
| with ops.Graph().as_default(): |
| self.assertEqual(f().shape, ()) |
| |
| def testBasicGraphFunction(self): |
| matmul = def_function.function(math_ops.matmul) |
| |
| @def_function.function |
| def sq(a): |
| return matmul(a, a) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| |
| sq_op = sq.get_concrete_function(t) |
| self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) |
| out = sq_op(t) |
| self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) |
| |
| def testGetConcreteFunctionThreadSafety(self): |
| |
| @def_function.function |
| def sq(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| return math_ops.matmul(t, t) |
| |
| concrete_functions = [] |
| |
| def thread_func(_): |
| cf = sq.get_concrete_function() |
| concrete_functions.append(cf) |
| |
| num_threads = 100 |
| pool = multiprocessing.pool.ThreadPool(num_threads) |
| _ = pool.map(thread_func, list(range(num_threads))) |
| |
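    # Every thread must observe the same traced ConcreteFunction.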
| self.assertLen(set(concrete_functions), 1) |
| |
| def testGetConcreteFunctionThreadSafetyWithArgs(self): |
| @def_function.function |
| def add_100(*args): |
| return math_ops.add_n(args) |
| |
| p = multiprocessing.pool.ThreadPool(2) |
| args = (constant_op.constant(1.),) * 100 |
| f1, f2 = p.map(add_100.get_concrete_function, [args] * 2) |
    # Without thread-safe tracing, one concrete function could end up
    # expecting about len(args) + max(0, len(args) - 3) arguments; calling f1
    # verifies that the trace is consistent.
| f1(*args) |
| del f2 |
| |
| def testInputSpecGraphFunction(self): |
| matmul = def_function.function(math_ops.matmul) |
| |
| @def_function.function |
| def sq(a): |
| return matmul(a, a) |
| |
| sq_op = sq.get_concrete_function( |
| tensor_spec.TensorSpec((None, None), dtypes.float32)) |
| self.assertEqual([None, None], sq_op.output_shapes.as_list()) |
| |
| t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| out1 = sq_op(t1) |
| self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy()) |
| |
| t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| out2 = sq_op(t2) |
| self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy()) |
| |
| def testNestedInputSpecGraphFunction(self): |
| matmul = def_function.function(math_ops.matmul) |
| |
| @def_function.function |
| def sq(mats): |
| ((a, b),) = mats |
| return matmul(a, b) |
| |
| sq_op_autonamed = sq.get_concrete_function( |
| [(tensor_spec.TensorSpec((None, None), dtypes.float32), |
| tensor_spec.TensorSpec((None, None), dtypes.float32))]) |
| self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list()) |
| |
| sq_op = sq.get_concrete_function( |
| [(tensor_spec.TensorSpec((None, None), dtypes.float32, |
| name='first_mat'), |
| tensor_spec.TensorSpec((None, None), dtypes.float32, |
| name='second_mat'))]) |
| self.assertEqual([None, None], sq_op.output_shapes.as_list()) |
| |
| t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]]) |
| out = sq_op(first_mat=t1, second_mat=t2) |
| self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy()) |
| self.assertAllEqual(sq_op_autonamed(t1, t2), |
| math_ops.matmul(t1, t2).numpy()) |
| |
| def testExecutingStatelessDefunConcurrently(self): |
| |
| @def_function.function |
| def stateless(x): |
| return math_ops.multiply(2.0, x) |
| |
| pool = multiprocessing.pool.ThreadPool() |
| inputs = [constant_op.constant(1.0 * x) for x in range(100)] |
| outputs = [float(out) for out in pool.map(stateless, inputs)] |
| expected = [float(2.0 * x) for x in inputs] |
| self.assertSequenceEqual(outputs, expected) |
| |
| def testExecutingManyStatelessDefunsConcurrently(self): |
| |
| @def_function.function |
| def stateless(x): |
| del x |
| return math_ops.multiply(2.0, 2.0) |
| |
| pool = multiprocessing.pool.ThreadPool() |
| # `pool.map` below instantiates 100 functions, one for each object. |
| objects = [object() for _ in range(100)] |
| outputs = [float(out) for out in pool.map(stateless, objects)] |
| expected = [4.0] * 100 |
| self.assertSequenceEqual(outputs, expected) |
| |
| @test_util.disable_tfrt('b/169431085: This test is flaky on tfrt') |
| def testExecutingStatefulDefunConcurrently(self): |
| |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @def_function.function |
| def stateful(x): |
| v.assign(x) |
| |
| pool = multiprocessing.pool.ThreadPool() |
| inputs = [constant_op.constant(0.0)] * 100 |
| pool.map(stateful, inputs) |
| self.assertEqual(float(v.read_value()), 0.0) |
| |
| def testExecutingManyStatefulDefunsConcurrently(self): |
| |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @def_function.function |
| def stateful(x): |
| del x |
| return v.assign(0.0) |
| |
| pool = multiprocessing.pool.ThreadPool() |
| # `pool.map` below instantiates 100 functions, one for each object. |
| pool.map(stateful, [object() for _ in range(100)]) |
| self.assertEqual(float(v.read_value()), 0.0) |
| |
| def testShareRendezvous(self): |
| |
    # Prevent Grappler from inlining the functions. Note that we run the send
    # & recv in graph mode, since in eager mode the functions would
    # automatically be inlined.
| context.context().set_optimizer_experimental_options( |
| {'disable_meta_optimizer': True}) |
| |
| cpu = '/device:CPU:0' |
| |
| signature = [tensor_spec.TensorSpec([], dtypes.int32)] |
| |
| @def_function.function |
| def send(): |
| x = constant_op.constant(1) |
| gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu) |
| return x |
| |
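    # Share the rendezvous with the caller so the Send above can pair with a
    # Recv issued from a different function call.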
| send._shared_rendezvous = True # pylint: disable=protected-access |
| |
| @def_function.function(input_signature=signature) |
| def send_body(n): |
| send() |
| return n - 1 |
| |
| @def_function.function |
| def recv(): |
| return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu) |
| |
| recv._shared_rendezvous = True # pylint: disable=protected-access |
| |
| @def_function.function(input_signature=signature) |
| def recv_body(n): |
| recv() |
| return n - 1 |
| |
| @def_function.function(input_signature=signature) |
| def cond(n): |
| return n > 0 |
| |
    # Instead of calling the send & recv functions directly, we call them
    # through a functional While to ensure the rendezvous is shared across
    # the while boundary.
| @def_function.function |
| def fn(n): |
| functional_ops.While([n], cond.get_concrete_function(), |
| send_body.get_concrete_function()) |
| return functional_ops.While([n], cond.get_concrete_function(), |
| recv_body.get_concrete_function()) |
| |
| # Use a graph context since functions will not be automatically inlined |
| with context.graph_mode(), self.cached_session(): |
| self.evaluate(fn(2)) |
| |
| def disabled_testRandomSeed(self): |
| |
| @def_function.function |
| def f(): |
| return random_ops.random_normal(()) |
| |
| random_seed.set_random_seed(1) |
| x = f() |
| self.assertNotEqual(x, f()) |
| random_seed.set_random_seed(1) |
| self.assertAllEqual(f(), x) |
| |
| def testNestedInputsGraphFunction(self): |
| matmul = def_function.function(math_ops.matmul) |
| |
| pair = collections.namedtuple('pair', ['a', 'b']) |
| |
| @def_function.function |
| def a_times_b(inputs): |
| return matmul(inputs.a['a'], inputs.b['b']) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| sq_op = a_times_b.get_concrete_function( |
| pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')), |
| dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b')))) |
| self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) |
| out = sq_op(a=t, b=t) |
| self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) |
| |
| def testNestedOutputGraphFunction(self): |
| matmul = def_function.function(math_ops.matmul) |
| |
| @def_function.function |
| def sq(a): |
| return (matmul(a, a), {'b': constant_op.constant(1.0)}) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| |
| sq_op = sq.get_concrete_function(t) |
| self.assertEqual(sq_op.output_shapes, |
| (tensor_shape.TensorShape([2, 2]), |
| {'b': tensor_shape.TensorShape([])})) |
| self.assertEqual(sq_op.output_dtypes, |
| (dtypes.float32, {'b': dtypes.float32})) |
| (a, b) = sq_op(t) |
| self.assertAllEqual(a, math_ops.matmul(t, t).numpy()) |
| self.assertAllEqual(b['b'].numpy(), 1.0) |
| |
| def testGraphFunctionNoneOutput(self): |
| @def_function.function |
| def fn(unused_a, unused_b): |
| return None |
| |
| x = constant_op.constant(1) |
| fn_op = fn.get_concrete_function(x, x) |
| self.assertEqual(fn_op.output_dtypes, None) |
| self.assertEqual(fn_op.output_shapes, None) |
| self.assertAllEqual(fn_op(x, x), None) |
| |
| def testDefunNumpyArraysConvertedToTensors(self): |
| |
| def f(x): |
| self.assertIsInstance(x, ops.Tensor) |
| return x |
| |
| x = random_ops.random_uniform([2, 2]).numpy() |
| defined = function.defun(f) |
| defined(x) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| x = random_ops.random_uniform([2, 2]).numpy() |
| defined(x) |
| # A NumPy array with different values but the same shape and dtype |
| # shouldn't trigger another function definition. |
| self.assertLen(total_function_cache(defined), 1) |
| |
| np_ones = numpy.ones([], numpy.float32) |
| np_zeros = numpy.zeros([], numpy.float32) |
| tf_ones = array_ops.ones([]) |
| tf_zeros = array_ops.zeros([]) |
| |
    # Test that the numpy array is passed as an argument to the graph
    # function (its value is not baked into the trace as a constant).
| self.assertEqual(1., defined(np_ones).numpy()) |
| self.assertLen(total_function_cache(defined), 2) |
| self.assertEqual(0., defined(np_zeros).numpy()) |
| self.assertEqual(1., defined(tf_ones).numpy()) |
| self.assertEqual(0., defined(tf_zeros).numpy()) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| # Test that mutable inputs are supported. |
| mutable = numpy.ones([], numpy.float32) |
| self.assertEqual(1., defined(mutable).numpy()) |
| mutable.fill(0) |
| self.assertEqual(0., defined(mutable).numpy()) |
| |
| class MyNdarray(numpy.ndarray): |
| pass |
| |
| # Test that the subclasses of ndarray are converted too. |
| self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy()) |
| self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy()) |
| |
| # We should not have triggered any re-tracing of the python function. |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testNumpyDtypeInputSupported(self): |
| @function.defun |
| def f(x, dtype): |
| return constant_op.constant(dtype(x)) |
| |
| self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1)) |
| self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2)) |
| self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1)) |
| self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2)) |
| |
| def testDefunNumpyArraysConvertedToTensorsInKwargs(self): |
| |
| def f(**kwargs): |
| x = kwargs.pop('x') |
| self.assertIsInstance(x, ops.Tensor) |
| return x |
| |
| x = random_ops.random_uniform([2, 2]).numpy() |
| defined = function.defun(f) |
| defined(x=x) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| x = random_ops.random_uniform([2, 2]).numpy() |
| defined(x=x) |
| # A NumPy array with different values but the same shape and dtype |
| # shouldn't trigger another function definition. |
| self.assertLen(total_function_cache(defined), 1) |
| |
| # Test that the numpy array is properly an argument to the graph function. |
| self.assertEqual(1., defined(x=numpy.ones([])).numpy()) |
| self.assertEqual(0., defined(x=numpy.zeros([])).numpy()) |
| self.assertEqual(1., defined(x=array_ops.ones([])).numpy()) |
| self.assertEqual(0., defined(x=array_ops.zeros([])).numpy()) |
| |
| def testDefunCapturedInt32(self): |
| x = constant_op.constant(1, dtype=dtypes.int32) |
| |
| @def_function.function |
| def add_int32s(): |
| return x + x |
| |
| self.assertEqual(2, int(add_int32s())) |
| |
| def testDefunReadVariable(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @def_function.function |
| def f(): |
| return v.read_value() |
| |
| self.assertEqual(1.0, float(f())) |
| |
| def testDefunAssignAddVariable(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| x = constant_op.constant(2.0) |
| |
| @def_function.function |
| def test_assign_add(): |
| v.assign_add(x) |
| return v.read_value() |
| |
| self.assertEqual(3.0, float(test_assign_add())) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testTensorInitializationInFunctionRaisesError(self): |
| error_msg = ('Tensor-typed variable initializers must either be ' |
| 'wrapped in an init_scope or callable.*') |
| |
| @def_function.function |
| def tensor_init(): |
| with self.assertRaisesRegex(ValueError, error_msg): |
| resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) |
| |
| tensor_init() |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testCallableTensorInitializationInFunction(self): |
| |
| @def_function.function |
| def tensor_init(): |
| self.v = resource_variable_ops.ResourceVariable( |
| lambda: constant_op.constant(2.0)) |
| return self.v.read_value() |
| |
| value = tensor_init() |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(self.evaluate(value), 2.0) |
| |
| @test_util.also_run_as_tf_function |
| def testInitScopeTensorInitializationInFunction(self): |
| |
| @def_function.function |
| def tensor_init(): |
| with ops.init_scope(): |
| const = constant_op.constant(2.0) |
        # Note: this variable bypasses tf.function's variable creation
        # requirements because it uses ResourceVariable directly instead of
        # Variable, which skips variable_creator_scope.
| self.v = resource_variable_ops.ResourceVariable(const) |
| return self.v.read_value() |
| |
| value = tensor_init() |
| self.assertAllEqual(value, 2.0) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testGetConcreteFunctionCreatesVariables(self): |
| |
| v_holder = [] |
| |
| @def_function.function |
| def tensor_init(): |
| if not v_holder: |
| v_holder.append(variables.Variable(5.)) |
| return v_holder[0].read_value() |
| |
| concrete = tensor_init.get_concrete_function() |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertAllEqual(5., self.evaluate(concrete())) |
| self.assertAllEqual(5., self.evaluate(tensor_init())) |
| |
| def testFuncGraphCaptureByValue(self): |
| v = variables.Variable(1.0) |
| |
| def trivial_function(): |
| return v.read_value() |
| |
| graph_function = function.Function( |
| trivial_function, 'test', capture_by_value=True) |
| |
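    # With capture_by_value, the variable's value at trace time is baked into
    # the graph, so the later assignment is not reflected.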
| self.assertAllEqual(graph_function(), 1.0) |
| v.assign(2.0) |
| self.assertAllEqual(graph_function(), 1.0) |
| |
| def testFuncGraphCaptureByValueNested(self): |
| v = variables.Variable(1.0) |
| |
| def trivial_function(): |
| return control_flow_ops.cond( |
| array_ops.placeholder_with_default(True, ()), |
| v.read_value, v.read_value) |
| |
| graph_function = function.Function( |
| trivial_function, 'test', capture_by_value=True) |
| |
| self.assertAllEqual(graph_function(), 1.0) |
| v.assign(2.0) |
| self.assertAllEqual(graph_function(), 1.0) |
| |
| def testDefunShapeInferenceWithCapturedResourceVariable(self): |
| v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) |
| |
| def f(): |
| x = constant_op.constant([[1, 2], [3, 4]]) |
| out = math_ops.matmul(v, x) |
| self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) |
| # We do not return v directly since the tensor conversion function of |
| # ResourceVariable returns the read value and not the resource itself. |
| return v._handle |
| |
| compiled = def_function.function(f) |
| var_handle = compiled() |
| self.assertEqual(var_handle.dtype, dtypes.resource) |
| self.assertEqual(var_handle.shape, tensor_shape.TensorShape([])) |
| var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) |
| self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) |
| |
| def testShapeInferenceForMoreSpecificInput(self): |
| |
| def f(a): |
| return array_ops.reshape(a, [-1, 3]) |
| |
| signature = [tensor_spec.TensorSpec(None, dtypes.float32)] |
| compiled = def_function.function(f, input_signature=signature) |
| |
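    # Even though the input signature leaves the shape fully unknown, calling
    # the compiled function should infer the same output shape as calling f.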
| @def_function.function |
| def use_f(): |
| inputs = array_ops.zeros([10, 10, 3]) |
| self.assertAllEqual(f(inputs).shape, compiled(inputs).shape) |
| |
| use_f() |
| |
| def testFuncListAttr(self): |
| |
| @function.defun |
| def test_function(val): |
| |
| def fn1(): |
| return array_ops.ones([10]) |
| |
| fn2 = lambda: array_ops.ones([10]) * 2 |
| |
| def fn3(x=3): |
| return array_ops.ones([10]) * x |
| fn4 = functools.partial(fn3, x=4) |
| fn5 = functools.partial(fn3, 5) |
| |
| return gen_functional_ops.case(val, [], [dtypes.float32], |
| [function.defun(f).get_concrete_function() |
| for f in (fn1, fn2, fn3, fn4, fn5)]) |
| |
| ones = array_ops.ones([10]) |
| self.assertAllEqual([ones], test_function(0)) |
| self.assertAllEqual([ones * 2], test_function(1)) |
| self.assertAllEqual([ones * 3], test_function(2)) |
| self.assertAllEqual([ones * 4], test_function(3)) |
| self.assertAllEqual([ones * 5], test_function(4)) |
| self.assertAllEqual([ones * 5], test_function(22)) # default branch |
| |
| @test_util.enable_control_flow_v2 |
| def testVariableInLoopInFunction(self): |
| |
| @function.defun |
| def test_function(): |
| |
| def loop_test(_): |
| return False |
| |
| def loop_body(_): |
| return variable_scope.get_variable('a', shape=()) |
| |
| return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) |
| |
| self.assertEqual(test_function().shape, []) |
| |
| def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): |
| with context.graph_mode(): |
| v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) |
| |
| def f(): |
| x = constant_op.constant([[1, 2], [3, 4]]) |
| out = math_ops.matmul(v, x) |
| self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) |
| # We do not return v directly since the tensor conversion function of |
| # ResourceVariable returns the read value and not the resource itself. |
| return v._handle |
| |
| compiled = def_function.function(f) |
| var_handle = compiled() |
| self.assertEqual(var_handle.dtype, dtypes.resource) |
| self.assertEqual(var_handle.shape, tensor_shape.TensorShape([])) |
| var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) |
| self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) |
| |
| def testDefunShapeInferenceWithCapturedVariableInGraphMode(self): |
| with context.graph_mode(): |
| v = variables.Variable([[1, 2], [3, 4]]) |
| |
| def f(): |
| x = constant_op.constant([[1, 2], [3, 4]]) |
| out = math_ops.matmul(v, x) |
| self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) |
| |
| # Check that shape inference works while creating the defun |
| compiled = def_function.function(f) |
| compiled() |
| |
| def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self): |
| with context.graph_mode(): |
| tensor_list = list_ops.empty_tensor_list( |
| element_dtype=dtypes.float32, |
| element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) |
| tensor_list = list_ops.tensor_list_push_back(tensor_list, |
| constant_op.constant(1.0)) |
| tensor_list = list_ops.tensor_list_push_back(tensor_list, |
| constant_op.constant(2.0)) |
| |
| def f(): |
| tl, value = list_ops.tensor_list_pop_back( |
| tensor_list, element_dtype=dtypes.float32) |
| self.assertEqual(value.shape, tensor_shape.TensorShape([])) |
| return tl |
| |
| compiled = def_function.function(f) |
| output_tensor_list = compiled() |
| _, value = list_ops.tensor_list_pop_back( |
| output_tensor_list, element_dtype=dtypes.float32) |
| self.assertEqual(value.shape, tensor_shape.TensorShape([])) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testDefunForcesResourceVariables(self): |
| |
| def variable_creator(): |
| self.v = variables.Variable(0.0) |
| return self.v.read_value() |
| |
| self.v = None |
| defined = function.defun(variable_creator) |
| defined() # Create the variable. |
| self.assertIsInstance( |
| self.v, resource_variable_ops.ResourceVariable) |
| |
| @test_util.disable_tfrt('b/169294215') |
| def testRunMetadata(self): |
| |
| @def_function.function |
| def f(x): |
| return x * x |
| |
| with ops.device('cpu:0'): |
| context.enable_run_metadata() |
| f(constant_op.constant(1.0)) |
| run_metadata = context.export_run_metadata() |
| context.disable_run_metadata() |
| self.assertLen(run_metadata.partition_graphs, 1) |
| |
| def testGraphModeCaptureVariable(self): |
| with context.graph_mode(), self.cached_session(): |
| |
| class HasAVar(object): |
| |
| def __init__(self): |
| self.v = resource_variable_ops.ResourceVariable(1.0) |
| |
| def call(self): |
| return self.v * 2 |
| |
| o = HasAVar() |
| self.evaluate(variables.global_variables_initializer()) |
| call = def_function.function(o.call) |
| op = call() |
| self.assertAllEqual(self.evaluate(op), 2.0) |
| |
| def testGraphModeManyFunctions(self): |
| with ops.Graph().as_default(), self.cached_session(): |
| |
| @def_function.function |
| def f(x): |
| return x * x |
| |
| @def_function.function |
| def g(x): |
| return f(x) + 1 |
| |
| self.assertAllEqual(g(constant_op.constant(2.0)), 5.0) |
| |
| def testDict(self): |
| |
| @def_function.function |
| def f(x): |
| return {'name': x + 1} |
| |
| self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0) |
| |
| def testTensorConversionWithDefun(self): |
| |
| @def_function.function |
| def f(x): |
| return math_ops.add(x, constant_op.constant(3)) |
| |
| self.assertAllEqual(5, f(constant_op.constant(2))) |
| |
| def testTensorConversionCall(self): |
| |
| @def_function.function |
| def f(x): |
| return math_ops.add(x, constant_op.constant(3)) |
| |
| @def_function.function |
| def g(x): |
| return f(f(x)) |
| |
| self.assertAllEqual(8, g(constant_op.constant(2))) |
| |
| def testCallShape(self): |
| |
| @def_function.function |
| def f(x): |
| return x + 1 |
| |
| @def_function.function |
| def g(x): |
| x = f(x) |
| self.assertEqual(x.shape.as_list(), []) |
| return None |
| |
| g(constant_op.constant(1.0)) |
| |
| def testNestedDefunWithNoOutputAndTapedInput(self): |
| three = resource_variable_ops.ResourceVariable(3.0, name='v') |
| |
| @def_function.function |
| def f(x): |
| # This function intentionally takes a taped variable as input, |
| # but does not return any values |
| math_ops.add(x, three) |
| |
| @def_function.function |
| def g(x): |
| y = math_ops.add(x, three) |
| f(y) |
| |
| g(three) |
| |
| def testGatherResourceWithDefun(self): |
| with ops.device('cpu:0'): |
| v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) |
| |
| def sum_gather(): |
| return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) |
| |
| defined = def_function.function(sum_gather) |
| self.assertAllEqual(sum_gather(), defined()) |
| |
| @parameterized.named_parameters([ |
| ('IndexedSlicesWithDenseShape', |
| _example_indexed_slices_with_dense_shape,), |
| ('IndexedSlicesWithoutDenseShape', |
| _example_indexed_slices_without_dense_shape,), |
| ('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths, |
| {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}), |
| ('RaggedTensorRaggedRank2', |
| ragged_tensor.RaggedTensor.from_nested_row_lengths, |
| {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), |
| ('SparseTensor', sparse_tensor.SparseTensor, |
| {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), |
| ]) # pyformat: disable |
| def testReturnCompositeTensorWithDefun(self, |
| factory_fn, |
| factory_kwargs={}, |
| input_signature=None): |
| input_ct = factory_fn(**factory_kwargs) |
| |
| @def_function.function(input_signature=input_signature) |
| def f(): |
| return input_ct |
| |
| output_ct = f() |
| self.assertIsInstance(output_ct, type(input_ct)) |
| nest.assert_same_structure(input_ct, output_ct, expand_composites=True) |
| |
| input_flat = nest.flatten(input_ct, expand_composites=True) |
| output_flat = nest.flatten(output_ct, expand_composites=True) |
| for (input_component, output_component) in zip(input_flat, output_flat): |
| self.assertAllEqual(input_component, output_component) |
| |
| @parameterized.named_parameters([ |
| ('IndexedSlicesWithDenseShape', |
| _example_indexed_slices_with_dense_shape,), |
| ('IndexedSlicesWithoutDenseShape', |
| _example_indexed_slices_without_dense_shape,), |
| ('RaggedTensorRaggedRank1', |
| ragged_tensor.RaggedTensor.from_row_lengths, |
| {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}), |
| ('RaggedTensorRaggedRank2', |
| ragged_tensor.RaggedTensor.from_nested_row_lengths, |
| {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}), |
| ('SparseTensor', |
| sparse_tensor.SparseTensor, |
| {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}), |
| ('RaggedTensorRaggedRank1WithSignature', |
| ragged_tensor.RaggedTensor.from_row_lengths, |
| {'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}, |
| [ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]), |
| ('RaggedTensorRaggedRank2WithSignature', |
| ragged_tensor.RaggedTensor.from_nested_row_lengths, |
| {'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}, |
| [ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]), |
| ('SparseTensorWithSignature', |
| sparse_tensor.SparseTensor, |
| {'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}, |
| [sparse_tensor.SparseTensorSpec([None], dtypes.int32)]), |
| ]) # pyformat: disable |
| def testCompositeAsArgumentTensorWithDefun(self, |
| factory_fn, |
| factory_kwargs={}, |
| input_signature=None): |
| input_ct = factory_fn(**factory_kwargs) |
| |
| @def_function.function(input_signature=input_signature) |
| def f(x): |
| return x |
| |
| output_ct = f(input_ct) |
| self.assertIsInstance(output_ct, type(input_ct)) |
| nest.assert_same_structure(input_ct, output_ct, expand_composites=True) |
| |
| input_flat = nest.flatten(input_ct, expand_composites=True) |
| output_flat = nest.flatten(output_ct, expand_composites=True) |
| for (input_component, output_component) in zip(input_flat, output_flat): |
| self.assertAllEqual(input_component, output_component) |
| |
| def testTracedCompositeDiscardsShapeInfo(self): |
| # SparseTensorSpec intentionally excludes info about the number of elements |
| # that are in a sparse tensor (which is recorded as st.indices.shape[0] and |
| # st.values.shape[0]). Similarly, RaggedTensorSpec intentionally excludes |
| # info about the total number of values in a RaggedTensor (stored as |
| # rt.values.shape[0]). This test checks that the placeholders created by |
| # tf.function() properly mask this shape info. |
| @def_function.function |
| def f(rt, st): |
| self.assertEqual(st.indices.shape.as_list()[:1], [None]) |
| self.assertEqual(st.values.shape.as_list(), [None]) |
| return (rt, st) |
| |
| rt = ragged_factory_ops.constant([[1, 2], [3]]) |
| st = sparse_tensor.SparseTensor([[0]], [0], [10]) |
| f(rt, st) |
| |
| @test_util.run_gpu_only |
| def testFunctionOnDevice(self): |
| x = constant_op.constant([1.]).gpu() |
| f = def_function.function(math_ops.add) |
| y = f(x, x).cpu() |
| self.assertAllEqual(y, [2.]) |
| |
| @test_util.run_gpu_only |
| @test_util.run_in_graph_and_eager_modes |
| def testFunctionWithResourcesOnDifferentDevices(self): |
| with ops.device('/cpu:0'): |
| v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) |
| |
| with ops.device('/gpu:0'): |
| v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) |
| |
| def sum_gather(): |
| cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2])) |
| gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) |
| return cpu_result, gpu_result |
| |
| defined = function.defun(sum_gather) |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| expected = self.evaluate(sum_gather()) |
| self.assertAllEqual(expected, self.evaluate(defined())) |
| |
| @test_util.run_gpu_only |
| @test_util.run_in_graph_and_eager_modes |
| def testOpInFunctionWithConflictingResourceInputs(self): |
| with ops.device('/cpu:0'): |
| v_cpu = resource_variable_ops.ResourceVariable( |
| [0.0, 1.0, 2.0], name='cpu') |
| v_also_cpu = resource_variable_ops.ResourceVariable( |
| [0.0, 1.0, 2.0], name='also_cpu') |
| |
| with ops.device('/gpu:0'): |
| v_gpu = resource_variable_ops.ResourceVariable( |
| [0.0, 1.0, 2.0], name='gpu') |
| |
| @def_function.function |
| def resource_apply_adam(): |
| training_ops.resource_apply_adam( |
| v_cpu.handle, |
| v_gpu.handle, |
| v_also_cpu.handle, |
| 1.0, # beta1_power |
| 1.0, # beta2_power |
| 1.0, # learning_rate |
| 1.0, # beta1 |
| 1.0, # beta2 |
| 1.0, # epsilon, |
| [1.0, 1.0, 1.0], # grad |
| False) # use_locking |
| return None |
| |
| with self.assertRaisesRegex( |
| errors.InvalidArgumentError, |
| 'Cannot place the graph because a reference or resource edge connects ' |
| 'colocation groups with incompatible assigned devices'): |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| self.evaluate(resource_apply_adam()) |
| |
| @test_util.run_gpu_only |
| def testFunctionHandlesInputsOnDifferentDevices(self): |
| # The Reshape op requires the shape tensor to be placed in host memory. |
| reshape = def_function.function(array_ops.reshape) |
| value = constant_op.constant([1., 2.]).gpu() |
| shape = constant_op.constant([2, 1]) |
| reshaped = reshape(value, shape).cpu() |
| self.assertAllEqual(reshaped, [[1], [2]]) |
| |
| @test_util.run_gpu_only |
| def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self): |
| # The Reshape op requires the shape tensor to be placed in host memory. |
| reshape = def_function.function(array_ops.reshape) |
| value = constant_op.constant([1., 2.]) |
| shape = constant_op.constant([2, 1]).gpu() |
| reshape(value, shape) # No error is raised |
| |
| def testNoneOutput(self): |
| |
| @def_function.function |
| def my_function(_): |
| return None |
| |
| self.assertAllEqual(my_function(1), None) |
| |
| def testNestedFunctions(self): |
    # A legacy Defun-style TensorFlow function (the kind used in TensorFlow
    # 1.x graph construction).
| @tf_function.Defun(dtypes.int32, dtypes.int32) |
| def add(a, b): |
| return math_ops.add(a, b) |
| |
| @def_function.function |
| def add_one(x): |
| return add(x, 1) |
| |
| self.assertAllEqual(3, add_one(constant_op.constant(2))) |
| |
| def testVariableCaptureInNestedFunctions(self): |
| v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32) |
| |
| @def_function.function |
| def inner_read(): |
| return v.read_value() |
| |
| @def_function.function |
| def outer(): |
| return inner_read() |
| |
| self.assertEqual(1, int(outer())) |
| |
| def testReturnCapturedEagerTensor(self): |
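    # A tensor captured from the enclosing eager context keeps its value when
    # returned from the traced function.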
| t = constant_op.constant(1) |
| |
| @def_function.function |
| def read(): |
| return t |
| |
| self.assertEqual(1, int(read())) |
| |
| def testReturnCapturedGraphTensor(self): |
| with context.graph_mode(), self.cached_session(): |
| t = constant_op.constant(1) |
| |
| @def_function.function |
| def read(): |
| return t |
| |
| self.assertEqual(1, int(self.evaluate(read()))) |
| |
| def testSequenceInputs(self): |
| clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm) |
| t_list = [constant_op.constant(1.0), constant_op.constant(2.0)] |
| clipped_list, global_norm = clip_by_global_norm(t_list, |
| constant_op.constant(.2)) |
| for t in clipped_list: |
| self.assertIsInstance(t, ops.Tensor) |
| self.assertIsInstance(global_norm, ops.Tensor) |
| |
| def testNestedSequenceInputs(self): |
| |
| def my_op(inputs): |
| a, b, c = inputs |
| e, f = b |
| g, h = e |
| return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c |
| |
| my_eager_op = def_function.function(my_op) |
| ret = my_eager_op([ |
| constant_op.constant(1), [(constant_op.constant(2), |
| constant_op.constant(3)), |
| constant_op.constant(4)], |
| constant_op.constant(5) |
| ]) |
| self.assertLen(ret, 2) |
| self.assertAllEqual(ret[0][0], 2) |
| self.assertAllEqual(ret[0][1][0][0], 8) |
| self.assertAllEqual(ret[0][1][0][1], 4) |
| self.assertIsInstance(ret[0][1][0], tuple) |
| self.assertAllEqual(ret[0][1][1], 6) |
| self.assertAllEqual(ret[0][2], 10) |
| self.assertAllEqual(ret[1], 15) |
| |
| def testVariableNamesRespectNameScopesWithDefun(self): |
| @def_function.function |
| def create_variable(): |
| with ops.name_scope('foo', skip_on_eager=False): |
| v = resource_variable_ops.ResourceVariable(0.0, name='bar') |
| self.assertEqual(v.name, 'foo/bar:0') |
| |
| create_variable() |
| |
| def testVariableNamesRespectNameScopesWithDefunInGraph(self): |
| with context.graph_mode(): |
| @def_function.function |
| def create_variable(): |
| with ops.name_scope('foo', skip_on_eager=False): |
| v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar') |
| self.assertEqual(v.name, 'foo/bar:0') |
| |
| with ops.get_default_graph().as_default(): |
| create_variable() |
| |
| @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) |
| def testLayerInDefun(self): |
| conv = convolutional.Conv2D( |
| filters=1, |
| kernel_size=2, |
| kernel_initializer=init_ops.ones_initializer(), |
| bias_initializer=init_ops.zeros_initializer()) |
| |
| @function.defun |
| def model(x): |
| return conv(x) |
| |
| x = array_ops.ones([1, 2, 2, 1]) |
| y = model(x) |
| |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| |
| self.assertAllClose([[[[4.0]]]], self.evaluate(y)) |
| |
| # Variable lifting is somewhat different between defun/tf.function, so testing |
| # device placement on both makes sense. |
| @parameterized.named_parameters( |
| dict(testcase_name='Defun', |
| function_decorator=function.defun), |
| dict(testcase_name='DefFunction', |
| function_decorator=def_function.function)) |
| @test_util.run_in_graph_and_eager_modes |
| def testVariablesPlacedOnOutsideDevice(self, function_decorator): |
| |
| class _Obj(object): |
| |
| def __init__(self): |
| self.v = None |
| |
| @function_decorator |
| def f(self): |
| if self.v is None: |
| self.v = variables.Variable(1.) |
| return self.v + 1. |
| |
| has_device = _Obj() |
| with ops.device('cpu:0'): |
| has_device.f() |
| self.assertIn('CPU', has_device.v.device) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testDeviceAnnotationsRespected(self): |
| |
| def multi_device_fn(): |
| with ops.device('/cpu:0'): |
| s0 = test_ops.device_placement_op() |
| with ops.device('/cpu:1'): |
| s1 = test_ops.device_placement_op() |
| with ops.device('/cpu:2'): |
| s2 = test_ops.device_placement_op() |
| s3 = test_ops.device_placement_op() |
| return s0, s1, s2, s3 |
| |
| defined = function.defun(multi_device_fn) |
| outputs = self.evaluate(defined()) |
| self.assertLen(total_function_cache(defined), 1) |
| self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) |
| self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) |
| self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) |
| |
| with ops.device('/cpu:3'): |
| outputs = self.evaluate(defined()) |
| # All function definitions are agnostic to call site devices. |
| self.assertLen(total_function_cache(defined), 1) |
| self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) |
| self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) |
| self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) |
| self.assertIn(compat.as_bytes('CPU:3'), outputs[3]) |
| |
| with ops.device('/cpu:0'): |
| outputs = self.evaluate(defined()) |
| self.assertLen(total_function_cache(defined), 1) |
| self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) |
| self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) |
| self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) |
| self.assertIn(compat.as_bytes('CPU:0'), outputs[3]) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testCallingGraphFunctionOnDifferentDevice(self): |
| |
| def func(): |
| return constant_op.constant(0) |
| |
| defined = def_function.function(func) |
| with ops.device('cpu:0'): |
| cpu_graph_function = defined.get_concrete_function() |
| |
| with ops.device('cpu:0'): |
| self.assertEqual( |
| self.evaluate(cpu_graph_function()), self.evaluate(func())) |
| |
| with ops.device('cpu:1'): |
| self.assertEqual(0., self.evaluate(cpu_graph_function())) |
| |
| with ops.device(None): |
| self.assertEqual(0., self.evaluate(cpu_graph_function())) |
| |
| default_graph_function = defined.get_concrete_function() |
| self.assertEqual( |
| self.evaluate(default_graph_function()), self.evaluate(func())) |
| |
| with ops.device('cpu:1'): |
| self.assertEqual(0., self.evaluate(default_graph_function())) |
| |
| @test_util.run_gpu_only |
| @test_util.run_in_graph_and_eager_modes |
| def testColocateWithRespected(self): |
| # TODO(b/113291792): Use multiple CPUs instead of a GPU. |
| with ops.device('cpu:0'): |
| x = array_ops.identity(1.0) |
| |
| with ops.device('gpu:0'): |
| y = array_ops.identity(1.0) |
| |
| @def_function.function |
| def foo(): |
| return test_ops.device_placement_op() |
| |
| with ops.colocate_with(x): |
| self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo())) |
| |
| with ops.colocate_with(y): |
| self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo())) |
| |
| def testVariablesAreTracked(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| def foo(x): |
| return v * x |
| |
| defined = def_function.function(foo) |
| |
| x = constant_op.constant([1.0]) |
| self.assertEqual(1., self.evaluate(defined(x))) |
| v.assign(2.) |
| |
| x = constant_op.constant([1.0, 2.0]) |
| self.assertAllEqual([2., 4.], self.evaluate(defined(x))) |
| |
| def testCacheObjectHashCollisions(self): |
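    # Distinct objects whose hashes collide must still get separate cache
    # entries.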
| |
| class Foo(object): |
| |
| def __hash__(self): |
| return 42 |
| |
| def func(foo): |
| del foo |
| return |
| |
| defined = function.defun(func) |
| defined(Foo()) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| defined(Foo()) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testCacheTensorDtypeCollision(self): |
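    # Tensors with the same value but different dtypes must trace separately.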
| |
| def func(t): |
| return t + t |
| |
| defined = function.defun(func) |
| t = constant_op.constant([[1.0]], dtype=dtypes.complex64) |
| defined(t) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| t = constant_op.constant([[1.0]], dtype=dtypes.complex128) |
| defined(t) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testCacheTensorShapeCollision(self): |
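    # Tensors with the same dtype but different shapes must trace separately.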
| |
| def func(t): |
| return t + t |
| |
| defined = function.defun(func) |
| t = constant_op.constant([[1.0]], dtype=dtypes.complex64) |
| defined(t) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| t = constant_op.constant([1.0], dtype=dtypes.complex64) |
| defined(t) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testCacheTensorShapeDtypeCollision(self): |
| |
| def func(t): |
| return t + t |
| |
| defined = function.defun(func) |
| t = constant_op.constant([[1.0]], dtype=dtypes.complex64) |
| defined(t) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| t = constant_op.constant([1.0], dtype=dtypes.complex128) |
| defined(t) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testCacheTensorUnknownShapesCollisionRelaxedShapes(self): |
| |
| def func(t): |
| return t + t |
| |
| with context.graph_mode(), self.cached_session(): |
| defined = function.defun(func, experimental_relax_shapes=True) |
| |
| p = array_ops.placeholder(dtype=dtypes.float32, shape=[]) |
| defined(p) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| p = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) |
| defined(p) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| p = array_ops.placeholder(dtype=dtypes.float32, shape=[2]) |
| defined(p) |
      # Gradual shape relaxation is performed; the common shape between
      # [1] and [2] is one containing unknown dimensions.
| self.assertLen(total_function_cache(defined), 2) |
| |
| # pylint: disable=protected-access |
| self.assertLen(defined._function_cache.arg_relaxed_specs, 1) |
| relaxed_specs = ( |
| list(defined._function_cache.arg_relaxed_specs.values())[0]) |
| self.assertLen(relaxed_specs, 1) |
| relaxed_shape = relaxed_specs[0].shape |
| # pylint: enable=protected-access |
| self.assertEqual(relaxed_shape.rank, 1) |
| self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None) |
| |
| t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32) |
| defined(t) |
| # Shape (3,) matches the relaxed shape TensorShape([None]) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testPythonFunctionWithDefaultArgs(self): |
| |
| def func(foo, bar=1, baz=2): |
| del foo |
| del bar |
| del baz |
| return |
| |
| defined = function.defun(func) |
| defined(0, baz=20) |
| |
| def cache_keys(): |
| """Sanitizes cache keys of non-input metadata.""" |
| return tuple(key[0] for key in total_function_cache(defined)) |
| |
    # The cache key encodes the bound arguments: foo=0, bar=1 (default),
    # baz=20.
| self.assertIn(('URRRu', (0, 1, 20)), cache_keys()) |
| |
| defined(1) # bar=1, baz=2 |
| self.assertIn(('URRRu', (1, 1, 2)), cache_keys()) |
| |
| # This matches the previous call. |
| defined(foo=1) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| defined(1, 2, 3) |
| self.assertLen(total_function_cache(defined), 3) |
| self.assertIn(('URRRu', (1, 2, 3)), cache_keys()) |
| |
| # This matches the previous call. |
| defined(1, bar=2, baz=3) |
| self.assertLen(total_function_cache(defined), 3) |
| |
| # This matches the previous call. |
| defined(1, baz=3, bar=2) |
| self.assertLen(total_function_cache(defined), 3) |
| |
| def testFunctoolsPartialUnwrappedCorrectly(self): |
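    # A defunned functools.partial should produce the same outputs as calling
    # the partial directly.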
| |
| def full_function(a, b, c=3): |
| return a, b, c |
| |
| partial = functools.partial(full_function, 1, c=4) |
| a, b, c = partial(2) |
| |
| defined = function.defun(partial) |
| func_a, func_b, func_c = defined(2) |
| self.assertEqual(func_a.numpy(), a) |
| self.assertEqual(func_b.numpy(), b) |
| self.assertEqual(func_c.numpy(), c) |
| |
| def testInputSignatureWithMatchingInputs(self): |
| |
| def foo(a): |
| self.assertEqual(a.shape, (2,)) |
| return a |
| |
| signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] |
| defined = function.defun(foo, input_signature=signature) |
| a = array_ops.ones([2]) |
| self.assertAllEqual(a, defined(a)) |
| self.assertLen(total_function_cache(defined), 1) |
| self.assertAllEqual(a, defined.get_concrete_function()(a)) |
| self.assertAllEqual(a, defined.get_concrete_function(a)(a)) |
| self.assertAllEqual(a, defined.get_concrete_function( |
| tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a)) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| def bar(a): |
| self.assertEqual(a._shape_tuple(), (2, None)) |
| return a |
| |
| signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)] |
| defined = function.defun(bar, input_signature=signature) |
| a = array_ops.ones([2, 1]) |
| out = defined(a) |
| self.assertLen(total_function_cache(defined), 1) |
| self.assertAllEqual(out, a) |
| |
| # Changing the second dimension shouldn't create a new function. |
| b = array_ops.ones([2, 3]) |
| out = defined(b) |
| self.assertLen(total_function_cache(defined), 1) |
| self.assertAllEqual(out, b) |
| |
| def testInputSignatureWithCompatibleInputs(self): |
| |
| rank2_spec = tensor_spec.TensorSpec(shape=(None, None), |
| dtype=dtypes.float32) |
| |
| @function.defun(input_signature=[rank2_spec]) |
| def func(a): |
| self.assertEqual([None, None], a.shape.as_list()) |
| return array_ops.shape(a) |
| |
| self.assertAllEqual([3, 1], func([[0], [1.0], [1]])) |
| self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]]))) |
| |
| with self.assertRaisesRegex(ValueError, 'incompatible'): |
| func([0.0, 1.0, 2.0]) # Wrong shape. |
| |
| with self.assertRaisesRegex(ValueError, 'incompatible'): |
| func([['wrong dtype']]) |
| |
| def testNoKeywordOnlyArgumentsWithInputSignature(self): |
| if sys.version_info[0] < 3: |
      self.skipTest('keyword-only arguments only exist in Python 3.')
| |
| func = eval('lambda x, *, y: x') # pylint: disable=eval-used |
| signature = [tensor_spec.TensorSpec(None, dtypes.int32)] |
| with self.assertRaisesRegex( |
| ValueError, 'Cannot define a TensorFlow function from a Python ' |
| 'function with keyword-only arguments when input_signature is ' |
| 'provided.'): |
| def_function.function(func, signature) |
| |
| def testNestedInputSignatures(self): |
| |
| def expected_foo(a, b): |
| return [a, b] |
| |
| @function.defun(input_signature=[ |
| [tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2, |
| tensor_spec.TensorSpec((1,), dtypes.float32), |
| ]) |
| def foo(a, b): |
| self.assertEqual(a[0]._shape_tuple(), (2, None)) |
| self.assertEqual(a[1]._shape_tuple(), (2, None)) |
| self.assertEqual(b._shape_tuple(), (1,)) |
| return [a, b] |
| |
| a = array_ops.ones([2, 1]) |
| b = array_ops.ones([1]) |
| expected = expected_foo([a, a], b) |
| out = foo([a, a], b) |
| self.assertLen(total_function_cache(foo), 1) |
| nest.assert_same_structure(out, expected) |
| self.assertAllEqual(out[0][0], a) |
| self.assertAllEqual(out[0][1], a) |
| self.assertAllEqual(out[1], b) |
| |
| # Changing the unspecified dimensions shouldn't create a new function. |
| a = array_ops.ones([2, 3]) |
| b = array_ops.ones([2, 5]) |
| c = array_ops.ones([1]) |
| expected = expected_foo([a, b], c) |
| out = foo([a, b], c) |
| self.assertLen(total_function_cache(foo), 1) |
| nest.assert_same_structure(out, expected) |
| self.assertAllEqual(out[0][0], a) |
| self.assertAllEqual(out[0][1], b) |
| self.assertAllEqual(out[1], c) |
| |
| # Passing compatible inputs should work. |
| a = a.numpy().tolist() |
| b = b.numpy().tolist() |
| c = c.numpy().tolist() |
| out = foo([a, b], c) |
| self.assertLen(total_function_cache(foo), 1) |
| nest.assert_same_structure(out, expected) |
| self.assertAllEqual(out[0][0], a) |
| self.assertAllEqual(out[0][1], b) |
| self.assertAllEqual(out[1], c) |
| |
| def testNestedInputSignaturesWithDict(self): |
| def expected_bar(a): |
| return a |
| |
| @function.defun(input_signature=[{ |
| 'a': tensor_spec.TensorSpec((2, None), dtypes.float32), |
| 'b': tensor_spec.TensorSpec((2, None), dtypes.float32), |
| 'c': tensor_spec.TensorSpec((1,), dtypes.float32)}]) |
| def bar(a): |
| self.assertEqual(a['a']._shape_tuple(), (2, None)) |
| self.assertEqual(a['b']._shape_tuple(), (2, None)) |
| self.assertEqual(a['c']._shape_tuple(), (1,)) |
| return a |
| |
| a = array_ops.ones([2, 3]) |
| b = array_ops.ones([1]) |
| inputs = {'a': a, 'b': a, 'c': b} |
| expected = expected_bar(inputs) |
| out = bar(inputs) |
| nest.assert_same_structure(out, expected) |
| self.assertAllEqual(out['a'], expected['a']) |
| self.assertAllEqual(out['b'], expected['b']) |
| self.assertAllEqual(out['c'], expected['c']) |
| |
| # Passing compatible inputs should work. |
| a = a.numpy().tolist() |
| b = b.numpy().tolist() |
| inputs = {'a': a, 'b': a, 'c': b} |
| out = bar(inputs) |
| nest.assert_same_structure(out, expected) |
| self.assertAllEqual(out['a'], expected['a']) |
| self.assertAllEqual(out['b'], expected['b']) |
| self.assertAllEqual(out['c'], expected['c']) |
| |
| def testInputSignatureMustBeSequenceOfTensorSpecs(self): |
| |
| def foo(a, b): |
| del a |
| del b |
| |
| # Signatures must consist exclusively of `TensorSpec` objects. |
| signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)] |
| with self.assertRaisesRegex(TypeError, 'Invalid input_signature.*'): |
| def_function.function(foo, input_signature=signature) |
| |
| # Signatures must be either lists or tuples on their outermost levels. |
| signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)} |
| with self.assertRaisesRegex( |
| TypeError, 'input_signature must be either a ' |
| 'tuple or a list.*'): |
| function.defun(foo, input_signature=signature) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testInputsIncompatibleWithSignatureRaisesError(self): |
| |
| def foo(a): |
| return a |
| |
| signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] |
| defined = def_function.function(foo, input_signature=signature) |
| |
| # Invalid shapes. |
| with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'): |
| defined(array_ops.ones([3])) |
| |
| with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'): |
| defined(array_ops.ones([2, 1])) |
| |
| # Wrong number of arguments. |
| with self.assertRaisesRegex( |
| TypeError, r'takes 1 positional arguments \(as specified by the ' |
| r'input_signature\) but 2 were given'): |
| defined(array_ops.ones([2]), array_ops.ones([2])) |
| with self.assertRaisesRegex(ValueError, |
| 'Structure of Python function inputs.*'): |
| defined() |
| |
| with self.assertRaisesRegex(ValueError, |
| 'inputs incompatible with input_signature'): |
| defined.get_concrete_function( |
| tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32)) |
| |
| def testInputsIncompatibleWithNestedSignatureRaisesError(self): |
| |
| def foo(a, b): |
| return [a, b] |
| |
| signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2, |
| [tensor_spec.TensorSpec((1,), dtypes.float32)] * 2] |
| defined = function.defun(foo, input_signature=signature) |
| a = array_ops.ones([1]) |
| |
| with self.assertRaisesRegex(ValueError, |
| 'Structure of Python function inputs.*'): |
| defined([a, a, a], [a]) |
| |
| with self.assertRaisesRegex(ValueError, |
| 'Structure of Python function inputs.*'): |
| defined([a], [a, a, a]) |
| defined([a, a], [a, a]) |
| |
| def testUnderspecifiedInputSignature(self): |
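    # Arguments not covered by the input_signature cannot be passed at call
    # time; their Python defaults are used instead.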
| @function.defun(input_signature=[ |
| tensor_spec.TensorSpec([], dtypes.float32), |
| ]) |
| def foo(a, training=True): |
| if training: |
| return a |
| else: |
| return -1.0 * a |
| |
| x = constant_op.constant(1.0) |
| with self.assertRaisesRegex( |
| TypeError, 'got keyword argument `training` ' |
| 'that was not included in input_signature'): |
| foo(x, training=True) |
| |
| with self.assertRaisesRegex( |
| TypeError, 'got keyword argument `training` ' |
| 'that was not included in input_signature'): |
| foo(x, training=False) |
| |
| self.assertAllEqual(x.numpy(), foo(x).numpy()) |
| |
| def testInputSignatureWithPartialFunction(self): |
| def full_function(a, b, c=3.0): |
| return a, b, c |
| |
| partial = functools.partial(full_function, 1, c=4) |
| a, b, c = partial(2.0) |
| signature = [tensor_spec.TensorSpec([], dtypes.float32)] |
| defined = function.defun(partial, input_signature=signature) |
| x = constant_op.constant(2.0) |
| func_a, func_b, func_c = defined(x) |
| self.assertEqual(func_a.numpy(), a) |
| self.assertEqual(func_b.numpy(), b) |
| self.assertEqual(func_c.numpy(), c) |
| |
| def testInputSignatureConversionWithDefaultArg(self): |
| |
| def foo(a, training=True): |
| if training: |
| return a |
| else: |
| return -1.0 * a |
| |
| signature = [ |
| tensor_spec.TensorSpec([], dtypes.float32), |
| tensor_spec.TensorSpec([], dtypes.bool), |
| ] |
| defined = def_function.function(foo, input_signature=signature) |
| a = constant_op.constant(1.0) |
| self.assertAllEqual(a.numpy(), defined(a)) |
| self.assertAllEqual(a.numpy(), defined(a, training=True)) |
| self.assertAllEqual(-a.numpy(), defined(a, training=False)) |
| |
| def testInputSignatureWithKeywordPositionalArgs(self): |
| |
| @function.defun(input_signature=[ |
| tensor_spec.TensorSpec([], dtypes.float32), |
| tensor_spec.TensorSpec([], dtypes.int64) |
| ]) |
| def foo(flt, integer): |
| return flt, integer |
| |
| flt = constant_op.constant(1.0) |
| integer = constant_op.constant(2, dtypes.int64) |
| |
| out1, out2 = foo(flt, integer) |
| self.assertLen(total_function_cache(foo), 1) |
| self.assertEqual(out1.numpy(), 1.0) |
| self.assertEqual(out2.numpy(), 2) |
| |
| out1, out2 = foo(flt=flt, integer=integer) |
| self.assertLen(total_function_cache(foo), 1) |
| self.assertEqual(out1.numpy(), 1.0) |
| self.assertEqual(out2.numpy(), 2) |
| |
| out1, out2 = foo(integer=integer, flt=flt) |
| self.assertLen(total_function_cache(foo), 1) |
| self.assertEqual(out1.numpy(), 1.0) |
| self.assertEqual(out2.numpy(), 2) |
| |
| out1, out2 = foo(flt, integer=integer) |
| self.assertLen(total_function_cache(foo), 1) |
| self.assertEqual(out1.numpy(), 1.0) |
| self.assertEqual(out2.numpy(), 2) |
| |
| def testInputSignatureWithKeywordArgs(self): |
| def foo(a, b, **kwargs): |
| del kwargs |
| return a, b |
| |
| x = function.defun( |
| foo, |
| input_signature=[ |
| tensor_spec.TensorSpec([], dtypes.float32), |
| tensor_spec.TensorSpec([], dtypes.int32) |
| ]).get_concrete_function() |
| result = x(constant_op.constant(5.0), constant_op.constant(5)) |
| self.assertAllEqual(result, [5.0, 5]) |
| |
| def testInputSignatureWithCompositeTensors(self): |
| def f(rt): |
| self.assertEqual(rt.values.shape.as_list(), [None]) |
| self.assertEqual(rt.row_splits.shape.as_list(), [4]) |
| return rt |
| |
| signature = [ragged_tensor.RaggedTensorSpec( |
| shape=[3, None], dtype=dtypes.int32)] |
| defined = function.defun(f, input_signature=signature) |
| rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]]) |
| out1 = defined(rt1) |
| self.assertLen(total_function_cache(defined), 1) |
| self.assertAllEqual(out1.values, rt1.values) |
| self.assertAllEqual(out1.row_splits, rt1.row_splits) |
| |
| # Changing the row lengths shouldn't create a new function. |
| rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]]) |
| out2 = defined(rt2) |
| self.assertLen(total_function_cache(defined), 1) |
| self.assertAllEqual(out2.values, rt2.values) |
| self.assertAllEqual(out2.row_splits, rt2.row_splits) |
| |
| # Different number of rows |
| rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]]) |
| with self.assertRaisesRegex(ValueError, 'incompatible'): |
| defined(rt3) |
| |
| # Different dtype |
| rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]]) |
| with self.assertRaisesRegex(ValueError, 'Structure .* does not match'): |
| defined(rt4) |
| |
| # Different rank |
| rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]]) |
| with self.assertRaisesRegex(ValueError, 'does not match'): |
| defined(rt5) |
| |
| def testInputSignatureWithVariableArgs(self): |
| |
| def f(v): |
| v.assign_add(1) |
| |
| signature = [ |
| resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32) |
| ] |
| defined = function.defun(f, input_signature=signature) |
| |
| v1 = variables.Variable(0) |
| v2 = variables.Variable(0) |
| |
| defined(v1) |
| self.assertEqual(v1.numpy(), 1) |
| self.assertEqual(v2.numpy(), 0) |
| |
| defined(v=v2) |
| self.assertEqual(v1.numpy(), 1) |
| self.assertEqual(v2.numpy(), 1) |
| |
| def testTensorKeywordArguments(self): |
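    # Positional and keyword calls that bind the same values should share one
    # trace; swapping the bindings forces a new one.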
| |
| def foo(a, b): |
| del a |
| return b |
| |
| defined = function.defun(foo) |
| a = constant_op.constant(2.0) |
| b = constant_op.constant([1.0, 2.0]) |
| one = defined(a, b) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| two = defined(a=a, b=b) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| three = defined(b=b, a=a) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| four = defined(a, b=b) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| # The next call corresponds to a new input signature, hence |
| # we expect another function to be defined. |
| five = defined(b, a) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| six = defined(a=b, b=a) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| seven = defined(b=a, a=b) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| self.assertAllEqual(one, [1.0, 2.0]) |
| self.assertAllEqual(two, [1.0, 2.0]) |
| self.assertAllEqual(three, [1.0, 2.0]) |
| self.assertAllEqual(four, [1.0, 2.0]) |
| self.assertAllEqual(five, 2.0) |
| self.assertAllEqual(six, 2.0) |
| self.assertAllEqual(seven, 2.0) |
| |
| def testDefuningInstanceMethod(self): |
| |
| integer = constant_op.constant(2, dtypes.int64) |
| |
| class Foo(object): |
| |
| def one(self, tensor): |
| return tensor |
| |
| @def_function.function |
| def two(self, tensor, other=integer): |
| return self.one(tensor), other |
| |
| foo = Foo() |
| t = constant_op.constant(1.0) |
| one, two = foo.two(t) |
| self.assertEqual(one.numpy(), 1.0) |
| self.assertEqual(two.numpy(), 2) |
| |
| def testDefuningInstanceMethodWithDefaultArgument(self): |
| |
| integer = constant_op.constant(2, dtypes.int64) |
| |
| class Foo(object): |
| |
| @def_function.function |
| def func(self, other=integer): |
| return other |
| |
| foo = Foo() |
| self.assertEqual(foo.func().numpy(), int(integer)) |
| |
| def testPythonCallWithSideEffects(self): |
| state = [] |
| |
| @def_function.function |
| def side_effecting_function(): |
| state.append(0) |
| |
| side_effecting_function() |
| self.assertAllEqual(state, [0]) |
| |
| # The second invocation should call the graph function, which shouldn't |
| # trigger the list append. |
| side_effecting_function() |
| self.assertAllEqual(state, [0]) |
| |
| # Whereas calling the python function directly should create a side-effect. |
| side_effecting_function.python_function() |
| self.assertAllEqual(state, [0, 0]) |
| |
| def testFunctionWithNestedFunctionCallAndSideEffects(self): |
| v1 = variables.Variable(1.0) |
| v2 = variables.Variable(1.0) |
| |
| @def_function.function |
| def add_one(a): |
| a.assign_add(1.0) |
| |
    # Grappler will inline calls to `add_one` into the function body; we check
    # that all side effects were executed.
| @def_function.function |
| def side_effecting_function(a, b): |
| add_one(a) |
| add_one(b) |
| return a + b |
| |
| result = side_effecting_function(v1, v2) |
| self.assertEqual(result.numpy(), 4.0) |
| |
| def testFunctionWithExtraAttributes(self): |
| @function.defun_with_attributes(attributes={'experimental_1': 'value1', |
| 'experimental_2': 2}) |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| |
| def add(x, y): |
| return math_ops.add(x, y) |
| defun_add = function.defun_with_attributes( |
| add, attributes={'experimental_3': True, 'experimental_4': 1.0}) |
| |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| sq = matmul(t, t) |
| double = defun_add(t, t) |
| self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) |
| self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) |
| |
| graph = ops.get_default_graph() |
| # pylint: disable=protected-access |
| self.assertLen(graph._functions, 2) |
| functions = list(graph._functions.values()) |
| self.assertRegex(functions[0].definition.signature.name, '.*matmul.*') |
| attrs = functions[0].definition.attr |
| self.assertLen(attrs, 2) |
| self.assertEqual(attrs['experimental_1'].s, b'value1') |
| self.assertEqual(attrs['experimental_2'].i, 2) |
| |
| self.assertRegex(functions[1].definition.signature.name, '.*add.*') |
| attrs = functions[1].definition.attr |
| self.assertLen(attrs, 2) |
| self.assertEqual(attrs['experimental_3'].b, True) |
| self.assertEqual(attrs['experimental_4'].f, 1.0) |
| # pylint: enable=protected-access |
| |
| def testFunctionWithInvalidAttribute(self): |
| @function.defun_with_attributes(attributes={'experimental_1': ['value1']}) |
| def add(x, y): |
| return math_ops.add(x, y) |
| |
| with self.assertRaisesRegex(ValueError, '.*Unsupported attribute type.*'): |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| add(t, t) |
| |
| def testRegisterFunction(self): |
| |
| @function.defun |
| def add(x, y): |
| return math_ops.add(x, y) |
| |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| defun_matmul = function.defun(matmul) |
| |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| function.register(defun_matmul, t, t) |
| function.register(add, t, t) |
| |
| graph = ops.get_default_graph() |
| # pylint: disable=protected-access |
| self.assertLen(graph._functions, 6) |
        # Two sets of functions, each consisting of (inference, forward,
        # backward).
| functions = list(graph._functions.values()) |
| captured_function_names = [ |
| f.definition.signature.name for f in functions |
| ] |
| expected_func_name_regex = [ |
| '.*inference.*matmul.*', |
| '.*forward.*matmul.*', |
| '.*inference.*backward.*matmul.*', |
| '.*inference.*add.*', |
| '.*forward.*add.*', |
| '.*inference.*backward.*add.*', |
| ] |
| for i in range(len(functions)): |
| self.assertRegex(captured_function_names[i], |
| expected_func_name_regex[i]) |
| |
        # Check that the forward and backward functions have the correct
        # attributes.
| self.assertEqual( |
| functions[1].definition.attr['backward_function_name'].s, |
| functions[2].name) |
| self.assertEqual( |
| functions[2].definition.attr['forward_function_name'].s, |
| functions[1].name) |
| |
| self.assertEqual( |
| functions[4].definition.attr['backward_function_name'].s, |
| functions[5].name) |
| self.assertEqual( |
| functions[5].definition.attr['forward_function_name'].s, |
| functions[4].name) |
| |
| sq = defun_matmul(t, t) |
| double = add(t, t) |
| self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) |
| self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) |
        # Make sure the pre-registered function is used, and no other function
        # is added.
| self.assertLen(graph._functions, 6) |
| functions = list(graph._functions.values()) |
| for i in range(len(functions)): |
| self.assertEqual(captured_function_names[i], |
| functions[i].definition.signature.name) |
| |
| @parameterized.named_parameters( |
| dict(testcase_name='Defun', |
| function_decorator=function.defun), |
| dict(testcase_name='DefFunction', |
| function_decorator=def_function.function)) |
| def testRegisterConcreteFunction(self, function_decorator): |
| @function_decorator |
| def py_add(x, y): |
| return math_ops.add(x, y) |
| |
| py_add(array_ops.ones([]), array_ops.ones([])) |
| add = py_add.get_concrete_function( |
| tensor_spec.TensorSpec(None, dtypes.float32), |
| tensor_spec.TensorSpec(None, dtypes.float32)) |
| |
| @function_decorator |
| def py_composite(x, y): |
| return x, add(x, y) |
| |
| py_composite(array_ops.ones([]), array_ops.ones([])) |
| composite = py_composite.get_concrete_function( |
| tensor_spec.TensorSpec(None, dtypes.float32), |
| tensor_spec.TensorSpec(None, dtypes.float32)) |
| |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| composite.add_to_graph() |
| composite.add_gradient_functions_to_graph() |
| |
| graph = ops.get_default_graph() |
| # pylint: disable=protected-access |
| self.assertLen(graph._functions, 6) |
        # Two sets of functions, each consisting of (inference, forward,
        # backward).
| functions = list(graph._functions.values()) |
| captured_function_names = [ |
| f.definition.signature.name for f in functions |
| ] |
| expected_func_name_regex = [ |
| '.*inference.*py_composite.*', |
| '.*inference.*py_add.*', |
| '.*forward.*py_composite.*', |
| '.*forward.*py_add.*', |
| '.*inference.*backward.*py_composite.*', |
| '.*inference.*backward.*py_add.*', |
| ] |
| for expected, found in zip( |
| expected_func_name_regex, |
| captured_function_names): |
| self.assertRegex(found, expected) |
| |
| composite_t, composite_double = composite(t, t) |
| double = add(t, t) |
| self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double)) |
| self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double)) |
| self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t)) |
        # Make sure the pre-registered function is used, and no other function
        # is added.
| self.assertLen(graph._functions, 6) |
| |
| @parameterized.named_parameters( |
| dict(testcase_name='Defun', |
| function_decorator=function.defun), |
| dict(testcase_name='DefFunction', |
| function_decorator=def_function.function)) |
| def testEagerCaptures(self, function_decorator): |
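    # Tensors larger than _EAGER_CONST_THRESHOLD and variables are captured as
    # placeholders, while small tensors are inlined as constants.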
| with context.eager_mode(): |
| large_tensor = array_ops.ones(shape=(256,)) |
| self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD) |
| |
| small_tensor = array_ops.ones(shape=(4,)) |
| self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD) |
| |
| v = resource_variable_ops.ResourceVariable(0.0) |
| |
| for captured, op_type in [(large_tensor, 'Placeholder'), |
| (small_tensor, 'Const'), (v, 'Placeholder')]: |
| @function_decorator |
| def test_fn(): |
| return captured + 1 # pylint: disable=cell-var-from-loop |
| |
| g = test_fn.get_concrete_function().graph |
| internal_captures = g.internal_captures |
| self.assertLen(internal_captures, 1) |
| self.assertEqual(internal_captures[0].op.type, op_type) |
| |
| def testRegisterFunctionWithInputSignature(self): |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| defun_matmul = function.defun( |
| matmul, |
| input_signature=[ |
| tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), |
| tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32) |
| ]) |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| function.register(defun_matmul, t, t) |
| |
| graph = ops.get_default_graph() |
| # pylint: disable=protected-access |
| self.assertLen(graph._functions, 3) |
| |
        # Register the function again using its cache; note that the inputs
        # are ignored.
| function.register(defun_matmul) |
| graph = ops.get_default_graph() |
| self.assertLen(graph._functions, 3) |
| |
| def testRegisterFunctionWithCache(self): |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| defun_matmul = function.defun(matmul) |
| |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]]) |
| function.register(defun_matmul, t, t) |
| function.register(defun_matmul, t2, t2) |
| |
| graph = ops.get_default_graph() |
        # Only one function is registered since the input params have the
        # same type.
| # pylint: disable=protected-access |
| self.assertLen(graph._functions, 3) |
| |
| def testCallingFunctionWithDifferentVariables(self): |
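    # Variables passed as arguments become explicit inputs rather than
    # captures, so one concrete function can be reused with different
    # variables.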
| |
| @function.defun |
| def foo(v): |
| v.assign_add(1.0) |
| return v.read_value() |
| |
| v = resource_variable_ops.ResourceVariable(0.0) |
| graph_function = foo.get_concrete_function(v) |
| self.assertLen(graph_function.inputs, 1) |
| self.assertEmpty(graph_function.captured_inputs) |
| |
| self.assertEqual(float(graph_function(v)), 1.0) |
| self.assertEqual(float(graph_function(v)), 2.0) |
| |
| w = resource_variable_ops.ResourceVariable(0.0) |
| |
| @function.defun |
| def bar(v): |
| del v |
| return constant_op.constant(1.0) |
| |
| graph_function = bar.get_concrete_function(v) |
| self.assertEqual(float(graph_function(v)), 1.0) |
| self.assertEqual(float(graph_function(w)), 1.0) |
| |
| def testCallingFunctionWithNonTensorsFails(self): |
| |
| @function.defun |
| def foo(x): |
| return x |
| |
| graph_function = foo.get_concrete_function(constant_op.constant(1.0)) |
| with self.assertRaises((TypeError, ValueError)): |
| graph_function('Not a Tensor.') |
| |
| def testSwapImplementationWithGrapplerPlugin(self): |
    # Set the min_graph_nodes to -1 since the graph in this test is too small,
    # and would otherwise be ignored by Grappler.
| rewrites = rewriter_config_pb2.RewriterConfig() |
| rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON |
| rewrites.min_graph_nodes = -1 |
| graph_options = config_pb2.GraphOptions( |
| rewrite_options=rewrites, build_cost_model=1) |
| config_proto = config_pb2.ConfigProto(graph_options=graph_options) |
| |
| with context.graph_mode(), self.cached_session( |
| config=config_proto, graph=ops.Graph(), use_gpu=True): |
| |
| @function.defun_with_attributes( |
| attributes={ |
| 'api_implements': 'random_boost', |
| 'api_preferred_device': 'CPU' |
| }) |
| def cpu_boost(x): |
| return math_ops.add(x, 2.0) |
| |
| @function.defun_with_attributes( |
| attributes={ |
| 'api_implements': 'random_boost', |
| 'api_preferred_device': 'GPU' |
| }) |
| def gpu_boost(x): |
| return math_ops.add(x, 4.0) |
| |
| x = constant_op.constant(1.0) |
| |
| function.register(cpu_boost, x) |
| y = gpu_boost(x) |
| y_value = self.evaluate(y) |
| |
| if test.is_gpu_available(): |
| self.assertEqual(y_value, 5.0) |
| else: |
        # Grappler falls back to the CPU impl even when the GPU function is
        # called.
| self.assertEqual(y_value, 3.0) |
| |
| def testSwapImplementationInEager(self): |
| if not context.executing_eagerly(): |
| self.skipTest('eager only') |
| |
    # testSharedRendezvous sets the disable_meta_optimizer flag to True. If
    # that subtest runs before this one, the flag left set to True would cause
    # this subtest to fail. To avoid that scenario, explicitly set the
    # disable_meta_optimizer flag to False here.
| context.context().set_optimizer_experimental_options({ |
| 'min_graph_nodes': -1, |
| 'implementation_selector': True, |
| 'disable_meta_optimizer': False |
| }) |
| |
| @function.defun_with_attributes( |
| attributes={'api_implements': 'foo', |
| 'api_preferred_device': 'CPU'}) |
| def on_cpu(x): |
| return x + 2 |
| |
| @function.defun_with_attributes( |
| attributes={'api_implements': 'foo', |
| 'api_preferred_device': 'GPU'}) |
| def on_gpu(x): |
| return x + 4 |
| |
| @function.defun |
| def run_on_cpu(t): |
| function.register(on_cpu, t) |
| with ops.device('CPU:0'): |
| return on_gpu(t) |
| |
    # Expect to run the on_cpu branch, regardless of whether a GPU is
    # available.
| self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3) |
| |
| def testDefunFunctionSeparateGraphs(self): |
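    # Function traces are per-graph; entering a new graph forces retracing.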
| with context.graph_mode(): |
| |
| @function.defun |
| def add(x): |
| return x + 5 |
| |
| @function.defun |
| def maybe_add(x, should_add): |
| if should_add: |
| return add(x) |
| else: |
| return x |
| |
| with ops.Graph().as_default(): |
| x = constant_op.constant(11) |
| maybe_add(x, True) |
| self.assertLen(total_function_cache(maybe_add), 1) |
| self.assertLen(total_function_cache(add), 1) |
| |
| maybe_add(x, False) |
| self.assertLen(total_function_cache(maybe_add), 2) |
| self.assertLen(total_function_cache(add), 1) |
| |
| with ops.Graph().as_default(): |
| x = constant_op.constant(11) |
| maybe_add(x, True) |
| self.assertLen(total_function_cache(maybe_add), 3) |
| self.assertLen(total_function_cache(add), 2) |
| |
| def testCacheKeyOverlappingShapes(self): |
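    # Shapes [12, 1] and [1, 21] must produce distinct cache keys even though
    # a naive encoding of their dimensions would collide.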
| @function.defun |
| def defined(t): |
| return t |
| |
| defined(array_ops.zeros([12, 1])) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| defined(array_ops.zeros([1, 21])) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testCacheKeyNestedLists(self): |
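    # The same leaves arranged in a different nesting must produce a new
    # trace.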
| @function.defun |
| def defined(l): |
| return l |
| |
| a = constant_op.constant(1.) |
| b = constant_op.constant(2.) |
| c = constant_op.constant(3.) |
| defined([[a], b, c]) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| defined([[a, b], c]) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testCacheKeyAttrsClass(self): |
| if attr is None: |
| self.skipTest('attr module is unavailable.') |
| |
| @attr.s |
| class TestClass(object): |
| a = attr.ib() |
| b = attr.ib() |
| |
| @function.defun |
| def defined(l): |
| return l |
| |
| defined( |
| TestClass( |
| constant_op.constant(1.), |
| [constant_op.constant(2.), |
| constant_op.constant(3.)])) |
| self.assertLen(total_function_cache(defined), 1) |
| defined( |
| TestClass( |
| constant_op.constant(1.), |
| [constant_op.constant(2.), |
| constant_op.constant(3.)])) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| defined( |
| TestClass([constant_op.constant(1.), |
| constant_op.constant(2.)], constant_op.constant(3.))) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| def testCacheKeyVariables(self): |
| @function.defun |
| def defined(a, b, c): |
| return a + b + c |
| |
| x = resource_variable_ops.ResourceVariable(0.0) |
| y = resource_variable_ops.ResourceVariable(0.0) |
| z = resource_variable_ops.ResourceVariable(0.0) |
| |
| # If tensor equality is not enabled, we always get a cache miss if the |
| # function is called with different variables. With equality enabled we |
| # should only get a miss if the aliasing changed. |
| defined(x, y, z) |
| self.assertLen(total_function_cache(defined), 1) |
| defined(x, y, z) |
| self.assertLen(total_function_cache(defined), 1) |
| |
| # Re-arranging arguments causes cache miss |
| defined(z, y, x) |
| self.assertLen(total_function_cache(defined), 2) |
| defined(z, y, x) |
| self.assertLen(total_function_cache(defined), 2) |
| |
| # Aliasing causes cache miss |
| defined(x, x, z) |
| self.assertLen(total_function_cache(defined), 3) |
| defined(x, x, z) |
| self.assertLen(total_function_cache(defined), 3) |
| |
| # Re-arranging arguments causes cache miss |
| defined(y, y, z) |
| self.assertLen(total_function_cache(defined), 4) |
| defined(y, y, z) |
| self.assertLen(total_function_cache(defined), 4) |
| |
| # Different alias positions causes cache miss |
| defined(z, y, y) |
| self.assertLen(total_function_cache(defined), 5) |
| defined(z, y, y) |
| self.assertLen(total_function_cache(defined), 5) |
| |
| x_copy = copy.deepcopy(x) |
| |
| # Deep copy causes cache miss |
| defined(x_copy, y, z) |
| self.assertLen(total_function_cache(defined), 6) |
| defined(x_copy, y, z) |
| self.assertLen(total_function_cache(defined), 6) |
| |
| def testVariableRetracing(self): |
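    # Each distinct variable object, including a deep copy, triggers its own
    # trace, so lookups keyed by id() resolve correctly.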
| v1 = variables.Variable(1.) |
| v2 = variables.Variable(1.) |
| v3 = copy.deepcopy(variables.Variable(1.)) |
| |
| var_dict = {id(v1): constant_op.constant(1), |
| id(v2): constant_op.constant(2), |
| id(v3): constant_op.constant(3)} |
| |
| @function.defun |
| def lookup_tensor(v): |
| return var_dict[id(v)] |
| |
| self.assertEqual(1, lookup_tensor(v1).numpy()) |
| self.assertEqual(2, lookup_tensor(v2).numpy()) |
| self.assertEqual(3, lookup_tensor(v3).numpy()) |
| |
| def testDecoratedMethodInspect(self): |
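    # tf_inspect should see through the defun decorator to the original
    # argspec.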
| |
| class DefunnedMiniModel(object): |
| |
| @function.defun |
| def call(self, inputs, training=True): |
| pass |
| |
| m = DefunnedMiniModel() |
| fullargspec = tf_inspect.getfullargspec(m.call) |
| self.assertIn('training', fullargspec.args) |
| |
| def testFunctionModifiesInputList(self): |
    # Tests `list` methods that perform in-place modification, except
    # `list.sort`, since it cannot even be "defunned" in the first place.
| |
| def get_list(): |
| return [constant_op.constant(0.), constant_op.constant(1.)] |
| |
| expected_msg = '.*() should not modify' |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def append(l): |
| l.append(constant_op.constant(0.)) |
| |
| append(get_list()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def extend(l): |
| l.extend([constant_op.constant(0.)]) |
| |
| extend(get_list()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def insert(l): |
| l.insert(0, constant_op.constant(0.)) |
| |
| insert(get_list()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def pop(l): |
| l.pop() |
| |
| pop(get_list()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def reverse(l): |
| l.reverse() |
| |
| reverse(get_list()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def remove(l): |
| l.remove(l[0]) |
| |
| remove(get_list()) |
| |
    # `list.clear` is a method that exists in Py3 but not in Py2.
| if sys.version.startswith('3'): |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def clear(l): |
| l.clear() |
| |
| clear(get_list()) |
| |
| # One last test for keyword arguments |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def kwdappend(**kwargs): |
| l = kwargs['l'] |
| l.append(constant_op.constant(0.)) |
| |
| kwdappend(l=get_list()) |
| |
| def testFunctionModifiesInputDict(self): |
| |
| def get_dict(): |
| return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)} |
| |
| expected_msg = '.* should not modify' |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def clear(m): |
| m.clear() |
| |
| clear(get_dict()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def pop(m): |
| m.pop('t1') |
| |
| pop(get_dict()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def popitem(m): |
| m.popitem() |
| |
| popitem(get_dict()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def update(m): |
| m.update({'t1': constant_op.constant(3.)}) |
| |
| update(get_dict()) |
| |
| with self.assertRaisesRegex(ValueError, expected_msg): |
| |
| @def_function.function |
| def setdefault(m): |
| m.setdefault('t3', constant_op.constant(3.)) |
| |
| setdefault(get_dict()) |
| |
| def testFunctionModifiesInputNest(self): |
| with self.assertRaisesRegex(ValueError, 'modify.* should not modify'): |
| |
| @def_function.function |
| def modify(n): |
| n[0]['t1'].append(constant_op.constant(1.)) |
| |
| nested_input = [{ |
| 't1': [constant_op.constant(0.), |
| constant_op.constant(1.)], |
| }, |
| constant_op.constant(2.)] |
| |
| modify(nested_input) |
| |
| with self.assertRaisesRegex(ValueError, |
| 'modify_same_flat.* should not modify'): |
| |
| # The flat list doesn't change whereas the true structure changes |
| @def_function.function |
| def modify_same_flat(n): |
| n[0].append(n[1].pop(0)) |
| |
| nested_input = [[constant_op.constant(0.)], |
| [constant_op.constant(1.), |
| constant_op.constant(2.)]] |
| |
| modify_same_flat(nested_input) |
| |
| def testExecutorType(self): |
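    # An unknown executor type should raise NotFoundError, while the empty
    # string, 'DEFAULT', and None all select the default executor.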
| @function.defun |
| def add_five(x): |
| return x + 5 |
| |
| self.assertEqual( |
| 5, |
| add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy()) |
| |
| with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'): |
| with context.function_executor_type('NON_EXISTENT_EXECUTOR'): |
| add_five(constant_op.constant(0, dtype=dtypes.int32)) |
| |
| for executor_type in ('', 'DEFAULT', None): |
| with context.function_executor_type(executor_type): |
| self.assertAllEqual( |
| 5, |
| add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy()) |
| |
| @test_util.assert_no_garbage_created |
| def testReferenceCycles(self): |
| |
| fn = function.defun(lambda x: 2. * x) |
| |
| fn(constant_op.constant(4.0)) |
| weak_fn = weakref.ref(fn) |
| del fn |
| # Tests that the weak reference we made to the function is now dead, which |
| # means the object has been deleted. This should be true as long as the |
| # function itself is not involved in a reference cycle. |
| self.assertIs(None, weak_fn()) |
| |
| def testFunctionStackInErrorMessage(self): |
| if context.executing_eagerly(): |
| # TODO(b/122736651): Remove this skipTest once fixed. |
| self.skipTest('Error interpolation is not working when function is ' |
| 'invoked without PartitionedCallOp.') |
| |
| @def_function.function() |
| def fn3(x): |
| return x + 2 |
| |
| @def_function.function() |
| def fn2(x): |
| check_ops.assert_equal(fn3(x), 3) |
| return 2 |
| |
| @def_function.function() |
| def fn(x): |
| return fn2(x) |
| |
| with self.assertRaises(errors.InvalidArgumentError) as cm: |
| fn(2) |
| e = cm.exception |
| self.assertIn('fn -> fn2', e.message) |
| self.assertIn('node assert_equal/Assert/Assert (defined at', e.message) |
| self.assertNotIn('fn3', e.message) |
| |
| @test_util.run_gpu_only |
| def testFunctionIsNotPinned(self): |
| """Tests that functions aren't pinned to the CPU by the eager runtime.""" |
| seed1, seed2 = 79, 25 |
| shape = constant_op.constant([4, 7]) |
| dtype = dtypes.float32 |
| |
| @def_function.function |
| def func(): |
| with ops.device('GPU:0'): |
| return gen_random_ops.random_standard_normal( |
| shape, dtype=dtype, seed=seed1, seed2=seed2) |
| |
| with ops.device('GPU:0'): |
| x = func() |
| self.assertRegex(x.device, 'GPU') |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testShapeCaching(self): |
| |
| @function.defun |
| def func(x): |
| return array_ops.shape(x) |
| |
| @function.defun( |
| input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)]) |
| def calls_func(x): |
| return func(x) |
| |
| self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1])))) |
| self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2])))) |
| self.assertAllEqual( |
| [3, 3], |
| self.evaluate(calls_func(array_ops.zeros([3, 3])))) |
| |
| def testLimitedRetracing(self): |
| trace_count = [0] |
| @function.defun |
| def func(x): |
| trace_count[0] += 1 |
| return x |
| |
| for _ in range(50): |
| func(constant_op.constant(3.)) |
| func(constant_op.constant(4.)) |
| func(constant_op.constant([[1., 2.]])) |
| func(constant_op.constant([[]])) |
| func(constant_op.constant([[3., 4.], [5., 6.]])) |
| func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]])) |
| # Tracing more than twice per input doesn't make sense. |
| self.assertLess(trace_count[0], 13) |
| |
| def testLimitedRetracingWithCompositeTensors(self): |
| trace_count = [0] |
| |
| @def_function.function |
| def f(x): |
| trace_count[0] += 1 |
| return x |
| |
| for i in range(10): |
| f(ragged_factory_ops.constant([[1, 2], [i]])) |
| f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]])) |
| f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])) |
| self.assertEqual(trace_count[0], 3) |
| |
| def test_concrete_function_shape_mismatch(self): |
| |
| @def_function.function |
| def f(argument_name): |
| return argument_name + 1. |
| |
| f_concrete = f.get_concrete_function(constant_op.constant([1.])) |
| |
| # Calling a function from eager doesn't do any shape checking above what |
| # kernels do while executing. |
| self.assertAllEqual( |
| [2., 3.], |
| f_concrete(constant_op.constant([1., 2.])).numpy()) |
| |
| @def_function.function |
| def g(): |
| f_concrete(constant_op.constant([1., 2.])) |
| |
| with self.assertRaisesRegex(ValueError, 'argument_name'): |
| g() |
| |
| @test_util.run_in_graph_and_eager_modes |
| def test_shape_inference_with_symbolic_shapes(self): |
| |
| @def_function.function |
| def _uses_symbolic_shapes(w, x, y): |
| x = array_ops.identity(x, name='name_collision') |
| x = array_ops.transpose(x, [1, 0, 2]) |
| x_batch = array_ops.shape(x)[0] |
| y_batch = array_ops.shape(y)[0] |
| y *= w |
| n = y_batch // x_batch |
| return array_ops.reshape(y, [n, x_batch, -1]) |
| |
| conc = _uses_symbolic_shapes.get_concrete_function( |
| tensor_spec.TensorSpec(None, dtypes.float32), |
| tensor_spec.TensorSpec(None, dtypes.float32), |
| tensor_spec.TensorSpec(None, dtypes.float32)) |
| |
| @def_function.function |
| def _call_concrete(): |
| c = constant_op.constant(1.) |
| array_ops.identity(c, name='name_collision') |
| output1 = conc(array_ops.ones([2]), |
| array_ops.ones([5, 4, 2]), |
| array_ops.ones([20, 2])) |
| self.assertEqual([5, 4, 2], output1.shape) |
| output2 = conc(array_ops.ones([3]), |
| array_ops.ones([5, 4, 3]), |
| array_ops.ones([40, 3])) |
| self.assertEqual([10, 4, 3], output2.shape) |
| return output1, output2 |
| |
| output1, output2 = _call_concrete() |
| self.assertEqual((5, 4, 2), self.evaluate(output1).shape) |
| self.assertEqual((10, 4, 3), self.evaluate(output2).shape) |
| |
| def testAutoGraphContext(self): |
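    # AutoGraph should be enabled inside the traced function, and the previous
    # status should be restored afterwards.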
| |
| @def_function.function |
| def test_fn(): |
| self.assertEqual( |
| ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED) |
| |
| prev_status = ag_ctx.control_status_ctx().status |
| test_fn() |
| self.assertEqual(ag_ctx.control_status_ctx().status, prev_status) |
| |
| @test_util.disable_tfrt('b/170435618') |
| def testCancelBeforeFunctionExecution(self): |
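    # Cancelling before the call should abort the function immediately.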
| if not context.executing_eagerly(): |
| self.skipTest('eager only') |
| |
| q = data_flow_ops.FIFOQueue(1, dtypes.int32) |
| |
| @def_function.function |
| def f(): |
| return q.dequeue() |
| |
| c_mgr = cancellation.CancellationManager() |
| cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) |
| |
| c_mgr.start_cancel() |
| with self.assertRaises(errors.CancelledError): |
| cancelable_func() |
| |
| @test_util.disable_tfrt('b/170435618') |
| def testCancelBlockedFunctionExecution(self): |
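    # Cancelling from another thread should unblock a function stuck on an
    # empty queue.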
| if not context.executing_eagerly(): |
| self.skipTest('eager only') |
| |
| q = data_flow_ops.FIFOQueue(1, dtypes.int32) |
| |
| @def_function.function |
| def f(): |
| return q.dequeue() |
| |
| c_mgr = cancellation.CancellationManager() |
| cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) |
| |
| def cancel_thread(): |
| time.sleep(0.5) |
| c_mgr.start_cancel() |
| |
| t = self.checkedThread(cancel_thread) |
| t.start() |
| with self.assertRaises(errors.CancelledError): |
| cancelable_func() |
| t.join() |
| |
| @test_util.disable_tfrt('b/170435618') |
| def testCancelAfterFunctionExecution(self): |
| if not context.executing_eagerly(): |
| self.skipTest('eager only') |
| |
| q = data_flow_ops.FIFOQueue(1, dtypes.int32) |
| q.enqueue(37) |
| |
| @def_function.function |
| def f(): |
| return q.dequeue() |
| |
| c_mgr = cancellation.CancellationManager() |
| cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function()) |
| |
| self.assertAllEqual(37, cancelable_func().numpy()) |
| |
| # Cancellation after the function executes is a no-op. |
| c_mgr.start_cancel() |
| |
| def testAddFunctionCallback(self): |
| functions = [] |
| def function_callback(f): |
| functions.append(f) |
| |
| @def_function.function |
| def plus_one(x): |
| return x + 1 |
| |
| try: |
| function.add_function_callback(function_callback) |
| x_float32 = numpy.array(3.0, dtype=numpy.float32) |
| self.assertAllClose(plus_one(x_float32), 4.0) |
| self.assertLen(functions, 1) |
| # Function is already created. Executing it again should not invoke the |
| # function callback. |
| self.assertAllClose(plus_one(x_float32), 4.0) |
| self.assertLen(functions, 1) |
| # Signature change leads to a new Function being built. |
| x_float64 = numpy.array(3.0, dtype=numpy.float64) |
| self.assertAllClose(plus_one(x_float64), 4.0) |
| self.assertLen(functions, 2) |
| finally: |
| function.clear_function_callbacks() |
| |
| def testRemoveFunctionCallback(self): |
| functions_1 = [] |
| def function_callback_1(f): |
| functions_1.append(f) |
| |
| functions_2 = [] |
| def function_callback_2(f): |
| functions_2.append(f) |
| |
| @def_function.function |
| def plus_one(x): |
| return x + 1 |
| |
| try: |
| function.add_function_callback(function_callback_1) |
| function.add_function_callback(function_callback_2) |
| self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0) |
| self.assertLen(functions_1, 1) |
| self.assertLen(functions_2, 1) |
| function.remove_function_callback(function_callback_1) |
      # The 1st callback should not be invoked after remove_function_callback()
      # is called.
| self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0) |
| self.assertLen(functions_1, 1) |
| self.assertLen(functions_2, 2) |
| finally: |
| function.clear_function_callbacks() |
| |
| def testClearFunctionCallbacks(self): |
| function.add_function_callback(lambda f: None) |
| function.add_function_callback(lambda f: None) |
    self.assertLen(function._function_callbacks, 2)  # pylint:disable=protected-access
| function.clear_function_callbacks() |
| self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionWithNestedTensorInputs(self): |
| |
| @def_function.function |
| def f(x, y): |
| return (x['a'] + x['b'], y[0] + y[1]) |
| |
| a = constant_op.constant(1000) |
| b = constant_op.constant(200) |
| c = constant_op.constant(30) |
| d = {'a': a, 'b': b} |
| e = (c, 4) |
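    # The Python int 4 in `e` is treated as a trace-time constant, so the
    # flat signature below takes only the three Tensors a, b, and c.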
| |
| # Test different argument signatures when constructing the concrete func. |
| for cf in [ |
| f.get_concrete_function(d, e), |
| f.get_concrete_function(d, y=e), |
| f.get_concrete_function(y=e, x=d), |
| f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)), |
| f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)), |
| f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d)) |
| ]: |
| # Test different calling conventions when calling the concrete func. |
| for output in [ |
| cf(d, e), # structured signature |
| cf(d, y=e), # structured signature w/ kwarg |
| cf(y=e, x=d), # structured signature w/ 2 kwargs |
| cf(a, b, c), # flat signature |
| cf(x=a, x_1=b, y=c) # flat signature w/ kwargs |
| ]: |
| self.assertIsInstance(output, tuple) |
| self.assertLen(output, 2) |
| self.assertAllEqual(output[0], 1200) |
| self.assertAllEqual(output[1], 34) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionWithNestedNonTensorInputs(self): |
| |
| @def_function.function |
| def f(x, y): |
| return (x['a'] + x['b'], y[0] + y[1]) |
| |
| a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)} |
| b = (50, 3) |
| |
| for cf in [ # argument y is bound to non-Tensor value (50, 3). |
| f.get_concrete_function(a, b), |
| f.get_concrete_function(a, y=b), |
| f.get_concrete_function(x=a, y=b) |
| ]: |
| for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]: |
| self.assertAllEqual(output[0] + output[1], 1253) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionWithBoundNestedNonTensorInputs(self): |
| |
| @def_function.function |
| def f(x, y): |
| return (x['a'] + x['b'], y[0] + y[1]) |
| |
| a = {'a': 3000, 'b': 200, 'c': 9000} |
| b = (constant_op.constant(30), 4) |
| |
| for cf in [ # argument x is bound to non-tensor value `a` |
| f.get_concrete_function(a, b), |
| f.get_concrete_function(a, y=b), |
| f.get_concrete_function(x=a, y=b) |
| ]: |
| for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]: |
| self.assertAllEqual(output[0] + output[1], 3234) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self): |
| |
| @def_function.function |
| def f(x, y): |
| return (x['a'] + x['b'], y[0] + y[1]) |
| |
| a = {'a': 5000, 'b': 500} |
| b = (50, 5) |
| |
| cf = f.get_concrete_function(a, b) |
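    # Every leaf is bound to a non-Tensor value, so the concrete function
    # takes no Tensor inputs and may even be called with no arguments.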
| for output in [cf(), cf(a), cf(y=b)]: |
| self.assertAllEqual(output[0] + output[1], 5555) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionMethodWithVarargs(self): |
| float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) |
| |
| class MyModel(module.Module): |
| |
| @def_function.function(input_signature=[float32_scalar, float32_scalar]) |
| def add(self, *arg): |
| return math_ops.add(*arg) |
| |
| m = MyModel() |
| cf = m.add.get_concrete_function() |
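    # The input_signature already specifies both varargs, so
    # get_concrete_function needs no arguments; the resulting concrete
    # function still takes the two scalars.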
| cf(-12.0, 3.0) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionStructuredSignatureKeywordOrder(self): |
| # Check that keyword-only arguments are sorted appropriately, so that they |
| # feed the right tensor into each input. |
| @def_function.function |
| def g(**kwargs): |
| return string_ops.reduce_join( |
| string_ops.reduce_join( |
| ops.convert_to_tensor(sorted(kwargs.items())), |
| axis=1, |
| separator='='), |
| axis=0, |
| separator=', ') |
| |
| s = constant_op.constant('s') |
| g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s) |
| self.assertAllEqual( |
| g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'), |
| b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') |
| self.assertAllEqual( |
| g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'), |
| b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') |
| self.assertAllEqual( |
| g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'), |
| b'a=f, l=e, m=a, p=g, q=d, r=b, v=c') |
| |
| # pylint: disable=g-long-lambda |
| @parameterized.named_parameters([ |
| dict( |
| testcase_name='MissingArg', |
| conc_args=lambda: (1, constant_op.constant(2)), |
| call_args=lambda: (1,), |
| error=r'func\(x, y\) missing required arguments: y'), |
| dict( |
| testcase_name='MissingVararg', |
| conc_args=lambda: (1, 2, constant_op.constant(1.0)), |
| call_args=lambda: (1, 2), |
| error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'), |
| dict( |
| testcase_name='ExtraPositionalArg', |
| conc_args=lambda: (1, 2), |
| call_args=lambda: (1, 2, 3), |
| error=r'func\(x, y\) takes 2 positional arguments but 3 were given'), |
| dict( |
| testcase_name='MissingKeywordOnlyArg', |
| conc_args=lambda: (1, 2), |
| conc_kwargs=lambda: {'c': constant_op.constant(1.0)}, |
| call_args=lambda: (1, 2), |
| error=r'func\(x, y, \*, c\) missing required arguments: c'), |
| dict( |
| testcase_name='ExtraKeywordArg', |
| conc_args=lambda: (1, 2), |
| call_args=lambda: (1, 2), |
| call_kwargs=lambda: {'c': constant_op.constant(1.0)}, |
| error=r'func\(x, y\) got unexpected keyword arguments: c'), |
| dict( |
| testcase_name='ExpectedRaggedGotNest', |
| conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),), |
| call_args=lambda: ({ |
| 'a': constant_op.constant([1, 2, 3]) |
| },), |
| error=r'func\(x, y\): argument x had incorrect type\n' |
| r' expected: RaggedTensor\n' |
| r" got: {'a': (Eager)?Tensor}"), |
| dict( |
| testcase_name='WrongRaggedRank', |
| conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),), |
| call_args=lambda: (ragged_factory_ops.constant([[[1]]]),), |
| error=r'func\(x, y\): argument x had incorrect type\n'), |
| dict( |
| testcase_name='WrongRaggedDType', |
| conc_args=lambda: (ragged_factory_ops.constant([[1]]),), |
| call_args=lambda: (ragged_factory_ops.constant([[1.0]]),), |
| error=r'func\(x, y\): argument x had incorrect type\n'), |
| dict( |
| testcase_name='ExpectedDictGotTensor', |
| conc_args=lambda: ({ |
| 'a': constant_op.constant(1), |
| 'b': constant_op.constant(1) |
| },), |
| call_args=lambda: (constant_op.constant(1),), |
| error=r'func\(x, y\): argument x had incorrect type\n'), |
| dict( |
| testcase_name='ExpectedTupleGotTensor', |
| conc_args=lambda: |
| ((constant_op.constant(1), constant_op.constant(2)),), |
| call_args=lambda: (constant_op.constant(1),), |
| error=r'func\(x, y\): argument x had incorrect type\n'), |
| dict( |
| testcase_name='WrongDType', |
| conc_args=lambda: (constant_op.constant(1),), |
| call_args=lambda: (constant_op.constant(1.0),), |
| exception=(ValueError, errors.InvalidArgumentError, |
| # on xla_gpu, we get InternalError instead. |
| errors.InternalError)), |
| dict( |
| testcase_name='ExpectedTensorGotInt', |
| conc_args=lambda: (constant_op.constant(1),), |
| call_args=lambda: (5,), |
| error=r'func\(x, y\) expected a Tensor in x, but got int value 5'), |
| dict( |
| testcase_name='ExpectedIntGotDifferentInt', |
| conc_args=lambda: (5,), |
| call_args=lambda: (8,), |
| error=r'ConcreteFunction func\(x, y\) was constructed with int ' |
| r'value 5 in x, but was called with int value 8'), |
| dict( |
| testcase_name='ExpectedIntGotTensor', |
| conc_args=lambda: (5,), |
| call_args=lambda: (constant_op.constant(6),), |
| error=r'ConcreteFunction func\(x, y\) was constructed with int ' |
| 'value 5 in x, but was called with (Eager)?Tensor value .*'), |
| dict( |
| testcase_name='TwoValuesForArgument', |
| conc_args=lambda: (1, 2), |
| call_args=lambda: (1, 2), |
| call_kwargs=lambda: {'x': 3}, |
| error=r"func\(x, y\) got two values for argument 'x'"), |
| ]) |
| # pylint: enable=g-long-lambda |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionStructuredSignatureError(self, |
| conc_args=(), |
| conc_kwargs=None, |
| call_args=(), |
| call_kwargs=None, |
| error='.*', |
| exception=TypeError): |
| """Tests for errors in the structrued signature. |
| |
| Args: |
| conc_args: Positional arguments used for get_concrete_function. |
| conc_kwargs: Keyword arguments used for get_concrete_function. |
| call_args: Positional arguments used to call the function. |
| call_kwargs: Keyword arguments used to call the function. |
| error: Expected exception message. |
| exception: Expected exception type. |
| """ |
| conc_args = conc_args() if callable(conc_args) else conc_args |
| conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {} |
| call_args = call_args() if callable(call_args) else call_args |
| call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {} |
| self.assertIsInstance(conc_args, tuple) |
| self.assertIsInstance(call_args, tuple) |
| self.assertIsInstance(conc_kwargs, dict) |
| self.assertIsInstance(call_kwargs, dict) |
| |
| @def_function.function |
| def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg |
| del y, varargs, kwargs |
| return x |
| |
| conc = func.get_concrete_function(*conc_args, **conc_kwargs) |
| with self.assertRaisesRegex(exception, error): |
| self.evaluate(conc(*call_args, **call_kwargs)) |
| |
| # pylint: disable=g-long-lambda |
| @parameterized.named_parameters([ |
| dict( |
| testcase_name='MissingArg', |
| conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| call_args=lambda: (constant_op.constant(1),), |
| error=r'func\(x, y\) missing required arguments: y'), |
| dict( |
| testcase_name='TwoValuesForArg', |
| conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| call_args=lambda: (constant_op.constant(1),), |
| call_kwargs=lambda: { |
| 'x': constant_op.constant(1), |
| 'y': constant_op.constant(1) |
| }, |
| error=r"func\(x, y\) got two values for argument 'x'"), |
| dict( |
| testcase_name='ExtraPositionalArg', |
| conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| call_args=lambda: (constant_op.constant(1), constant_op.constant(2), |
| constant_op.constant(3)), |
| error=r'func\(x, y\) takes 2 positional arguments but 3 were given'), |
| dict( |
| testcase_name='UnexpectedKeywordArg', |
| conc_args=lambda: (constant_op.constant(1),), |
| call_args=lambda: (constant_op.constant(1),), |
| call_kwargs=lambda: {'c': constant_op.constant(1)}, |
| error=r'func\(x\) got unexpected keyword arguments: c'), |
| dict( |
| testcase_name='MissingVararg', |
| conc_args=lambda: (constant_op.constant(1), constant_op.constant(2), |
| constant_op.constant(3)), |
| call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| error=r'func\(x, y, varargs_0\) missing required ' |
| r'arguments: varargs_0'), |
| dict( |
| testcase_name='MissingKeywordArg', |
| conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| conc_kwargs=lambda: {'c': constant_op.constant(1)}, |
| call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| error=r'func\(x, y, c\) missing required arguments: c'), |
| dict( |
| testcase_name='ExpectedTensorGotInt', |
| conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| call_args=lambda: (5, constant_op.constant(2)), |
| error=r'func\(x, y\): expected argument #0\(zero-based\) to be ' |
| r'a Tensor; got int \(5\)'), |
| dict( |
| testcase_name='WrongDType', |
| conc_args=lambda: (constant_op.constant(1),), |
| call_args=lambda: (constant_op.constant(1.0),), |
| exception=(ValueError, errors.InvalidArgumentError, |
| # on xla_gpu, we get InternalError instead. |
| errors.InternalError)), |
| dict( |
| testcase_name='MissingKeywordArgNestPiece', |
| conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])}, |
| call_args=lambda: (constant_op.constant(1), constant_op.constant(2)), |
| call_kwargs=lambda: {'c': constant_op.constant(1)}, |
| error=r'func\(x, y, c, c_1\) missing required arguments: c_1'), |
| ]) |
| # pylint: enable=g-long-lambda |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionFlatSignatureError(self, |
| conc_args=(), |
| conc_kwargs=None, |
| call_args=(), |
| call_kwargs=None, |
| error='.*', |
| exception=TypeError): |
| """Tests for errors in the flat signature. |
| |
| Args: |
| conc_args: Positional arguments used for get_concrete_function. |
| conc_kwargs: Keyword arguments used for get_concrete_function. |
| call_args: Positional arguments used to call the function. |
| call_kwargs: Keyword arguments used to call the function. |
| error: Expected exception message. |
| exception: Expected exception type. |
| """ |
| conc_args = conc_args() if callable(conc_args) else conc_args |
| conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {} |
| call_args = call_args() if callable(call_args) else call_args |
| call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {} |
| self.assertIsInstance(conc_args, tuple) |
| self.assertIsInstance(call_args, tuple) |
| self.assertIsInstance(conc_kwargs, dict) |
| self.assertIsInstance(call_kwargs, dict) |
| |
| @def_function.function |
| def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg |
| del y, varargs, kwargs |
| return x |
| |
| conc = func.get_concrete_function(*conc_args, **conc_kwargs) |
| |
    # Remove _function_spec to disable the structured signature.
| conc._set_function_spec(None) # pylint: disable=protected-access |
| |
| with self.assertRaisesRegex(exception, error): |
| self.evaluate(conc(*call_args, **call_kwargs)) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testConcreteFunctionAmbiguousSignature(self): |
| # When both the flat & structured signatures are applicable, but they |
| # give different results, we use the structured signature. Note: we expect |
| # this to be extremely rare. |
| @def_function.function |
| def f(x, y): |
| return x * 10 + y |
| |
| conc = f.get_concrete_function( |
| x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'), |
| y=tensor_spec.TensorSpec(None, dtypes.int32, name='x')) |
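    # The spec names are deliberately swapped, so the flat signature would
    # bind the keywords to the wrong inputs; the structured signature binds
    # x=5 and y=6 by argument name, giving 5 * 10 + 6 = 56.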
| |
| result = conc(x=constant_op.constant(5), y=constant_op.constant(6)) |
| self.assertAllEqual(result, 56) |
| |
| def testPrettyPrintedSignature(self): |
| |
| @def_function.function |
| def func(x, kangaroo=None, octopus=7): |
| del octopus, kangaroo |
| return x |
| |
| scalar = constant_op.constant(5) |
| vector = constant_op.constant([10, 10, 20]) |
| ragged = ragged_factory_ops.constant([[10, 20], [40]]) |
| |
| c1 = func.get_concrete_function(scalar, vector) |
| c1_summary = r'func\(x, kangaroo, octopus=7\)' |
| c1_details = (r' Args:\n' |
| r' x: int32 Tensor, shape=\(\)\n' |
| r' kangaroo: int32 Tensor, shape=\(3,\)\n' |
| r' Returns:\n' |
| r' int32 Tensor, shape=\(\)') |
| self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary) |
| self.assertRegex( |
| c1.pretty_printed_signature(verbose=True), |
| c1_summary + '\n' + c1_details) |
| self.assertRegex( |
| repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>') |
| self.assertRegex( |
| str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details)) |
| |
| c2 = func.get_concrete_function(scalar, ragged, 3) |
| c2_summary = r'func\(x, kangaroo, octopus=3\)' |
| c2_details = (r' Args:\n' |
| r' x: int32 Tensor, shape=\(\)\n' |
| r' kangaroo: RaggedTensorSpec\(.*\)\n' |
| r' Returns:\n' |
| r' int32 Tensor, shape=\(\)') |
| self.assertRegex(c2.pretty_printed_signature(), |
| c2_summary + '\n' + c2_details) |
| |
| c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]}) |
| c3_summary = r'func\(x, kangaroo=None, octopus=7\)' |
| c3_details = (r' Args:\n' |
| r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n" |
| r' <1>: int32 Tensor, shape=\(\)\n' |
| r' <2>: RaggedTensorSpec\(.*\)\n' |
| r' <3>: RaggedTensorSpec\(.*\)\n' |
| r' Returns:\n' |
| r" {'a': <1>, 'b': \[<2>, <3>\]}\n" |
| r' <1>: int32 Tensor, shape=\(\)\n' |
| r' <2>: RaggedTensorSpec\(.*\)\n' |
| r' <3>: RaggedTensorSpec\(.*\)') |
| |
    # Python 3.5 does not guarantee deterministic iteration of dict contents,
    # which can lead to a mismatch in the pretty_printed_signature output for
    # "Args".
| if sys.version_info >= (3, 6): |
| self.assertRegex(c3.pretty_printed_signature(), |
| c3_summary + '\n' + c3_details) |
| |
| # pylint: disable=keyword-arg-before-vararg |
| @def_function.function |
| def func2(x, y=3, *args, **kwargs): |
| return (x, y, args, kwargs) |
| |
| c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar) |
| c4_summary = 'func2(x, y=4, <arg3>=5, *, a)' |
| self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary) |
| |
| c5 = func2.get_concrete_function(8, vector) |
| c5_summary = 'func2(x=8, y)' |
| self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary) |
| |
| def testPrettyPrintedExplicitSignatureWithKeywordArg(self): # b/159639913 |
| |
| @def_function.function(input_signature=[tensor_spec.TensorSpec(None)]) |
| def fn(a, b=1): |
| return a + b |
| |
| concrete_fn = fn.get_concrete_function() |
| self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)') |
| self.assertEqual( |
| concrete_fn.pretty_printed_signature(True), 'fn(a)\n' |
| ' Args:\n' |
| ' a: float32 Tensor, shape=<unknown>\n' |
| ' Returns:\n' |
| ' float32 Tensor, shape=<unknown>') |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testIndexedSlicesAsGradientsForConcreteFunctions(self): |
| |
| @def_function.function |
| def summing_rnn(inputs): |
| return math_ops.reduce_sum(inputs, axis=1) |
| |
| @def_function.function |
| def gradients(inputs): |
| with backprop.GradientTape() as tape: |
| tape.watch(inputs) |
| hidden = summing_rnn(inputs) |
| hidden = array_ops.gather(hidden, constant_op.constant([0])) |
| loss = math_ops.reduce_mean(hidden) |
| return tape.gradient(loss, inputs) |
| |
| gradients(constant_op.constant([[[1.0], [2.0]]])) # No error is raised |
| |
| def testFollowTypeHintsTraceBasic(self): |
| trace_count = [0] |
| |
| def func(x: ops.Tensor): |
| trace_count[0] += 1 |
| return x |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| disabled = def_function.function(func, experimental_follow_type_hints=False) |
| |
| enabled(1) # Initial call gets traced |
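    # With experimental_follow_type_hints=True, the ops.Tensor annotation
    # causes the Python int argument to be converted to a Tensor before
    # tracing, so later calls that differ only in its value reuse the trace.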
| enabled(2) |
| enabled(3) |
| self.assertEqual(trace_count[0], 1) |
| |
| trace_count = [0] |
| disabled(1) |
| disabled(2) # Retrace |
| disabled(3) # Retrace |
| self.assertEqual(trace_count[0], 3) |
| |
| def testFollowTypeHintsTraceWithArgs(self): |
| trace_count = [0] |
| |
| def func(*args: ops.Tensor): |
| trace_count[0] += 1 |
| return args |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| disabled = def_function.function(func, experimental_follow_type_hints=False) |
| |
| args = ( |
| 'abc', |
| 'def', |
| ) * 20 |
| args2 = ( |
| 'def', |
| 'abc', |
| ) * 20 |
| |
| enabled(args) |
| enabled(args2) |
| self.assertEqual(trace_count[0], 1) |
| |
| trace_count = [0] |
| disabled(args) |
| disabled(args2) # Retrace |
| self.assertEqual(trace_count[0], 2) |
| |
| def testFollowTypeHintsTraceWithKwargs(self): |
| trace_count = [0] |
| |
| def func(t: ops.Tensor, **kwargs: ops.Tensor): |
| del kwargs |
| trace_count[0] += 1 |
| return t |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| disabled = def_function.function(func, experimental_follow_type_hints=False) |
| |
| enabled(1, x=1, y=1.0, z='one') |
| enabled(2, x=2, y=2.0, z='two') |
| self.assertEqual(trace_count[0], 1) |
| |
| trace_count = [0] |
| disabled(1, x=1, y=1.0, z='one') |
| disabled(2, x=2, y=2.0, z='two') # Retrace |
| self.assertEqual(trace_count[0], 2) |
| |
| def testFollowTypeHintsTraceWithMultipleInputTypes(self): |
| trace_count = [0] |
| |
| def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor): |
| del args, kwargs |
| trace_count[0] += 1 |
| return t |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| disabled = def_function.function(func, experimental_follow_type_hints=False) |
| |
| enabled(1, constant_op.constant(1), 'str', x=4.0) |
| enabled(2, constant_op.constant(2), 'str2', x=5.0) |
| self.assertEqual(trace_count[0], 1) |
| |
| trace_count = [0] |
| disabled(1, constant_op.constant(1), 'str', x=4.0) |
| disabled(2, constant_op.constant(2), 'str2', x=5.0) # Retrace |
| self.assertEqual(trace_count[0], 2) |
| |
| def testFollowTypeHintsTraceWithOnlyArgNamed(self): |
| trace_count = [0] |
| |
| def func(t: ops.Tensor, i: int = 1, **kwargs): # pylint: disable=bad-whitespace |
| del i, kwargs |
| trace_count[0] += 1 |
| return t |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(1, 3, x=4.0, y='str') |
| enabled(2, 4, x=4.0, y='str') # Retrace |
| self.assertEqual(trace_count[0], 2) |
| |
| def testFollowTypeHintsTraceWithNotAllNamed(self): |
| trace_count = [0] |
| |
| def func(x, y: ops.Tensor, z: int): |
| del y, z |
| trace_count[0] += 1 |
| return x |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(1, 2, 3) |
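    # Only ops.Tensor-annotated arguments are generalized; untyped and
    # int-annotated arguments still key the trace cache by value.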
| enabled(1, 20, 3) # No retrace - change in ops.Tensor typed arg |
| enabled(2, 2, 3) # Retrace - change in untyped arg |
    enabled(2, 2, 4)  # Retrace - change in int-typed arg
| self.assertEqual(trace_count[0], 3) |
| |
| def testFollowTypeHintsTraceWithOnlyArgsNamed(self): |
| trace_count = [0] |
| |
| def func(x, y, *args: ops.Tensor): |
| del y, args |
| trace_count[0] += 1 |
| return x |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(1, 20, 3, 4, 5, 6) |
| enabled(1, 20, 3, 4, 5, 60) # No retrace - change in *args |
    enabled(1, 30, 7, 8, 9, 10)  # Retrace - change in untyped arg y
| self.assertEqual(trace_count[0], 2) |
| |
| def testFollowTypeHintsTraceWithOnlyKwargsNamed(self): |
| trace_count = [0] |
| |
| def func(x, y, *args, **kwargs: ops.Tensor): |
| del y, args, kwargs |
| trace_count[0] += 1 |
| return x |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) |
| enabled( |
| 1, 2, 3, 4, 5, 6, a=1.5, b=2.5, |
| c=3.5) # No retrace - change in **kwargs |
| enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) # Retrace - change in args |
| enabled( |
| 1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0) # Retrace - change in *args |
| self.assertEqual(trace_count[0], 3) |
| |
| def testFollowTypeHintsTraceWithArgsEquals(self): |
| trace_count = [0] |
| |
| def func( |
| x: ops.Tensor = 0, # pylint:disable=bad-whitespace |
| y: int = 1, # pylint:disable=bad-whitespace |
| **kwargs: ops.Tensor): |
| del y, kwargs |
| trace_count[0] += 1 |
| return x |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(x=1, y=2, z=3) |
| enabled(x=1, y=3, z=3) # Retrace - change in args |
| enabled(x=2, y=2, z=4) # No retrace - change in args and **kwargs |
| enabled(x=2, y=2, z=4, u=5) # Retrace - change in **kwargs |
| self.assertEqual(trace_count[0], 3) |
| |
| def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self): |
| trace_count = [0] |
| |
| def func(x, y, **kwargs: ops.Tensor): |
| del y, kwargs |
| trace_count[0] += 1 |
| return x |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(x=1, y=2, z=3) |
| enabled(x=1, y=3, z=3) # Retrace |
| enabled(x=1, y=2, z=4) # No retrace |
| enabled(x=2, y=2, z=4) # Retrace |
| enabled(x=2, y=2, z=4, u=5) # Retrace |
| self.assertEqual(trace_count[0], 4) |
| |
| def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self): |
| trace_count = [0] |
| |
| def func(x: ops.Tensor, y: int, **kwargs): |
| del y, kwargs |
| trace_count[0] += 1 |
| return x |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(x=1, y=2, z=3) |
| enabled(x=1, y=3, z=3) # Retrace |
| enabled(x=1, y=2, z=4) # Retrace |
| enabled(x=2, y=2, z=3) # No retrace |
| enabled(x=2, y=2, z=4, u=5) # Retrace |
| self.assertEqual(trace_count[0], 4) |
| |
| def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self): |
| trace_count = [0] |
| |
| def func(*, a: ops.Tensor = None, b=1): # pylint: disable=bad-whitespace |
| del b |
| trace_count[0] += 1 |
| return a |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(a=1, b=2) |
| enabled(a=2, b=2) # No retrace |
| enabled(a=1, b=1) # Retrace |
| self.assertEqual(trace_count[0], 2) |
| |
| def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self): |
| trace_count = [0] |
| |
| def func(arg: ops.Tensor, *args, kwonly, **kwargs): |
| del args, kwonly, kwargs |
| trace_count[0] += 1 |
| return arg |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) |
| enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace |
| enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace |
| enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace |
| self.assertEqual(trace_count[0], 4) |
| |
| def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self): |
| trace_count = [0] |
| |
| def func(arg, *args: ops.Tensor, kwonly, **kwargs): |
| del args, kwonly, kwargs |
| trace_count[0] += 1 |
| return arg |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) |
| enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # No retrace |
| enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7) # No retrace |
| enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace |
| self.assertEqual(trace_count[0], 4) |
| |
| def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self): |
| trace_count = [0] |
| |
| def func(arg, *args, kwonly: ops.Tensor, **kwargs): |
| del args, kwonly, kwargs |
| trace_count[0] += 1 |
| return arg |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) |
| enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # No retrace |
| enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7) # No retrace |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace |
| self.assertEqual(trace_count[0], 4) |
| |
| def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self): |
| trace_count = [0] |
| |
| def func(arg, *args, kwonly, **kwargs: ops.Tensor): |
| del args, kwonly, kwargs |
| trace_count[0] += 1 |
| return arg |
| |
| enabled = def_function.function(func, experimental_follow_type_hints=True) |
| |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) |
| enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # No retrace |
| enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace |
| self.assertEqual(trace_count[0], 4) |
| |
| def testWithModuleNameScope(self): |
    self.skipTest('b/166158748: function does not handle this case correctly.')
| |
| class Foo(module.Module): |
| |
| def __init__(self): |
| super().__init__() |
| self.var = None |
| |
| @def_function.function |
| @module.Module.with_name_scope |
| def add(self, x, y, z=1): |
| if self.var is None: |
| return x + y + z |
| |
| foo = Foo() |
| self.assertEqual(foo.add(2, 3), 6) |
| |
| def testWithModuleNameScopeRedundantArgs(self): |
    self.skipTest('b/166158748: function does not handle this case correctly.')
| |
| class Foo(module.Module): |
| |
| def __init__(self): |
| super().__init__() |
| self.var = None |
| |
| @def_function.function |
| @module.Module.with_name_scope |
| def add(self, x, y): |
| if self.var is None: |
| return x + y |
| |
| foo = Foo() |
| with self.assertRaisesRegex(TypeError, 'got two values for argument'): |
| foo.add(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter |
| |
| def testWithModuleNameScopeMissingArgs(self): |
    self.skipTest('b/166158748: function does not handle this case correctly.')
| |
| class Foo(module.Module): |
| |
| def __init__(self): |
| super().__init__() |
| self.var = None |
| |
| @def_function.function |
| @module.Module.with_name_scope |
| def add(self, x, y): |
| if self.var is None: |
| return x + y |
| |
| foo = Foo() |
| with self.assertRaisesRegex(TypeError, 'missing required arguments: y'): |
| foo.add(2) # pylint: disable=no-value-for-parameter |
| |
| def testShapeInferencePropagateConstNestedStack(self): |
| |
| @def_function.function(input_signature=[ |
| tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), |
| tensor_spec.TensorSpec((), dtype=dtypes.int32), |
| ]) |
| def f(x, s): |
| old_shape = array_ops.shape(x) |
| new_shape = array_ops.stack([old_shape[0], s], axis=0) |
| y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) |
| return y |
| |
| @def_function.function(input_signature=[ |
| tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) |
| ]) |
| def g(x): |
| y = f(x, s=5) |
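      # Shape inference propagates the constant s=5 and x's static shape
      # through the nested call, giving y the fully static shape [3, 5].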
| assert y.shape.as_list() == [3, 5], y.shape.as_list() |
| return y |
| |
| self.assertAllEqual( |
| g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5])) |
| |
| def testShapeInferencePropagateConstNestedUnstackStack(self): |
| |
| @def_function.function(input_signature=[ |
| tensor_spec.TensorSpec((None, None), dtype=dtypes.int32), |
| tensor_spec.TensorSpec((), dtype=dtypes.int32), |
| ]) |
| def f(x, s): |
| s0, _ = array_ops.unstack(array_ops.shape(x), axis=0) |
| new_shape = array_ops.stack([s0, s], axis=0) |
| y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) |
| return y |
| |
| @def_function.function(input_signature=[ |
| tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32) |
| ]) |
| def g(x): |
| y = f(x, s=5) |
| assert y.shape.as_list() == [3, 5], y.shape.as_list() |
| return y |
| |
| self.assertAllEqual( |
| g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5])) |
| |
| def testShapeInferencePropagateConstNestedConcat(self): |
| |
| @def_function.function(input_signature=[ |
| tensor_spec.TensorSpec((), dtype=dtypes.int32), |
| tensor_spec.TensorSpec((), dtype=dtypes.int32), |
| tensor_spec.TensorSpec((), dtype=dtypes.int32), |
| ]) |
| def f(d1, d2, d3): |
| new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) |
| y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) |
| return y |
| |
| @def_function.function() |
| def g(): |
| y = f(1, 2, 3) |
| assert y.shape.as_list() == [1, 2, 3], y.shape.as_list() |
| return y |
| |
| self.assertAllEqual(g(), array_ops.ones([1, 2, 3])) |
| |
| def testShapeInferencePropagateConstDoubleNested(self): |
| |
| @def_function.function(input_signature=[ |
| tensor_spec.TensorSpec((), dtype=dtypes.int32), |
| tensor_spec.TensorSpec((), dtype=dtypes.int32), |
| tensor_spec.TensorSpec((), dtype=dtypes.int32), |
| ]) |
| def f(d1, d2, d3): |
| new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1) |
| y = array_ops.ones(shape=new_shape, dtype=dtypes.int32) |
| return y |
| |
| @def_function.function() |
| def g(): |
| y = def_function.function(f)(1, 2, 3) |
| assert y.shape.as_list() == [1, 2, 3], y.shape.as_list() |
| return y |
| |
| self.assertAllEqual(g(), array_ops.ones([1, 2, 3])) |
| |
| @test_util.run_v2_only |
| def testControlDependencyAfterInline(self): |
| v = variables.Variable(0.) |
| |
| @def_function.function |
| def assign(): |
| return v.assign(1.) |
| |
| @def_function.function |
| def assign_add(): |
| return v.assign_add(1.) |
| |
| @def_function.function |
| def f(): |
| check_ops.assert_equal_v2(assign(), 1.) |
| check_ops.assert_equal_v2(assign_add(), 2.) |
| |
| # We don't have a way to inspect the inlined graph in Python, so we run it |
| # multiple times to have more confidence the dependency is correct. |
| for _ in range(30): |
| f() |
| |
| @test_util.run_v2_only |
| def testReadInFuncWriteOutside(self): |
| # Run many times since we are testing for a potential race condition. |
| for _ in range(30): |
| # pylint: disable=cell-var-from-loop |
| v = variables.Variable(1.) |
| |
| @def_function.function |
| def add_one(): |
| return v + 1. |
| |
| @def_function.function |
| def get_v_plus_one(): |
| v_plus_one = add_one() |
| v.assign_add(2.0) |
| return v_plus_one |
| |
| self.assertAllEqual(get_v_plus_one(), 2.0) |
| |
| |
| class MultiDeviceTest(test.TestCase, parameterized.TestCase): |
| |
| @test_util.run_gpu_only |
| def testMultiDeviceOutput(self): |
| """Tests that functions can produce outputs on multiple devices.""" |
| @function.defun |
| def func(a, b, transpose_a): |
| with ops.device('/device:CPU:0'): |
| m1 = math_ops.matmul(a, b, transpose_a=transpose_a) |
| with ops.device('/device:GPU:0'): |
| m2 = math_ops.matmul(a, b, transpose_a=transpose_a) |
| return m1, m2 |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| m1, m2 = func(t, t, transpose_a=True) |
| self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]]) |
| self.assertRegex(m1.backing_device, 'CPU') |
| self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]]) |
| self.assertRegex(m2.backing_device, 'GPU') |
| |
| @test_util.run_gpu_only |
| def testEmptyBody(self): |
| @function.defun |
| def func(a, b): |
| return b, a |
| |
| with ops.device('/device:CPU:0'): |
| a = array_ops.identity(3.0) |
| with ops.device('/device:GPU:0'): |
| b = array_ops.identity(5.0) |
| |
| m1, m2 = func(a, b) |
| self.assertAllEqual(m1.numpy(), 5.0) |
| self.assertRegex(m1.backing_device, 'GPU') |
| self.assertAllEqual(m2.numpy(), 3.0) |
| self.assertRegex(m2.backing_device, 'CPU') |
| |
| @test_util.run_gpu_only |
| def testMultiDeviceInt32(self): |
| """Tests that multi-device functions can take and output INT32s. |
| |
| When an INT32 device tensor is fed into a function, it is copied to CPU |
| by the eager runtime. The function sees all INT32 inputs on CPU. |
| |
    We set the allocator attribute 'on_host' for INT32 outputs: they can be
    partitioned into the GPU component function, but will nevertheless be
    allocated on CPU.

    FunctionLibraryRuntime now has experimental support for `ints_on_device`,
    which could be tried here in the future.
    """
| with ops.device('/device:CPU:0'): |
| int_cpu = constant_op.constant(3, dtype=dtypes.int32) |
| resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32) |
| with ops.device('/device:GPU:0'): |
| int_gpu = constant_op.constant(7, dtype=dtypes.int32) |
| |
| @function.defun |
| def func(int_cpu, resource, int_gpu): |
| with ops.device('/device:CPU:0'): |
| m1 = int_cpu * resource + int_gpu |
| with ops.device('/device:GPU:0'): |
| # This computation will happen on GPU but m2 will be copied to CPU. |
| m2 = int_gpu * resource + int_cpu + 1 |
| return m1, m2 |
| |
| m1, m2 = func(int_cpu, resource, int_gpu) |
| self.assertAllEqual(m1.numpy(), 22) |
| self.assertRegex(m1.backing_device, 'CPU') |
| self.assertAllEqual(m2.numpy(), 39) |
| self.assertRegex(m2.backing_device, 'CPU') |
| |
    # Flip the arguments.
| m1, m2 = func(int_gpu, resource, int_cpu) |
| self.assertAllEqual(m1.numpy(), 38) |
| self.assertRegex(m1.backing_device, 'CPU') |
| self.assertAllEqual(m2.numpy(), 23) |
| self.assertRegex(m2.backing_device, 'CPU') |
| |
| @test_util.run_gpu_only |
| def testMultiDeviceColocateWith(self): |
| """Tests that function's outputs respect colocation constraints.""" |
| @function.defun |
| def func(a, b): |
| with ops.colocate_with(a): |
| ra = 2 * a |
| with ops.colocate_with(b): |
| rb = 3 * b |
| return ra, rb |
| |
| devices = ['/device:CPU:0', '/device:GPU:0'] |
| for dev1, dev2 in itertools.product(devices, devices): |
| with ops.device(dev1): |
| a = array_ops.identity(1.0) |
| with ops.device(dev2): |
| b = array_ops.identity(10.0) |
| |
| ra, rb = func(a, b) |
| self.assertEqual(ra.numpy(), 2.0) |
| self.assertRegex(ra.backing_device, dev1) |
| self.assertEqual(rb.numpy(), 30.0) |
| self.assertRegex(rb.backing_device, dev2) |
| |
| @test_util.run_gpu_only |
| def testMultiDeviceResources(self): |
| with ops.device('/device:CPU:0'): |
| c1 = resource_variable_ops.ResourceVariable(2.0) |
| c2 = resource_variable_ops.ResourceVariable(7.0) |
| with ops.device('/device:GPU:0'): |
| g1 = resource_variable_ops.ResourceVariable(3.0) |
| g2 = resource_variable_ops.ResourceVariable(5.0) |
| |
| @function.defun |
| def func(resource1, resource2): |
| with ops.device('/device:CPU:0'): |
| result1 = resource1 * g2 |
| with ops.device('/device:GPU:0'): |
| result2 = resource2 * c2 |
| return result1, result2 |
| |
| r1, r2 = func(c1, g1) |
| self.assertEqual(r1.numpy(), 10.0) |
| self.assertRegex(r1.backing_device, 'CPU') |
| self.assertEqual(r2.numpy(), 21.0) |
| self.assertRegex(r2.backing_device, 'GPU') |
| |
    # Call with flipped inputs. Check that we look at the resources' devices
    # and reinstantiate the function when the inputs' devices change.
| r1, r2 = func(g1, c1) |
| self.assertEqual(r1.numpy(), 15.0) |
| self.assertRegex(r1.backing_device, 'CPU') |
| self.assertEqual(r2.numpy(), 14.0) |
| self.assertRegex(r2.backing_device, 'GPU') |
| |
| @test_util.run_gpu_only |
| def testOutputResources(self): |
| with ops.device('/device:CPU:0'): |
| c1 = resource_variable_ops.ResourceVariable(2.0) |
| with ops.device('/device:GPU:0'): |
| g1 = resource_variable_ops.ResourceVariable(3.0) |
| |
| @function.defun |
| def func(resource1, resource2): |
| with ops.device('/device:CPU:0'): |
| result1 = resource1 * 5 |
| with ops.device('/device:GPU:0'): |
| result2 = resource2 * 7 |
| return result1, resource1.handle, result2, resource2.handle |
| |
| r1, res1, r2, res2 = func(c1, g1) |
| self.assertEqual(r1.numpy(), 10.0) |
| self.assertRegex(r1.backing_device, 'CPU') |
| self.assertEqual(r2.numpy(), 21.0) |
| self.assertRegex(r2.backing_device, 'GPU') |
| |
| def check_handle(handle, expected_value): |
| self.assertRegex(handle.backing_device, 'CPU') |
| tensor = gen_resource_variable_ops.read_variable_op( |
| handle, dtypes.float32) |
| self.assertEqual(tensor.numpy(), expected_value) |
| |
| # Check that handles returned from functions are on CPU and an op using |
| # the resource handle is correctly placed on the device backing the |
| # resource. |
| check_handle(res1, 2.0) |
| check_handle(res2, 3.0) |
| |
    # Call with flipped inputs to make sure the function is reinstantiated and
    # the eager runtime does not mess up the device assignment for ops
    # consuming handles returned from defuns.
| r1, res1, r2, res2 = func(g1, c1) |
| self.assertEqual(r1.numpy(), 15.0) |
| self.assertRegex(r1.backing_device, 'CPU') |
| self.assertEqual(r2.numpy(), 14.0) |
| self.assertRegex(r2.backing_device, 'GPU') |
| check_handle(res1, 3.0) |
| check_handle(res2, 2.0) |
| |
| @test_util.run_gpu_only |
| def testPassResourceThroughNestedFunctionCall(self): |
| """Test passing GPU resource to noinline function call placed on CPU. |
| |
| PartitionedCallOp must not enforce any particular device assignment for the |
| resource output. Inner function marked as `_nospecialize`, so Grappler would |
| not prune unused function output. |
| """ |
| |
| with ops.device('/device:GPU:0'): |
| g1 = resource_variable_ops.ResourceVariable(3.0) |
| |
| @function.defun_with_attributes(attributes={ |
| '_noinline': True, |
| '_nospecialize': True |
| }) |
| def inner(resource1): |
| return resource1 * 2, resource1.handle |
| |
| @function.defun |
| def outer(resource1): |
| with ops.device('/device:CPU:0'): |
| r1, _ = inner(resource1) |
| return r1 |
| |
| r1 = outer(g1) |
| |
| self.assertEqual(r1.numpy(), 6.0) |
| self.assertRegex(r1.backing_device, 'CPU') |
| |
| @test_util.run_gpu_only |
| def testReturnResourceFromNestedFunctionCall(self): |
| """Test returning GPU resource from noinline function call placed on CPU. |
| |
    When inferring output devices for the return value, do not set a device
    for returns of DT_RESOURCE data type based on the device assignment of
    the node that produced that resource. As an example, a function call
    placed on CPU can return resources that live on GPU.
| """ |
| |
| with ops.device('/device:GPU:0'): |
| g1 = resource_variable_ops.ResourceVariable(3.0) |
| |
| @function.defun_with_attributes(attributes={ |
| '_noinline': True |
| }) |
| def inner(resource1): |
| resource1.assign_add(2.0) |
| return resource1 * 2, resource1.handle |
| |
| @function.defun |
| def outer(resource1): |
| with ops.device('/device:CPU:0'): |
| r1, res1 = inner(resource1) |
| return r1, res1 |
| |
| r1, res1 = outer(g1) |
| |
| self.assertEqual(r1.numpy(), 10.0) |
| self.assertRegex(r1.backing_device, 'CPU') |
| |
| def check_handle(handle, expected_value): |
| self.assertRegex(handle.backing_device, 'CPU') |
| tensor = gen_resource_variable_ops.read_variable_op( |
| handle, dtypes.float32) |
| self.assertEqual(tensor.numpy(), expected_value) |
| |
| # Check that handles returned from functions are on CPU and an op using |
| # the resource handle is correctly placed on the device backing the |
| # resource. |
| check_handle(res1, 5.0) |
| |
| @test_util.run_gpu_only |
| def testComplexInputOutputDevicePattern(self): |
| """Tests input/output mapping logic in partitioning.""" |
| with ops.device('/device:CPU:0'): |
| rc0 = resource_variable_ops.ResourceVariable(2.0) |
| rc1 = resource_variable_ops.ResourceVariable(3.0) |
| cc0 = array_ops.identity(5.0) |
| cc1 = array_ops.identity(7.0) |
| with ops.device('/device:GPU:0'): |
| rg0 = resource_variable_ops.ResourceVariable(11.0) |
| rg1 = resource_variable_ops.ResourceVariable(13.0) |
| cg0 = array_ops.identity(17.0) |
| cg1 = array_ops.identity(19.0) |
| |
| # Make sure tensors are on expected devices. |
| for tensor in [cc0, cc1]: |
| self.assertRegex(tensor.backing_device, 'CPU:0') |
| for tensor in [cg0, cg1]: |
| self.assertRegex(tensor.backing_device, 'GPU:0') |
| |
| @function.defun |
| def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1): |
| with ops.device('/device:CPU:0'): |
| m1 = rc0 * cg0 |
| with ops.device('/device:GPU:0'): |
| m2 = rg0 * cc0 |
| |
| with ops.device('/device:CPU:0'): |
| r1 = 1000.0 * m2 + rc1 * cg1 |
| with ops.device('/device:GPU:0'): |
| r2 = 1000.0 * m1 + rg1 * cc1 |
| |
| return r1, r2, m2, m1 |
| |
| r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1) |
| self.assertRegex(m1.backing_device, 'CPU') |
| self.assertRegex(r1.backing_device, 'CPU') |
| self.assertRegex(m2.backing_device, 'GPU') |
| self.assertRegex(r2.backing_device, 'GPU') |
| self.assertEqual(m1.numpy(), 34.0) |
| self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0) |
| self.assertEqual(m2.numpy(), 55.0) |
| self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0) |
| |
| @test_util.run_gpu_only |
| def testArgumentPruning(self): |
| """Tests functions taking unnecessary arguments.""" |
| with ops.device('/device:CPU:0'): |
| c1 = constant_op.constant(5.0) |
| c2 = constant_op.constant(7.0) |
| |
| with ops.device('/device:GPU:0'): |
| g1 = constant_op.constant(11.0) |
| g2 = constant_op.constant(13.0) |
| g3 = constant_op.constant(17.0) |
| |
| @function.defun |
| def func(g1, g2, c1, g3, c2): # pylint: disable=unused-argument |
      # Arguments g1 and g2 are unused and can be pruned by Grappler.
| return c1 * g3 * c2 |
| |
| result = func(g1, g2, c1, g3, c2) |
| self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0) |
| |
| def testNestedCallWatchedVariables(self): |
| |
| v = variables.Variable(4.) |
| |
| @def_function.function |
| def f(): |
| return v ** 2. |
| |
| with backprop.GradientTape() as tape: |
| f() |
| |
| self.assertEqual((v,), tape.watched_variables()) |
| |
| @def_function.function |
| def g(): |
| return f() |
| |
| with backprop.GradientTape() as tape: |
| g() |
| |
| self.assertEqual((v,), tape.watched_variables()) |
| |
| # f() can rely on the variable being read during its trace. g() checks that |
| # variables from a function which knows about them are recorded on the |
| # tape. h() tests that functions forward knowledge of variables to callers. |
| |
| @def_function.function |
| def h(): |
| return g() |
| |
| with backprop.GradientTape() as tape: |
| h() |
| |
| self.assertEqual((v,), tape.watched_variables()) |
| |
| def testDeferredCapture(self): |
| value = 1.0 |
| |
| @def_function.function |
| def lazy_capture(x): |
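      # capture_call_time_value re-evaluates the closure on every call, so
      # the captured tensor tracks later rebindings of `value`.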
| y = ops.get_default_graph().capture_call_time_value( |
| lambda: value, tensor_spec.TensorSpec(None)) |
| return x + y |
| |
| self.assertAllEqual(lazy_capture(2.0), 3.0) |
| # After changing the value of `value` the function call should return a |
| # different result. |
| value = 2.0 |
| self.assertAllEqual(lazy_capture(2.0), 4.0) |
| |
| def testDeferredCaptureWithKey(self): |
| value0 = 1.0 |
| value1 = 2.0 |
| |
| @def_function.function |
| def lazy_capture(x): |
| w = ops.get_default_graph().capture_call_time_value( |
| lambda: value0, tensor_spec.TensorSpec(None), key=0) |
| y = ops.get_default_graph().capture_call_time_value( |
| lambda: value1, tensor_spec.TensorSpec(None), key=1) |
| def bad_closure(): |
| raise ValueError('Should not run') |
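      # key=1 was already captured above, so the existing capture is reused
      # and bad_closure is never evaluated.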
| z = ops.get_default_graph().capture_call_time_value( |
| bad_closure, tensor_spec.TensorSpec(None), key=1) |
| return x + y + w + z |
| |
| self.assertAllEqual(lazy_capture(2.0), 7.0) |
| value0 = 2.0 |
| value1 = 3.0 |
| self.assertAllEqual(lazy_capture(2.0), 10.0) |
| |
| def testDeferredCaptureTypeError(self): |
| value = constant_op.constant(1.0) |
| |
| @def_function.function |
| def lazy_capture(x): |
| y = ops.get_default_graph().capture_call_time_value( |
| lambda: value, tensor_spec.TensorSpec(())) |
| return x + y |
| |
| self.assertAllEqual(lazy_capture(2.0), 3.0) |
| |
| # dtype mismatch |
| value = constant_op.constant(1) |
| with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'): |
| lazy_capture(2.0) |
| |
| # shape mismatch |
| value = constant_op.constant([1.0]) |
| with self.assertRaisesRegex(ValueError, 'Value .* shape'): |
| lazy_capture(2.0) |
| |
| def testDeferredCaptureReturnNestWithCompositeTensor(self): |
| i_s = indexed_slices.IndexedSlices( |
| constant_op.constant([1, 2]), |
| constant_op.constant([0, 1], dtype=dtypes.int64), |
| constant_op.constant([2])) |
| r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]) |
| s_t = sparse_tensor.SparseTensor( |
| values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20]) |
| |
| @def_function.function |
| def lazy_capture(): |
| y = ops.get_default_graph().capture_call_time_value( |
| lambda: {'i': i_s, 't': (r_t, s_t)}, |
| {'i': indexed_slices.IndexedSlicesSpec( |
| dtype=dtypes.int32, dense_shape_dtype=dtypes.int32), |
| 't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32), |
| sparse_tensor.SparseTensorSpec([None], dtypes.int32))}) |
| return y['i'], y['t'] |
| |
| i, (r, s) = lazy_capture() |
| self.assertAllEqual(i_s.values, i.values) |
| self.assertAllEqual(i_s.indices, i.indices) |
| self.assertAllEqual(i_s.dense_shape, i.dense_shape) |
| self.assertAllEqual(r_t, r) |
| self.assertAllEqual(s_t.indices, s.indices) |
| self.assertAllEqual(s_t.values, s.values) |
| self.assertAllEqual(s_t.dense_shape, s.dense_shape) |
| |
| def testDeferredCaptureCompositeTensorSpecTypeMismatch(self): |
| value = indexed_slices.IndexedSlices( |
| constant_op.constant([1, 2]), |
| constant_op.constant([0, 1], dtype=dtypes.int64)) |
| |
| @def_function.function |
| def lazy_capture(): |
| return ops.get_default_graph().capture_call_time_value( |
| lambda: value, |
| indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32)) |
| |
| # Type matches spec. |
| lazy_capture() |
| |
| # Extra dense shape component. |
| value = indexed_slices.IndexedSlices( |
| constant_op.constant([1, 2]), |
| constant_op.constant([0, 1], dtype=dtypes.int64), |
| constant_op.constant([2])) |
| with self.assertRaises(ValueError): |
| lazy_capture() |
| |
| # Index dtype mismatch int32 vs. int64. |
| value = indexed_slices.IndexedSlices( |
| constant_op.constant([1, 2]), |
| constant_op.constant([0, 1])) |
| with self.assertRaises(ValueError): |
| lazy_capture() |
| |
| |
| if __name__ == '__main__': |
| ops.enable_eager_execution() |
| test.main() |