# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
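"""Tests for tf.function and defun tracing, caching, and calling."""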
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import functools
import itertools
import multiprocessing.pool
import sys
import time
import weakref
from absl.testing import parameterized
import numpy
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.layers import convolutional
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.platform import test
from tensorflow.python.training import training_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
try:
import attr # pylint:disable=g-import-not-at-top
except ImportError:
attr = None
def total_function_cache(defined):
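  """Returns the union of primary and shape-relaxed cache entries."""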
# pylint: disable=protected-access
return (set(defined._function_cache.primary)
| set(defined._function_cache.arg_relaxed))
# pylint: enable=protected-access
def _example_indexed_slices_with_dense_shape():
return indexed_slices.IndexedSlices(
constant_op.constant([1, 2]), constant_op.constant([0, 1]),
constant_op.constant([2]))
def _example_indexed_slices_without_dense_shape():
return indexed_slices.IndexedSlices(
constant_op.constant([1, 2]), constant_op.constant([0, 1]))
def _spec_for_value(value):
"""Returns the (nested) TypeSpec for a value."""
if nest.is_sequence(value):
return nest.map_structure(_spec_for_value, value)
elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
return type_spec.type_spec_from_value(value)
else:
return value
class FunctionTest(test.TestCase, parameterized.TestCase):
def setUp(self):
super(FunctionTest, self).setUp()
cpus = config.list_physical_devices('CPU')
# Set 4 virtual CPUs
config.set_logical_device_configuration(cpus[0], [
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
def testBasic(self):
matmul = def_function.function(math_ops.matmul)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq = matmul(t, t, transpose_a=True)
sq2 = matmul(sq, t, transpose_a=True)
self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])
def testOnExitCallback(self):
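    """Exit callbacks run when tracing finishes, and again on each retrace."""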
values = []
def append_1():
values.append(1)
def append_2():
values.append(2)
def g(x):
old_values = list(values)
ops.add_exit_callback_to_default_func_graph(append_1)
self.assertEqual(old_values, values)
return x + 1
tf_g = def_function.function(g)
def f(x):
old_values = list(values)
ops.add_exit_callback_to_default_func_graph(append_2)
self.assertEqual(old_values, values)
return tf_g(x)
tf_f = def_function.function(f)
self.assertEmpty(values)
tf_f(constant_op.constant(1.0))
self.assertEqual(values, [1, 2]) # Once for g, once for f.
tf_f(constant_op.constant([1.0])) # force a retrace
self.assertEqual(values, [1, 2, 1, 2]) # And again.
def testCannotAddExitCallbackWhenNotInFunctionScope(self):
with self.assertRaisesRegex(RuntimeError, 'when not building a function.'):
ops.add_exit_callback_to_default_func_graph(lambda: None)
def testVariable(self):
v1 = variables.Variable(1.0)
add = def_function.function(lambda x, v: x + v1 + v)
v2 = variables.Variable(1.0)
x = constant_op.constant(1.0)
r = add(x, v2)
self.assertEqual(3.0, self.evaluate(r))
def testVariableOnly(self):
v = variables.Variable(1.0)
add = def_function.function(lambda x: x.assign_add(1.0))
r1 = add(v)
self.assertEqual(2.0, self.evaluate(r1))
c = constant_op.constant(1.0)
with self.assertRaisesRegex(AttributeError, 'no attribute'):
add(c)
@test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.')
def testPackedVariable(self):
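    """Packed handles carry a composite device and read per-device values."""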
with ops.device('/cpu:0'):
v0_0 = resource_variable_ops.ResourceVariable(1.0)
with ops.device('/cpu:1'):
v0_1 = resource_variable_ops.ResourceVariable(2.0)
v1_0 = resource_variable_ops.ResourceVariable(3.0)
with ops.device('/cpu:2'):
v1_1 = resource_variable_ops.ResourceVariable(4.0)
packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle])
packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle])
    # TODO(b/145922293): use ResourceVariable.assign_add and
    # ResourceVariable.read_value directly once we support packing multiple
    # ResourceVariables into one ResourceVariable.
@def_function.function
def read_var():
resource_variable_ops.assign_add_variable_op(
packed_var_0, constant_op.constant(5.0))
resource_variable_ops.assign_add_variable_op(
packed_var_1, constant_op.constant(6.0))
with ops.device('/cpu:0'):
read0 = resource_variable_ops.read_variable_op(
packed_var_0, dtype=dtypes.float32)
with ops.device('/cpu:1'):
read1 = resource_variable_ops.read_variable_op(
packed_var_0, dtype=dtypes.float32)
read2 = resource_variable_ops.read_variable_op(
packed_var_1, dtype=dtypes.float32)
with ops.device('/cpu:2'):
read3 = resource_variable_ops.read_variable_op(
packed_var_1, dtype=dtypes.float32)
return read0, read1, read2, read3
arg_attrs = read_var.get_concrete_function().function_def.arg_attr
self.assertLen(arg_attrs, 2)
self.assertEqual(arg_attrs[0].attr['_composite_device'].s,
compat.as_bytes(packed_var_0.device))
self.assertEqual(arg_attrs[1].attr['_composite_device'].s,
compat.as_bytes(packed_var_1.device))
self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
def testImplementsAttributeBasic(self):
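    """experimental_implements marks only the main traced FunctionDef."""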
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y)
with context.graph_mode(), self.cached_session():
a = array_ops.placeholder(dtypes.float32, ())
b = array_ops.placeholder(dtypes.float32, ())
v(a, b)
gradients_impl.gradients(v(a, b), [a, b])
fdefs = ops.get_default_graph().as_graph_def().library.function
self.assertLen(fdefs, 3)
not_present = 0
present = 0
for f in fdefs:
name = f.signature.name
if 'forward' in name or 'backward' in name:
not_present += 1
self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
else:
present += 1
self.assertEqual(f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME].s,
'func'.encode('ascii'), f)
self.assertEqual(not_present, 2, fdefs)
self.assertEqual(present, 1, fdefs)
def testImplementsAttributeAssertsOnSideInput(self):
with context.graph_mode(), self.cached_session():
z = array_ops.zeros(0)
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y + z)
a = array_ops.ones((1.0,))
b = array_ops.ones((1.0,))
with self.assertRaisesRegex(AssertionError,
'variables are always captured'):
v(a, b)
functions = ops.get_default_graph().as_graph_def().library.function
self.assertEmpty(functions)
def testImplementsAttributeWorksOnVariables(self):
with context.graph_mode(), self.cached_session():
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable((1.0,))
b = variables.Variable((1.0,))
r1 = v(a, b)
_ = v(a, a)
functions = ops.get_default_graph().as_graph_def().library.function
# Verify that we created only one function
self.assertLen(functions, 1)
# Verify that eval() reads the current values.
a.initializer.run()
b.initializer.run()
self.assertEqual(r1.eval(), 2)
a.assign_add([1]).eval()
self.assertEqual(r1.eval(), 3)
def testImplementsAttributeWorksOnConstants(self):
with context.graph_mode(), self.cached_session():
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable(1.0)
r1 = v(a, 2.)
r2 = v(2., a)
functions = ops.get_default_graph().as_graph_def().library.function
self.assertLen(functions, 1)
self.assertLen(functions[0].signature.input_arg, 2)
# Verify that eval() reads the current values.
a.initializer.run()
self.assertEqual(r1.eval(), 3)
self.assertEqual(r2.eval(), 3)
def testImplementsAttributeSpecializes(self):
with context.graph_mode(), self.cached_session():
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable(1.0)
r1 = v(a, [2.])
r2 = v([2., 2], a)
functions = ops.get_default_graph().as_graph_def().library.function
self.assertLen(functions, 2)
# Ensure that all parameters are still there and haven't been inlined!
self.assertLen(functions[0].signature.input_arg, 2)
self.assertLen(functions[1].signature.input_arg, 2)
# Verify that eval() reads the current values.
a.initializer.run()
numpy.testing.assert_equal(r1.eval(), [3.])
numpy.testing.assert_equal(r2.eval(), [3., 3.])
def testImplementsAttributeAsNameAttrList(self):
implements_attr = (
'name: "embedding_matmul" attr { key: "key1" value { i: 2 } '
'} attr { key: "key2" value { b: false } }')
v = def_function.function(
experimental_implements=implements_attr)(lambda x, y: x + y)
with context.graph_mode(), self.cached_session():
a = array_ops.placeholder(dtypes.float32, ())
b = array_ops.placeholder(dtypes.float32, ())
v(a, b)
gradients_impl.gradients(v(a, b), [a, b])
fdefs = ops.get_default_graph().as_graph_def().library.function
self.assertLen(fdefs, 3)
not_present = 0
present = 0
for f in fdefs:
name = f.signature.name
if 'forward' in name or 'backward' in name:
not_present += 1
self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
else:
present += 1
attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME]
self.assertIsNotNone(attr_value.func, f)
self.assertEqual(attr_value.func.name, 'embedding_matmul')
name_attrs = attr_value.func.attr
self.assertLen(name_attrs, 2)
self.assertEqual(not_present, 2, fdefs)
self.assertEqual(present, 1, fdefs)
def testExternalControlDependency(self):
with ops.Graph().as_default(), self.test_session():
v = variables.Variable(1.0)
v.initializer.run()
op = v.assign_add(1.0)
@function.defun
def f():
with ops.control_dependencies([op]):
return 1.0
self.evaluate(f())
self.assertAllEqual(self.evaluate(v), 2.0)
def testInputShapeFunctionRelaxation(self):
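    """Repeated shape changes trigger a relaxed (unknown-dimension) trace."""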
unknown_dim = [False]
@function.defun(experimental_relax_shapes=True)
def func(a):
if a._shape_tuple()[0] is None:
unknown_dim[0] = True
return a + 1
func(constant_op.constant([]))
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 1)
func(constant_op.constant([1.0]))
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
func(constant_op.constant([1.0, 2.0]))
self.assertTrue(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
def testInputShapeRelaxationOnInstanceMethod(self):
    # Test that experimental_relax_shapes is passed during
    # instance method binding.
unknown_dim = [False]
class Foo(object):
@def_function.function(experimental_relax_shapes=True)
def func(self, a):
if a._shape_tuple()[0] is None:
unknown_dim[0] = True
return a + 1
foo = Foo()
foo.func(constant_op.constant([]))
self.assertFalse(unknown_dim[0])
foo.func(constant_op.constant([1.0]))
self.assertFalse(unknown_dim[0])
foo.func(constant_op.constant([1.0, 2.0]))
self.assertTrue(unknown_dim[0])
def testInputShapeFunctionRelaxationWithRaggedTensors(self):
traced_type_spec = [None]
@def_function.function(experimental_relax_shapes=True)
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
func(x)
self.assertEqual(traced_type_spec[0], expected_trace)
check_trace( # Initial call gets traced.
ragged_factory_ops.constant([[1], [2, 3, 4]]),
ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32))
check_trace( # Input TypeSpec is the same -> no retrace.
ragged_factory_ops.constant([[1, 2], [3, 4]]), None)
check_trace( # Even if component tensor shapes change -> no retrace.
ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None)
check_trace( # Different TypeSpec shape (nrows): retrace
ragged_factory_ops.constant([[1], [2], [3]]),
ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32))
check_trace( # Different nrows again: relax & retrace
ragged_factory_ops.constant([[1], [2], [3], [4]]),
ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32))
    check_trace(  # Different nrows yet again: no retrace (already relaxed)
ragged_factory_ops.constant([[1]]), None)
check_trace( # Different ragged_rank: retrace
ragged_factory_ops.constant([[[1]]]),
ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32))
check_trace( # Different ragged_rank again: retrace & relax
ragged_factory_ops.constant([[[1]], [[2]]]),
ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32))
def testInputShapeFunctionRelaxationWithStructuredTensors(self):
traced_type_spec = [None]
@def_function.function(experimental_relax_shapes=True)
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
func(x)
self.assertEqual(traced_type_spec[0], expected_trace)
# If we have TypeSpecs that differ in ways other than just their shape,
# then retrace each time.
check_trace(
structured_tensor.StructuredTensor.from_pyval({'a': [1]}),
structured_tensor.StructuredTensorSpec(
[], {'a': tensor_spec.TensorSpec((1,), dtypes.int32)}))
check_trace(
structured_tensor.StructuredTensor.from_pyval({'b': [1]}),
structured_tensor.StructuredTensorSpec(
[], {'b': tensor_spec.TensorSpec((1,), dtypes.int32)}))
check_trace(
structured_tensor.StructuredTensor.from_pyval({'c': [1]}),
structured_tensor.StructuredTensorSpec(
[], {'c': tensor_spec.TensorSpec((1,), dtypes.int32)}))
# But if we call again with only shape different, then do relax:
check_trace( # retrace
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}),
structured_tensor.StructuredTensorSpec(
[], {'a': tensor_spec.TensorSpec((2,), dtypes.int32)}))
check_trace( # relax & retrace
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}),
structured_tensor.StructuredTensorSpec(
[], {'a': tensor_spec.TensorSpec((None,), dtypes.int32)}))
check_trace( # use relaxed graph
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}),
None)
def testInputShapeFunctionRelaxationWithDatasetIterators(self):
# For dataset iterators, the TypeSpec includes type information that's
# not derivable from the component tensors. Make sure that the TypeSpec
# shapes get relaxed as appropriate.
traced_type_spec = [None]
@def_function.function(experimental_relax_shapes=True)
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
func(x)
self.assertEqual(traced_type_spec[0], expected_trace)
ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2]))
ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2]))
ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2]))
ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2]))
ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1]))
check_trace( # shape=[1, 2]: retrace
dataset_ops.make_one_shot_iterator(ds_1_2),
iterator_ops.IteratorSpec(
tensor_spec.TensorSpec([1, 2], dtypes.float32)))
check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph)
dataset_ops.make_one_shot_iterator(ds_1_2), None)
check_trace( # shape=[2, 2]: retrace
dataset_ops.make_one_shot_iterator(ds_2_2),
iterator_ops.IteratorSpec(
tensor_spec.TensorSpec([2, 2], dtypes.float32)))
check_trace( # shape=[3, 2]: relax to [None, 2] and retrace
dataset_ops.make_one_shot_iterator(ds_3_2),
iterator_ops.IteratorSpec(
tensor_spec.TensorSpec([None, 2], dtypes.float32)))
check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph)
dataset_ops.make_one_shot_iterator(ds_4_2), None)
check_trace( # shape=[2, 1]: relax to [None, None] and retrace
dataset_ops.make_one_shot_iterator(ds_2_1),
iterator_ops.IteratorSpec(
tensor_spec.TensorSpec([None, None], dtypes.float32)))
def testCapturesVariables(self):
a = variables.Variable(1.0, trainable=False)
b = variables.Variable(1.0)
cc = [None]
@def_function.function
def f():
c = cc[0]
if c is None:
c = cc[0] = variables.Variable(1.)
return a + b + c + 1
cf = f.get_concrete_function()
c = cc[0]
captured_variables = {v.ref() for v in (a, b, c)}
trainable_variables = {v.ref() for v in (b, c)}
self.assertEqual({v.ref() for v in cf.variables}, captured_variables)
self.assertEqual({v.ref() for v in cf.trainable_variables},
trainable_variables)
self.assertEqual(cf.variables, cf.graph.variables)
self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)
def testNestedInputShapeFunctionRelaxation(self):
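    """Shape relaxation also applies to tensors nested inside structures."""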
unknown_dim = [False]
@function.defun(experimental_relax_shapes=True)
def func(a_, b_=None):
del a_ # Only used to check which cache is used.
self.assertEqual(b_[0]._shape_tuple(), ())
if b_[1]._shape_tuple()[0] is None:
unknown_dim[0] = True
return b_[0] + 1
a = 'hi'
b0 = constant_op.constant(1.0)
func(a, b_=[b0, constant_op.constant([])])
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 1)
func(a, b_=[b0, constant_op.constant([1.0])])
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
self.assertTrue(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
unknown_dim[0] = False
    # Now do the same except with a new value for `a`, which is not a tensor;
    # this should change the cache key.
a = 'bye'
func(a, b_=[b0, constant_op.constant([])])
self.assertFalse(unknown_dim[0])
self.assertLen(total_function_cache(func), 3)
# Since we already marked a cache miss for a function with the same
# non-input signatures, here we will immediately start relaxing shapes.
func(a, b_=[b0, constant_op.constant([1.0])])
self.assertTrue(unknown_dim[0])
self.assertLen(total_function_cache(func), 3)
def testNestedShapeFunctionRelaxation(self):
got_shape = [None]
# The inner function will go through shape relaxation because the shapes it
# receives will be [1], [2], [3], ...
@def_function.function(experimental_relax_shapes=True)
def bar(x_shape):
got_shape[0] = x_shape._shape_tuple()
return x_shape
    # The outer function will not go through shape relaxation because the
    # shapes it receives will be [1], [1, 1], [1, 1, 1], ...
@def_function.function(experimental_relax_shapes=True)
def foo(ones):
return bar(array_ops.shape(ones))
for rank in range(1, 6):
x_shape = self.evaluate(foo(array_ops.ones([1] * rank)))
self.assertAllEqual(x_shape, [1] * rank)
if rank < 3:
self.assertEqual(got_shape[0], (rank,))
else:
self.assertEqual(got_shape[0], (None,))
def testNoHash(self):
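    """A Python set argument raises a ValueError naming the offending type."""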
@def_function.function()
def f(_):
return 1.0
with self.assertRaisesRegex(ValueError, r'Got type: set'):
f(set([]))
def testFuncName(self):
@function.defun_with_attributes(attributes={'func_name': 'multiply'})
def add(x, y):
_ = x * y
return x + y
@function.defun
def add_2(x, y):
_ = x * y
return x + y
self.assertEqual(add._name, 'multiply')
self.assertEqual(add_2._name, 'add_2')
def testBasicGraphMode(self):
matmul = def_function.function(math_ops.matmul)
@def_function.function
def sq(a):
return matmul(a, a)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out = sq(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedInputsGraphMode(self):
matmul = def_function.function(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
@def_function.function
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out = a_times_b(pair({'a': t}, {'b': t}))
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedOutputsGraphMode(self):
matmul = def_function.function(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
@def_function.function()
def pairs_mul(pair_a, pair_b):
return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b))
a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]])
b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]])
out = pairs_mul(pair(a, b), pair(b, a))
expected = pair(math_ops.matmul(a, b).numpy(),
math_ops.matmul(b, a).numpy())
self.assertAllClose(out, expected)
@parameterized.named_parameters(
dict(testcase_name='Defun',
function_decorator=function.defun),
dict(testcase_name='DefFunction',
function_decorator=def_function.function))
def testNestedFunctionGraphNotOutOfDate(self, function_decorator):
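    """outer_graph stays usable after the enclosing function is deleted."""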
@function_decorator
def f():
return constant_op.constant(1.)
class _Model(object):
@function_decorator
def g(self):
self.f = f.get_concrete_function()
model = _Model()
model.g()
concrete = model.f
weak_g_graph = weakref.ref(model.g.get_concrete_function().graph)
self.assertIs(weak_g_graph(), concrete.graph.outer_graph)
weak_g = weakref.ref(model.g)
del model
self.assertIsNone(weak_g())
self.assertIsNone(weak_g_graph())
self.assertIsNotNone(concrete.graph.outer_graph)
self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph)
def testGraphEagerIsolation(self):
@function.defun
def f():
self.v = variables.Variable(1.0)
return self.v.read_value()
self.assertAllEqual(f(), 1.0)
with ops.Graph().as_default():
self.assertEqual(f().shape, ())
def testBasicGraphFunction(self):
matmul = def_function.function(math_ops.matmul)
@def_function.function
def sq(a):
return matmul(a, a)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testGetConcreteFunctionThreadSafety(self):
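    """Concurrent get_concrete_function calls return one shared trace."""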
@def_function.function
def sq():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
return math_ops.matmul(t, t)
concrete_functions = []
def thread_func(_):
cf = sq.get_concrete_function()
concrete_functions.append(cf)
num_threads = 100
pool = multiprocessing.pool.ThreadPool(num_threads)
_ = pool.map(thread_func, list(range(num_threads)))
self.assertLen(set(concrete_functions), 1)
def testGetConcreteFunctionThreadSafetyWithArgs(self):
@def_function.function
def add_100(*args):
return math_ops.add_n(args)
p = multiprocessing.pool.ThreadPool(2)
args = (constant_op.constant(1.),) * 100
f1, f2 = p.map(add_100.get_concrete_function, [args] * 2)
    # Roughly len(args) + max(0, len(args) - 3) arguments are expected.
f1(*args)
del f2
def testInputSpecGraphFunction(self):
matmul = def_function.function(math_ops.matmul)
@def_function.function
def sq(a):
return matmul(a, a)
sq_op = sq.get_concrete_function(
tensor_spec.TensorSpec((None, None), dtypes.float32))
self.assertEqual([None, None], sq_op.output_shapes.as_list())
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out1 = sq_op(t1)
self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out2 = sq_op(t2)
self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
def testNestedInputSpecGraphFunction(self):
matmul = def_function.function(math_ops.matmul)
@def_function.function
def sq(mats):
((a, b),) = mats
return matmul(a, b)
sq_op_autonamed = sq.get_concrete_function(
[(tensor_spec.TensorSpec((None, None), dtypes.float32),
tensor_spec.TensorSpec((None, None), dtypes.float32))])
self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())
sq_op = sq.get_concrete_function(
[(tensor_spec.TensorSpec((None, None), dtypes.float32,
name='first_mat'),
tensor_spec.TensorSpec((None, None), dtypes.float32,
name='second_mat'))])
self.assertEqual([None, None], sq_op.output_shapes.as_list())
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
out = sq_op(first_mat=t1, second_mat=t2)
self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
self.assertAllEqual(sq_op_autonamed(t1, t2),
math_ops.matmul(t1, t2).numpy())
def testExecutingStatelessDefunConcurrently(self):
@def_function.function
def stateless(x):
return math_ops.multiply(2.0, x)
pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(1.0 * x) for x in range(100)]
outputs = [float(out) for out in pool.map(stateless, inputs)]
expected = [float(2.0 * x) for x in inputs]
self.assertSequenceEqual(outputs, expected)
def testExecutingManyStatelessDefunsConcurrently(self):
@def_function.function
def stateless(x):
del x
return math_ops.multiply(2.0, 2.0)
pool = multiprocessing.pool.ThreadPool()
# `pool.map` below instantiates 100 functions, one for each object.
objects = [object() for _ in range(100)]
outputs = [float(out) for out in pool.map(stateless, objects)]
expected = [4.0] * 100
self.assertSequenceEqual(outputs, expected)
@test_util.disable_tfrt('b/169431085: This test is flaky on tfrt')
def testExecutingStatefulDefunConcurrently(self):
v = resource_variable_ops.ResourceVariable(1.0)
@def_function.function
def stateful(x):
v.assign(x)
pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(0.0)] * 100
pool.map(stateful, inputs)
self.assertEqual(float(v.read_value()), 0.0)
def testExecutingManyStatefulDefunsConcurrently(self):
v = resource_variable_ops.ResourceVariable(1.0)
@def_function.function
def stateful(x):
del x
return v.assign(0.0)
pool = multiprocessing.pool.ThreadPool()
# `pool.map` below instantiates 100 functions, one for each object.
pool.map(stateful, [object() for _ in range(100)])
self.assertEqual(float(v.read_value()), 0.0)
def testShareRendezvous(self):
# Disable grappler from inlining the functions. Note we run the send & recv
# in graph mode since with eager mode the function should automatically be
# inlined.
context.context().set_optimizer_experimental_options(
{'disable_meta_optimizer': True})
cpu = '/device:CPU:0'
signature = [tensor_spec.TensorSpec([], dtypes.int32)]
@def_function.function
def send():
x = constant_op.constant(1)
gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu)
return x
send._shared_rendezvous = True # pylint: disable=protected-access
@def_function.function(input_signature=signature)
def send_body(n):
send()
return n - 1
@def_function.function
def recv():
return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu)
recv._shared_rendezvous = True # pylint: disable=protected-access
@def_function.function(input_signature=signature)
def recv_body(n):
recv()
return n - 1
@def_function.function(input_signature=signature)
def cond(n):
return n > 0
# Instead of calling the send & recv functions directly we want to call them
# through a functional while to ensure the rendezvous is shared across the
# while boundary.
@def_function.function
def fn(n):
functional_ops.While([n], cond.get_concrete_function(),
send_body.get_concrete_function())
return functional_ops.While([n], cond.get_concrete_function(),
recv_body.get_concrete_function())
# Use a graph context since functions will not be automatically inlined
with context.graph_mode(), self.cached_session():
self.evaluate(fn(2))
def disabled_testRandomSeed(self):
@def_function.function
def f():
return random_ops.random_normal(())
random_seed.set_random_seed(1)
x = f()
self.assertNotEqual(x, f())
random_seed.set_random_seed(1)
self.assertAllEqual(f(), x)
def testNestedInputsGraphFunction(self):
matmul = def_function.function(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
@def_function.function
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = a_times_b.get_concrete_function(
pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')),
dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b'))))
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(a=t, b=t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedOutputGraphFunction(self):
matmul = def_function.function(math_ops.matmul)
@def_function.function
def sq(a):
return (matmul(a, a), {'b': constant_op.constant(1.0)})
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes,
(tensor_shape.TensorShape([2, 2]),
{'b': tensor_shape.TensorShape([])}))
self.assertEqual(sq_op.output_dtypes,
(dtypes.float32, {'b': dtypes.float32}))
(a, b) = sq_op(t)
self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
self.assertAllEqual(b['b'].numpy(), 1.0)
def testGraphFunctionNoneOutput(self):
@def_function.function
def fn(unused_a, unused_b):
return None
x = constant_op.constant(1)
fn_op = fn.get_concrete_function(x, x)
self.assertEqual(fn_op.output_dtypes, None)
self.assertEqual(fn_op.output_shapes, None)
self.assertAllEqual(fn_op(x, x), None)
def testDefunNumpyArraysConvertedToTensors(self):
def f(x):
self.assertIsInstance(x, ops.Tensor)
return x
x = random_ops.random_uniform([2, 2]).numpy()
defined = function.defun(f)
defined(x)
self.assertLen(total_function_cache(defined), 1)
x = random_ops.random_uniform([2, 2]).numpy()
defined(x)
# A NumPy array with different values but the same shape and dtype
# shouldn't trigger another function definition.
self.assertLen(total_function_cache(defined), 1)
np_ones = numpy.ones([], numpy.float32)
np_zeros = numpy.zeros([], numpy.float32)
tf_ones = array_ops.ones([])
tf_zeros = array_ops.zeros([])
# Test that the numpy array is properly an argument to the graph function.
self.assertEqual(1., defined(np_ones).numpy())
self.assertLen(total_function_cache(defined), 2)
self.assertEqual(0., defined(np_zeros).numpy())
self.assertEqual(1., defined(tf_ones).numpy())
self.assertEqual(0., defined(tf_zeros).numpy())
self.assertLen(total_function_cache(defined), 2)
# Test that mutable inputs are supported.
mutable = numpy.ones([], numpy.float32)
self.assertEqual(1., defined(mutable).numpy())
mutable.fill(0)
self.assertEqual(0., defined(mutable).numpy())
class MyNdarray(numpy.ndarray):
pass
# Test that the subclasses of ndarray are converted too.
self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())
# We should not have triggered any re-tracing of the python function.
self.assertLen(total_function_cache(defined), 2)
def testNumpyDtypeInputSupported(self):
@function.defun
def f(x, dtype):
return constant_op.constant(dtype(x))
self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1))
self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2))
self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1))
self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2))
def testDefunNumpyArraysConvertedToTensorsInKwargs(self):
def f(**kwargs):
x = kwargs.pop('x')
self.assertIsInstance(x, ops.Tensor)
return x
x = random_ops.random_uniform([2, 2]).numpy()
defined = function.defun(f)
defined(x=x)
self.assertLen(total_function_cache(defined), 1)
x = random_ops.random_uniform([2, 2]).numpy()
defined(x=x)
# A NumPy array with different values but the same shape and dtype
# shouldn't trigger another function definition.
self.assertLen(total_function_cache(defined), 1)
# Test that the numpy array is properly an argument to the graph function.
self.assertEqual(1., defined(x=numpy.ones([])).numpy())
self.assertEqual(0., defined(x=numpy.zeros([])).numpy())
self.assertEqual(1., defined(x=array_ops.ones([])).numpy())
self.assertEqual(0., defined(x=array_ops.zeros([])).numpy())
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@def_function.function
def add_int32s():
return x + x
self.assertEqual(2, int(add_int32s()))
def testDefunReadVariable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@def_function.function
def f():
return v.read_value()
self.assertEqual(1.0, float(f()))
def testDefunAssignAddVariable(self):
v = resource_variable_ops.ResourceVariable(1.0)
x = constant_op.constant(2.0)
@def_function.function
def test_assign_add():
v.assign_add(x)
return v.read_value()
self.assertEqual(3.0, float(test_assign_add()))
@test_util.run_in_graph_and_eager_modes
def testTensorInitializationInFunctionRaisesError(self):
error_msg = ('Tensor-typed variable initializers must either be '
'wrapped in an init_scope or callable.*')
@def_function.function
def tensor_init():
with self.assertRaisesRegex(ValueError, error_msg):
resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
tensor_init()
@test_util.run_in_graph_and_eager_modes
def testCallableTensorInitializationInFunction(self):
@def_function.function
def tensor_init():
self.v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(value), 2.0)
@test_util.also_run_as_tf_function
def testInitScopeTensorInitializationInFunction(self):
@def_function.function
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
      # Note: this variable sidesteps tf.function's variable-creation
      # requirements by using ResourceVariable directly instead of Variable,
      # which bypasses variable_creator_scope.
self.v = resource_variable_ops.ResourceVariable(const)
return self.v.read_value()
value = tensor_init()
self.assertAllEqual(value, 2.0)
@test_util.run_in_graph_and_eager_modes
def testGetConcreteFunctionCreatesVariables(self):
v_holder = []
@def_function.function
def tensor_init():
if not v_holder:
v_holder.append(variables.Variable(5.))
return v_holder[0].read_value()
concrete = tensor_init.get_concrete_function()
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual(5., self.evaluate(concrete()))
self.assertAllEqual(5., self.evaluate(tensor_init()))
def testFuncGraphCaptureByValue(self):
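    """capture_by_value freezes the variable's value at trace time."""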
v = variables.Variable(1.0)
def trivial_function():
return v.read_value()
graph_function = function.Function(
trivial_function, 'test', capture_by_value=True)
self.assertAllEqual(graph_function(), 1.0)
v.assign(2.0)
self.assertAllEqual(graph_function(), 1.0)
def testFuncGraphCaptureByValueNested(self):
v = variables.Variable(1.0)
def trivial_function():
return control_flow_ops.cond(
array_ops.placeholder_with_default(True, ()),
v.read_value, v.read_value)
graph_function = function.Function(
trivial_function, 'test', capture_by_value=True)
self.assertAllEqual(graph_function(), 1.0)
v.assign(2.0)
self.assertAllEqual(graph_function(), 1.0)
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# We do not return v directly since the tensor conversion function of
# ResourceVariable returns the read value and not the resource itself.
return v._handle
compiled = def_function.function(f)
var_handle = compiled()
self.assertEqual(var_handle.dtype, dtypes.resource)
self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testShapeInferenceForMoreSpecificInput(self):
def f(a):
return array_ops.reshape(a, [-1, 3])
signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
compiled = def_function.function(f, input_signature=signature)
@def_function.function
def use_f():
inputs = array_ops.zeros([10, 10, 3])
self.assertAllEqual(f(inputs).shape, compiled(inputs).shape)
use_f()
def testFuncListAttr(self):
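    """Functions, lambdas, and partials can all be passed as case branches."""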
@function.defun
def test_function(val):
def fn1():
return array_ops.ones([10])
fn2 = lambda: array_ops.ones([10]) * 2
def fn3(x=3):
return array_ops.ones([10]) * x
fn4 = functools.partial(fn3, x=4)
fn5 = functools.partial(fn3, 5)
return gen_functional_ops.case(val, [], [dtypes.float32],
[function.defun(f).get_concrete_function()
for f in (fn1, fn2, fn3, fn4, fn5)])
ones = array_ops.ones([10])
self.assertAllEqual([ones], test_function(0))
self.assertAllEqual([ones * 2], test_function(1))
self.assertAllEqual([ones * 3], test_function(2))
self.assertAllEqual([ones * 4], test_function(3))
self.assertAllEqual([ones * 5], test_function(4))
self.assertAllEqual([ones * 5], test_function(22)) # default branch
@test_util.enable_control_flow_v2
def testVariableInLoopInFunction(self):
@function.defun
def test_function():
def loop_test(_):
return False
def loop_body(_):
return variable_scope.get_variable('a', shape=())
return control_flow_ops.while_loop(loop_test, loop_body, [0.0])
self.assertEqual(test_function().shape, [])
def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
with context.graph_mode():
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# We do not return v directly since the tensor conversion function of
# ResourceVariable returns the read value and not the resource itself.
return v._handle
compiled = def_function.function(f)
var_handle = compiled()
self.assertEqual(var_handle.dtype, dtypes.resource)
self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
with context.graph_mode():
v = variables.Variable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# Check that shape inference works while creating the defun
compiled = def_function.function(f)
compiled()
def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
with context.graph_mode():
tensor_list = list_ops.empty_tensor_list(
element_dtype=dtypes.float32,
element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
tensor_list = list_ops.tensor_list_push_back(tensor_list,
constant_op.constant(1.0))
tensor_list = list_ops.tensor_list_push_back(tensor_list,
constant_op.constant(2.0))
def f():
tl, value = list_ops.tensor_list_pop_back(
tensor_list, element_dtype=dtypes.float32)
self.assertEqual(value.shape, tensor_shape.TensorShape([]))
return tl
compiled = def_function.function(f)
output_tensor_list = compiled()
_, value = list_ops.tensor_list_pop_back(
output_tensor_list, element_dtype=dtypes.float32)
self.assertEqual(value.shape, tensor_shape.TensorShape([]))
@test_util.run_in_graph_and_eager_modes
def testDefunForcesResourceVariables(self):
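    """Variables created inside a defun are always ResourceVariables."""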
def variable_creator():
self.v = variables.Variable(0.0)
return self.v.read_value()
self.v = None
defined = function.defun(variable_creator)
defined() # Create the variable.
self.assertIsInstance(
self.v, resource_variable_ops.ResourceVariable)
@test_util.disable_tfrt('b/169294215')
def testRunMetadata(self):
@def_function.function
def f(x):
return x * x
with ops.device('cpu:0'):
context.enable_run_metadata()
f(constant_op.constant(1.0))
run_metadata = context.export_run_metadata()
context.disable_run_metadata()
self.assertLen(run_metadata.partition_graphs, 1)
def testGraphModeCaptureVariable(self):
with context.graph_mode(), self.cached_session():
class HasAVar(object):
def __init__(self):
self.v = resource_variable_ops.ResourceVariable(1.0)
def call(self):
return self.v * 2
o = HasAVar()
self.evaluate(variables.global_variables_initializer())
call = def_function.function(o.call)
op = call()
self.assertAllEqual(self.evaluate(op), 2.0)
def testGraphModeManyFunctions(self):
with ops.Graph().as_default(), self.cached_session():
@def_function.function
def f(x):
return x * x
@def_function.function
def g(x):
return f(x) + 1
self.assertAllEqual(g(constant_op.constant(2.0)), 5.0)
def testDict(self):
@def_function.function
def f(x):
return {'name': x + 1}
self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0)
def testTensorConversionWithDefun(self):
@def_function.function
def f(x):
return math_ops.add(x, constant_op.constant(3))
self.assertAllEqual(5, f(constant_op.constant(2)))
def testTensorConversionCall(self):
@def_function.function
def f(x):
return math_ops.add(x, constant_op.constant(3))
@def_function.function
def g(x):
return f(f(x))
self.assertAllEqual(8, g(constant_op.constant(2)))
def testCallShape(self):
@def_function.function
def f(x):
return x + 1
@def_function.function
def g(x):
x = f(x)
self.assertEqual(x.shape.as_list(), [])
return None
g(constant_op.constant(1.0))
def testNestedDefunWithNoOutputAndTapedInput(self):
three = resource_variable_ops.ResourceVariable(3.0, name='v')
@def_function.function
def f(x):
# This function intentionally takes a taped variable as input,
# but does not return any values
math_ops.add(x, three)
@def_function.function
def g(x):
y = math_ops.add(x, three)
f(y)
g(three)
def testGatherResourceWithDefun(self):
with ops.device('cpu:0'):
v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
def sum_gather():
return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
defined = def_function.function(sum_gather)
self.assertAllEqual(sum_gather(), defined())
@parameterized.named_parameters([
('IndexedSlicesWithDenseShape',
_example_indexed_slices_with_dense_shape,),
('IndexedSlicesWithoutDenseShape',
_example_indexed_slices_without_dense_shape,),
('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
('RaggedTensorRaggedRank2',
ragged_tensor.RaggedTensor.from_nested_row_lengths,
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
('SparseTensor', sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
]) # pyformat: disable
def testReturnCompositeTensorWithDefun(self,
factory_fn,
factory_kwargs={},
input_signature=None):
input_ct = factory_fn(**factory_kwargs)
@def_function.function(input_signature=input_signature)
def f():
return input_ct
output_ct = f()
self.assertIsInstance(output_ct, type(input_ct))
nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
input_flat = nest.flatten(input_ct, expand_composites=True)
output_flat = nest.flatten(output_ct, expand_composites=True)
for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component)
@parameterized.named_parameters([
('IndexedSlicesWithDenseShape',
_example_indexed_slices_with_dense_shape,),
('IndexedSlicesWithoutDenseShape',
_example_indexed_slices_without_dense_shape,),
('RaggedTensorRaggedRank1',
ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
('RaggedTensorRaggedRank2',
ragged_tensor.RaggedTensor.from_nested_row_lengths,
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
('SparseTensor',
sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
('RaggedTensorRaggedRank1WithSignature',
ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
[ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]),
('RaggedTensorRaggedRank2WithSignature',
ragged_tensor.RaggedTensor.from_nested_row_lengths,
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
[ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]),
('SparseTensorWithSignature',
sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
[sparse_tensor.SparseTensorSpec([None], dtypes.int32)]),
]) # pyformat: disable
def testCompositeAsArgumentTensorWithDefun(self,
factory_fn,
factory_kwargs={},
input_signature=None):
input_ct = factory_fn(**factory_kwargs)
@def_function.function(input_signature=input_signature)
def f(x):
return x
output_ct = f(input_ct)
self.assertIsInstance(output_ct, type(input_ct))
nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
input_flat = nest.flatten(input_ct, expand_composites=True)
output_flat = nest.flatten(output_ct, expand_composites=True)
for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component)
def testTracedCompositeDiscardsShapeInfo(self):
# SparseTensorSpec intentionally excludes info about the number of elements
# that are in a sparse tensor (which is recorded as st.indices.shape[0] and
# st.values.shape[0]). Similarly, RaggedTensorSpec intentionally excludes
# info about the total number of values in a RaggedTensor (stored as
# rt.values.shape[0]). This test checks that the placeholders created by
# tf.function() properly mask this shape info.
@def_function.function
def f(rt, st):
self.assertEqual(st.indices.shape.as_list()[:1], [None])
self.assertEqual(st.values.shape.as_list(), [None])
return (rt, st)
rt = ragged_factory_ops.constant([[1, 2], [3]])
st = sparse_tensor.SparseTensor([[0]], [0], [10])
f(rt, st)
@test_util.run_gpu_only
def testFunctionOnDevice(self):
x = constant_op.constant([1.]).gpu()
f = def_function.function(math_ops.add)
y = f(x, x).cpu()
self.assertAllEqual(y, [2.])
@test_util.run_gpu_only
@test_util.run_in_graph_and_eager_modes
def testFunctionWithResourcesOnDifferentDevices(self):
with ops.device('/cpu:0'):
v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
with ops.device('/gpu:0'):
v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
def sum_gather():
cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
return cpu_result, gpu_result
defined = function.defun(sum_gather)
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
expected = self.evaluate(sum_gather())
self.assertAllEqual(expected, self.evaluate(defined()))
@test_util.run_gpu_only
@test_util.run_in_graph_and_eager_modes
def testOpInFunctionWithConflictingResourceInputs(self):
with ops.device('/cpu:0'):
v_cpu = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name='cpu')
v_also_cpu = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name='also_cpu')
with ops.device('/gpu:0'):
v_gpu = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name='gpu')
@def_function.function
def resource_apply_adam():
training_ops.resource_apply_adam(
v_cpu.handle,
v_gpu.handle,
v_also_cpu.handle,
1.0, # beta1_power
1.0, # beta2_power
1.0, # learning_rate
1.0, # beta1
1.0, # beta2
1.0, # epsilon,
[1.0, 1.0, 1.0], # grad
False) # use_locking
return None
with self.assertRaisesRegex(
errors.InvalidArgumentError,
'Cannot place the graph because a reference or resource edge connects '
'colocation groups with incompatible assigned devices'):
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.evaluate(resource_apply_adam())
@test_util.run_gpu_only
def testFunctionHandlesInputsOnDifferentDevices(self):
# The Reshape op requires the shape tensor to be placed in host memory.
reshape = def_function.function(array_ops.reshape)
value = constant_op.constant([1., 2.]).gpu()
shape = constant_op.constant([2, 1])
reshaped = reshape(value, shape).cpu()
self.assertAllEqual(reshaped, [[1], [2]])
@test_util.run_gpu_only
def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
# The Reshape op requires the shape tensor to be placed in host memory.
reshape = def_function.function(array_ops.reshape)
value = constant_op.constant([1., 2.])
shape = constant_op.constant([2, 1]).gpu()
reshape(value, shape) # No error is raised
def testNoneOutput(self):
@def_function.function
def my_function(_):
return None
self.assertAllEqual(my_function(1), None)
def testNestedFunctions(self):
# TensorFlow function (which is what would be used in TensorFlow graph
# construction).
@tf_function.Defun(dtypes.int32, dtypes.int32)
def add(a, b):
return math_ops.add(a, b)
@def_function.function
def add_one(x):
return add(x, 1)
self.assertAllEqual(3, add_one(constant_op.constant(2)))
def testVariableCaptureInNestedFunctions(self):
v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)
@def_function.function
def inner_read():
return v.read_value()
@def_function.function
def outer():
return inner_read()
self.assertEqual(1, int(outer()))
def testReturnCapturedEagerTensor(self):
t = constant_op.constant(1)
@def_function.function
def read():
return t
self.assertEqual(1, int(read()))
def testReturnCapturedGraphTensor(self):
with context.graph_mode(), self.cached_session():
t = constant_op.constant(1)
@def_function.function
def read():
return t
self.assertEqual(1, int(self.evaluate(read())))
def testSequenceInputs(self):
clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm)
t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
clipped_list, global_norm = clip_by_global_norm(t_list,
constant_op.constant(.2))
for t in clipped_list:
self.assertIsInstance(t, ops.Tensor)
self.assertIsInstance(global_norm, ops.Tensor)
def testNestedSequenceInputs(self):
def my_op(inputs):
a, b, c = inputs
e, f = b
g, h = e
return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c
my_eager_op = def_function.function(my_op)
ret = my_eager_op([
constant_op.constant(1), [(constant_op.constant(2),
constant_op.constant(3)),
constant_op.constant(4)],
constant_op.constant(5)
])
self.assertLen(ret, 2)
self.assertAllEqual(ret[0][0], 2)
self.assertAllEqual(ret[0][1][0][0], 8)
self.assertAllEqual(ret[0][1][0][1], 4)
self.assertIsInstance(ret[0][1][0], tuple)
self.assertAllEqual(ret[0][1][1], 6)
self.assertAllEqual(ret[0][2], 10)
self.assertAllEqual(ret[1], 15)
def testVariableNamesRespectNameScopesWithDefun(self):
@def_function.function
def create_variable():
with ops.name_scope('foo', skip_on_eager=False):
v = resource_variable_ops.ResourceVariable(0.0, name='bar')
self.assertEqual(v.name, 'foo/bar:0')
create_variable()
def testVariableNamesRespectNameScopesWithDefunInGraph(self):
with context.graph_mode():
@def_function.function
def create_variable():
with ops.name_scope('foo', skip_on_eager=False):
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
self.assertEqual(v.name, 'foo/bar:0')
with ops.get_default_graph().as_default():
create_variable()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testLayerInDefun(self):
conv = convolutional.Conv2D(
filters=1,
kernel_size=2,
kernel_initializer=init_ops.ones_initializer(),
bias_initializer=init_ops.zeros_initializer())
@function.defun
def model(x):
return conv(x)
x = array_ops.ones([1, 2, 2, 1])
y = model(x)
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.assertAllClose([[[[4.0]]]], self.evaluate(y))
# Variable lifting is somewhat different between defun/tf.function, so testing
# device placement on both makes sense.
@parameterized.named_parameters(
dict(testcase_name='Defun',
function_decorator=function.defun),
dict(testcase_name='DefFunction',
function_decorator=def_function.function))
@test_util.run_in_graph_and_eager_modes
def testVariablesPlacedOnOutsideDevice(self, function_decorator):
class _Obj(object):
def __init__(self):
self.v = None
@function_decorator
def f(self):
if self.v is None:
self.v = variables.Variable(1.)
return self.v + 1.
has_device = _Obj()
with ops.device('cpu:0'):
has_device.f()
self.assertIn('CPU', has_device.v.device)
@test_util.run_in_graph_and_eager_modes
def testDeviceAnnotationsRespected(self):
def multi_device_fn():
with ops.device('/cpu:0'):
s0 = test_ops.device_placement_op()
with ops.device('/cpu:1'):
s1 = test_ops.device_placement_op()
with ops.device('/cpu:2'):
s2 = test_ops.device_placement_op()
s3 = test_ops.device_placement_op()
return s0, s1, s2, s3
defined = function.defun(multi_device_fn)
outputs = self.evaluate(defined())
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
with ops.device('/cpu:3'):
outputs = self.evaluate(defined())
# All function definitions are agnostic to call site devices.
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
with ops.device('/cpu:0'):
outputs = self.evaluate(defined())
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
@test_util.run_in_graph_and_eager_modes
def testCallingGraphFunctionOnDifferentDevice(self):
def func():
return constant_op.constant(0)
defined = def_function.function(func)
with ops.device('cpu:0'):
cpu_graph_function = defined.get_concrete_function()
with ops.device('cpu:0'):
self.assertEqual(
self.evaluate(cpu_graph_function()), self.evaluate(func()))
with ops.device('cpu:1'):
self.assertEqual(0., self.evaluate(cpu_graph_function()))
with ops.device(None):
self.assertEqual(0., self.evaluate(cpu_graph_function()))
default_graph_function = defined.get_concrete_function()
self.assertEqual(
self.evaluate(default_graph_function()), self.evaluate(func()))
with ops.device('cpu:1'):
self.assertEqual(0., self.evaluate(default_graph_function()))
@test_util.run_gpu_only
@test_util.run_in_graph_and_eager_modes
def testColocateWithRespected(self):
# TODO(b/113291792): Use multiple CPUs instead of a GPU.
with ops.device('cpu:0'):
x = array_ops.identity(1.0)
with ops.device('gpu:0'):
y = array_ops.identity(1.0)
@def_function.function
def foo():
return test_ops.device_placement_op()
with ops.colocate_with(x):
self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
with ops.colocate_with(y):
self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
def testVariablesAreTracked(self):
v = resource_variable_ops.ResourceVariable(1.0)
def foo(x):
return v * x
defined = def_function.function(foo)
x = constant_op.constant([1.0])
self.assertEqual(1., self.evaluate(defined(x)))
v.assign(2.)
x = constant_op.constant([1.0, 2.0])
self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
def testCacheObjectHashCollisions(self):
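    """Distinct objects with colliding __hash__ values get separate traces."""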
class Foo(object):
def __hash__(self):
return 42
def func(foo):
del foo
return
defined = function.defun(func)
defined(Foo())
self.assertLen(total_function_cache(defined), 1)
defined(Foo())
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorDtypeCollision(self):
def func(t):
return t + t
defined = function.defun(func)
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
defined(t)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
defined(t)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorShapeCollision(self):
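# Tensors with the same dtype but different shapes must map to different
# cache entries.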
def func(t):
return t + t
defined = function.defun(func)
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
defined(t)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([1.0], dtype=dtypes.complex64)
defined(t)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorShapeDtypeCollision(self):
def func(t):
return t + t
defined = function.defun(func)
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
defined(t)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([1.0], dtype=dtypes.complex128)
defined(t)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorUnknownShapesCollisionRelaxedShapes(self):
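# With experimental_relax_shapes=True, after tracing for a couple of distinct
# fully-defined shapes, further calls should be served by a single relaxed
# signature with an unknown dimension instead of retracing for every shape.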
def func(t):
return t + t
with context.graph_mode(), self.cached_session():
defined = function.defun(func, experimental_relax_shapes=True)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
defined(p)
self.assertLen(total_function_cache(defined), 1)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
defined(p)
self.assertLen(total_function_cache(defined), 2)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
defined(p)
# Gradual shape relaxation is performed: the common shape between
# [1] and [2] is one containing unknown dimensions.
self.assertLen(total_function_cache(defined), 2)
# pylint: disable=protected-access
self.assertLen(defined._function_cache.arg_relaxed_specs, 1)
relaxed_specs = (
list(defined._function_cache.arg_relaxed_specs.values())[0])
self.assertLen(relaxed_specs, 1)
relaxed_shape = relaxed_specs[0].shape
# pylint: enable=protected-access
self.assertEqual(relaxed_shape.rank, 1)
self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)
t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
defined(t)
# Shape (3,) matches the relaxed shape TensorShape([None])
self.assertLen(total_function_cache(defined), 2)
def testPythonFunctionWithDefaultArgs(self):
def func(foo, bar=1, baz=2):
del foo
del bar
del baz
return
defined = function.defun(func)
defined(0, baz=20)
def cache_keys():
"""Sanitizes cache keys of non-input metadata."""
return tuple(key[0] for key in total_function_cache(defined))
# `True` corresponds to the fact that we're executing eagerly
self.assertIn(('URRRu', (0, 1, 20)), cache_keys())
defined(1) # bar=1, baz=2
self.assertIn(('URRRu', (1, 1, 2)), cache_keys())
# This matches the previous call.
defined(foo=1)
self.assertLen(total_function_cache(defined), 2)
defined(1, 2, 3)
self.assertLen(total_function_cache(defined), 3)
self.assertIn(('URRRu', (1, 2, 3)), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
self.assertLen(total_function_cache(defined), 3)
# This matches the previous call.
defined(1, baz=3, bar=2)
self.assertLen(total_function_cache(defined), 3)
def testFunctoolsPartialUnwrappedCorrectly(self):
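# defun should unwrap a functools.partial and produce the same outputs as
# calling the partial directly.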
def full_function(a, b, c=3):
return a, b, c
partial = functools.partial(full_function, 1, c=4)
a, b, c = partial(2)
defined = function.defun(partial)
func_a, func_b, func_c = defined(2)
self.assertEqual(func_a.numpy(), a)
self.assertEqual(func_b.numpy(), b)
self.assertEqual(func_c.numpy(), c)
def testInputSignatureWithMatchingInputs(self):
def foo(a):
self.assertEqual(a.shape, (2,))
return a
signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
defined = function.defun(foo, input_signature=signature)
a = array_ops.ones([2])
self.assertAllEqual(a, defined(a))
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(a, defined.get_concrete_function()(a))
self.assertAllEqual(a, defined.get_concrete_function(a)(a))
self.assertAllEqual(a, defined.get_concrete_function(
tensor_spec.TensorSpec((2,), dtype=dtypes.float32))(a))
self.assertLen(total_function_cache(defined), 1)
def bar(a):
self.assertEqual(a._shape_tuple(), (2, None))
return a
signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)]
defined = function.defun(bar, input_signature=signature)
a = array_ops.ones([2, 1])
out = defined(a)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(out, a)
# Changing the second dimension shouldn't create a new function.
b = array_ops.ones([2, 3])
out = defined(b)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(out, b)
def testInputSignatureWithCompatibleInputs(self):
rank2_spec = tensor_spec.TensorSpec(shape=(None, None),
dtype=dtypes.float32)
@function.defun(input_signature=[rank2_spec])
def func(a):
self.assertEqual([None, None], a.shape.as_list())
return array_ops.shape(a)
self.assertAllEqual([3, 1], func([[0], [1.0], [1]]))
self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]])))
with self.assertRaisesRegex(ValueError, 'incompatible'):
func([0.0, 1.0, 2.0]) # Wrong shape.
with self.assertRaisesRegex(ValueError, 'incompatible'):
func([['wrong dtype']])
def testNoKeywordOnlyArgumentsWithInputSignature(self):
if sys.version_info[0] < 3:
self.skipTest('keyword_only arguments only exist in Python 3.')
func = eval('lambda x, *, y: x') # pylint: disable=eval-used
signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
with self.assertRaisesRegex(
ValueError, 'Cannot define a TensorFlow function from a Python '
'function with keyword-only arguments when input_signature is '
'provided.'):
def_function.function(func, signature)
def testNestedInputSignatures(self):
def expected_foo(a, b):
return [a, b]
@function.defun(input_signature=[
[tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
tensor_spec.TensorSpec((1,), dtypes.float32),
])
def foo(a, b):
self.assertEqual(a[0]._shape_tuple(), (2, None))
self.assertEqual(a[1]._shape_tuple(), (2, None))
self.assertEqual(b._shape_tuple(), (1,))
return [a, b]
a = array_ops.ones([2, 1])
b = array_ops.ones([1])
expected = expected_foo([a, a], b)
out = foo([a, a], b)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], a)
self.assertAllEqual(out[1], b)
# Changing the unspecified dimensions shouldn't create a new function.
a = array_ops.ones([2, 3])
b = array_ops.ones([2, 5])
c = array_ops.ones([1])
expected = expected_foo([a, b], c)
out = foo([a, b], c)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
self.assertAllEqual(out[1], c)
# Passing compatible inputs should work.
a = a.numpy().tolist()
b = b.numpy().tolist()
c = c.numpy().tolist()
out = foo([a, b], c)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
self.assertAllEqual(out[1], c)
def testNestedInputSignaturesWithDict(self):
def expected_bar(a):
return a
@function.defun(input_signature=[{
'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
'c': tensor_spec.TensorSpec((1,), dtypes.float32)}])
def bar(a):
self.assertEqual(a['a']._shape_tuple(), (2, None))
self.assertEqual(a['b']._shape_tuple(), (2, None))
self.assertEqual(a['c']._shape_tuple(), (1,))
return a
a = array_ops.ones([2, 3])
b = array_ops.ones([1])
inputs = {'a': a, 'b': a, 'c': b}
expected = expected_bar(inputs)
out = bar(inputs)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out['a'], expected['a'])
self.assertAllEqual(out['b'], expected['b'])
self.assertAllEqual(out['c'], expected['c'])
# Passing compatible inputs should work.
a = a.numpy().tolist()
b = b.numpy().tolist()
inputs = {'a': a, 'b': a, 'c': b}
out = bar(inputs)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out['a'], expected['a'])
self.assertAllEqual(out['b'], expected['b'])
self.assertAllEqual(out['c'], expected['c'])
def testInputSignatureMustBeSequenceOfTensorSpecs(self):
def foo(a, b):
del a
del b
# Signatures must consist exclusively of `TensorSpec` objects.
signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
with self.assertRaisesRegex(TypeError, 'Invalid input_signature.*'):
def_function.function(foo, input_signature=signature)
# Signatures must be either lists or tuples on their outermost levels.
signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
with self.assertRaisesRegex(
TypeError, 'input_signature must be either a '
'tuple or a list.*'):
function.defun(foo, input_signature=signature)
@test_util.run_in_graph_and_eager_modes
def testInputsIncompatibleWithSignatureRaisesError(self):
def foo(a):
return a
signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
defined = def_function.function(foo, input_signature=signature)
# Invalid shapes.
with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
defined(array_ops.ones([3]))
with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
defined(array_ops.ones([2, 1]))
# Wrong number of arguments.
with self.assertRaisesRegex(
TypeError, r'takes 1 positional arguments \(as specified by the '
r'input_signature\) but 2 were given'):
defined(array_ops.ones([2]), array_ops.ones([2]))
with self.assertRaisesRegex(ValueError,
'Structure of Python function inputs.*'):
defined()
with self.assertRaisesRegex(ValueError,
'inputs incompatible with input_signature'):
defined.get_concrete_function(
tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))
def testInputsIncompatibleWithNestedSignatureRaisesError(self):
def foo(a, b):
return [a, b]
signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2,
[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2]
defined = function.defun(foo, input_signature=signature)
a = array_ops.ones([1])
with self.assertRaisesRegex(ValueError,
'Structure of Python function inputs.*'):
defined([a, a, a], [a])
with self.assertRaisesRegex(ValueError,
'Structure of Python function inputs.*'):
defined([a], [a, a, a])
defined([a, a], [a, a])
def testUnderspecifiedInputSignature(self):
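# A keyword argument not covered by the input_signature cannot be passed at
# call time, even with its default value; omitting it uses the Python default.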
@function.defun(input_signature=[
tensor_spec.TensorSpec([], dtypes.float32),
])
def foo(a, training=True):
if training:
return a
else:
return -1.0 * a
x = constant_op.constant(1.0)
with self.assertRaisesRegex(
TypeError, 'got keyword argument `training` '
'that was not included in input_signature'):
foo(x, training=True)
with self.assertRaisesRegex(
TypeError, 'got keyword argument `training` '
'that was not included in input_signature'):
foo(x, training=False)
self.assertAllEqual(x.numpy(), foo(x).numpy())
def testInputSignatureWithPartialFunction(self):
def full_function(a, b, c=3.0):
return a, b, c
partial = functools.partial(full_function, 1, c=4)
a, b, c = partial(2.0)
signature = [tensor_spec.TensorSpec([], dtypes.float32)]
defined = function.defun(partial, input_signature=signature)
x = constant_op.constant(2.0)
func_a, func_b, func_c = defined(x)
self.assertEqual(func_a.numpy(), a)
self.assertEqual(func_b.numpy(), b)
self.assertEqual(func_c.numpy(), c)
def testInputSignatureConversionWithDefaultArg(self):
def foo(a, training=True):
if training:
return a
else:
return -1.0 * a
signature = [
tensor_spec.TensorSpec([], dtypes.float32),
tensor_spec.TensorSpec([], dtypes.bool),
]
defined = def_function.function(foo, input_signature=signature)
a = constant_op.constant(1.0)
self.assertAllEqual(a.numpy(), defined(a))
self.assertAllEqual(a.numpy(), defined(a, training=True))
self.assertAllEqual(-a.numpy(), defined(a, training=False))
def testInputSignatureWithKeywordPositionalArgs(self):
@function.defun(input_signature=[
tensor_spec.TensorSpec([], dtypes.float32),
tensor_spec.TensorSpec([], dtypes.int64)
])
def foo(flt, integer):
return flt, integer
flt = constant_op.constant(1.0)
integer = constant_op.constant(2, dtypes.int64)
out1, out2 = foo(flt, integer)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt=flt, integer=integer)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(integer=integer, flt=flt)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt, integer=integer)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
def testInputSignatureWithKeywordArgs(self):
def foo(a, b, **kwargs):
del kwargs
return a, b
x = function.defun(
foo,
input_signature=[
tensor_spec.TensorSpec([], dtypes.float32),
tensor_spec.TensorSpec([], dtypes.int32)
]).get_concrete_function()
result = x(constant_op.constant(5.0), constant_op.constant(5))
self.assertAllEqual(result, [5.0, 5])
def testInputSignatureWithCompositeTensors(self):
def f(rt):
self.assertEqual(rt.values.shape.as_list(), [None])
self.assertEqual(rt.row_splits.shape.as_list(), [4])
return rt
signature = [ragged_tensor.RaggedTensorSpec(
shape=[3, None], dtype=dtypes.int32)]
defined = function.defun(f, input_signature=signature)
rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]])
out1 = defined(rt1)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(out1.values, rt1.values)
self.assertAllEqual(out1.row_splits, rt1.row_splits)
# Changing the row lengths shouldn't create a new function.
rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]])
out2 = defined(rt2)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(out2.values, rt2.values)
self.assertAllEqual(out2.row_splits, rt2.row_splits)
# Different number of rows
rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]])
with self.assertRaisesRegex(ValueError, 'incompatible'):
defined(rt3)
# Different dtype
rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]])
with self.assertRaisesRegex(ValueError, 'Structure .* does not match'):
defined(rt4)
# Different rank
rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
with self.assertRaisesRegex(ValueError, 'does not match'):
defined(rt5)
def testInputSignatureWithVariableArgs(self):
def f(v):
v.assign_add(1)
signature = [
resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
]
defined = function.defun(f, input_signature=signature)
v1 = variables.Variable(0)
v2 = variables.Variable(0)
defined(v1)
self.assertEqual(v1.numpy(), 1)
self.assertEqual(v2.numpy(), 0)
defined(v=v2)
self.assertEqual(v1.numpy(), 1)
self.assertEqual(v2.numpy(), 1)
def testTensorKeywordArguments(self):
def foo(a, b):
del a
return b
defined = function.defun(foo)
a = constant_op.constant(2.0)
b = constant_op.constant([1.0, 2.0])
one = defined(a, b)
self.assertLen(total_function_cache(defined), 1)
two = defined(a=a, b=b)
self.assertLen(total_function_cache(defined), 1)
three = defined(b=b, a=a)
self.assertLen(total_function_cache(defined), 1)
four = defined(a, b=b)
self.assertLen(total_function_cache(defined), 1)
# The next call corresponds to a new input signature, hence
# we expect another function to be defined.
five = defined(b, a)
self.assertLen(total_function_cache(defined), 2)
six = defined(a=b, b=a)
self.assertLen(total_function_cache(defined), 2)
seven = defined(b=a, a=b)
self.assertLen(total_function_cache(defined), 2)
self.assertAllEqual(one, [1.0, 2.0])
self.assertAllEqual(two, [1.0, 2.0])
self.assertAllEqual(three, [1.0, 2.0])
self.assertAllEqual(four, [1.0, 2.0])
self.assertAllEqual(five, 2.0)
self.assertAllEqual(six, 2.0)
self.assertAllEqual(seven, 2.0)
def testDefuningInstanceMethod(self):
integer = constant_op.constant(2, dtypes.int64)
class Foo(object):
def one(self, tensor):
return tensor
@def_function.function
def two(self, tensor, other=integer):
return self.one(tensor), other
foo = Foo()
t = constant_op.constant(1.0)
one, two = foo.two(t)
self.assertEqual(one.numpy(), 1.0)
self.assertEqual(two.numpy(), 2)
def testDefuningInstanceMethodWithDefaultArgument(self):
integer = constant_op.constant(2, dtypes.int64)
class Foo(object):
@def_function.function
def func(self, other=integer):
return other
foo = Foo()
self.assertEqual(foo.func().numpy(), int(integer))
def testPythonCallWithSideEffects(self):
state = []
@def_function.function
def side_effecting_function():
state.append(0)
side_effecting_function()
self.assertAllEqual(state, [0])
# The second invocation should call the graph function, which shouldn't
# trigger the list append.
side_effecting_function()
self.assertAllEqual(state, [0])
# Calling the python function directly, however, should create a side effect.
side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
def testFunctionWithNestedFunctionCallAndSideEffects(self):
v1 = variables.Variable(1.0)
v2 = variables.Variable(1.0)
@def_function.function
def add_one(a):
a.assign_add(1.0)
# Grappler will inline calls to `add_one` into the function body; we check
# that all side effects were executed.
@def_function.function
def side_effecting_function(a, b):
add_one(a)
add_one(b)
return a + b
result = side_effecting_function(v1, v2)
self.assertEqual(result.numpy(), 4.0)
def testFunctionWithExtraAttributes(self):
@function.defun_with_attributes(attributes={'experimental_1': 'value1',
'experimental_2': 2})
def matmul(x, y):
return math_ops.matmul(x, y)
def add(x, y):
return math_ops.add(x, y)
defun_add = function.defun_with_attributes(
add, attributes={'experimental_3': True, 'experimental_4': 1.0})
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq = matmul(t, t)
double = defun_add(t, t)
self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
graph = ops.get_default_graph()
# pylint: disable=protected-access
self.assertLen(graph._functions, 2)
functions = list(graph._functions.values())
self.assertRegex(functions[0].definition.signature.name, '.*matmul.*')
attrs = functions[0].definition.attr
self.assertLen(attrs, 2)
self.assertEqual(attrs['experimental_1'].s, b'value1')
self.assertEqual(attrs['experimental_2'].i, 2)
self.assertRegex(functions[1].definition.signature.name, '.*add.*')
attrs = functions[1].definition.attr
self.assertLen(attrs, 2)
self.assertEqual(attrs['experimental_3'].b, True)
self.assertEqual(attrs['experimental_4'].f, 1.0)
# pylint: enable=protected-access
def testFunctionWithInvalidAttribute(self):
@function.defun_with_attributes(attributes={'experimental_1': ['value1']})
def add(x, y):
return math_ops.add(x, y)
with self.assertRaisesRegex(ValueError, '.*Unsupported attribute type.*'):
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
add(t, t)
def testRegisterFunction(self):
@function.defun
def add(x, y):
return math_ops.add(x, y)
def matmul(x, y):
return math_ops.matmul(x, y)
defun_matmul = function.defun(matmul)
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
function.register(defun_matmul, t, t)
function.register(add, t, t)
graph = ops.get_default_graph()
# pylint: disable=protected-access
self.assertLen(graph._functions, 6)
# Two sets of functions, each consisting of (inference, forward, backward).
functions = list(graph._functions.values())
captured_function_names = [
f.definition.signature.name for f in functions
]
expected_func_name_regex = [
'.*inference.*matmul.*',
'.*forward.*matmul.*',
'.*inference.*backward.*matmul.*',
'.*inference.*add.*',
'.*forward.*add.*',
'.*inference.*backward.*add.*',
]
for i in range(len(functions)):
self.assertRegex(captured_function_names[i],
expected_func_name_regex[i])
# Check that the forward and backward functions have the correct attributes.
self.assertEqual(
functions[1].definition.attr['backward_function_name'].s,
functions[2].name)
self.assertEqual(
functions[2].definition.attr['forward_function_name'].s,
functions[1].name)
self.assertEqual(
functions[4].definition.attr['backward_function_name'].s,
functions[5].name)
self.assertEqual(
functions[5].definition.attr['forward_function_name'].s,
functions[4].name)
sq = defun_matmul(t, t)
double = add(t, t)
self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
# Make sure the pre-registered function is used, and no other function
# is added.
self.assertLen(graph._functions, 6)
functions = list(graph._functions.values())
for i in range(len(functions)):
self.assertEqual(captured_function_names[i],
functions[i].definition.signature.name)
@parameterized.named_parameters(
dict(testcase_name='Defun',
function_decorator=function.defun),
dict(testcase_name='DefFunction',
function_decorator=def_function.function))
def testRegisterConcreteFunction(self, function_decorator):
@function_decorator
def py_add(x, y):
return math_ops.add(x, y)
py_add(array_ops.ones([]), array_ops.ones([]))
add = py_add.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32))
@function_decorator
def py_composite(x, y):
return x, add(x, y)
py_composite(array_ops.ones([]), array_ops.ones([]))
composite = py_composite.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32))
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
composite.add_to_graph()
composite.add_gradient_functions_to_graph()
graph = ops.get_default_graph()
# pylint: disable=protected-access
self.assertLen(graph._functions, 6)
# Two sets of functions, each consisting of (inference, forward, backward).
functions = list(graph._functions.values())
captured_function_names = [
f.definition.signature.name for f in functions
]
expected_func_name_regex = [
'.*inference.*py_composite.*',
'.*inference.*py_add.*',
'.*forward.*py_composite.*',
'.*forward.*py_add.*',
'.*inference.*backward.*py_composite.*',
'.*inference.*backward.*py_add.*',
]
for expected, found in zip(
expected_func_name_regex,
captured_function_names):
self.assertRegex(found, expected)
composite_t, composite_double = composite(t, t)
double = add(t, t)
self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
# Make sure the pre-registered function is used, and no other function
# is added.
self.assertLen(graph._functions, 6)
@parameterized.named_parameters(
dict(testcase_name='Defun',
function_decorator=function.defun),
dict(testcase_name='DefFunction',
function_decorator=def_function.function))
def testEagerCaptures(self, function_decorator):
with context.eager_mode():
large_tensor = array_ops.ones(shape=(256,))
self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
small_tensor = array_ops.ones(shape=(4,))
self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
v = resource_variable_ops.ResourceVariable(0.0)
for captured, op_type in [(large_tensor, 'Placeholder'),
(small_tensor, 'Const'), (v, 'Placeholder')]:
@function_decorator
def test_fn():
return captured + 1 # pylint: disable=cell-var-from-loop
g = test_fn.get_concrete_function().graph
internal_captures = g.internal_captures
self.assertLen(internal_captures, 1)
self.assertEqual(internal_captures[0].op.type, op_type)
def testRegisterFunctionWithInputSignature(self):
def matmul(x, y):
return math_ops.matmul(x, y)
defun_matmul = function.defun(
matmul,
input_signature=[
tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
])
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
function.register(defun_matmul, t, t)
graph = ops.get_default_graph()
# pylint: disable=protected-access
self.assertLen(graph._functions, 3)
# Test registering an already-cached function; note that the inputs are ignored.
function.register(defun_matmul)
graph = ops.get_default_graph()
self.assertLen(graph._functions, 3)
def testRegisterFunctionWithCache(self):
def matmul(x, y):
return math_ops.matmul(x, y)
defun_matmul = function.defun(matmul)
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
function.register(defun_matmul, t, t)
function.register(defun_matmul, t2, t2)
graph = ops.get_default_graph()
# Only one function is registered since the input params are of the same type.
# pylint: disable=protected-access
self.assertLen(graph._functions, 3)
def testCallingFunctionWithDifferentVariables(self):
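# Variables passed as arguments become function inputs rather than captures,
# so a concrete function traced with one variable can be called with another.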
@function.defun
def foo(v):
v.assign_add(1.0)
return v.read_value()
v = resource_variable_ops.ResourceVariable(0.0)
graph_function = foo.get_concrete_function(v)
self.assertLen(graph_function.inputs, 1)
self.assertEmpty(graph_function.captured_inputs)
self.assertEqual(float(graph_function(v)), 1.0)
self.assertEqual(float(graph_function(v)), 2.0)
w = resource_variable_ops.ResourceVariable(0.0)
@function.defun
def bar(v):
del v
return constant_op.constant(1.0)
graph_function = bar.get_concrete_function(v)
self.assertEqual(float(graph_function(v)), 1.0)
self.assertEqual(float(graph_function(w)), 1.0)
def testCallingFunctionWithNonTensorsFails(self):
@function.defun
def foo(x):
return x
graph_function = foo.get_concrete_function(constant_op.constant(1.0))
with self.assertRaises((TypeError, ValueError)):
graph_function('Not a Tensor.')
def testSwapImplementationWithGrapplerPlugin(self):
# Set min_graph_nodes to -1 since the graph in this test is too small and
# would be ignored by Grappler if we didn't set this.
rewrites = rewriter_config_pb2.RewriterConfig()
rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
rewrites.min_graph_nodes = -1
graph_options = config_pb2.GraphOptions(
rewrite_options=rewrites, build_cost_model=1)
config_proto = config_pb2.ConfigProto(graph_options=graph_options)
with context.graph_mode(), self.cached_session(
config=config_proto, graph=ops.Graph(), use_gpu=True):
@function.defun_with_attributes(
attributes={
'api_implements': 'random_boost',
'api_preferred_device': 'CPU'
})
def cpu_boost(x):
return math_ops.add(x, 2.0)
@function.defun_with_attributes(
attributes={
'api_implements': 'random_boost',
'api_preferred_device': 'GPU'
})
def gpu_boost(x):
return math_ops.add(x, 4.0)
x = constant_op.constant(1.0)
function.register(cpu_boost, x)
y = gpu_boost(x)
y_value = self.evaluate(y)
if test.is_gpu_available():
self.assertEqual(y_value, 5.0)
else:
# Grappler falls back to the CPU implementation even though the GPU
# function was called.
self.assertEqual(y_value, 3.0)
def testSwapImplementationInEager(self):
if not context.executing_eagerly():
self.skipTest('eager only')
# testSharedRendezvous sets the disable_meta_optimizer flag to True. If that
# subtest runs before this one, leaving the flag set to True will cause this
# subtest to fail. To avoid that scenario, explicitly set the
# disable_meta_optimizer flag to False here.
context.context().set_optimizer_experimental_options({
'min_graph_nodes': -1,
'implementation_selector': True,
'disable_meta_optimizer': False
})
@function.defun_with_attributes(
attributes={'api_implements': 'foo',
'api_preferred_device': 'CPU'})
def on_cpu(x):
return x + 2
@function.defun_with_attributes(
attributes={'api_implements': 'foo',
'api_preferred_device': 'GPU'})
def on_gpu(x):
return x + 4
@function.defun
def run_on_cpu(t):
function.register(on_cpu, t)
with ops.device('CPU:0'):
return on_gpu(t)
# Expect to run the on_cpu branch, regardless of whether a GPU is available.
self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3)
def testDefunFunctionSeparateGraphs(self):
with context.graph_mode():
@function.defun
def add(x):
return x + 5
@function.defun
def maybe_add(x, should_add):
if should_add:
return add(x)
else:
return x
with ops.Graph().as_default():
x = constant_op.constant(11)
maybe_add(x, True)
self.assertLen(total_function_cache(maybe_add), 1)
self.assertLen(total_function_cache(add), 1)
maybe_add(x, False)
self.assertLen(total_function_cache(maybe_add), 2)
self.assertLen(total_function_cache(add), 1)
with ops.Graph().as_default():
x = constant_op.constant(11)
maybe_add(x, True)
self.assertLen(total_function_cache(maybe_add), 3)
self.assertLen(total_function_cache(add), 2)
def testCacheKeyOverlappingShapes(self):
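# Shapes [12, 1] and [1, 21] could be ambiguous if dimensions were naively
# concatenated into the cache key, so they must produce two distinct traces.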
@function.defun
def defined(t):
return t
defined(array_ops.zeros([12, 1]))
self.assertLen(total_function_cache(defined), 1)
defined(array_ops.zeros([1, 21]))
self.assertLen(total_function_cache(defined), 2)
def testCacheKeyNestedLists(self):
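# The nesting structure of the arguments is part of the cache key, so
# regrouping the same tensors forces a retrace.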
@function.defun
def defined(l):
return l
a = constant_op.constant(1.)
b = constant_op.constant(2.)
c = constant_op.constant(3.)
defined([[a], b, c])
self.assertLen(total_function_cache(defined), 1)
defined([[a, b], c])
self.assertLen(total_function_cache(defined), 2)
def testCacheKeyAttrsClass(self):
if attr is None:
self.skipTest('attr module is unavailable.')
@attr.s
class TestClass(object):
a = attr.ib()
b = attr.ib()
@function.defun
def defined(l):
return l
defined(
TestClass(
constant_op.constant(1.),
[constant_op.constant(2.),
constant_op.constant(3.)]))
self.assertLen(total_function_cache(defined), 1)
defined(
TestClass(
constant_op.constant(1.),
[constant_op.constant(2.),
constant_op.constant(3.)]))
self.assertLen(total_function_cache(defined), 1)
defined(
TestClass([constant_op.constant(1.),
constant_op.constant(2.)], constant_op.constant(3.)))
self.assertLen(total_function_cache(defined), 2)
def testCacheKeyVariables(self):
@function.defun
def defined(a, b, c):
return a + b + c
x = resource_variable_ops.ResourceVariable(0.0)
y = resource_variable_ops.ResourceVariable(0.0)
z = resource_variable_ops.ResourceVariable(0.0)
# If tensor equality is not enabled, we always get a cache miss if the
# function is called with different variables. With equality enabled we
# should only get a miss if the aliasing changed.
defined(x, y, z)
self.assertLen(total_function_cache(defined), 1)
defined(x, y, z)
self.assertLen(total_function_cache(defined), 1)
# Re-arranging arguments causes cache miss
defined(z, y, x)
self.assertLen(total_function_cache(defined), 2)
defined(z, y, x)
self.assertLen(total_function_cache(defined), 2)
# Aliasing causes cache miss
defined(x, x, z)
self.assertLen(total_function_cache(defined), 3)
defined(x, x, z)
self.assertLen(total_function_cache(defined), 3)
# Re-arranging arguments causes cache miss
defined(y, y, z)
self.assertLen(total_function_cache(defined), 4)
defined(y, y, z)
self.assertLen(total_function_cache(defined), 4)
# Different alias positions cause a cache miss
defined(z, y, y)
self.assertLen(total_function_cache(defined), 5)
defined(z, y, y)
self.assertLen(total_function_cache(defined), 5)
x_copy = copy.deepcopy(x)
# Deep copy causes cache miss
defined(x_copy, y, z)
self.assertLen(total_function_cache(defined), 6)
defined(x_copy, y, z)
self.assertLen(total_function_cache(defined), 6)
def testVariableRetracing(self):
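# Variables are distinguished by identity when tracing, so v1, v2, and the
# deep-copied v3 each trigger their own trace and the id() lookup inside the
# function resolves to the matching constant.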
v1 = variables.Variable(1.)
v2 = variables.Variable(1.)
v3 = copy.deepcopy(variables.Variable(1.))
var_dict = {id(v1): constant_op.constant(1),
id(v2): constant_op.constant(2),
id(v3): constant_op.constant(3)}
@function.defun
def lookup_tensor(v):
return var_dict[id(v)]
self.assertEqual(1, lookup_tensor(v1).numpy())
self.assertEqual(2, lookup_tensor(v2).numpy())
self.assertEqual(3, lookup_tensor(v3).numpy())
def testDecoratedMethodInspect(self):
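# tf_inspect should still report the original argument names of a
# defun-decorated method.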
class DefunnedMiniModel(object):
@function.defun
def call(self, inputs, training=True):
pass
m = DefunnedMiniModel()
fullargspec = tf_inspect.getfullargspec(m.call)
self.assertIn('training', fullargspec.args)
def testFunctionModifiesInputList(self):
# Tests `list` methods that do in-place modification, except `list.sort`,
# since it cannot even be "defunned" in the first place.
def get_list():
return [constant_op.constant(0.), constant_op.constant(1.)]
expected_msg = '.*() should not modify'
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def append(l):
l.append(constant_op.constant(0.))
append(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def extend(l):
l.extend([constant_op.constant(0.)])
extend(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def insert(l):
l.insert(0, constant_op.constant(0.))
insert(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def pop(l):
l.pop()
pop(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def reverse(l):
l.reverse()
reverse(get_list())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def remove(l):
l.remove(l[0])
remove(get_list())
# `list.clear` is a method that is in Py3 but not Py2
if sys.version.startswith('3'):
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def clear(l):
l.clear()
clear(get_list())
# One last test for keyword arguments
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def kwdappend(**kwargs):
l = kwargs['l']
l.append(constant_op.constant(0.))
kwdappend(l=get_list())
def testFunctionModifiesInputDict(self):
def get_dict():
return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
expected_msg = '.* should not modify'
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def clear(m):
m.clear()
clear(get_dict())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def pop(m):
m.pop('t1')
pop(get_dict())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def popitem(m):
m.popitem()
popitem(get_dict())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def update(m):
m.update({'t1': constant_op.constant(3.)})
update(get_dict())
with self.assertRaisesRegex(ValueError, expected_msg):
@def_function.function
def setdefault(m):
m.setdefault('t3', constant_op.constant(3.))
setdefault(get_dict())
def testFunctionModifiesInputNest(self):
with self.assertRaisesRegex(ValueError, 'modify.* should not modify'):
@def_function.function
def modify(n):
n[0]['t1'].append(constant_op.constant(1.))
nested_input = [{
't1': [constant_op.constant(0.),
constant_op.constant(1.)],
},
constant_op.constant(2.)]
modify(nested_input)
with self.assertRaisesRegex(ValueError,
'modify_same_flat.* should not modify'):
# The flat list doesn't change whereas the true structure changes
@def_function.function
def modify_same_flat(n):
n[0].append(n[1].pop(0))
nested_input = [[constant_op.constant(0.)],
[constant_op.constant(1.),
constant_op.constant(2.)]]
modify_same_flat(nested_input)
def testExecutorType(self):
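# An unknown executor type should raise NotFoundError, while '', 'DEFAULT',
# and None all select the default executor.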
@function.defun
def add_five(x):
return x + 5
self.assertEqual(
5,
add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
add_five(constant_op.constant(0, dtype=dtypes.int32))
for executor_type in ('', 'DEFAULT', None):
with context.function_executor_type(executor_type):
self.assertAllEqual(
5,
add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
@test_util.assert_no_garbage_created
def testReferenceCycles(self):
fn = function.defun(lambda x: 2. * x)
fn(constant_op.constant(4.0))
weak_fn = weakref.ref(fn)
del fn
# Tests that the weak reference we made to the function is now dead, which
# means the object has been deleted. This should be true as long as the
# function itself is not involved in a reference cycle.
self.assertIs(None, weak_fn())
def testFunctionStackInErrorMessage(self):
if context.executing_eagerly():
# TODO(b/122736651): Remove this skipTest once fixed.
self.skipTest('Error interpolation is not working when function is '
'invoked without PartitionedCallOp.')
@def_function.function()
def fn3(x):
return x + 2
@def_function.function()
def fn2(x):
check_ops.assert_equal(fn3(x), 3)
return 2
@def_function.function()
def fn(x):
return fn2(x)
with self.assertRaises(errors.InvalidArgumentError) as cm:
fn(2)
e = cm.exception
self.assertIn('fn -> fn2', e.message)
self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
self.assertNotIn('fn3', e.message)
@test_util.run_gpu_only
def testFunctionIsNotPinned(self):
"""Tests that functions aren't pinned to the CPU by the eager runtime."""
seed1, seed2 = 79, 25
shape = constant_op.constant([4, 7])
dtype = dtypes.float32
@def_function.function
def func():
with ops.device('GPU:0'):
return gen_random_ops.random_standard_normal(
shape, dtype=dtype, seed=seed1, seed2=seed2)
with ops.device('GPU:0'):
x = func()
self.assertRegex(x.device, 'GPU')
@test_util.run_in_graph_and_eager_modes
def testShapeCaching(self):
@function.defun
def func(x):
return array_ops.shape(x)
@function.defun(
input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
def calls_func(x):
return func(x)
self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
self.assertAllEqual(
[3, 3],
self.evaluate(calls_func(array_ops.zeros([3, 3]))))
def testLimitedRetracing(self):
trace_count = [0]
@function.defun
def func(x):
trace_count[0] += 1
return x
for _ in range(50):
func(constant_op.constant(3.))
func(constant_op.constant(4.))
func(constant_op.constant([[1., 2.]]))
func(constant_op.constant([[]]))
func(constant_op.constant([[3., 4.], [5., 6.]]))
func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]]))
# Tracing more than twice per input doesn't make sense.
self.assertLess(trace_count[0], 13)
def testLimitedRetracingWithCompositeTensors(self):
trace_count = [0]
@def_function.function
def f(x):
trace_count[0] += 1
return x
for i in range(10):
f(ragged_factory_ops.constant([[1, 2], [i]]))
f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]]))
f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]))
self.assertEqual(trace_count[0], 3)
def test_concrete_function_shape_mismatch(self):
@def_function.function
def f(argument_name):
return argument_name + 1.
f_concrete = f.get_concrete_function(constant_op.constant([1.]))
# Calling a function from eager doesn't do any shape checking above what
# kernels do while executing.
self.assertAllEqual(
[2., 3.],
f_concrete(constant_op.constant([1., 2.])).numpy())
@def_function.function
def g():
f_concrete(constant_op.constant([1., 2.]))
with self.assertRaisesRegex(ValueError, 'argument_name'):
g()
@test_util.run_in_graph_and_eager_modes
def test_shape_inference_with_symbolic_shapes(self):
@def_function.function
def _uses_symbolic_shapes(w, x, y):
x = array_ops.identity(x, name='name_collision')
x = array_ops.transpose(x, [1, 0, 2])
x_batch = array_ops.shape(x)[0]
y_batch = array_ops.shape(y)[0]
y *= w
n = y_batch // x_batch
return array_ops.reshape(y, [n, x_batch, -1])
conc = _uses_symbolic_shapes.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32))
@def_function.function
def _call_concrete():
c = constant_op.constant(1.)
array_ops.identity(c, name='name_collision')
output1 = conc(array_ops.ones([2]),
array_ops.ones([5, 4, 2]),
array_ops.ones([20, 2]))
self.assertEqual([5, 4, 2], output1.shape)
output2 = conc(array_ops.ones([3]),
array_ops.ones([5, 4, 3]),
array_ops.ones([40, 3]))
self.assertEqual([10, 4, 3], output2.shape)
return output1, output2
output1, output2 = _call_concrete()
self.assertEqual((5, 4, 2), self.evaluate(output1).shape)
self.assertEqual((10, 4, 3), self.evaluate(output2).shape)
def testAutoGraphContext(self):
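# AutoGraph should report ENABLED inside the function body, and the outer
# control status should be left unchanged by the call.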
@def_function.function
def test_fn():
self.assertEqual(
ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED)
prev_status = ag_ctx.control_status_ctx().status
test_fn()
self.assertEqual(ag_ctx.control_status_ctx().status, prev_status)
@test_util.disable_tfrt('b/170435618')
def testCancelBeforeFunctionExecution(self):
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
@def_function.function
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
c_mgr.start_cancel()
with self.assertRaises(errors.CancelledError):
cancelable_func()
@test_util.disable_tfrt('b/170435618')
def testCancelBlockedFunctionExecution(self):
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
@def_function.function
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
def cancel_thread():
time.sleep(0.5)
c_mgr.start_cancel()
t = self.checkedThread(cancel_thread)
t.start()
with self.assertRaises(errors.CancelledError):
cancelable_func()
t.join()
@test_util.disable_tfrt('b/170435618')
def testCancelAfterFunctionExecution(self):
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
q.enqueue(37)
@def_function.function
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
self.assertAllEqual(37, cancelable_func().numpy())
# Cancellation after the function executes is a no-op.
c_mgr.start_cancel()
def testAddFunctionCallback(self):
functions = []
def function_callback(f):
functions.append(f)
@def_function.function
def plus_one(x):
return x + 1
try:
function.add_function_callback(function_callback)
x_float32 = numpy.array(3.0, dtype=numpy.float32)
self.assertAllClose(plus_one(x_float32), 4.0)
self.assertLen(functions, 1)
# Function is already created. Executing it again should not invoke the
# function callback.
self.assertAllClose(plus_one(x_float32), 4.0)
self.assertLen(functions, 1)
# Signature change leads to a new Function being built.
x_float64 = numpy.array(3.0, dtype=numpy.float64)
self.assertAllClose(plus_one(x_float64), 4.0)
self.assertLen(functions, 2)
finally:
function.clear_function_callbacks()
def testRemoveFunctionCallback(self):
functions_1 = []
def function_callback_1(f):
functions_1.append(f)
functions_2 = []
def function_callback_2(f):
functions_2.append(f)
@def_function.function
def plus_one(x):
return x + 1
try:
function.add_function_callback(function_callback_1)
function.add_function_callback(function_callback_2)
self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0)
self.assertLen(functions_1, 1)
self.assertLen(functions_2, 1)
function.remove_function_callback(function_callback_1)
# The 1st callback should not be invoked after remove_function_callback()
# is called.
self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0)
self.assertLen(functions_1, 1)
self.assertLen(functions_2, 2)
finally:
function.clear_function_callbacks()
def testClearFunctionCallbacks(self):
function.add_function_callback(lambda f: None)
function.add_function_callback(lambda f: None)
self.assertLen(function._function_callbacks, 2)
function.clear_function_callbacks()
self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNestedTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = constant_op.constant(1000)
b = constant_op.constant(200)
c = constant_op.constant(30)
d = {'a': a, 'b': b}
e = (c, 4)
# Test different argument signatures when constructing the concrete func.
for cf in [
f.get_concrete_function(d, e),
f.get_concrete_function(d, y=e),
f.get_concrete_function(y=e, x=d),
f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
]:
# Test different calling conventions when calling the concrete func.
for output in [
cf(d, e), # structured signature
cf(d, y=e), # structured signature w/ kwarg
cf(y=e, x=d), # structured signature w/ 2 kwargs
cf(a, b, c), # flat signature
cf(x=a, x_1=b, y=c) # flat signature w/ kwargs
]:
self.assertIsInstance(output, tuple)
self.assertLen(output, 2)
self.assertAllEqual(output[0], 1200)
self.assertAllEqual(output[1], 34)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
b = (50, 3)
for cf in [ # argument y is bound to non-Tensor value (50, 3).
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
]:
for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 1253)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 3000, 'b': 200, 'c': 9000}
b = (constant_op.constant(30), 4)
for cf in [ # argument x is bound to non-tensor value `a`
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
]:
for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 3234)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 5000, 'b': 500}
b = (50, 5)
cf = f.get_concrete_function(a, b)
for output in [cf(), cf(a), cf(y=b)]:
self.assertAllEqual(output[0] + output[1], 5555)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionMethodWithVarargs(self):
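# The input_signature binds to the method's *args, so get_concrete_function()
# with no arguments yields a concrete function taking two float32 scalars.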
float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
class MyModel(module.Module):
@def_function.function(input_signature=[float32_scalar, float32_scalar])
def add(self, *arg):
return math_ops.add(*arg)
m = MyModel()
cf = m.add.get_concrete_function()
cf(-12.0, 3.0)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionStructuredSignatureKeywordOrder(self):
# Check that keyword-only arguments are sorted appropriately, so that they
# feed the right tensor into each input.
@def_function.function
def g(**kwargs):
return string_ops.reduce_join(
string_ops.reduce_join(
ops.convert_to_tensor(sorted(kwargs.items())),
axis=1,
separator='='),
axis=0,
separator=', ')
s = constant_op.constant('s')
g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
self.assertAllEqual(
g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
self.assertAllEqual(
g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
self.assertAllEqual(
g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
# pylint: disable=g-long-lambda
@parameterized.named_parameters([
dict(
testcase_name='MissingArg',
conc_args=lambda: (1, constant_op.constant(2)),
call_args=lambda: (1,),
error=r'func\(x, y\) missing required arguments: y'),
dict(
testcase_name='MissingVararg',
conc_args=lambda: (1, 2, constant_op.constant(1.0)),
call_args=lambda: (1, 2),
error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
dict(
testcase_name='ExtraPositionalArg',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2, 3),
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
dict(
testcase_name='MissingKeywordOnlyArg',
conc_args=lambda: (1, 2),
conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
call_args=lambda: (1, 2),
error=r'func\(x, y, \*, c\) missing required arguments: c'),
dict(
testcase_name='ExtraKeywordArg',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'c': constant_op.constant(1.0)},
error=r'func\(x, y\) got unexpected keyword arguments: c'),
dict(
testcase_name='ExpectedRaggedGotNest',
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: ({
'a': constant_op.constant([1, 2, 3])
},),
error=r'func\(x, y\): argument x had incorrect type\n'
r' expected: RaggedTensor\n'
r" got: {'a': (Eager)?Tensor}"),
dict(
testcase_name='WrongRaggedRank',
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='WrongRaggedDType',
conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='ExpectedDictGotTensor',
conc_args=lambda: ({
'a': constant_op.constant(1),
'b': constant_op.constant(1)
},),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='ExpectedTupleGotTensor',
conc_args=lambda:
((constant_op.constant(1), constant_op.constant(2)),),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='WrongDType',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(ValueError, errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
errors.InternalError)),
dict(
testcase_name='ExpectedTensorGotInt',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (5,),
error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
dict(
testcase_name='ExpectedIntGotDifferentInt',
conc_args=lambda: (5,),
call_args=lambda: (8,),
error=r'ConcreteFunction func\(x, y\) was constructed with int '
r'value 5 in x, but was called with int value 8'),
dict(
testcase_name='ExpectedIntGotTensor',
conc_args=lambda: (5,),
call_args=lambda: (constant_op.constant(6),),
error=r'ConcreteFunction func\(x, y\) was constructed with int '
'value 5 in x, but was called with (Eager)?Tensor value .*'),
dict(
testcase_name='TwoValuesForArgument',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'x': 3},
error=r"func\(x, y\) got two values for argument 'x'"),
])
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionStructuredSignatureError(self,
conc_args=(),
conc_kwargs=None,
call_args=(),
call_kwargs=None,
error='.*',
exception=TypeError):
"""Tests for errors in the structrued signature.
Args:
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
"""
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
@def_function.function
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
with self.assertRaisesRegex(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
# pylint: disable=g-long-lambda
@parameterized.named_parameters([
dict(
testcase_name='MissingArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\) missing required arguments: y'),
dict(
testcase_name='TwoValuesForArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {
'x': constant_op.constant(1),
'y': constant_op.constant(1)
},
error=r"func\(x, y\) got two values for argument 'x'"),
dict(
testcase_name='ExtraPositionalArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
constant_op.constant(3)),
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
dict(
testcase_name='UnexpectedKeywordArg',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x\) got unexpected keyword arguments: c'),
dict(
testcase_name='MissingVararg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
constant_op.constant(3)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, varargs_0\) missing required '
r'arguments: varargs_0'),
dict(
testcase_name='MissingKeywordArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': constant_op.constant(1)},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, c\) missing required arguments: c'),
dict(
testcase_name='ExpectedTensorGotInt',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (5, constant_op.constant(2)),
error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
r'a Tensor; got int \(5\)'),
dict(
testcase_name='WrongDType',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(ValueError, errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
errors.InternalError)),
dict(
testcase_name='MissingKeywordArgNestPiece',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
])
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionFlatSignatureError(self,
conc_args=(),
conc_kwargs=None,
call_args=(),
call_kwargs=None,
error='.*',
exception=TypeError):
"""Tests for errors in the flat signature.
Args:
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
"""
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
@def_function.function
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
# Remove _function_spec, to disable the structured signature.
conc._set_function_spec(None) # pylint: disable=protected-access
with self.assertRaisesRegex(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionAmbiguousSignature(self):
# When both the flat & structured signatures are applicable, but they
# give different results, we use the structured signature. Note: we expect
# this to be extremely rare.
@def_function.function
def f(x, y):
return x * 10 + y
conc = f.get_concrete_function(
x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))
result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
self.assertAllEqual(result, 56)
def testPrettyPrintedSignature(self):
@def_function.function
def func(x, kangaroo=None, octopus=7):
del octopus, kangaroo
return x
scalar = constant_op.constant(5)
vector = constant_op.constant([10, 10, 20])
ragged = ragged_factory_ops.constant([[10, 20], [40]])
c1 = func.get_concrete_function(scalar, vector)
c1_summary = r'func\(x, kangaroo, octopus=7\)'
c1_details = (r' Args:\n'
r' x: int32 Tensor, shape=\(\)\n'
r' kangaroo: int32 Tensor, shape=\(3,\)\n'
r' Returns:\n'
r' int32 Tensor, shape=\(\)')
self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary)
self.assertRegex(
c1.pretty_printed_signature(verbose=True),
c1_summary + '\n' + c1_details)
self.assertRegex(
repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
self.assertRegex(
str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))
c2 = func.get_concrete_function(scalar, ragged, 3)
c2_summary = r'func\(x, kangaroo, octopus=3\)'
c2_details = (r' Args:\n'
r' x: int32 Tensor, shape=\(\)\n'
r' kangaroo: RaggedTensorSpec\(.*\)\n'
r' Returns:\n'
r' int32 Tensor, shape=\(\)')
self.assertRegex(c2.pretty_printed_signature(),
c2_summary + '\n' + c2_details)
c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
c3_details = (r' Args:\n'
r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
r' <1>: int32 Tensor, shape=\(\)\n'
r' <2>: RaggedTensorSpec\(.*\)\n'
r' <3>: RaggedTensorSpec\(.*\)\n'
r' Returns:\n'
r" {'a': <1>, 'b': \[<2>, <3>\]}\n"
r' <1>: int32 Tensor, shape=\(\)\n'
r' <2>: RaggedTensorSpec\(.*\)\n'
r' <3>: RaggedTensorSpec\(.*\)')
# Python 3.5 does not guarantee deterministic iteration of dict contents,
# which can lead to a mismatch in the pretty_printed_signature output for
# "Args".
if sys.version_info >= (3, 6):
self.assertRegex(c3.pretty_printed_signature(),
c3_summary + '\n' + c3_details)
# pylint: disable=keyword-arg-before-vararg
@def_function.function
def func2(x, y=3, *args, **kwargs):
return (x, y, args, kwargs)
c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)
c5 = func2.get_concrete_function(8, vector)
c5_summary = 'func2(x=8, y)'
self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)
def testPrettyPrintedExplicitSignatureWithKeywordArg(self): # b/159639913
@def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
def fn(a, b=1):
return a + b
concrete_fn = fn.get_concrete_function()
self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)')
self.assertEqual(
concrete_fn.pretty_printed_signature(True), 'fn(a)\n'
' Args:\n'
' a: float32 Tensor, shape=<unknown>\n'
' Returns:\n'
' float32 Tensor, shape=<unknown>')
@test_util.run_in_graph_and_eager_modes
def testIndexedSlicesAsGradientsForConcreteFunctions(self):
@def_function.function
def summing_rnn(inputs):
return math_ops.reduce_sum(inputs, axis=1)
@def_function.function
def gradients(inputs):
with backprop.GradientTape() as tape:
tape.watch(inputs)
hidden = summing_rnn(inputs)
hidden = array_ops.gather(hidden, constant_op.constant([0]))
loss = math_ops.reduce_mean(hidden)
return tape.gradient(loss, inputs)
gradients(constant_op.constant([[[1.0], [2.0]]])) # No error is raised
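# The tests below exercise `experimental_follow_type_hints`: when enabled,
# arguments annotated as ops.Tensor are converted to tensors before tracing,
# so different Python values for those arguments reuse the existing trace,
# while other arguments still trigger a retrace when their values change.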
def testFollowTypeHintsTraceBasic(self):
trace_count = [0]
def func(x: ops.Tensor):
trace_count[0] += 1
return x
enabled = def_function.function(func, experimental_follow_type_hints=True)
disabled = def_function.function(func, experimental_follow_type_hints=False)
enabled(1) # Initial call gets traced
enabled(2)
enabled(3)
self.assertEqual(trace_count[0], 1)
trace_count = [0]
disabled(1)
disabled(2) # Retrace
disabled(3) # Retrace
self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithArgs(self):
trace_count = [0]
def func(*args: ops.Tensor):
trace_count[0] += 1
return args
enabled = def_function.function(func, experimental_follow_type_hints=True)
disabled = def_function.function(func, experimental_follow_type_hints=False)
args = (
'abc',
'def',
) * 20
args2 = (
'def',
'abc',
) * 20
enabled(args)
enabled(args2)
self.assertEqual(trace_count[0], 1)
trace_count = [0]
disabled(args)
disabled(args2) # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithKwargs(self):
trace_count = [0]
def func(t: ops.Tensor, **kwargs: ops.Tensor):
del kwargs
trace_count[0] += 1
return t
enabled = def_function.function(func, experimental_follow_type_hints=True)
disabled = def_function.function(func, experimental_follow_type_hints=False)
enabled(1, x=1, y=1.0, z='one')
enabled(2, x=2, y=2.0, z='two')
self.assertEqual(trace_count[0], 1)
trace_count = [0]
disabled(1, x=1, y=1.0, z='one')
disabled(2, x=2, y=2.0, z='two') # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithMultipleInputTypes(self):
trace_count = [0]
def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor):
del args, kwargs
trace_count[0] += 1
return t
enabled = def_function.function(func, experimental_follow_type_hints=True)
disabled = def_function.function(func, experimental_follow_type_hints=False)
enabled(1, constant_op.constant(1), 'str', x=4.0)
enabled(2, constant_op.constant(2), 'str2', x=5.0)
self.assertEqual(trace_count[0], 1)
trace_count = [0]
disabled(1, constant_op.constant(1), 'str', x=4.0)
disabled(2, constant_op.constant(2), 'str2', x=5.0) # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithOnlyArgNamed(self):
trace_count = [0]
def func(t: ops.Tensor, i: int = 1, **kwargs): # pylint: disable=bad-whitespace
del i, kwargs
trace_count[0] += 1
return t
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(1, 3, x=4.0, y='str')
enabled(2, 4, x=4.0, y='str') # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithNotAllNamed(self):
trace_count = [0]
def func(x, y: ops.Tensor, z: int):
del y, z
trace_count[0] += 1
return x
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(1, 2, 3)
enabled(1, 20, 3) # No retrace - change in ops.Tensor typed arg
enabled(2, 2, 3) # Retrace - change in untyped arg
enabled(2, 2, 4) # Retrace - change in typed arg
self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithOnlyArgsNamed(self):
trace_count = [0]
def func(x, y, *args: ops.Tensor):
del y, args
trace_count[0] += 1
return x
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(1, 20, 3, 4, 5, 6)
enabled(1, 20, 3, 4, 5, 60) # No retrace - change in *args
enabled(1, 30, 7, 8, 9, 10) # Retrace - change in args
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithOnlyKwargsNamed(self):
trace_count = [0]
def func(x, y, *args, **kwargs: ops.Tensor):
del y, args, kwargs
trace_count[0] += 1
return x
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)
enabled(
1, 2, 3, 4, 5, 6, a=1.5, b=2.5,
c=3.5) # No retrace - change in **kwargs
enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0) # Retrace - change in args
enabled(
1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0) # Retrace - change in *args
self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithArgsEquals(self):
trace_count = [0]
def func(
x: ops.Tensor = 0, # pylint:disable=bad-whitespace
y: int = 1, # pylint:disable=bad-whitespace
**kwargs: ops.Tensor):
del y, kwargs
trace_count[0] += 1
return x
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(x=1, y=2, z=3)
enabled(x=1, y=3, z=3) # Retrace - change in args
enabled(x=2, y=2, z=4) # No retrace - change in args and **kwargs
enabled(x=2, y=2, z=4, u=5) # Retrace - change in **kwargs
self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
trace_count = [0]
def func(x, y, **kwargs: ops.Tensor):
del y, kwargs
trace_count[0] += 1
return x
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(x=1, y=2, z=3)
enabled(x=1, y=3, z=3) # Retrace
enabled(x=1, y=2, z=4) # No retrace
enabled(x=2, y=2, z=4) # Retrace
enabled(x=2, y=2, z=4, u=5) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self):
trace_count = [0]
def func(x: ops.Tensor, y: int, **kwargs):
del y, kwargs
trace_count[0] += 1
return x
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(x=1, y=2, z=3)
enabled(x=1, y=3, z=3) # Retrace
enabled(x=1, y=2, z=4) # Retrace
enabled(x=2, y=2, z=3) # No retrace
enabled(x=2, y=2, z=4, u=5) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self):
trace_count = [0]
def func(*, a: ops.Tensor = None, b=1): # pylint: disable=bad-whitespace
del b
trace_count[0] += 1
return a
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(a=1, b=2)
enabled(a=2, b=2) # No retrace
enabled(a=1, b=1) # Retrace
self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self):
trace_count = [0]
def func(arg: ops.Tensor, *args, kwonly, **kwargs):
del args, kwonly, kwargs
trace_count[0] += 1
return arg
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
enabled(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self):
trace_count = [0]
def func(arg, *args: ops.Tensor, kwonly, **kwargs):
del args, kwonly, kwargs
trace_count[0] += 1
return arg
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self):
trace_count = [0]
def func(arg, *args, kwonly: ops.Tensor, **kwargs):
del args, kwonly, kwargs
trace_count[0] += 1
return arg
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7) # No retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # Retrace
self.assertEqual(trace_count[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self):
trace_count = [0]
def func(arg, *args, kwonly, **kwargs: ops.Tensor):
del args, kwonly, kwargs
trace_count[0] += 1
return arg
enabled = def_function.function(func, experimental_follow_type_hints=True)
enabled(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
enabled(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7) # Retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70) # No retrace
enabled(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700) # No retrace
self.assertEqual(trace_count[0], 4)
def testWithModuleNameScope(self):
self.skipTest('b/166158748:function does not handle this case correctly.')
class Foo(module.Module):
def __init__(self):
super().__init__()
self.var = None
@def_function.function
@module.Module.with_name_scope
def add(self, x, y, z=1):
if self.var is None:
return x + y + z
foo = Foo()
self.assertEqual(foo.add(2, 3), 6)
def testWithModuleNameScopeRedundantArgs(self):
self.skipTest('b/166158748:function does not handle this case correctly.')
class Foo(module.Module):
def __init__(self):
super().__init__()
self.var = None
@def_function.function
@module.Module.with_name_scope
def add(self, x, y):
if self.var is None:
return x + y
foo = Foo()
with self.assertRaisesRegex(TypeError, 'got two values for argument'):
foo.add(2, x=3) # pylint: disable=redundant-keyword-arg,no-value-for-parameter
def testWithModuleNameScopeMissingArgs(self):
self.skipTest('b/166158748:function does not handle this case correctly.')
class Foo(module.Module):
def __init__(self):
super().__init__()
self.var = None
@def_function.function
@module.Module.with_name_scope
def add(self, x, y):
if self.var is None:
return x + y
foo = Foo()
with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
foo.add(2) # pylint: disable=no-value-for-parameter
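# The following tests check that constant arguments passed into nested
# functions propagate through shape computations (stack/unstack/concat), so
# the inner array_ops.ones call ends up with a fully known static shape.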
def testShapeInferencePropagateConstNestedStack(self):
@def_function.function(input_signature=[
tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
])
def f(x, s):
old_shape = array_ops.shape(x)
new_shape = array_ops.stack([old_shape[0], s], axis=0)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
])
def g(x):
y = f(x, s=5)
assert y.shape.as_list() == [3, 5], y.shape.as_list()
return y
self.assertAllEqual(
g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
def testShapeInferencePropagateConstNestedUnstackStack(self):
@def_function.function(input_signature=[
tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
])
def f(x, s):
s0, _ = array_ops.unstack(array_ops.shape(x), axis=0)
new_shape = array_ops.stack([s0, s], axis=0)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
])
def g(x):
y = f(x, s=5)
assert y.shape.as_list() == [3, 5], y.shape.as_list()
return y
self.assertAllEqual(
g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
def testShapeInferencePropagateConstNestedConcat(self):
@def_function.function(input_signature=[
tensor_spec.TensorSpec((), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
])
def f(d1, d2, d3):
new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@def_function.function()
def g():
y = f(1, 2, 3)
assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
return y
self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
def testShapeInferencePropagateConstDoubleNested(self):
@def_function.function(input_signature=[
tensor_spec.TensorSpec((), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
tensor_spec.TensorSpec((), dtype=dtypes.int32),
])
def f(d1, d2, d3):
new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
return y
@def_function.function()
def g():
y = def_function.function(f)(1, 2, 3)
assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
return y
self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
@test_util.run_v2_only
def testControlDependencyAfterInline(self):
v = variables.Variable(0.)
@def_function.function
def assign():
return v.assign(1.)
@def_function.function
def assign_add():
return v.assign_add(1.)
@def_function.function
def f():
check_ops.assert_equal_v2(assign(), 1.)
check_ops.assert_equal_v2(assign_add(), 2.)
# We don't have a way to inspect the inlined graph in Python, so we run the
# function multiple times to gain confidence that the dependency is correct.
for _ in range(30):
f()
@test_util.run_v2_only
def testReadInFuncWriteOutside(self):
# Run many times since we are testing for a potential race condition.
for _ in range(30):
# pylint: disable=cell-var-from-loop
v = variables.Variable(1.)
@def_function.function
def add_one():
return v + 1.
@def_function.function
def get_v_plus_one():
v_plus_one = add_one()
v.assign_add(2.0)
return v_plus_one
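# add_one() must observe v before the assign_add above runs, so the result
# is 1. + 1. = 2.0 even though v itself ends up at 3.0.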
self.assertAllEqual(get_v_plus_one(), 2.0)
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
@test_util.run_gpu_only
def testMultiDeviceOutput(self):
"""Tests that functions can produce outputs on multiple devices."""
@function.defun
def func(a, b, transpose_a):
with ops.device('/device:CPU:0'):
m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
with ops.device('/device:GPU:0'):
m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
return m1, m2
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
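# With transpose_a=True both branches compute t^T @ t = [[10, 14], [14, 20]];
# the test checks that each copy is produced on its requested device.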
m1, m2 = func(t, t, transpose_a=True)
self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]])
self.assertRegex(m1.backing_device, 'CPU')
self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]])
self.assertRegex(m2.backing_device, 'GPU')
@test_util.run_gpu_only
def testEmptyBody(self):
@function.defun
def func(a, b):
return b, a
with ops.device('/device:CPU:0'):
a = array_ops.identity(3.0)
with ops.device('/device:GPU:0'):
b = array_ops.identity(5.0)
m1, m2 = func(a, b)
self.assertAllEqual(m1.numpy(), 5.0)
self.assertRegex(m1.backing_device, 'GPU')
self.assertAllEqual(m2.numpy(), 3.0)
self.assertRegex(m2.backing_device, 'CPU')
@test_util.run_gpu_only
def testMultiDeviceInt32(self):
"""Tests that multi-device functions can take and output INT32s.
When an INT32 device tensor is fed into a function, it is copied to CPU
by the eager runtime. The function sees all INT32 inputs on CPU.
We set the allocator attribute 'on_host' for INT32 outputs. They can be
partitioned into the GPU component function, but they will nevertheless be
allocated on CPU.
There is now experimental support for `ints_on_device` in
FunctionLibraryRuntime; we could try using that.
"""
with ops.device('/device:CPU:0'):
int_cpu = constant_op.constant(3, dtype=dtypes.int32)
resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
with ops.device('/device:GPU:0'):
int_gpu = constant_op.constant(7, dtype=dtypes.int32)
@function.defun
def func(int_cpu, resource, int_gpu):
with ops.device('/device:CPU:0'):
m1 = int_cpu * resource + int_gpu
with ops.device('/device:GPU:0'):
# This computation will happen on GPU but m2 will be copied to CPU.
m2 = int_gpu * resource + int_cpu + 1
return m1, m2
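# With the original argument order: m1 = 3 * 5 + 7 = 22 and
# m2 = 7 * 5 + 3 + 1 = 39; as the docstring explains, both outputs should be
# backed by CPU.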
m1, m2 = func(int_cpu, resource, int_gpu)
self.assertAllEqual(m1.numpy(), 22)
self.assertRegex(m1.backing_device, 'CPU')
self.assertAllEqual(m2.numpy(), 39)
self.assertRegex(m2.backing_device, 'CPU')
# Flip the arguments: now m1 = 7 * 5 + 3 = 38 and m2 = 3 * 5 + 7 + 1 = 23,
# both still backed by CPU.
m1, m2 = func(int_gpu, resource, int_cpu)
self.assertAllEqual(m1.numpy(), 38)
self.assertRegex(m1.backing_device, 'CPU')
self.assertAllEqual(m2.numpy(), 23)
self.assertRegex(m2.backing_device, 'CPU')
@test_util.run_gpu_only
def testMultiDeviceColocateWith(self):
"""Tests that function's outputs respect colocation constraints."""
@function.defun
def func(a, b):
with ops.colocate_with(a):
ra = 2 * a
with ops.colocate_with(b):
rb = 3 * b
return ra, rb
devices = ['/device:CPU:0', '/device:GPU:0']
for dev1, dev2 in itertools.product(devices, devices):
with ops.device(dev1):
a = array_ops.identity(1.0)
with ops.device(dev2):
b = array_ops.identity(10.0)
ra, rb = func(a, b)
self.assertEqual(ra.numpy(), 2.0)
self.assertRegex(ra.backing_device, dev1)
self.assertEqual(rb.numpy(), 30.0)
self.assertRegex(rb.backing_device, dev2)
@test_util.run_gpu_only
def testMultiDeviceResources(self):
with ops.device('/device:CPU:0'):
c1 = resource_variable_ops.ResourceVariable(2.0)
c2 = resource_variable_ops.ResourceVariable(7.0)
with ops.device('/device:GPU:0'):
g1 = resource_variable_ops.ResourceVariable(3.0)
g2 = resource_variable_ops.ResourceVariable(5.0)
@function.defun
def func(resource1, resource2):
with ops.device('/device:CPU:0'):
result1 = resource1 * g2
with ops.device('/device:GPU:0'):
result2 = resource2 * c2
return result1, result2
r1, r2 = func(c1, g1)
self.assertEqual(r1.numpy(), 10.0)
self.assertRegex(r1.backing_device, 'CPU')
self.assertEqual(r2.numpy(), 21.0)
self.assertRegex(r2.backing_device, 'GPU')
# Call with flipped inputs. Check that we look at the resources' devices and
# reinstantiate the function when the inputs' devices change.
r1, r2 = func(g1, c1)
self.assertEqual(r1.numpy(), 15.0)
self.assertRegex(r1.backing_device, 'CPU')
self.assertEqual(r2.numpy(), 14.0)
self.assertRegex(r2.backing_device, 'GPU')
@test_util.run_gpu_only
def testOutputResources(self):
with ops.device('/device:CPU:0'):
c1 = resource_variable_ops.ResourceVariable(2.0)
with ops.device('/device:GPU:0'):
g1 = resource_variable_ops.ResourceVariable(3.0)
@function.defun
def func(resource1, resource2):
with ops.device('/device:CPU:0'):
result1 = resource1 * 5
with ops.device('/device:GPU:0'):
result2 = resource2 * 7
return result1, resource1.handle, result2, resource2.handle
r1, res1, r2, res2 = func(c1, g1)
self.assertEqual(r1.numpy(), 10.0)
self.assertRegex(r1.backing_device, 'CPU')
self.assertEqual(r2.numpy(), 21.0)
self.assertRegex(r2.backing_device, 'GPU')
def check_handle(handle, expected_value):
self.assertRegex(handle.backing_device, 'CPU')
tensor = gen_resource_variable_ops.read_variable_op(
handle, dtypes.float32)
self.assertEqual(tensor.numpy(), expected_value)
# Check that handles returned from functions are on CPU and an op using
# the resource handle is correctly placed on the device backing the
# resource.
check_handle(res1, 2.0)
check_handle(res2, 3.0)
# Call with flipped inputs to make sure the function is reinstantiated and
# the eager runtime does not mix up the device assignment for ops consuming
# handles returned from defuns.
r1, res1, r2, res2 = func(g1, c1)
self.assertEqual(r1.numpy(), 15.0)
self.assertRegex(r1.backing_device, 'CPU')
self.assertEqual(r2.numpy(), 14.0)
self.assertRegex(r2.backing_device, 'GPU')
check_handle(res1, 3.0)
check_handle(res2, 2.0)
@test_util.run_gpu_only
def testPassResourceThroughNestedFunctionCall(self):
"""Test passing GPU resource to noinline function call placed on CPU.
PartitionedCallOp must not enforce any particular device assignment for the
resource output. Inner function marked as `_nospecialize`, so Grappler would
not prune unused function output.
"""
with ops.device('/device:GPU:0'):
g1 = resource_variable_ops.ResourceVariable(3.0)
@function.defun_with_attributes(attributes={
'_noinline': True,
'_nospecialize': True
})
def inner(resource1):
return resource1 * 2, resource1.handle
@function.defun
def outer(resource1):
with ops.device('/device:CPU:0'):
r1, _ = inner(resource1)
return r1
r1 = outer(g1)
self.assertEqual(r1.numpy(), 6.0)
self.assertRegex(r1.backing_device, 'CPU')
@test_util.run_gpu_only
def testReturnResourceFromNestedFunctionCall(self):
"""Test returning GPU resource from noinline function call placed on CPU.
When inferring output devices for the return value, do not set a device for
returns of DT_RESOURCE data type based on the device assignment of the node
that produced that resource. As an example function call placed on CPU can
return resources on GPU.
"""
with ops.device('/device:GPU:0'):
g1 = resource_variable_ops.ResourceVariable(3.0)
@function.defun_with_attributes(attributes={
'_noinline': True
})
def inner(resource1):
resource1.assign_add(2.0)
return resource1 * 2, resource1.handle
@function.defun
def outer(resource1):
with ops.device('/device:CPU:0'):
r1, res1 = inner(resource1)
return r1, res1
r1, res1 = outer(g1)
self.assertEqual(r1.numpy(), 10.0)
self.assertRegex(r1.backing_device, 'CPU')
def check_handle(handle, expected_value):
self.assertRegex(handle.backing_device, 'CPU')
tensor = gen_resource_variable_ops.read_variable_op(
handle, dtypes.float32)
self.assertEqual(tensor.numpy(), expected_value)
# Check that handles returned from functions are on CPU and an op using
# the resource handle is correctly placed on the device backing the
# resource.
check_handle(res1, 5.0)
@test_util.run_gpu_only
def testComplexInputOutputDevicePattern(self):
"""Tests input/output mapping logic in partitioning."""
with ops.device('/device:CPU:0'):
rc0 = resource_variable_ops.ResourceVariable(2.0)
rc1 = resource_variable_ops.ResourceVariable(3.0)
cc0 = array_ops.identity(5.0)
cc1 = array_ops.identity(7.0)
with ops.device('/device:GPU:0'):
rg0 = resource_variable_ops.ResourceVariable(11.0)
rg1 = resource_variable_ops.ResourceVariable(13.0)
cg0 = array_ops.identity(17.0)
cg1 = array_ops.identity(19.0)
# Make sure tensors are on expected devices.
for tensor in [cc0, cc1]:
self.assertRegex(tensor.backing_device, 'CPU:0')
for tensor in [cg0, cg1]:
self.assertRegex(tensor.backing_device, 'GPU:0')
@function.defun
def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
with ops.device('/device:CPU:0'):
m1 = rc0 * cg0
with ops.device('/device:GPU:0'):
m2 = rg0 * cc0
with ops.device('/device:CPU:0'):
r1 = 1000.0 * m2 + rc1 * cg1
with ops.device('/device:GPU:0'):
r2 = 1000.0 * m1 + rg1 * cc1
return r1, r2, m2, m1
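# Expected values: m1 = 2 * 17 = 34, m2 = 11 * 5 = 55,
# r1 = 1000 * 55 + 3 * 19, r2 = 1000 * 34 + 13 * 7, with each result backed
# by the device of the ops.device block that produced it.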
r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)
self.assertRegex(m1.backing_device, 'CPU')
self.assertRegex(r1.backing_device, 'CPU')
self.assertRegex(m2.backing_device, 'GPU')
self.assertRegex(r2.backing_device, 'GPU')
self.assertEqual(m1.numpy(), 34.0)
self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
self.assertEqual(m2.numpy(), 55.0)
self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)
@test_util.run_gpu_only
def testArgumentPruning(self):
"""Tests functions taking unnecessary arguments."""
with ops.device('/device:CPU:0'):
c1 = constant_op.constant(5.0)
c2 = constant_op.constant(7.0)
with ops.device('/device:GPU:0'):
g1 = constant_op.constant(11.0)
g2 = constant_op.constant(13.0)
g3 = constant_op.constant(17.0)
@function.defun
def func(g1, g2, c1, g3, c2): # pylint: disable=unused-argument
# Arguments g1 and g2 are unused and can be pruned by Grappler.
return c1 * g3 * c2
result = func(g1, g2, c1, g3, c2)
self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)
def testNestedCallWatchedVariables(self):
v = variables.Variable(4.)
@def_function.function
def f():
return v ** 2.
with backprop.GradientTape() as tape:
f()
self.assertEqual((v,), tape.watched_variables())
@def_function.function
def g():
return f()
with backprop.GradientTape() as tape:
g()
self.assertEqual((v,), tape.watched_variables())
# f() can rely on the variable being read during its trace. g() checks that
# variables used by an inner function it calls are still recorded on the
# tape. h() tests that functions forward knowledge of variables to their
# callers.
@def_function.function
def h():
return g()
with backprop.GradientTape() as tape:
h()
self.assertEqual((v,), tape.watched_variables())
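# capture_call_time_value evaluates the supplied closure at every call of the
# traced function rather than once at trace time, so later changes to the
# captured Python value show up in subsequent calls.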
def testDeferredCapture(self):
value = 1.0
@def_function.function
def lazy_capture(x):
y = ops.get_default_graph().capture_call_time_value(
lambda: value, tensor_spec.TensorSpec(None))
return x + y
self.assertAllEqual(lazy_capture(2.0), 3.0)
# After changing the value of `value` the function call should return a
# different result.
value = 2.0
self.assertAllEqual(lazy_capture(2.0), 4.0)
def testDeferredCaptureWithKey(self):
value0 = 1.0
value1 = 2.0
@def_function.function
def lazy_capture(x):
w = ops.get_default_graph().capture_call_time_value(
lambda: value0, tensor_spec.TensorSpec(None), key=0)
y = ops.get_default_graph().capture_call_time_value(
lambda: value1, tensor_spec.TensorSpec(None), key=1)
def bad_closure():
raise ValueError('Should not run')
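# A value was already captured above under key=1, so the closure below is
# never invoked; the existing capture is reused instead.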
z = ops.get_default_graph().capture_call_time_value(
bad_closure, tensor_spec.TensorSpec(None), key=1)
return x + y + w + z
self.assertAllEqual(lazy_capture(2.0), 7.0)
value0 = 2.0
value1 = 3.0
self.assertAllEqual(lazy_capture(2.0), 10.0)
def testDeferredCaptureTypeError(self):
value = constant_op.constant(1.0)
@def_function.function
def lazy_capture(x):
y = ops.get_default_graph().capture_call_time_value(
lambda: value, tensor_spec.TensorSpec(()))
return x + y
self.assertAllEqual(lazy_capture(2.0), 3.0)
# dtype mismatch
value = constant_op.constant(1)
with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'):
lazy_capture(2.0)
# shape mismatch
value = constant_op.constant([1.0])
with self.assertRaisesRegex(ValueError, 'Value .* shape'):
lazy_capture(2.0)
def testDeferredCaptureReturnNestWithCompositeTensor(self):
i_s = indexed_slices.IndexedSlices(
constant_op.constant([1, 2]),
constant_op.constant([0, 1], dtype=dtypes.int64),
constant_op.constant([2]))
r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])
s_t = sparse_tensor.SparseTensor(
values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20])
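# capture_call_time_value also accepts a nest of TypeSpecs; the closure's
# return value must match both the structure and the component specs.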
@def_function.function
def lazy_capture():
y = ops.get_default_graph().capture_call_time_value(
lambda: {'i': i_s, 't': (r_t, s_t)},
{'i': indexed_slices.IndexedSlicesSpec(
dtype=dtypes.int32, dense_shape_dtype=dtypes.int32),
't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32),
sparse_tensor.SparseTensorSpec([None], dtypes.int32))})
return y['i'], y['t']
i, (r, s) = lazy_capture()
self.assertAllEqual(i_s.values, i.values)
self.assertAllEqual(i_s.indices, i.indices)
self.assertAllEqual(i_s.dense_shape, i.dense_shape)
self.assertAllEqual(r_t, r)
self.assertAllEqual(s_t.indices, s.indices)
self.assertAllEqual(s_t.values, s.values)
self.assertAllEqual(s_t.dense_shape, s.dense_shape)
def testDeferredCaptureCompositeTensorSpecTypeMismatch(self):
value = indexed_slices.IndexedSlices(
constant_op.constant([1, 2]),
constant_op.constant([0, 1], dtype=dtypes.int64))
@def_function.function
def lazy_capture():
return ops.get_default_graph().capture_call_time_value(
lambda: value,
indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32))
# Type matches spec.
lazy_capture()
# Extra dense shape component.
value = indexed_slices.IndexedSlices(
constant_op.constant([1, 2]),
constant_op.constant([0, 1], dtype=dtypes.int64),
constant_op.constant([2]))
with self.assertRaises(ValueError):
lazy_capture()
# Index dtype mismatch int32 vs. int64.
value = indexed_slices.IndexedSlices(
constant_op.constant([1, 2]),
constant_op.constant([0, 1]))
with self.assertRaises(ValueError):
lazy_capture()
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()