# Source blob 7ebcf77ff61032c66bebe64b24f4b65ba7640bf4 (code-viewer page header; not part of the module).
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import functools
import itertools
import multiprocessing.pool
import sys
import time
import weakref
from absl.testing import parameterized
import numpy
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.autograph.core import ag_ctx
from import dataset_ops
from import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.layers import convolutional
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.platform import test
from import training_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
import attr # pylint:disable=g-import-not-at-top
except ImportError:
attr = None
def total_function_cache(defined):
  """Returns the union of the keys held in both caches of `defined`.

  Args:
    defined: An object exposing `_function_cache.primary` and
      `_function_cache.arg_relaxed`, each iterable over cache keys.

  Returns:
    A `set` containing every key found in either cache.
  """
  # pylint: disable=protected-access
  cache = defined._function_cache
  return set(cache.primary) | set(cache.arg_relaxed)
  # pylint: enable=protected-access
def _example_indexed_slices_with_dense_shape():
  """Builds a sample `IndexedSlices` that also carries a dense shape."""
  values = constant_op.constant([1, 2])
  indices = constant_op.constant([0, 1])
  # NOTE(review): the third argument was lost from this copy of the file (the
  # call was left unclosed). `[2]` is the dense shape that matches the two
  # values/indices above — confirm against upstream.
  dense_shape = constant_op.constant([2])
  return indexed_slices.IndexedSlices(values, indices, dense_shape)
def _example_indexed_slices_without_dense_shape():
  """Builds a sample `IndexedSlices` from values and indices only."""
  values = constant_op.constant([1, 2])
  indices = constant_op.constant([0, 1])
  return indexed_slices.IndexedSlices(values, indices)
def _spec_for_value(value):
  """Returns the (nested) TypeSpec for a value.

  Args:
    value: A tensor, a composite tensor, a nested structure, or any other
      Python object.

  Returns:
    For a nested structure, the result of mapping this function over it; for
    an `ops.Tensor` or `CompositeTensor`, the result of
    `type_spec.type_spec_from_value`; otherwise `value` itself, unchanged.
  """
  if nest.is_sequence(value):
    return nest.map_structure(_spec_for_value, value)
  if isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
    return type_spec.type_spec_from_value(value)
  return value
class FunctionTest(test.TestCase, parameterized.TestCase):
def setUp(self):
# Per-test setup: runs the base-class setUp, looks up the physical CPU devices
# and starts configuring logical devices on the first one.
# NOTE(review): the list passed to set_logical_device_configuration is cut off
# after its opening bracket (extraction loss) — restore the entries and the
# closing brackets from upstream.
super(FunctionTest, self).setUp()
cpus = config.list_physical_devices('CPU')
# Set 4 virtual CPUs
config.set_logical_device_configuration(cpus[0], [
def testBasic(self):
  """Checks chained transposed matmuls through `def_function.function`."""
  matmul = def_function.function(math_ops.matmul)
  t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  # With transpose_a=True each call multiplies the transpose of its first
  # argument by the second one.
  sq = matmul(t, t, transpose_a=True)
  sq2 = matmul(sq, t, transpose_a=True)
  self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
  self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108])
def testOnExitCallback(self):
# Wraps `g` and `f` (which calls the wrapped `g`) with def_function.function
# and expects `values` to grow by [1, 2] each time the pair is traced.
# NOTE(review): the bodies of append_1/append_2, whatever registered them
# inside g and f, and the first call of tf_f appear to be missing (extraction
# loss) — restore from upstream.
values = []
def append_1():
def append_2():
def g(x):
old_values = list(values)
self.assertEqual(old_values, values)
return x + 1
tf_g = def_function.function(g)
def f(x):
old_values = list(values)
self.assertEqual(old_values, values)
return tf_g(x)
tf_f = def_function.function(f)
self.assertEqual(values, [1, 2]) # Once for g, once for f.
tf_f(constant_op.constant([1.0])) # force a retrace
self.assertEqual(values, [1, 2, 1, 2]) # And again.
def testCannotAddExitCallbackWhenNotInFunctionScope(self):
  """Registering an exit callback outside a function must raise."""
  expected_message = 'when not building a function.'
  with self.assertRaisesRegex(RuntimeError, expected_message):
    ops.add_exit_callback_to_default_func_graph(lambda: None)
def testVariable(self):
  """Sums a constant, a closed-over variable and a variable argument."""
  v1 = variables.Variable(1.0)
  # v1 is referenced from the lambda's closure; v2 is passed in as `v`.
  add = def_function.function(lambda x, v: x + v1 + v)
  v2 = variables.Variable(1.0)
  x = constant_op.constant(1.0)
  r = add(x, v2)
  self.assertEqual(3.0, self.evaluate(r))
def testVariableOnly(self):
# Passes a variable straight into a wrapped lambda that calls assign_add on
# its argument and checks the result is 2.0, then sets up an AttributeError
# expectation after creating a constant.
# NOTE(review): the body of the final `with` block is missing (extraction
# loss); presumably it called `add(c)` — confirm against upstream.
v = variables.Variable(1.0)
add = def_function.function(lambda x: x.assign_add(1.0))
r1 = add(v)
self.assertEqual(2.0, self.evaluate(r1))
c = constant_op.constant(1.0)
with self.assertRaisesRegex(AttributeError, 'no attribute'):
# Creates four resource variables across /cpu:0, /cpu:1 and /cpu:2, packs their
# handles pairwise with ops.pack_eager_tensors, and reads the packed handles
# back on each device inside `read_var`.
# NOTE(review): the device placement of v1_0 is ambiguous without indentation,
# `read_var` has no visible function wrapper although get_concrete_function is
# called on it, and the first two statements of its body have lost their
# leading call (only the argument tails remain) — restore from upstream.
@test_util.disable_tfrt('Packed tensor is not supported in tfrt yet.')
def testPackedVariable(self):
with ops.device('/cpu:0'):
v0_0 = resource_variable_ops.ResourceVariable(1.0)
with ops.device('/cpu:1'):
v0_1 = resource_variable_ops.ResourceVariable(2.0)
v1_0 = resource_variable_ops.ResourceVariable(3.0)
with ops.device('/cpu:2'):
v1_1 = resource_variable_ops.ResourceVariable(4.0)
packed_var_0 = ops.pack_eager_tensors([v0_0.handle, v0_1.handle])
packed_var_1 = ops.pack_eager_tensors([v1_0.handle, v1_1.handle])
# TODO(b/145922293): use ResourceVariable.assign_add and
# ResourceVariable.read_value directly once we support packing multiple
# ResourceVariable into one ResourceVariable.
def read_var():
packed_var_0, constant_op.constant(5.0))
packed_var_1, constant_op.constant(6.0))
with ops.device('/cpu:0'):
read0 = resource_variable_ops.read_variable_op(
packed_var_0, dtype=dtypes.float32)
with ops.device('/cpu:1'):
read1 = resource_variable_ops.read_variable_op(
packed_var_0, dtype=dtypes.float32)
read2 = resource_variable_ops.read_variable_op(
packed_var_1, dtype=dtypes.float32)
with ops.device('/cpu:2'):
read3 = resource_variable_ops.read_variable_op(
packed_var_1, dtype=dtypes.float32)
return read0, read1, read2, read3
# The concrete function is expected to carry attributes for two arguments.
arg_attrs = read_var.get_concrete_function().function_def.arg_attr
self.assertLen(arg_attrs, 2)
# Expected reads are the initial values plus the 5.0 / 6.0 constants above.
self.assertAllEqual(read_var(), (1 + 5, 2 + 5, 3 + 6, 4 + 6))
def testImplementsAttributeBasic(self):
# Builds a function with experimental_implements='func', calls it and takes
# its gradients in graph mode, then inspects the three function definitions in
# the graph library: names containing 'forward'/'backward' must not carry the
# implements attribute, and the counts must be 2 without and 1 with.
# NOTE(review): the right-hand side of `name =`, the `else:` branch header and
# the start of the assertion ending in `'func'.encode('ascii'), f)` are missing
# (extraction loss) — restore from upstream.
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y)
with context.graph_mode(), self.cached_session():
a = array_ops.placeholder(dtypes.float32, ())
b = array_ops.placeholder(dtypes.float32, ())
v(a, b)
gradients_impl.gradients(v(a, b), [a, b])
fdefs = ops.get_default_graph().as_graph_def().library.function
self.assertLen(fdefs, 3)
not_present = 0
present = 0
for f in fdefs:
name =
if 'forward' in name or 'backward' in name:
not_present += 1
self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
present += 1
'func'.encode('ascii'), f)
self.assertEqual(not_present, 2, fdefs)
self.assertEqual(present, 1, fdefs)
def testImplementsAttributeAssertsOnSideInput(self):
# In graph mode, a function declared with experimental_implements='func' whose
# lambda also reads the outer tensor `z` is expected to raise AssertionError
# ('variables are always captured') when called.
# NOTE(review): `functions` is fetched but never checked; a trailing assertion
# was presumably lost in extraction — confirm against upstream.
with context.graph_mode(), self.cached_session():
z = array_ops.zeros(0)
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y + z)
a = array_ops.ones((1.0,))
b = array_ops.ones((1.0,))
with self.assertRaisesRegex(AssertionError,
'variables are always captured'):
v(a, b)
functions = ops.get_default_graph().as_graph_def().library.function
def testImplementsAttributeWorksOnVariables(self):
# Calls an experimental_implements function with variable arguments twice in
# graph mode and checks that only one function definition was created.
# NOTE(review): r1.eval() is asserted to be 2 and then 3 with nothing visible
# in between, and no variable initialisation is visible; the statements that
# initialise and update `a` were presumably lost — restore from upstream.
with context.graph_mode(), self.cached_session():
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable((1.0,))
b = variables.Variable((1.0,))
r1 = v(a, b)
_ = v(a, a)
functions = ops.get_default_graph().as_graph_def().library.function
# Verify that we created only one function
self.assertLen(functions, 1)
# Verify that eval() reads the current values.
self.assertEqual(r1.eval(), 2)
self.assertEqual(r1.eval(), 3)
def testImplementsAttributeWorksOnConstants(self):
# Calls an experimental_implements function with a variable and a Python
# float in both argument orders; expects a single function definition that
# still has two input arguments, and both results to evaluate to 3.
# NOTE(review): no initialisation of `a` is visible before the eval() calls;
# a line was presumably lost in extraction — confirm against upstream.
with context.graph_mode(), self.cached_session():
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable(1.0)
r1 = v(a, 2.)
r2 = v(2., a)
functions = ops.get_default_graph().as_graph_def().library.function
self.assertLen(functions, 1)
self.assertLen(functions[0].signature.input_arg, 2)
# Verify that eval() reads the current values.
self.assertEqual(r1.eval(), 3)
self.assertEqual(r2.eval(), 3)
def testImplementsAttributeSpecializes(self):
# Calls an experimental_implements function with list arguments of different
# lengths; expects two function definitions, each keeping two input arguments.
# NOTE(review): no initialisation of `a` is visible before the eval() calls;
# a line was presumably lost in extraction — confirm against upstream.
with context.graph_mode(), self.cached_session():
v = def_function.function(
experimental_implements='func')(lambda x, y: x + y)
a = variables.Variable(1.0)
r1 = v(a, [2.])
r2 = v([2., 2], a)
functions = ops.get_default_graph().as_graph_def().library.function
self.assertLen(functions, 2)
# Ensure that all parameters are still there and haven't been inlined!
self.assertLen(functions[0].signature.input_arg, 2)
self.assertLen(functions[1].signature.input_arg, 2)
# Verify that eval() reads the current values.
numpy.testing.assert_equal(r1.eval(), [3.])
numpy.testing.assert_equal(r2.eval(), [3., 3.])
def testImplementsAttributeAsNameAttrList(self):
# Same flow as the basic implements-attribute test, but the attribute is given
# as a text-format message with a name and two attrs; the test then reads the
# attribute's `func` field and expects two nested attrs.
# NOTE(review): the right-hand side of `name =`, the `else:` branch header and
# the first argument of the assertEqual against 'embedding_matmul' are missing
# (extraction loss) — restore from upstream.
implements_attr = (
'name: "embedding_matmul" attr { key: "key1" value { i: 2 } '
'} attr { key: "key2" value { b: false } }')
v = def_function.function(
experimental_implements=implements_attr)(lambda x, y: x + y)
with context.graph_mode(), self.cached_session():
a = array_ops.placeholder(dtypes.float32, ())
b = array_ops.placeholder(dtypes.float32, ())
v(a, b)
gradients_impl.gradients(v(a, b), [a, b])
fdefs = ops.get_default_graph().as_graph_def().library.function
self.assertLen(fdefs, 3)
not_present = 0
present = 0
for f in fdefs:
name =
if 'forward' in name or 'backward' in name:
not_present += 1
self.assertNotIn(function.IMPLEMENTS_ATTRIBUTE_NAME, f.attr, f)
present += 1
attr_value = f.attr[function.IMPLEMENTS_ATTRIBUTE_NAME]
self.assertIsNotNone(attr_value.func, f)
self.assertEqual(, 'embedding_matmul')
name_attrs = attr_value.func.attr
self.assertLen(name_attrs, 2)
self.assertEqual(not_present, 2, fdefs)
self.assertEqual(present, 1, fdefs)
def testExternalControlDependency(self):
# In a fresh graph, defines `f` whose body runs under a control dependency on
# an assign_add op created outside it, then expects the variable to read 2.0.
# NOTE(review): `f` is never wrapped or called in the visible code, and no
# variable initialisation is visible; those lines were presumably lost in
# extraction — restore from upstream.
with ops.Graph().as_default(), self.test_session():
v = variables.Variable(1.0)
op = v.assign_add(1.0)
def f():
with ops.control_dependencies([op]):
return 1.0
self.assertAllEqual(self.evaluate(v), 2.0)
def testInputShapeFunctionRelaxation(self):
# `func` records in unknown_dim whether it was traced with an unknown leading
# dimension; the assertions track how many entries the function cache holds.
# NOTE(review): the wrapper around `func`, the calls preceding the first two
# cache-size assertions and any checks of unknown_dim appear to be missing
# (extraction loss) — restore from upstream.
unknown_dim = [False]
def func(a):
if a._shape_tuple()[0] is None:
unknown_dim[0] = True
return a + 1
self.assertLen(total_function_cache(func), 1)
self.assertLen(total_function_cache(func), 2)
func(constant_op.constant([1.0, 2.0]))
self.assertLen(total_function_cache(func), 2)
def testInputShapeRelaxationOnInstanceMethod(self):
# Test that experimental_relax_shapes is passed during
# instance method bounding.
# `Foo.func` flags unknown_dim when its argument has an unknown leading
# dimension; the visible code makes a single call with a length-2 vector.
# NOTE(review): the decorator on `Foo.func`, further calls and the assertions
# on unknown_dim appear to be missing (extraction loss) — restore from
# upstream.
unknown_dim = [False]
class Foo(object):
def func(self, a):
if a._shape_tuple()[0] is None:
unknown_dim[0] = True
return a + 1
foo = Foo()
foo.func(constant_op.constant([1.0, 2.0]))
def testInputShapeFunctionRelaxationWithRaggedTensors(self):
# `func` stores the TypeSpec it was traced with; check_trace resets that slot
# and compares it with the expected spec (None meaning "no retrace").
# NOTE(review): the wrapper around `func`, the call to it inside check_trace
# and the first argument of the "Different ragged_rank: retrace" case are
# missing (extraction loss) — restore from upstream.
traced_type_spec = [None]
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
self.assertEqual(traced_type_spec[0], expected_trace)
check_trace( # Initial call gets traced.
ragged_factory_ops.constant([[1], [2, 3, 4]]),
ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32))
check_trace( # Input TypeSpec is the same -> no retrace.
ragged_factory_ops.constant([[1, 2], [3, 4]]), None)
check_trace( # Even if component tensor shapes change -> no retrace.
ragged_factory_ops.constant([[1, 2], [3, 4, 5, 6]]), None)
check_trace( # Different TypeSpec shape (nrows): retrace
ragged_factory_ops.constant([[1], [2], [3]]),
ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32))
check_trace( # Different nrows again: relax & retrace
ragged_factory_ops.constant([[1], [2], [3], [4]]),
ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32))
check_trace( # Different nrows yet again: not retrace
ragged_factory_ops.constant([[1]]), None)
check_trace( # Different ragged_rank: retrace
ragged_tensor.RaggedTensorSpec([1, None, None], dtypes.int32))
check_trace( # Different ragged_rank again: retrace & relax
ragged_factory_ops.constant([[[1]], [[2]]]),
ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32))
def testInputShapeFunctionRelaxationWithStructuredTensors(self):
# Same record-and-compare scheme as the ragged-tensor variant, applied to
# StructuredTensor values built with from_pyval.
# NOTE(review): the wrapper around `func`, the call inside check_trace, the
# `check_trace(` openers and spec constructors of the first three cases, and
# the tail of the last case are missing (extraction loss) — restore from
# upstream.
traced_type_spec = [None]
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
self.assertEqual(traced_type_spec[0], expected_trace)
# If we have TypeSpecs that differ in ways other than just their shape,
# then retrace each time.
structured_tensor.StructuredTensor.from_pyval({'a': [1]}),
[], {'a': tensor_spec.TensorSpec((1,), dtypes.int32)}))
structured_tensor.StructuredTensor.from_pyval({'b': [1]}),
[], {'b': tensor_spec.TensorSpec((1,), dtypes.int32)}))
structured_tensor.StructuredTensor.from_pyval({'c': [1]}),
[], {'c': tensor_spec.TensorSpec((1,), dtypes.int32)}))
# But if we call again with only shape different, then do relax:
check_trace( # retrace
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2]}),
[], {'a': tensor_spec.TensorSpec((2,), dtypes.int32)}))
check_trace( # relax & retrace
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3]}),
[], {'a': tensor_spec.TensorSpec((None,), dtypes.int32)}))
check_trace( # use relaxed graph
structured_tensor.StructuredTensor.from_pyval({'a': [1, 2, 3, 4]}),
def testInputShapeFunctionRelaxationWithDatasetIterators(self):
# For dataset iterators, the TypeSpec includes type information that's
# not derivable from the component tensors. Make sure that the TypeSpec
# shapes get relaxed as appropriate.
# NOTE(review): the wrapper around `func`, the call inside check_trace and the
# leading lines of every "retrace" case (iterator argument and the start of
# the expected spec) are missing, which also leaves ds_2_2, ds_3_2 and ds_2_1
# unused here (extraction loss) — restore from upstream.
traced_type_spec = [None]
def func(x):
traced_type_spec[0] = x._type_spec
return x
def check_trace(x, expected_trace):
traced_type_spec[0] = None
self.assertEqual(traced_type_spec[0], expected_trace)
ds_1_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([1, 2]))
ds_2_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 2]))
ds_3_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([3, 2]))
ds_4_2 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([4, 2]))
ds_2_1 = dataset_ops.DatasetV2.from_tensors(array_ops.zeros([2, 1]))
check_trace( # shape=[1, 2]: retrace
tensor_spec.TensorSpec([1, 2], dtypes.float32)))
check_trace( # shape=[1, 2]: no retrace (use the [1, 2] graph)
dataset_ops.make_one_shot_iterator(ds_1_2), None)
check_trace( # shape=[2, 2]: retrace
tensor_spec.TensorSpec([2, 2], dtypes.float32)))
check_trace( # shape=[3, 2]: relax to [None, 2] and retrace
tensor_spec.TensorSpec([None, 2], dtypes.float32)))
check_trace( # shape=[4, 2]: no retrace (use the [None, 2] graph)
dataset_ops.make_one_shot_iterator(ds_4_2), None)
check_trace( # shape=[2, 1]: relax to [None, None] and retrace
tensor_spec.TensorSpec([None, None], dtypes.float32)))
def testCapturesVariables(self):
# `f` reads two outer variables (one created with trainable=False) and lazily
# creates a third, cached in `cc`; the test compares the concrete function's
# `variables` / `trainable_variables` with those sets and with the graph's.
# NOTE(review): `f` has no visible wrapper although get_concrete_function is
# called on it, and the assertEqual on trainable_variables is cut off before
# its second argument (extraction loss) — restore from upstream.
a = variables.Variable(1.0, trainable=False)
b = variables.Variable(1.0)
cc = [None]
def f():
c = cc[0]
if c is None:
c = cc[0] = variables.Variable(1.)
return a + b + c + 1
cf = f.get_concrete_function()
c = cc[0]
captured_variables = {v.ref() for v in (a, b, c)}
trainable_variables = {v.ref() for v in (b, c)}
self.assertEqual({v.ref() for v in cf.variables}, captured_variables)
self.assertEqual({v.ref() for v in cf.trainable_variables},
self.assertEqual(cf.variables, cf.graph.variables)
self.assertEqual(cf.trainable_variables, cf.graph.trainable_variables)
def testNestedInputShapeFunctionRelaxation(self):
# Calls `func` with a non-tensor first argument and a list whose second tensor
# changes length, tracking the function-cache size after each call; the
# sequence is then repeated with a different non-tensor argument.
# NOTE(review): `func` has no visible wrapper, and the assertions on
# unknown_dim appear to be missing (extraction loss) — restore from upstream.
unknown_dim = [False]
def func(a_, b_=None):
del a_ # Only used to check which cache is used.
self.assertEqual(b_[0]._shape_tuple(), ())
if b_[1]._shape_tuple()[0] is None:
unknown_dim[0] = True
return b_[0] + 1
a = 'hi'
b0 = constant_op.constant(1.0)
func(a, b_=[b0, constant_op.constant([])])
self.assertLen(total_function_cache(func), 1)
func(a, b_=[b0, constant_op.constant([1.0])])
self.assertLen(total_function_cache(func), 2)
func(a, b_=[b0, constant_op.constant([1.0, 1.0])])
self.assertLen(total_function_cache(func), 2)
unknown_dim[0] = False
# Now do the same except with a new a which is not a tensor; this should
# change the cache key.
a = 'bye'
func(a, b_=[b0, constant_op.constant([])])
self.assertLen(total_function_cache(func), 3)
# Since we already marked a cache miss for a function with the same
# non-input signatures, here we will immediately start relaxing shapes.
func(a, b_=[b0, constant_op.constant([1.0])])
self.assertLen(total_function_cache(func), 3)
def testNestedShapeFunctionRelaxation(self):
# For ranks 1..5, feeds an all-ones tensor through `foo`, which passes its
# shape vector to `bar`; `bar` records the static shape it saw.
# NOTE(review): the wrappers around `bar`/`foo` and the `else:` separating the
# two got_shape assertions appear to be missing (extraction loss) — as written
# the two assertions would contradict each other.
got_shape = [None]
# The inner function will go through shape relaxation because the shapes it
# receives will be [1], [2], [3], ...
def bar(x_shape):
got_shape[0] = x_shape._shape_tuple()
return x_shape
# The outer function will not go through shape relaxation because the shapes
# it receives will be [1], [[1]], [[[1]]], ...
def foo(ones):
return bar(array_ops.shape(ones))
for rank in range(1, 6):
x_shape = self.evaluate(foo(array_ops.ones([1] * rank)))
self.assertAllEqual(x_shape, [1] * rank)
if rank < 3:
self.assertEqual(got_shape[0], (rank,))
self.assertEqual(got_shape[0], (None,))
def testNoHash(self):
# Expects a ValueError whose message matches 'Got type: set'.
# NOTE(review): the wrapper around `f` and the body of the `with` block are
# missing (extraction loss); presumably `f` was called with a set — confirm
# against upstream.
def f(_):
return 1.0
with self.assertRaisesRegex(ValueError, r'Got type: set'):
def testFuncName(self):
# `add` is wrapped with a 'func_name' attribute of 'multiply' and is expected
# to report that as its `_name`; `add_2` is expected to report 'add_2'.
# NOTE(review): `add_2` has no visible decorator, so `add_2._name` would not
# exist as written; its decorator was presumably lost in extraction — restore
# from upstream.
@function.defun_with_attributes(attributes={'func_name': 'multiply'})
def add(x, y):
_ = x * y
return x + y
def add_2(x, y):
_ = x * y
return x + y
self.assertEqual(add._name, 'multiply')
self.assertEqual(add_2._name, 'add_2')
def testBasicGraphMode(self):
  """Squares a matrix via a wrapped matmul and compares with a direct one."""
  matmul = def_function.function(math_ops.matmul)

  # NOTE(review): `sq` is a plain Python function in this copy of the file;
  # confirm against upstream whether it was meant to carry a decorator.
  def sq(a):
    return matmul(a, a)

  t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  out = sq(t)
  self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedInputsGraphMode(self):
  """Multiplies tensors pulled out of a namedtuple of dicts."""
  matmul = def_function.function(math_ops.matmul)
  pair = collections.namedtuple('pair', ['a', 'b'])

  # NOTE(review): `a_times_b` is a plain Python function in this copy of the
  # file; confirm against upstream whether it was meant to carry a decorator.
  def a_times_b(inputs):
    return matmul(inputs.a['a'], inputs.b['b'])

  t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  out = a_times_b(pair({'a': t}, {'b': t}))
  self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedOutputsGraphMode(self):
  """Returns a namedtuple of matmul results and compares both fields."""
  matmul = def_function.function(math_ops.matmul)
  pair = collections.namedtuple('pair', ['a', 'b'])

  # NOTE(review): `pairs_mul` is a plain Python function in this copy of the
  # file; confirm against upstream whether it was meant to carry a decorator.
  def pairs_mul(pair_a, pair_b):
    return pair(matmul(pair_a.a, pair_b.a), matmul(pair_a.b, pair_b.b))

  a = constant_op.constant([[1.0, 2.0], [1.0, 2.0]])
  b = constant_op.constant([[3.0, 4.0], [3.0, 4.0]])
  out = pairs_mul(pair(a, b), pair(b, a))
  expected = pair(math_ops.matmul(a, b).numpy(),
                  math_ops.matmul(b, a).numpy())
  self.assertAllClose(out, expected)
def testNestedFunctionGraphNotOutOfDate(self, function_decorator):
# Stores a concrete function of `f` on a model object from inside `g`, then
# checks which graph is the stored function's outer_graph before and after the
# model is deleted.
# NOTE(review): `function_decorator` is never used in the visible code, `f`
# and `g` carry no wrappers, `model.g` is never called before `model.f` is
# read, and `weak_g` is never checked; several lines were presumably lost in
# extraction — restore from upstream.
def f():
return constant_op.constant(1.)
class _Model(object):
def g(self):
self.f = f.get_concrete_function()
model = _Model()
concrete = model.f
weak_g_graph = weakref.ref(model.g.get_concrete_function().graph)
self.assertIs(weak_g_graph(), concrete.graph.outer_graph)
weak_g = weakref.ref(model.g)
del model
self.assertIs(ops.get_default_graph(), concrete.graph.outer_graph)
def testGraphEagerIsolation(self):
  """Calls the same variable-creating function eagerly and under a Graph."""

  # NOTE(review): `f` is a plain Python function in this copy of the file;
  # confirm against upstream whether it was meant to carry a decorator.
  def f():
    self.v = variables.Variable(1.0)
    return self.v.read_value()

  self.assertAllEqual(f(), 1.0)

  # Under an explicit graph only the static shape of the result is checked.
  with ops.Graph().as_default():
    self.assertEqual(f().shape, ())
def testBasicGraphFunction(self):
# Obtains a concrete function of `sq` for a 2x2 tensor, checks its output
# shape, and compares its result with a direct matmul.
# NOTE(review): `sq` has no visible wrapper, so `sq.get_concrete_function`
# would not exist as written; its decorator was presumably lost in extraction
# — restore from upstream.
matmul = def_function.function(math_ops.matmul)
def sq(a):
return matmul(a, a)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testGetConcreteFunctionThreadSafety(self):
# Requests the concrete function of `sq` from 100 pool threads and expects
# every thread to have received the same object.
# NOTE(review): `sq` has no visible wrapper, thread_func never stores `cf` in
# concrete_functions, and the pool call on the `_ =` line has lost its callee
# (extraction loss) — restore from upstream.
def sq():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
return math_ops.matmul(t, t)
concrete_functions = []
def thread_func(_):
cf = sq.get_concrete_function()
num_threads = 100
pool = multiprocessing.pool.ThreadPool(num_threads)
_ =, list(range(num_threads)))
self.assertLen(set(concrete_functions), 1)
def testGetConcreteFunctionThreadSafetyWithArgs(self):
# Builds a 100-element argument tuple for `add_100` and hands two copies of it
# to a two-thread pool.
# NOTE(review): `add_100` has no visible wrapper, the right-hand side of the
# `f1, f2 =` line has lost its callee, and the check that followed the
# "I see about ..." comment is missing (extraction loss) — restore from
# upstream.
def add_100(*args):
return math_ops.add_n(args)
p = multiprocessing.pool.ThreadPool(2)
args = (constant_op.constant(1.),) * 100
f1, f2 =, [args] * 2)
# I see about len(args) + max(0, len(args) - 3) arguments expected.
del f2
def testInputSpecGraphFunction(self):
# Gets a concrete function of `sq` from a TensorSpec with two unknown
# dimensions, expects unknown output dimensions, and runs it on two inputs.
# NOTE(review): `sq` has no visible wrapper, so `sq.get_concrete_function`
# would not exist as written; its decorator was presumably lost in extraction
# — restore from upstream.
matmul = def_function.function(math_ops.matmul)
def sq(a):
return matmul(a, a)
sq_op = sq.get_concrete_function(
tensor_spec.TensorSpec((None, None), dtypes.float32))
self.assertEqual([None, None], sq_op.output_shapes.as_list())
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out1 = sq_op(t1)
self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
out2 = sq_op(t2)
self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
def testNestedInputSpecGraphFunction(self):
# Gets two concrete functions of `sq` from nested TensorSpecs, then calls one
# with keyword arguments first_mat/second_mat and the other positionally.
# NOTE(review): `sq` has no visible wrapper, and the two TensorSpec calls in
# the second get_concrete_function are cut off before their remaining
# arguments — presumably the specs named 'first_mat' and 'second_mat'
# (extraction loss) — restore from upstream.
matmul = def_function.function(math_ops.matmul)
def sq(mats):
((a, b),) = mats
return matmul(a, b)
sq_op_autonamed = sq.get_concrete_function(
[(tensor_spec.TensorSpec((None, None), dtypes.float32),
tensor_spec.TensorSpec((None, None), dtypes.float32))])
self.assertEqual([None, None], sq_op_autonamed.output_shapes.as_list())
sq_op = sq.get_concrete_function(
[(tensor_spec.TensorSpec((None, None), dtypes.float32,
tensor_spec.TensorSpec((None, None), dtypes.float32,
self.assertEqual([None, None], sq_op.output_shapes.as_list())
t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
out = sq_op(first_mat=t1, second_mat=t2)
self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
self.assertAllEqual(sq_op_autonamed(t1, t2),
math_ops.matmul(t1, t2).numpy())
def testExecutingStatelessDefunConcurrently(self):
# Runs `stateless` (doubles its input) over 100 constants through a thread
# pool and compares the float results with 2.0 * input.
# NOTE(review): the wrapper around `stateless` and the callee in the `outputs`
# comprehension are missing (extraction loss) — restore from upstream.
def stateless(x):
return math_ops.multiply(2.0, x)
pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(1.0 * x) for x in range(100)]
outputs = [float(out) for out in, inputs)]
expected = [float(2.0 * x) for x in inputs]
self.assertSequenceEqual(outputs, expected)
def testExecutingManyStatelessDefunsConcurrently(self):
# Runs `stateless` (ignores its input, returns 2.0 * 2.0) over 100 distinct
# objects through a thread pool and expects 4.0 each time.
# NOTE(review): the wrapper around `stateless`, the name inside the backticks
# of the comment below and the callee in the `outputs` comprehension are
# missing (extraction loss) — restore from upstream.
def stateless(x):
del x
return math_ops.multiply(2.0, 2.0)
pool = multiprocessing.pool.ThreadPool()
# `` below instantiates 100 functions, one for each object.
objects = [object() for _ in range(100)]
outputs = [float(out) for out in, objects)]
expected = [4.0] * 100
self.assertSequenceEqual(outputs, expected)
# Expects the resource variable, created as 1.0, to read 0.0 after `stateful`
# has been run concurrently.
# NOTE(review): the body (and wrapper) of `stateful` is missing, and the
# `inputs = ...` line has a stray `, inputs)` tail that looks like the remains
# of a separate pool call merged onto it (extraction loss) — restore from
# upstream.
@test_util.disable_tfrt('b/169431085: This test is flaky on tfrt')
def testExecutingStatefulDefunConcurrently(self):
v = resource_variable_ops.ResourceVariable(1.0)
def stateful(x):
pool = multiprocessing.pool.ThreadPool()
inputs = [constant_op.constant(0.0)] * 100, inputs)
self.assertEqual(float(v.read_value()), 0.0)
def testExecutingManyStatefulDefunsConcurrently(self):
# `stateful` ignores its argument and assigns 0.0 to the variable; the test
# expects the variable to read 0.0 afterwards.
# NOTE(review): the wrapper around `stateful` is missing, and the statement
# that ran it over `[object() for _ in range(100)]` appears to have been
# merged into the comment line below, so nothing visible calls `stateful`
# (extraction loss) — restore from upstream.
v = resource_variable_ops.ResourceVariable(1.0)
def stateful(x):
del x
return v.assign(0.0)
pool = multiprocessing.pool.ThreadPool()
# `` below instantiates 100 functions, one for each object., [object() for _ in range(100)])
self.assertEqual(float(v.read_value()), 0.0)
def testShareRendezvous(self):
# Disable grappler from inlining the functions. Note we run the send & recv
# in graph mode since with eager mode the function should automatically be
# inlined.
# Defines `send`/`recv` around gen_sendrecv_ops with _shared_rendezvous set,
# plus loop bodies and a condition to drive them through functional_ops.While.
# NOTE(review): the call that takes the `{'disable_meta_optimizer': True}`
# argument, the wrappers on the nested functions (`signature` is otherwise
# unused), the calls to send/recv inside the loop bodies, the body arguments
# of both While calls and the body of the final `with` block are missing
# (extraction loss) — restore from upstream.
{'disable_meta_optimizer': True})
cpu = '/device:CPU:0'
signature = [tensor_spec.TensorSpec([], dtypes.int32)]
def send():
x = constant_op.constant(1)
gen_sendrecv_ops.send(x, 'x', cpu, 0, cpu)
return x
send._shared_rendezvous = True # pylint: disable=protected-access
def send_body(n):
return n - 1
def recv():
return gen_sendrecv_ops.recv(dtypes.int32, 'x', cpu, 0, cpu)
recv._shared_rendezvous = True # pylint: disable=protected-access
def recv_body(n):
return n - 1
def cond(n):
return n > 0
# Instead of calling the send & recv functions directly we want to call them
# through a functional while to ensure the rendezvous is shared across the
# while boundary.
def fn(n):
functional_ops.While([n], cond.get_concrete_function(),
return functional_ops.While([n], cond.get_concrete_function(),
# Use a graph context since functions will not be automatically inlined
with context.graph_mode(), self.cached_session():
def disabled_testRandomSeed(self):
# Draws scalar random normals from `f` and compares successive draws.
# NOTE(review): as written the test expects a second draw to differ from `x`
# and a third to equal it, with nothing in between; the wrapper on `f` and
# seed-setting statements were presumably lost in extraction — restore from
# upstream.
def f():
return random_ops.random_normal(())
x = f()
self.assertNotEqual(x, f())
self.assertAllEqual(f(), x)
def testNestedInputsGraphFunction(self):
# Gets a concrete function of `a_times_b` from a namedtuple of dicts of named
# TensorSpecs, checks the output shape, and calls it with keywords a / b.
# NOTE(review): `a_times_b` has no visible wrapper, so
# `a_times_b.get_concrete_function` would not exist as written; its decorator
# was presumably lost in extraction — restore from upstream.
matmul = def_function.function(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = a_times_b.get_concrete_function(
pair(dict(a=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'a')),
dict(b=tensor_spec.TensorSpec([2, 2], dtypes.float32, 'b'))))
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(a=t, b=t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
def testNestedOutputGraphFunction(self):
# `sq` returns a tuple of a matmul result and a dict holding a constant; the
# test unpacks the concrete function's result and checks both parts.
# NOTE(review): `sq` has no visible wrapper, and the two assertions whose
# trailing arguments are the nested shape / dtype structures below have lost
# their opening lines (extraction loss) — restore from upstream.
matmul = def_function.function(math_ops.matmul)
def sq(a):
return (matmul(a, a), {'b': constant_op.constant(1.0)})
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = sq.get_concrete_function(t)
(tensor_shape.TensorShape([2, 2]),
{'b': tensor_shape.TensorShape([])}))
(dtypes.float32, {'b': dtypes.float32}))
(a, b) = sq_op(t)
self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
self.assertAllEqual(b['b'].numpy(), 1.0)
def testGraphFunctionNoneOutput(self):
# A function returning None is expected to yield a concrete function whose
# output_dtypes and output_shapes are None and whose call result is None.
# NOTE(review): `fn` has no visible wrapper, so `fn.get_concrete_function`
# would not exist as written; its decorator was presumably lost in extraction
# — restore from upstream.
def fn(unused_a, unused_b):
return None
x = constant_op.constant(1)
fn_op = fn.get_concrete_function(x, x)
self.assertEqual(fn_op.output_dtypes, None)
self.assertEqual(fn_op.output_shapes, None)
self.assertAllEqual(fn_op(x, x), None)
def testDefunNumpyArraysConvertedToTensors(self):
# Wraps `f` (which asserts its argument is an ops.Tensor) with function.defun
# and feeds it NumPy arrays, TF tensors and an ndarray subclass view while
# tracking the function-cache size.
# NOTE(review): the `defined(x)` calls before the first two cache-size checks,
# the mutation of `mutable` between its two assertions and the body of
# `MyNdarray` appear to be missing (extraction loss) — restore from upstream.
def f(x):
self.assertIsInstance(x, ops.Tensor)
return x
x = random_ops.random_uniform([2, 2]).numpy()
defined = function.defun(f)
self.assertLen(total_function_cache(defined), 1)
x = random_ops.random_uniform([2, 2]).numpy()
# A NumPy array with different values but the same shape and dtype
# shouldn't trigger another function definition.
self.assertLen(total_function_cache(defined), 1)
np_ones = numpy.ones([], numpy.float32)
np_zeros = numpy.zeros([], numpy.float32)
tf_ones = array_ops.ones([])
tf_zeros = array_ops.zeros([])
# Test that the numpy array is properly an argument to the graph function.
self.assertEqual(1., defined(np_ones).numpy())
self.assertLen(total_function_cache(defined), 2)
self.assertEqual(0., defined(np_zeros).numpy())
self.assertEqual(1., defined(tf_ones).numpy())
self.assertEqual(0., defined(tf_zeros).numpy())
self.assertLen(total_function_cache(defined), 2)
# Test that mutable inputs are supported.
mutable = numpy.ones([], numpy.float32)
self.assertEqual(1., defined(mutable).numpy())
self.assertEqual(0., defined(mutable).numpy())
class MyNdarray(numpy.ndarray):
# Test that the subclasses of ndarray are converted too.
self.assertEqual(1., defined(np_ones.view(MyNdarray)).numpy())
self.assertEqual(0., defined(np_zeros.view(MyNdarray)).numpy())
# We should not have triggered any re-tracing of the python function.
self.assertLen(total_function_cache(defined), 2)
def testNumpyDtypeInputSupported(self):
  """Passes NumPy scalar types as arguments and checks the converted value."""

  # NOTE(review): `f` is a plain Python function in this copy of the file;
  # confirm against upstream whether it was meant to carry a decorator.
  def f(x, dtype):
    return constant_op.constant(dtype(x))

  self.assertEqual(f(1, numpy.float32).numpy(), numpy.float32(1))
  self.assertEqual(f(2, numpy.float32).numpy(), numpy.float32(2))
  self.assertEqual(f(1, numpy.int32).numpy(), numpy.int32(1))
  self.assertEqual(f(2, numpy.int32).numpy(), numpy.int32(2))
def testDefunNumpyArraysConvertedToTensorsInKwargs(self):
# Keyword-argument variant: `f` pops 'x' from its kwargs, asserts it is an
# ops.Tensor and returns it; the wrapped function is then called with NumPy
# and TF scalars passed as x=...
# NOTE(review): the cache-size checks are not preceded by any visible call of
# `defined`; the `defined(x=x)` calls were presumably lost in extraction —
# confirm against upstream.
def f(**kwargs):
x = kwargs.pop('x')
self.assertIsInstance(x, ops.Tensor)
return x
x = random_ops.random_uniform([2, 2]).numpy()
defined = function.defun(f)
self.assertLen(total_function_cache(defined), 1)
x = random_ops.random_uniform([2, 2]).numpy()
# A NumPy array with different values but the same shape and dtype
# shouldn't trigger another function definition.
self.assertLen(total_function_cache(defined), 1)
# Test that the numpy array is properly an argument to the graph function.
self.assertEqual(1., defined(x=numpy.ones([])).numpy())
self.assertEqual(0., defined(x=numpy.zeros([])).numpy())
self.assertEqual(1., defined(x=array_ops.ones([])).numpy())
self.assertEqual(0., defined(x=array_ops.zeros([])).numpy())
def testDefunCapturedInt32(self):
  """Adds a closed-over int32 constant to itself and expects 2."""
  x = constant_op.constant(1, dtype=dtypes.int32)

  # NOTE(review): nothing visible wraps `add_int32s` in a defun, although the
  # test name suggests one — confirm against upstream.
  def add_int32s():
    return x + x

  self.assertEqual(2, int(add_int32s()))
def testDefunReadVariable(self):
  """Reads a closed-over resource variable and expects its initial value."""
  v = resource_variable_ops.ResourceVariable(1.0)

  # NOTE(review): nothing visible wraps `f` in a defun, although the test name
  # suggests one — confirm against upstream.
  def f():
    return v.read_value()

  self.assertEqual(1.0, float(f()))
def testDefunAssignAddVariable(self):
# Expects the variable, created as 1.0, to read 3.0 from test_assign_add.
# NOTE(review): `x` (2.0) is never used and nothing visible modifies `v`; the
# wrapper on test_assign_add and an assign_add of `x` were presumably lost in
# extraction — confirm against upstream.
v = resource_variable_ops.ResourceVariable(1.0)
x = constant_op.constant(2.0)
def test_assign_add():
return v.read_value()
self.assertEqual(3.0, float(test_assign_add()))
def testTensorInitializationInFunctionRaisesError(self):
# Expects a ValueError matching error_msg (tensor-typed variable initializers
# must be wrapped in an init_scope or be callable).
# NOTE(review): the body (and wrapper) of tensor_init and the body of the
# `with` block are missing (extraction loss) — restore from upstream.
error_msg = ('Tensor-typed variable initializers must either be '
'wrapped in an init_scope or callable.*')
def tensor_init():
with self.assertRaisesRegex(ValueError, error_msg):
def testCallableTensorInitializationInFunction(self):
# tensor_init creates a ResourceVariable from a callable initializer that
# returns constant 2.0 and reads it back; the test expects 2.0.
# NOTE(review): tensor_init has no visible wrapper, and the `if not
# context.executing_eagerly():` line is followed directly by the assertion;
# its original body may have been lost in extraction, which would change when
# the assertion runs — confirm against upstream.
def tensor_init():
self.v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
self.assertEqual(self.evaluate(value), 2.0)
def testInitScopeTensorInitializationInFunction(self):
# tensor_init builds a constant under ops.init_scope(), creates a
# ResourceVariable from it and reads it back; the test expects 2.0.
# NOTE(review): tensor_init has no visible wrapper, and without indentation it
# is unclear which statements belong inside the init_scope block — confirm
# against upstream.
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
# Note: this variable bypasses tf.function's variable creation
# requirements by bypassing variable_creator_scope by using
# ResourceVariable instead of Variable.
self.v = resource_variable_ops.ResourceVariable(const)
return self.v.read_value()
value = tensor_init()
self.assertAllEqual(value, 2.0)
def testGetConcreteFunctionCreatesVariables(self):
# Expects both the concrete function and the function itself to evaluate to
# 5.0, read from a variable kept in v_holder.
# NOTE(review): tensor_init has no visible wrapper, and the body of
# `if not v_holder:` (presumably appending a variable holding 5.0 to v_holder)
# is missing (extraction loss) — restore from upstream.
v_holder = []
def tensor_init():
if not v_holder:
return v_holder[0].read_value()
concrete = tensor_init.get_concrete_function()
self.assertAllEqual(5., self.evaluate(concrete()))
self.assertAllEqual(5., self.evaluate(tensor_init()))
def testFuncGraphCaptureByValue(self):
# Builds a function.Function with capture_by_value=True around a read of `v`
# and expects it to return 1.0 on two consecutive calls.
# NOTE(review): the two assertions are identical with nothing in between;
# presumably `v` was reassigned between them to show the captured value does
# not change — confirm against upstream.
v = variables.Variable(1.0)
def trivial_function():
return v.read_value()
graph_function = function.Function(
trivial_function, 'test', capture_by_value=True)
self.assertAllEqual(graph_function(), 1.0)
self.assertAllEqual(graph_function(), 1.0)
def testFuncGraphCaptureByValueNested(self):
# Same as the capture-by-value test, but the variable read happens in both
# branches of a control_flow_ops.cond on a placeholder defaulting to True.
# NOTE(review): the two assertions are identical with nothing in between;
# presumably `v` was reassigned between them — confirm against upstream.
v = variables.Variable(1.0)
def trivial_function():
return control_flow_ops.cond(
array_ops.placeholder_with_default(True, ()),
v.read_value, v.read_value)
graph_function = function.Function(
trivial_function, 'test', capture_by_value=True)
self.assertAllEqual(graph_function(), 1.0)
self.assertAllEqual(graph_function(), 1.0)
def testDefunShapeInferenceWithCapturedResourceVariable(self):
  """Checks static shapes seen for a captured resource variable.

  Inside the traced function the product of the variable and a 2x2 constant
  must have a static 2x2 shape; outside, the returned handle must be a scalar
  of dtype `resource` and reading through it must give a 2x2 shape again.
  """
  v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])

  def f():
    x = constant_op.constant([[1, 2], [3, 4]])
    out = math_ops.matmul(v, x)
    self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
    # We do not return v directly since the tensor conversion function of
    # ResourceVariable returns the read value and not the resource itself.
    return v._handle

  compiled = def_function.function(f)
  var_handle = compiled()
  self.assertEqual(var_handle.dtype, dtypes.resource)
  self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
  var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
  self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testShapeInferenceForMoreSpecificInput(self):
# Compares the static shape produced by calling `f` directly with the one
# produced by `compiled`, which wraps `f` with a float32 TensorSpec whose
# shape argument is None.
def f(a):
return array_ops.reshape(a, [-1, 3])
signature = [tensor_spec.TensorSpec(None, dtypes.float32)]
compiled = def_function.function(f, input_signature=signature)
# NOTE(review): no call to `use_f` is visible in this view -- TODO confirm
# that it is invoked in the upstream file.
def use_f():
inputs = array_ops.zeros([10, 10, 3])
self.assertAllEqual(f(inputs).shape, compiled(inputs).shape)
def testFuncListAttr(self):
# Appears to exercise selection among five branch callables by index: the
# assertions expect indices 0-4 to give ones * 1..5 and index 22 to give the
# same result as the last branch.
def test_function(val):
def fn1():
return array_ops.ones([10])
fn2 = lambda: array_ops.ones([10]) * 2
def fn3(x=3):
return array_ops.ones([10]) * x
# fn4 binds `x` by keyword; fn5 binds it positionally.
fn4 = functools.partial(fn3, x=4)
fn5 = functools.partial(fn3, 5)
# NOTE(review): the return expression below is truncated in this view (the
# callee and the opening of the list comprehension are missing) -- restore
# it from the upstream file.
return, [], [dtypes.float32],
for f in (fn1, fn2, fn3, fn4, fn5)])
ones = array_ops.ones([10])
self.assertAllEqual([ones], test_function(0))
self.assertAllEqual([ones * 2], test_function(1))
self.assertAllEqual([ones * 3], test_function(2))
self.assertAllEqual([ones * 4], test_function(3))
self.assertAllEqual([ones * 5], test_function(4))
self.assertAllEqual([ones * 5], test_function(22)) # default branch
def testVariableInLoopInFunction(self):
  """Checks the result shape of a while_loop whose body calls get_variable.

  The loop condition always returns False and the loop body returns a
  variable named 'a' with shape (); the result is expected to have shape [].
  """

  def test_function():

    def loop_test(_):
      return False

    def loop_body(_):
      return variable_scope.get_variable('a', shape=())

    return control_flow_ops.while_loop(loop_test, loop_body, [0.0])

  result = test_function()
  self.assertEqual(result.shape, [])
def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
# Graph-mode check of static shapes: on the matmul result inside `f`, on the
# returned resource handle, and on a read performed through that handle.
# NOTE(review): indentation is lost, so the extent of the `graph_mode` block
# cannot be determined from this view.
with context.graph_mode():
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# We do not return v directly since the tensor conversion function of
# ResourceVariable returns the read value and not the resource itself.
return v._handle
compiled = def_function.function(f)
var_handle = compiled()
self.assertEqual(var_handle.dtype, dtypes.resource)
self.assertEqual(var_handle.shape, tensor_shape.TensorShape([]))
var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
# Graph-mode check that the matmul of a captured `variables.Variable` with a
# 2x2 constant has static shape [2, 2] inside `f`.
with context.graph_mode():
v = variables.Variable([[1, 2], [3, 4]])
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# Check that shape inference works while creating the defun
# NOTE(review): no call to `compiled` is visible after this line; a call may
# have been lost in this view -- TODO confirm.
compiled = def_function.function(f)
def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
# Graph-mode check that popping from a captured tensor list yields a scalar
# static shape, both inside `f` and on the list that `f` returns.
with context.graph_mode():
tensor_list = list_ops.empty_tensor_list(
element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
# NOTE(review): both push_back calls below are cut off after their first
# argument in this view -- restore the pushed values from the upstream file.
tensor_list = list_ops.tensor_list_push_back(tensor_list,
tensor_list = list_ops.tensor_list_push_back(tensor_list,
def f():
tl, value = list_ops.tensor_list_pop_back(
tensor_list, element_dtype=dtypes.float32)
self.assertEqual(value.shape, tensor_shape.TensorShape([]))
return tl
compiled = def_function.function(f)
output_tensor_list = compiled()
_, value = list_ops.tensor_list_pop_back(
output_tensor_list, element_dtype=dtypes.float32)
self.assertEqual(value.shape, tensor_shape.TensorShape([]))
def testDefunForcesResourceVariables(self):
# Creates a `variables.Variable` inside a defun-wrapped function and then
# inspects `self.v` against `ResourceVariable`.
def variable_creator():
self.v = variables.Variable(0.0)
return self.v.read_value()
self.v = None
defined = function.defun(variable_creator)
defined() # Create the variable.
# NOTE(review): the assertion call that takes these arguments is cut off in
# this view (presumably an isinstance-style check) -- TODO confirm.
self.v, resource_variable_ops.ResourceVariable)
def testRunMetadata(self):
# Expects the exported run metadata to contain exactly one partition graph.
def f(x):
return x * x
# NOTE(review): only the metadata export is visible under this device scope;
# the code that enables collection and calls `f` appears to be missing from
# this view -- TODO confirm.
with ops.device('cpu:0'):
run_metadata = context.export_run_metadata()
self.assertLen(run_metadata.partition_graphs, 1)
def testGraphModeCaptureVariable(self):
# In graph mode, wraps a method that multiplies an instance-held
# ResourceVariable (initialised to 1.0) by 2 and expects 2.0.
with context.graph_mode(), self.cached_session():
class HasAVar(object):
def __init__(self):
self.v = resource_variable_ops.ResourceVariable(1.0)
def call(self):
return self.v * 2
o = HasAVar()
# NOTE(review): the argument to `def_function.function(` is cut off in this
# view, and any variable initialisation step is not visible -- TODO confirm.
call = def_function.function(
op = call()
self.assertAllEqual(self.evaluate(op), 2.0)
def testGraphModeManyFunctions(self):
# Under a fresh default graph and a cached session, `g` calls `f` (x * x) and
# adds 1; for input 2.0 the expected result is 5.0.
with ops.Graph().as_default(), self.cached_session():
def f(x):
return x * x
def g(x):
return f(x) + 1
self.assertAllEqual(g(constant_op.constant(2.0)), 5.0)
def testDict(self):
  """Checks that a dict-returning function exposes its value under 'name'."""

  def f(x):
    return {'name': x + 1}

  outputs = f(constant_op.constant(1.0))
  self.assertAllEqual(outputs['name'], 2.0)
def testTensorConversionWithDefun(self):
  """Checks that adding a constant 3 to a constant 2 inside `f` gives 5."""

  def f(x):
    three = constant_op.constant(3)
    return math_ops.add(x, three)

  two = constant_op.constant(2)
  self.assertAllEqual(5, f(two))
def testTensorConversionCall(self):
  """Checks that applying the add-3 function twice to 2 gives 8."""

  def f(x):
    return math_ops.add(x, constant_op.constant(3))

  def g(x):
    # Feed the output of the first call straight into the second.
    once = f(x)
    return f(once)

  two = constant_op.constant(2)
  self.assertAllEqual(8, g(two))
def testCallShape(self):
# `g` calls `f` and asserts that the result has an empty (scalar) static
# shape.
def f(x):
return x + 1
def g(x):
x = f(x)
self.assertEqual(x.shape.as_list(), [])
return None
# NOTE(review): no call to `g` is visible in this view -- TODO confirm that it
# is invoked in the upstream file.
def testNestedDefunWithNoOutputAndTapedInput(self):
# Sets up a variable `three` and two helpers that add it to their input; `f`
# deliberately returns nothing.
three = resource_variable_ops.ResourceVariable(3.0, name='v')
def f(x):
# This function intentionally takes a taped variable as input,
# but does not return any values
math_ops.add(x, three)
# NOTE(review): the remainder of `g` and whatever invokes it are not visible
# in this view -- TODO confirm against the upstream file.
def g(x):
y = math_ops.add(x, three)
def testGatherResourceWithDefun(self):
# Compares the plain-Python result of summing gathered variable elements
# (indices 1 and 2) with the result of the `def_function.function` wrapper.
with ops.device('cpu:0'):
v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
def sum_gather():
return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
defined = def_function.function(sum_gather)
self.assertAllEqual(sum_gather(), defined())
# NOTE(review): the lines down to `# pyformat: disable` look like the tail of a
# parameterized-test decorator whose opening line (and at least one case
# header) is not visible in this view -- restore from the upstream file.
('RaggedTensorRaggedRank1', ragged_tensor.RaggedTensor.from_row_lengths,
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
('SparseTensor', sparse_tensor.SparseTensor,
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
]) # pyformat: disable
# NOTE(review): the parameter list of this signature is cut off; the body
# uses `factory_fn` and `factory_kwargs`.
def testReturnCompositeTensorWithDefun(self,
# Builds a composite tensor from the case's factory and kwargs, returns it
# from `f`, and checks type, structure, and component-wise equality.
input_ct = factory_fn(**factory_kwargs)
def f():
return input_ct
output_ct = f()
self.assertIsInstance(output_ct, type(input_ct))
nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
# Flatten with expand_composites=True so the component tensors are compared.
input_flat = nest.flatten(input_ct, expand_composites=True)
output_flat = nest.flatten(output_ct, expand_composites=True)
for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component)
# NOTE(review): the lines down to `# pyformat: disable` look like fragments of
# one or more parameterized-test decorators; the case names, factory
# functions, and opening lines are not visible in this view -- restore from
# the upstream file.
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]}),
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]}),
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]}),
{'values': [1, 2, 3], 'row_lengths': [2, 0, 1]},
[ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)]),
{'flat_values': [1, 2, 3], 'nested_row_lengths': [[1, 2], [2, 0, 1]]},
[ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.int32)]),
{'values': [1, 2, 3], 'indices': [[0], [8], [10]], 'dense_shape': [20]},
[sparse_tensor.SparseTensorSpec([None], dtypes.int32)]),
]) # pyformat: disable
# NOTE(review): the parameter list of this signature is cut off; the body
# uses `factory_fn` and `factory_kwargs`.
def testCompositeAsArgumentTensorWithDefun(self,
# Builds a composite tensor from the case's factory and kwargs, passes it
# through the identity function `f`, and checks type, structure, and
# component-wise equality.
input_ct = factory_fn(**factory_kwargs)
def f(x):
return x
output_ct = f(input_ct)
self.assertIsInstance(output_ct, type(input_ct))
nest.assert_same_structure(input_ct, output_ct, expand_composites=True)
# Flatten with expand_composites=True so the component tensors are compared.
input_flat = nest.flatten(input_ct, expand_composites=True)
output_flat = nest.flatten(output_ct, expand_composites=True)
for (input_component, output_component) in zip(input_flat, output_flat):
self.assertAllEqual(input_component, output_component)
def testTracedCompositeDiscardsShapeInfo(self):
# SparseTensorSpec intentionally excludes info about the number of elements
# that are in a sparse tensor (which is recorded as st.indices.shape[0] and
# st.values.shape[0]). Similarly, RaggedTensorSpec intentionally excludes
# info about the total number of values in a RaggedTensor (stored as
# rt.values.shape[0]). This test checks that the placeholders created by
# tf.function() properly mask this shape info.
def f(rt, st):
# NOTE(review): only the sparse-tensor assertions are visible here; a
# matching check on `rt` may have been lost in this view -- TODO confirm.
self.assertEqual(st.indices.shape.as_list()[:1], [None])
self.assertEqual(st.values.shape.as_list(), [None])
return (rt, st)
rt = ragged_factory_ops.constant([[1, 2], [3]])
st = sparse_tensor.SparseTensor([[0]], [0], [10])
f(rt, st)
def testFunctionOnDevice(self):
  """Adds a GPU-copied tensor to itself through a function; expects [2.]."""
  gpu_value = constant_op.constant([1.]).gpu()
  add_fn = def_function.function(math_ops.add)
  # Copy the sum back to the CPU before comparing.
  total = add_fn(gpu_value, gpu_value).cpu()
  self.assertAllEqual(total, [2.])
def testFunctionWithResourcesOnDifferentDevices(self):
# Creates one variable under a CPU device scope and one under a GPU device
# scope, then compares the plain-Python and defun results of gathering and
# summing elements 1 and 2 of each.
with ops.device('/cpu:0'):
v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
with ops.device('/gpu:0'):
v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
def sum_gather():
cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
return cpu_result, gpu_result
defined = function.defun(sum_gather)
# NOTE(review): with indentation lost it is unclear which statements form
# the body of this `if`; a line may be missing -- TODO confirm.
if not context.executing_eagerly():
expected = self.evaluate(sum_gather())
self.assertAllEqual(expected, self.evaluate(defined()))
def testOpInFunctionWithConflictingResourceInputs(self):
# Creates two variables under a CPU device scope and one under a GPU device
# scope, then expects an error whose message says the graph cannot be placed
# because a resource edge connects incompatible colocation groups.
with ops.device('/cpu:0'):
v_cpu = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name='cpu')
v_also_cpu = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name='also_cpu')
with ops.device('/gpu:0'):
v_gpu = resource_variable_ops.ResourceVariable(
[0.0, 1.0, 2.0], name='gpu')
# NOTE(review): the call that these arguments belong to (and its leading
# resource arguments) is cut off in this view; the exception type passed to
# `assertRaisesRegex` and the bodies of the `with`/`if` below are likewise
# missing -- restore from the upstream file.
def resource_apply_adam():
1.0, # beta1_power
1.0, # beta2_power
1.0, # learning_rate
1.0, # beta1
1.0, # beta2
1.0, # epsilon,
[1.0, 1.0, 1.0], # grad
False) # use_locking
return None
with self.assertRaisesRegex(
'Cannot place the graph because a reference or resource edge connects '
'colocation groups with incompatible assigned devices'):
if not context.executing_eagerly():
def testFunctionHandlesInputsOnDifferentDevices(self):
  """Reshapes a GPU-copied value using a shape tensor that was not moved."""
  # The Reshape op requires the shape tensor to be placed in host memory.
  reshape = def_function.function(array_ops.reshape)
  gpu_value = constant_op.constant([1., 2.]).gpu()
  host_shape = constant_op.constant([2, 1])
  # Copy the result back to the CPU before comparing.
  result = reshape(gpu_value, host_shape).cpu()
  self.assertAllEqual(result, [[1], [2]])
def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self):
  """Calls reshape with a GPU-copied shape tensor and expects no error."""
  # The Reshape op requires the shape tensor to be placed in host memory.
  reshape = def_function.function(array_ops.reshape)
  host_value = constant_op.constant([1., 2.])
  gpu_shape = constant_op.constant([2, 1]).gpu()
  # The test passes as long as this call does not raise.
  reshape(host_value, gpu_shape)
def testNoneOutput(self):
  """Checks that a function returning None yields None to the caller."""

  def my_function(_):
    return None

  result = my_function(1)
  self.assertAllEqual(result, None)
def testNestedFunctions(self):
  """Calls a `tf_function.Defun` from within another Python function."""

  # `add` is a TensorFlow function of the kind used in TensorFlow graph
  # construction, declared here with two int32 arguments.
  @tf_function.Defun(dtypes.int32, dtypes.int32)
  def add(a, b):
    return math_ops.add(a, b)

  def add_one(x):
    return add(x, 1)

  two = constant_op.constant(2)
  self.assertAllEqual(3, add_one(two))
def testVariableCaptureInNestedFunctions(self):
  """Reads an int32 variable through two levels of nested functions."""
  v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)

  def inner_read():
    return v.read_value()

  def outer():
    return inner_read()

  value = int(outer())
  self.assertEqual(1, value)
def testReturnCapturedEagerTensor(self):
  """Checks that a function returning a captured constant yields 1."""
  t = constant_op.constant(1)

  def read():
    return t

  value = int(read())
  self.assertEqual(1, value)
def testReturnCapturedGraphTensor(self):
# In graph mode with a cached session, `read` returns a captured constant and
# the evaluated result is expected to equal 1.
with context.graph_mode(), self.cached_session():
t = constant_op.constant(1)
def read():
return t
self.assertEqual(1, int(self.evaluate(read())))
def testSequenceInputs(self):
# Passes a list of tensors to a function-wrapped `clip_by_global_norm` and
# checks that every returned element, and the norm, is an `ops.Tensor`.
clip_by_global_norm = def_function.function(clip_ops.clip_by_global_norm)
t_list = [constant_op.constant(1.0), constant_op.constant(2.0)]
# NOTE(review): the second argument of this call is cut off in this view --
# restore it from the upstream file.
clipped_list, global_norm = clip_by_global_norm(t_list,
for t in clipped_list:
self.assertIsInstance(t, ops.Tensor)
self.assertIsInstance(global_norm, ops.Tensor)
def testNestedSequenceInputs(self):
# `my_op` unpacks a nested list/tuple input and returns a nested structure of
# doubled values together with a sum; the assertions check each position and
# that the inner pair is returned as a tuple.
def my_op(inputs):
a, b, c = inputs
e, f = b
g, h = e
return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c
my_eager_op = def_function.function(my_op)
# NOTE(review): the argument structure below is cut off in this view --
# restore the remaining constants from the upstream file.
ret = my_eager_op([
constant_op.constant(1), [(constant_op.constant(2),
self.assertLen(ret, 2)
self.assertAllEqual(ret[0][0], 2)
self.assertAllEqual(ret[0][1][0][0], 8)
self.assertAllEqual(ret[0][1][0][1], 4)
self.assertIsInstance(ret[0][1][0], tuple)
self.assertAllEqual(ret[0][1][1], 6)
self.assertAllEqual(ret[0][2], 10)
self.assertAllEqual(ret[1], 15)
def testVariableNamesRespectNameScopesWithDefun(self):
# Creates a variable named 'bar' under name scope 'foo' and compares
# something against the name 'foo/bar:0'.
def create_variable():
with ops.name_scope('foo', skip_on_eager=False):
v = resource_variable_ops.ResourceVariable(0.0, name='bar')
# NOTE(review): the first argument of this assertion is missing in this
# view, and no call to `create_variable` is visible -- TODO confirm.
self.assertEqual(, 'foo/bar:0')
def testVariableNamesRespectNameScopesWithDefunInGraph(self):
# Graph-mode variant: creates a variable named 'bar' under name scope 'foo'
# and compares something against the name 'foo/bar:0'.
with context.graph_mode():
def create_variable():
with ops.name_scope('foo', skip_on_eager=False):
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar')
# NOTE(review): the first argument of this assertion is missing in this
# view, and the body of the final `with` is not visible -- TODO confirm.
self.assertEqual(, 'foo/bar:0')
with ops.get_default_graph().as_default():
def testLayerInDefun(self):
# Applies a Conv2D layer to a 1x2x2x1 tensor of ones through `model` and
# expects the single output value 4.0.
# NOTE(review): the Conv2D constructor arguments are cut off in this view,
# and a statement may be missing under the `if` -- TODO confirm.
conv = convolutional.Conv2D(
def model(x):
return conv(x)
x = array_ops.ones([1, 2, 2, 1])
y = model(x)
if not context.executing_eagerly():
self.assertAllClose([[[[4.0]]]], self.evaluate(y))
# Variable lifting is somewhat different between defun/tf.function, so testing
# device placement on both makes sense.
# NOTE(review): `function_decorator` is not used in the visible body; the
# decorator lines that supply and apply it appear to be missing from this view.
def testVariablesPlacedOnOutsideDevice(self, function_decorator):
class _Obj(object):
def __init__(self):
self.v = None
def f(self):
# Create the variable on first use only.
if self.v is None:
self.v = variables.Variable(1.)
return self.v + 1.
has_device = _Obj()
# NOTE(review): no call to `has_device.f()` is visible before this assertion
# -- TODO confirm against the upstream file.
with ops.device('cpu:0'):
self.assertIn('CPU', has_device.v.device)
def testDeviceAnnotationsRespected(self):
# Per the assertions below: the first three outputs are expected to report
# CPU:0, CPU:1 and CPU:2 on every call, the fourth is expected to report the
# device active at the call site (CPU:3, then CPU:0), and the function cache
# is expected to hold a single entry throughout.
# NOTE(review): indentation is lost, so which device scope `s3` is created
# under cannot be read off this view.
def multi_device_fn():
with ops.device('/cpu:0'):
s0 = test_ops.device_placement_op()
with ops.device('/cpu:1'):
s1 = test_ops.device_placement_op()
with ops.device('/cpu:2'):
s2 = test_ops.device_placement_op()
s3 = test_ops.device_placement_op()
return s0, s1, s2, s3
defined = function.defun(multi_device_fn)
outputs = self.evaluate(defined())
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
with ops.device('/cpu:3'):
outputs = self.evaluate(defined())
# All function definitions are agnostic to call site devices.
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
with ops.device('/cpu:0'):
outputs = self.evaluate(defined())
self.assertLen(total_function_cache(defined), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
def testCallingGraphFunctionOnDifferentDevice(self):
# Obtains concrete functions for a constant-returning `func` (one under a
# cpu:0 device scope, one later) and evaluates them under several device
# scopes, expecting 0 each time.
def func():
return constant_op.constant(0)
defined = def_function.function(func)
with ops.device('cpu:0'):
cpu_graph_function = defined.get_concrete_function()
with ops.device('cpu:0'):
# NOTE(review): the assertion call wrapping these two arguments is cut off in
# this view (also at the similar line further down) -- TODO confirm.
self.evaluate(cpu_graph_function()), self.evaluate(func()))
with ops.device('cpu:1'):
self.assertEqual(0., self.evaluate(cpu_graph_function()))
with ops.device(None):
self.assertEqual(0., self.evaluate(cpu_graph_function()))
default_graph_function = defined.get_concrete_function()
self.evaluate(default_graph_function()), self.evaluate(func()))
with ops.device('cpu:1'):
self.assertEqual(0., self.evaluate(default_graph_function()))
def testColocateWithRespected(self):
# Calls `foo` under `colocate_with` for a CPU-scoped and a GPU-scoped tensor
# and expects the placement op to report CPU:0 and GPU:0 respectively.
# TODO(b/113291792): Use multiple CPUs instead of a GPU.
with ops.device('cpu:0'):
x = array_ops.identity(1.0)
with ops.device('gpu:0'):
y = array_ops.identity(1.0)
def foo():
return test_ops.device_placement_op()
with ops.colocate_with(x):
self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
with ops.colocate_with(y):
self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
def testVariablesAreTracked(self):
# Multiplies the input by a captured variable through a function wrapper and
# checks the result for two inputs.
v = resource_variable_ops.ResourceVariable(1.0)
def foo(x):
return v * x
defined = def_function.function(foo)
x = constant_op.constant([1.0])
self.assertEqual(1., self.evaluate(defined(x)))
# NOTE(review): the second expectation ([2., 4.] for input [1., 2.]) implies
# the variable holds 2.0 by then; the statement that updates `v` is not
# visible in this view -- TODO confirm.
x = constant_op.constant([1.0, 2.0])
self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
def testCacheObjectHashCollisions(self):
# `Foo` instances all hash to 42; the cache is expected to grow from one to
# two entries across the (not visible) calls.
class Foo(object):
def __hash__(self):
return 42
def func(foo):
del foo
defined = function.defun(func)
# NOTE(review): the calls to `defined` that should precede each assertion
# are not visible in this view -- TODO confirm against the upstream file.
self.assertLen(total_function_cache(defined), 1)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorDtypeCollision(self):
# Builds same-shaped tensors of dtype complex64 and then complex128; the
# cache is expected to hold one entry, then two.
def func(t):
return t + t
defined = function.defun(func)
# NOTE(review): the `defined(t)` calls that should follow each constant are
# not visible in this view -- TODO confirm against the upstream file.
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([[1.0]], dtype=dtypes.complex128)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorShapeCollision(self):
# Builds complex64 tensors from [[1.0]] and then [1.0]; the cache is expected
# to hold one entry, then two.
def func(t):
return t + t
defined = function.defun(func)
# NOTE(review): the `defined(t)` calls that should follow each constant are
# not visible in this view -- TODO confirm against the upstream file.
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([1.0], dtype=dtypes.complex64)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorShapeDtypeCollision(self):
# Builds a complex64 tensor from [[1.0]] and then a complex128 tensor from
# [1.0]; the cache is expected to hold one entry, then two.
def func(t):
return t + t
defined = function.defun(func)
# NOTE(review): the `defined(t)` calls that should follow each constant are
# not visible in this view -- TODO confirm against the upstream file.
t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
self.assertLen(total_function_cache(defined), 1)
t = constant_op.constant([1.0], dtype=dtypes.complex128)
self.assertLen(total_function_cache(defined), 2)
def testCacheTensorUnknownShapesCollisionRelaxedShapes(self):
# With `experimental_relax_shapes=True`, placeholders of shape [], [1] and
# [2] are expected to leave two cache entries, with one relaxed spec of rank
# 1 and unknown first dimension.
# NOTE(review): the `defined(...)` calls after each placeholder/constant and
# the right-hand side of `relaxed_specs = (` are not visible in this view --
# restore from the upstream file.
def func(t):
return t + t
with context.graph_mode(), self.cached_session():
defined = function.defun(func, experimental_relax_shapes=True)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[])
self.assertLen(total_function_cache(defined), 1)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
self.assertLen(total_function_cache(defined), 2)
p = array_ops.placeholder(dtype=dtypes.float32, shape=[2])
# Gradual shape relaxation is performed; and the common shape between
# [1] and [2] is one containing unknown dimensions.
self.assertLen(total_function_cache(defined), 2)
# pylint: disable=protected-access
self.assertLen(defined._function_cache.arg_relaxed_specs, 1)
relaxed_specs = (
self.assertLen(relaxed_specs, 1)
relaxed_shape = relaxed_specs[0].shape
# pylint: enable=protected-access
self.assertEqual(relaxed_shape.rank, 1)
self.assertEqual(tensor_shape.dimension_value(relaxed_shape[0]), None)
t = constant_op.constant([1.0, 1.0, 1.0], dtype=dtypes.float32)
# Shape (3,) matches the relaxed shape TensorShape([None])
self.assertLen(total_function_cache(defined), 2)
def testPythonFunctionWithDefaultArgs(self):
# Calls a defun with positional, keyword and default arguments and checks
# that calls resolving to the same (foo, bar, baz) values share one cache
# entry.
def func(foo, bar=1, baz=2):
del foo
del bar
del baz
defined = function.defun(func)
defined(0, baz=20)
def cache_keys():
"""Sanitizes cache keys of non-input metadata."""
return tuple(key[0] for key in total_function_cache(defined))
# `True` corresponds to the fact that we're executing eagerly
# NOTE(review): the comment above mentions `True`, but the asserted key
# starts with the string 'URRRu'; the comment may be stale -- TODO confirm.
self.assertIn(('URRRu', (0, 1, 20)), cache_keys())
defined(1) # bar=1, baz=2
self.assertIn(('URRRu', (1, 1, 2)), cache_keys())
# This matches the previous call.
# NOTE(review): the call this comment refers to is not visible in this view.
self.assertLen(total_function_cache(defined), 2)
defined(1, 2, 3)
self.assertLen(total_function_cache(defined), 3)
self.assertIn(('URRRu', (1, 2, 3)), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
self.assertLen(total_function_cache(defined), 3)
# This matches the previous call.
defined(1, baz=3, bar=2)
self.assertLen(total_function_cache(defined), 3)
def testFunctoolsPartialUnwrappedCorrectly(self):
  """Checks that a defun built from a functools.partial matches the partial.

  The partial binds the first positional argument and `c`; the defun's three
  outputs are compared with the values the plain partial returns.
  """

  def full_function(a, b, c=3):
    return a, b, c

  partial = functools.partial(full_function, 1, c=4)
  a, b, c = partial(2)

  func_a, func_b, func_c = function.defun(partial)(2)
  for actual, expected in ((func_a, a), (func_b, b), (func_c, c)):
    self.assertEqual(actual.numpy(), expected)
def testInputSignatureWithMatchingInputs(self):
  """Checks that inputs matching an input_signature keep a single cache entry.

  Covers a fully specified shape `(2,)` and a partially specified shape
  `(2, None)`; after every call the function cache is expected to hold
  exactly one entry.
  """

  def foo(a):
    self.assertEqual(a.shape, (2,))
    return a

  vector_spec = tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)
  defined = function.defun(foo, input_signature=[vector_spec])
  a = array_ops.ones([2])
  self.assertAllEqual(a, defined(a))
  self.assertLen(total_function_cache(defined), 1)

  # Concrete functions requested with no argument, with a tensor, and with a
  # TensorSpec are each expected to return the input unchanged.
  no_arg_fn = defined.get_concrete_function()
  self.assertAllEqual(a, no_arg_fn(a))
  tensor_arg_fn = defined.get_concrete_function(a)
  self.assertAllEqual(a, tensor_arg_fn(a))
  spec_arg_fn = defined.get_concrete_function(
      tensor_spec.TensorSpec((2,), dtype=dtypes.float32))
  self.assertAllEqual(a, spec_arg_fn(a))
  self.assertLen(total_function_cache(defined), 1)

  def bar(a):
    self.assertEqual(a._shape_tuple(), (2, None))
    return a

  partial_spec = tensor_spec.TensorSpec((2, None), dtypes.float32)
  defined = function.defun(bar, input_signature=[partial_spec])

  # Changing the second dimension shouldn't create a new function.
  for columns in (1, 3):
    inputs = array_ops.ones([2, columns])
    out = defined(inputs)
    self.assertLen(total_function_cache(defined), 1)
    self.assertAllEqual(out, inputs)
def testInputSignatureWithCompatibleInputs(self):
# Calls `func` with Python lists and a numpy array of rank 2, and expects
# ValueError ('incompatible') for a rank-1 input and for a string input.
# NOTE(review): the `rank2_spec` construction is cut off in this view and the
# decorator that applies it to `func` is not visible -- TODO confirm.
rank2_spec = tensor_spec.TensorSpec(shape=(None, None),
def func(a):
self.assertEqual([None, None], a.shape.as_list())
return array_ops.shape(a)
self.assertAllEqual([3, 1], func([[0], [1.0], [1]]))
self.assertAllEqual([2, 2], func(numpy.array([[1, 1], [2, 2]])))
with self.assertRaisesRegex(ValueError, 'incompatible'):
func([0.0, 1.0, 2.0]) # Wrong shape.
with self.assertRaisesRegex(ValueError, 'incompatible'):
func([['wrong dtype']])
def testNoKeywordOnlyArgumentsWithInputSignature(self):
# Expects a ValueError when a function with a keyword-only argument is
# wrapped together with an input signature.
if sys.version_info[0] < 3:
self.skipTest('keyword_only arguments only exist in Python 3.')
# The lambda is built with `eval` on a fixed literal; per the skip above,
# keyword-only arguments only exist in Python 3.
func = eval('lambda x, *, y: x') # pylint: disable=eval-used
signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
# NOTE(review): the end of the expected-message string and the closing of
# this `assertRaisesRegex(` call are cut off in this view -- TODO confirm.
with self.assertRaisesRegex(
ValueError, 'Cannot define a TensorFlow function from a Python '
'function with keyword-only arguments when input_signature is '
def_function.function(func, signature)
def testNestedInputSignatures(self):
# Compares `foo` against the plain reference `expected_foo` for nested list
# inputs, checking after each call that the function cache holds one entry.
def expected_foo(a, b):
return [a, b]
# NOTE(review): the two TensorSpec lines below look like the contents of a
# decorator's input signature for `foo`; the decorator's opening and closing
# lines are not visible in this view -- restore from the upstream file.
[tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2,
tensor_spec.TensorSpec((1,), dtypes.float32),
def foo(a, b):
self.assertEqual(a[0]._shape_tuple(), (2, None))
self.assertEqual(a[1]._shape_tuple(), (2, None))
self.assertEqual(b._shape_tuple(), (1,))
return [a, b]
a = array_ops.ones([2, 1])
b = array_ops.ones([1])
expected = expected_foo([a, a], b)
out = foo([a, a], b)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], a)
self.assertAllEqual(out[1], b)
# Changing the unspecified dimensions shouldn't create a new function.
a = array_ops.ones([2, 3])
b = array_ops.ones([2, 5])
c = array_ops.ones([1])
expected = expected_foo([a, b], c)
out = foo([a, b], c)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
self.assertAllEqual(out[1], c)
# Passing compatible inputs should work.
# Here the inputs are plain nested Python lists rather than tensors.
a = a.numpy().tolist()
b = b.numpy().tolist()
c = c.numpy().tolist()
out = foo([a, b], c)
self.assertLen(total_function_cache(foo), 1)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
self.assertAllEqual(out[1], c)
def testNestedInputSignaturesWithDict(self):
# Compares `bar` against the plain reference `expected_bar` for a dict input
# with keys 'a', 'b' and 'c', first with tensors and then with Python lists.
def expected_bar(a):
return a
# NOTE(review): the three dict entries below look like the contents of a
# decorator's input signature for `bar`; the decorator's opening line is not
# visible in this view -- restore from the upstream file.
'a': tensor_spec.TensorSpec((2, None), dtypes.float32),
'b': tensor_spec.TensorSpec((2, None), dtypes.float32),
'c': tensor_spec.TensorSpec((1,), dtypes.float32)}])
def bar(a):
self.assertEqual(a['a']._shape_tuple(), (2, None))
self.assertEqual(a['b']._shape_tuple(), (2, None))
self.assertEqual(a['c']._shape_tuple(), (1,))
return a
a = array_ops.ones([2, 3])
b = array_ops.ones([1])
inputs = {'a': a, 'b': a, 'c': b}
expected = expected_bar(inputs)
out = bar(inputs)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out['a'], expected['a'])
self.assertAllEqual(out['b'], expected['b'])
self.assertAllEqual(out['c'], expected['c'])
# Passing compatible inputs should work.
# Here the dict values are plain Python lists rather than tensors.
a = a.numpy().tolist()
b = b.numpy().tolist()
inputs = {'a': a, 'b': a, 'c': b}
out = bar(inputs)
nest.assert_same_structure(out, expected)
self.assertAllEqual(out['a'], expected['a'])
self.assertAllEqual(out['b'], expected['b'])
self.assertAllEqual(out['c'], expected['c'])
def testInputSignatureMustBeSequenceOfTensorSpecs(self):
  """Checks that malformed input signatures are rejected with TypeError.

  Two malformed signatures are tried: a list containing a non-TensorSpec
  element, and a dict used as the outermost container.
  """

  def foo(a, b):
    del a
    del b

  # Signatures must consist exclusively of `TensorSpec` objects.
  mixed_signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
  with self.assertRaisesRegex(TypeError, 'Invalid input_signature.*'):
    def_function.function(foo, input_signature=mixed_signature)

  # Signatures must be either lists or tuples on their outermost levels.
  dict_signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
  expected_message = ('input_signature must be either a '
                      'tuple or a list.*')
  with self.assertRaisesRegex(TypeError, expected_message):
    function.defun(foo, input_signature=dict_signature)
def testInputsIncompatibleWithSignatureRaisesError(self):
# Wraps `foo` with a signature of one float32 TensorSpec of shape (2,) and
# checks the errors raised for mismatching inputs.
def foo(a):
return a
signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)]
defined = def_function.function(foo, input_signature=signature)
# Invalid shapes.
# NOTE(review): the first `with` below has no visible body, and the calls
# inside the last two `with` blocks are cut off or missing in this view --
# restore from the upstream file.
with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
with self.assertRaisesRegex(ValueError, 'Python inputs incompatible.*'):
defined(array_ops.ones([2, 1]))
# Wrong number of arguments.
with self.assertRaisesRegex(
TypeError, r'takes 1 positional arguments \(as specified by the '
r'input_signature\) but 2 were given'):
defined(array_ops.ones([2]), array_ops.ones([2]))
with self.assertRaisesRegex(ValueError,
'Structure of Python function inputs.*'):
with self.assertRaisesRegex(ValueError,
'inputs incompatible with input_signature'):
tensor_spec.TensorSpec(shape=(3,), dtype=dtypes.float32))
def testInputsIncompatibleWithNestedSignatureRaisesError(self):
# The signature is two lists of two shape-(1,) float32 specs each; inputs
# whose list lengths differ from that are expected to raise ValueError.
def foo(a, b):
return [a, b]
signature = [[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2,
[tensor_spec.TensorSpec((1,), dtypes.float32)] * 2]
defined = function.defun(foo, input_signature=signature)
a = array_ops.ones([1])
with self.assertRaisesRegex(ValueError,
'Structure of Python function inputs.*'):
defined([a, a, a], [a])
with self.assertRaisesRegex(ValueError,
'Structure of Python function inputs.*'):
defined([a], [a, a, a])
# This call matches the signature structure (two lists of two).
# NOTE(review): with indentation lost, whether it sits inside the preceding
# `with` block cannot be told from this view -- TODO confirm.
defined([a, a], [a, a])
def testUnderspecifiedInputSignature(self):
# Passing `training` by keyword is expected to raise TypeError because it is
# not part of the input signature; calling with only `x` is expected to
# return `x` unchanged.
# NOTE(review): the TensorSpec line below looks like the contents of a
# decorator's input signature for `foo`; the decorator's opening and closing
# lines are not visible in this view -- restore from the upstream file.
tensor_spec.TensorSpec([], dtypes.float32),
def foo(a, training=True):
if training:
return a
return -1.0 * a
x = constant_op.constant(1.0)
with self.assertRaisesRegex(
TypeError, 'got keyword argument `training` '
'that was not included in input_signature'):
foo(x, training=True)
with self.assertRaisesRegex(
TypeError, 'got keyword argument `training` '
'that was not included in input_signature'):
foo(x, training=False)
self.assertAllEqual(x.numpy(), foo(x).numpy())
def testInputSignatureWithPartialFunction(self):
  """Checks a defun built from a functools.partial plus an input signature.

  The partial binds the first positional argument and `c`; the signature
  describes the single remaining argument as a scalar float32. The defun's
  three outputs are compared with the values the plain partial returns.
  """

  def full_function(a, b, c=3.0):
    return a, b, c

  partial = functools.partial(full_function, 1, c=4)
  a, b, c = partial(2.0)

  scalar_spec = tensor_spec.TensorSpec([], dtypes.float32)
  defined = function.defun(partial, input_signature=[scalar_spec])
  func_a, func_b, func_c = defined(constant_op.constant(2.0))
  for actual, expected in ((func_a, a), (func_b, b), (func_c, c)):
    self.assertEqual(actual.numpy(), expected)
def testInputSignatureConversionWithDefaultArg(self):
# With a signature covering both `a` (float32 scalar) and `training` (bool
# scalar), the default and explicit `training=True` calls are expected to
# return `a`, and `training=False` is expected to return `-a`.
def foo(a, training=True):
if training:
return a
return -1.0 * a
# NOTE(review): the closing bracket of this list is not visible in this
# view -- TODO confirm against the upstream file.
signature = [
tensor_spec.TensorSpec([], dtypes.float32),
tensor_spec.TensorSpec([], dtypes.bool),
defined = def_function.function(foo, input_signature=signature)
a = constant_op.constant(1.0)
self.assertAllEqual(a.numpy(), defined(a))
self.assertAllEqual(a.numpy(), defined(a, training=True))
self.assertAllEqual(-a.numpy(), defined(a, training=False))
def testInputSignatureWithKeywordPositionalArgs(self):
# Calls `foo` positionally, by keyword in both orders, and mixed; every call
# is expected to return (1.0, 2) and leave a single cache entry.
# NOTE(review): the two TensorSpec lines below look like the contents of a
# decorator's input signature for `foo`; the decorator's opening and closing
# lines are not visible in this view -- restore from the upstream file.
tensor_spec.TensorSpec([], dtypes.float32),
tensor_spec.TensorSpec([], dtypes.int64)
def foo(flt, integer):
return flt, integer
flt = constant_op.constant(1.0)
integer = constant_op.constant(2, dtypes.int64)
out1, out2 = foo(flt, integer)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt=flt, integer=integer)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(integer=integer, flt=flt)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt, integer=integer)
self.assertLen(total_function_cache(foo), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
def testInputSignatureWithKeywordArgs(self):
# `foo` accepts **kwargs but ignores them; calling the defun with a float and
# an int constant is expected to return [5.0, 5].
def foo(a, b, **kwargs):
del kwargs
return a, b
# NOTE(review): the arguments of this `defun(` call are cut off in this view
# (the function, the `input_signature=` keyword and the closing brackets are
# missing) -- restore from the upstream file.
x = function.defun(
tensor_spec.TensorSpec([], dtypes.float32),
tensor_spec.TensorSpec([], dtypes.int32)
result = x(constant_op.constant(5.0), constant_op.constant(5))
self.assertAllEqual(result, [5.0, 5])
def testInputSignatureWithCompositeTensors(self):
# Uses a RaggedTensorSpec of shape [3, None] as the signature: ragged inputs
# with three rows are expected to share one cache entry, while a different
# row count, dtype, or rank is expected to raise ValueError.
def f(rt):
self.assertEqual(rt.values.shape.as_list(), [None])
self.assertEqual(rt.row_splits.shape.as_list(), [4])
return rt
signature = [ragged_tensor.RaggedTensorSpec(
shape=[3, None], dtype=dtypes.int32)]
defined = function.defun(f, input_signature=signature)
rt1 = ragged_factory_ops.constant([[1], [], [2, 3, 4]])
out1 = defined(rt1)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(out1.values, rt1.values)
self.assertAllEqual(out1.row_splits, rt1.row_splits)
# Changing the row lengths shouldn't create a new function.
rt2 = ragged_factory_ops.constant([[1, 2], [3, 4], [5]])
out2 = defined(rt2)
self.assertLen(total_function_cache(defined), 1)
self.assertAllEqual(out2.values, rt2.values)
self.assertAllEqual(out2.row_splits, rt2.row_splits)
# NOTE(review): the three `assertRaisesRegex` blocks below have no visible
# body in this view; the calls with rt3/rt4/rt5 appear to be missing --
# restore from the upstream file.
# Different number of rows
rt3 = ragged_factory_ops.constant([[1, 2], [3, 4], [5], [6]])
with self.assertRaisesRegex(ValueError, 'incompatible'):
# Different dtype
rt4 = ragged_factory_ops.constant([[1.0, 2.0], [], [3.0]])
with self.assertRaisesRegex(ValueError, 'Structure .* does not match'):
# Different rank
rt5 = ragged_factory_ops.constant([[[1]], [[2]], [[3]]])
with self.assertRaisesRegex(ValueError, 'does not match'):
def testInputSignatureWithVariableArgs(self):
# Uses a scalar int32 VariableSpec as the signature; the assertions expect
# `v1` to become 1 first and `v2` to become 1 afterwards.
# NOTE(review): the body of `f`, the closing bracket of `signature`, and the
# `defined(v1)` / `defined(v2)` calls between the assertions are not visible
# in this view -- restore from the upstream file.
def f(v):
signature = [
resource_variable_ops.VariableSpec(shape=[], dtype=dtypes.int32)
defined = function.defun(f, input_signature=signature)
v1 = variables.Variable(0)
v2 = variables.Variable(0)
self.assertEqual(v1.numpy(), 1)
self.assertEqual(v2.numpy(), 0)
self.assertEqual(v1.numpy(), 1)
self.assertEqual(v2.numpy(), 1)
def testTensorKeywordArguments(self):
  """Checks caching when tensors are passed positionally and by keyword.

  Calls that bind the same tensors to the same parameters are expected to
  share one cache entry regardless of how the arguments are spelled; swapping
  which tensor is bound to which parameter is expected to add one more.
  """

  def foo(a, b):
    del a
    return b

  defined = function.defun(foo)
  scalar = constant_op.constant(2.0)
  vector = constant_op.constant([1.0, 2.0])

  # Four spellings of the same binding: a=scalar, b=vector.
  one = defined(scalar, vector)
  self.assertLen(total_function_cache(defined), 1)
  two = defined(a=scalar, b=vector)
  self.assertLen(total_function_cache(defined), 1)
  three = defined(b=vector, a=scalar)
  self.assertLen(total_function_cache(defined), 1)
  four = defined(scalar, b=vector)
  self.assertLen(total_function_cache(defined), 1)

  # The next call corresponds to a new input signature, hence
  # we expect another function to be defined.
  five = defined(vector, scalar)
  self.assertLen(total_function_cache(defined), 2)
  six = defined(a=vector, b=scalar)
  self.assertLen(total_function_cache(defined), 2)
  seven = defined(b=scalar, a=vector)
  self.assertLen(total_function_cache(defined), 2)

  # `foo` returns whatever was bound to `b`.
  for result in (one, two, three, four):
    self.assertAllEqual(result, [1.0, 2.0])
  for result in (five, six, seven):
    self.assertAllEqual(result, 2.0)
def testDefuningInstanceMethod(self):
# Calls `Foo.two` with a float tensor and relies on the int64 default for
# `other`; the two results are expected to be 1.0 and 2.
integer = constant_op.constant(2, dtypes.int64)
class Foo(object):
def one(self, tensor):
return tensor
def two(self, tensor, other=integer):
# NOTE(review): the first element of this return expression is missing in
# this view -- restore it from the upstream file.
return, other
foo = Foo()
t = constant_op.constant(1.0)
one, two = foo.two(t)
self.assertEqual(one.numpy(), 1.0)
self.assertEqual(two.numpy(), 2)
def testDefuningInstanceMethodWithDefaultArgument(self):
  """Checks that a method returns its tensor-valued default argument."""
  integer = constant_op.constant(2, dtypes.int64)

  class Foo(object):

    def func(self, other=integer):
      return other

  foo = Foo()
  # Call with no arguments so the default tensor is what comes back.
  returned = foo.func().numpy()
  self.assertEqual(returned, int(integer))
def testPythonCallWithSideEffects(self):
# Tracks Python-level side effects through the `state` list: it is expected
# to equal [0] at the first two checks and [0, 0] at the last one.
# NOTE(review): `side_effecting_function` has no body and none of the
# invocations described by the comments below are visible; this block looks
# truncated -- confirm against the upstream file.
state = []
def side_effecting_function():
self.assertAllEqual(state, [0])
# The second invocation should call the graph function, which shouldn't
# trigger the list append.
self.assertAllEqual(state, [0])
# Whereas calling the python function directly should create a side-effect.
self.assertAllEqual(state, [0, 0])
def testFunctionWithNestedFunctionCallAndSideEffects(self):
# Expects `side_effecting_function(v1, v2)` to evaluate to 4.0 although both
# variables are created with the value 1.0.
# NOTE(review): `add_one` has no body and is never called in the visible
# code, which would account for the difference; this block looks truncated
# -- confirm against the upstream file.
v1 = variables.Variable(1.0)
v2 = variables.Variable(1.0)
def add_one(a):
# Grappler will inline calls to `add_one` into the function body, we check
# that all side-effects were executed.
def side_effecting_function(a, b):
return a + b
result = side_effecting_function(v1, v2)
self.assertEqual(result.numpy(), 4.0)
def testFunctionWithExtraAttributes(self):
  """Checks that `defun_with_attributes` attaches attrs to the function def.

  Two functions are built, one through the decorator form and one through the
  direct-call form, each with two custom attributes. After both run in graph
  mode, the default graph is expected to hold exactly two functions whose
  definitions carry those attributes as string, int, bool and float values.
  """

  @function.defun_with_attributes(attributes={'experimental_1': 'value1',
                                              'experimental_2': 2})
  def matmul(x, y):
    return math_ops.matmul(x, y)

  def add(x, y):
    return math_ops.add(x, y)

  defun_add = function.defun_with_attributes(
      add, attributes={'experimental_3': True, 'experimental_4': 1.0})

  with context.graph_mode(), self.cached_session():
    with ops.get_default_graph().as_default():
      t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      sq = matmul(t, t)
      double = defun_add(t, t)
      self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
      self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])

      graph = ops.get_default_graph()
      # pylint: disable=protected-access
      self.assertLen(graph._functions, 2)
      functions = list(graph._functions.values())

      # `assertRegex` needs text, so match on the registered function's name
      # rather than on the function object itself.
      self.assertRegex(functions[0].name, '.*matmul.*')
      attrs = functions[0].definition.attr
      self.assertLen(attrs, 2)
      self.assertEqual(attrs['experimental_1'].s, b'value1')
      self.assertEqual(attrs['experimental_2'].i, 2)

      self.assertRegex(functions[1].name, '.*add.*')
      attrs = functions[1].definition.attr
      self.assertLen(attrs, 2)
      self.assertEqual(attrs['experimental_3'].b, True)
      self.assertEqual(attrs['experimental_4'].f, 1.0)
      # pylint: enable=protected-access
def testFunctionWithInvalidAttribute(self):
  """Expects a ValueError when an attribute value is a list."""

  def add(x, y):
    return math_ops.add(x, y)

  # Same wrapping as the decorator form, written as a direct call.
  add = function.defun_with_attributes(
      add, attributes={'experimental_1': ['value1']})

  expected_error = '.*Unsupported attribute type.*'
  with self.assertRaisesRegex(ValueError, expected_error):
    with context.graph_mode(), self.cached_session():
      with ops.get_default_graph().as_default():
        matrix = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
        add(matrix, matrix)
def testRegisterFunction(self):
# Registers `defun_matmul` and `add` with the default graph via
# `function.register`, expects six entries in `graph._functions`, then calls
# both functions and expects the count to stay at six.
# NOTE(review): this block looks truncated -- the two list expressions and
# both `for` loops have no visible bodies, and `add` is handed to
# `function.register` as an undecorated Python function. Confirm against the
# upstream file.
def add(x, y):
return math_ops.add(x, y)
def matmul(x, y):
return math_ops.matmul(x, y)
defun_matmul = function.defun(matmul)
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
function.register(defun_matmul, t, t)
function.register(add, t, t)
graph = ops.get_default_graph()
# pylint: disable=protected-access
self.assertLen(graph._functions, 6)
# two sets of functions, each of them are (inference, forward, backward)
functions = list(graph._functions.values())
captured_function_names = [ for f in functions
expected_func_name_regex = [
for i in range(len(functions)):
# Check the forward and backward function has the correct attributes.
sq = defun_matmul(t, t)
double = add(t, t)
self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
# Make sure the pre registered function is used, and no other function
# is added.
self.assertLen(graph._functions, 6)
functions = list(graph._functions.values())
for i in range(len(functions)):
def testRegisterConcreteFunction(self, function_decorator):
# Obtains concrete functions `add` and `composite` (which calls `add`) for
# float32 inputs of unknown shape, expects six entries in `graph._functions`
# in graph mode, evaluates both, and expects the count to stay at six.
# NOTE(review): this block looks truncated -- `function_decorator` is never
# used, `py_add` and `py_composite` are undecorated yet
# `get_concrete_function` is called on them, no registration call is
# visible, and the list expressions and `for ... in zip(` are cut short.
# Confirm against the upstream file.
def py_add(x, y):
return math_ops.add(x, y)
py_add(array_ops.ones([]), array_ops.ones([]))
add = py_add.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32))
def py_composite(x, y):
return x, add(x, y)
py_composite(array_ops.ones([]), array_ops.ones([]))
composite = py_composite.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32))
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
graph = ops.get_default_graph()
# pylint: disable=protected-access
self.assertLen(graph._functions, 6)
# two sets of functions, each of them are (inference, forward, backward)
functions = list(graph._functions.values())
captured_function_names = [ for f in functions
expected_func_name_regex = [
for expected, found in zip(
self.assertRegex(found, expected)
composite_t, composite_double = composite(t, t)
double = add(t, t)
self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(double))
self.assertAllEqual([[2, 4], [6, 8]], self.evaluate(composite_double))
self.assertAllEqual([[1, 2], [3, 4]], self.evaluate(composite_t))
# Make sure the pre registered function is used, and no other function
# is added.
self.assertLen(graph._functions, 6)
def testEagerCaptures(self, function_decorator):
# In eager mode, for a 256-element tensor, a 4-element tensor and a resource
# variable captured by `test_fn`, expects the traced graph to have exactly
# one internal capture whose op type is 'Placeholder', 'Const' and
# 'Placeholder' respectively. The two size assertions pin the tensors to
# either side of `func_graph._EAGER_CONST_THRESHOLD`.
# NOTE(review): `function_decorator` is never used and `test_fn` is
# undecorated although `get_concrete_function` is called on it; a decorator
# line looks to be missing -- confirm against the upstream file.
with context.eager_mode():
large_tensor = array_ops.ones(shape=(256,))
self.assertGreater(256, func_graph._EAGER_CONST_THRESHOLD)
small_tensor = array_ops.ones(shape=(4,))
self.assertLessEqual(4, func_graph._EAGER_CONST_THRESHOLD)
v = resource_variable_ops.ResourceVariable(0.0)
for captured, op_type in [(large_tensor, 'Placeholder'),
(small_tensor, 'Const'), (v, 'Placeholder')]:
def test_fn():
return captured + 1 # pylint: disable=cell-var-from-loop
g = test_fn.get_concrete_function().graph
internal_captures = g.internal_captures
self.assertLen(internal_captures, 1)
self.assertEqual(internal_captures[0].op.type, op_type)
def testRegisterFunctionWithInputSignature(self):
# Registers a defun'ed matmul that is built with two (2, 2) float32
# TensorSpecs and expects three entries in `graph._functions`, both after
# the visible registration and at the final check.
# NOTE(review): this block looks truncated -- the `function.defun(` call is
# never closed (the wrapped function and the `input_signature=` keyword are
# not visible) and the second registration announced by the "with cache"
# comment is missing. Confirm against the upstream file.
def matmul(x, y):
return math_ops.matmul(x, y)
defun_matmul = function.defun(
tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
with context.graph_mode(), self.cached_session():
with ops.get_default_graph().as_default():
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
function.register(defun_matmul, t, t)
graph = ops.get_default_graph()
# pylint: disable=protected-access
self.assertLen(graph._functions, 3)
# Test register function with cache, note inputs are ignored.
graph = ops.get_default_graph()
self.assertLen(graph._functions, 3)
def testRegisterFunctionWithCache(self):
  """Registering twice with same-typed inputs leaves three graph functions."""

  def matmul(x, y):
    return math_ops.matmul(x, y)

  defun_matmul = function.defun(matmul)

  with context.graph_mode(), self.cached_session():
    with ops.get_default_graph().as_default():
      # Create both operands up front, then register once per operand.
      operands = [
          constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          constant_op.constant([[2.0, 3.0], [4.0, 5.0]]),
      ]
      for operand in operands:
        function.register(defun_matmul, operand, operand)

      # Only one function is registered since the input param are in same type
      # pylint: disable=protected-access
      registered = ops.get_default_graph()._functions
      self.assertLen(registered, 3)
def testCallingFunctionWithDifferentVariables(self):
# Builds a concrete function from `foo` for variable `v`, expects it to have
# one input and to return 1.0 and then 2.0 on successive calls; then builds
# one from `bar`, which ignores its argument, and expects 1.0 for both `v`
# and a second variable `w`.
# NOTE(review): this block looks truncated -- `foo` only reads the variable,
# which does not explain the 1.0/2.0 expectations, and both `foo` and `bar`
# are undecorated although `get_concrete_function` is called on them.
# Confirm against the upstream file.
def foo(v):
return v.read_value()
v = resource_variable_ops.ResourceVariable(0.0)
graph_function = foo.get_concrete_function(v)
self.assertLen(graph_function.inputs, 1)
self.assertEqual(float(graph_function(v)), 1.0)
self.assertEqual(float(graph_function(v)), 2.0)
w = resource_variable_ops.ResourceVariable(0.0)
def bar(v):
del v
return constant_op.constant(1.0)
graph_function = bar.get_concrete_function(v)
self.assertEqual(float(graph_function(v)), 1.0)
self.assertEqual(float(graph_function(w)), 1.0)
def testCallingFunctionWithNonTensorsFails(self):
# Expects calling a concrete function built for a float tensor with a Python
# string to raise TypeError or ValueError.
# NOTE(review): `foo` is undecorated although `get_concrete_function` is
# called on it; a decorator line looks to be missing -- confirm against the
# upstream file.
def foo(x):
return x
graph_function = foo.get_concrete_function(constant_op.constant(1.0))
with self.assertRaises((TypeError, ValueError)):
graph_function('Not a Tensor.')
def testSwapImplementationWithGrapplerPlugin(self):
# Configures a session with the grappler implementation selector switched
# on, registers `cpu_boost` (adds 2.0), calls `gpu_boost` (adds 4.0) on the
# constant 1.0, and checks the result against 5.0 or 3.0.
# NOTE(review): this block looks truncated -- the two attribute fragments
# ('api_implements' / 'api_preferred_device') have no enclosing decorator
# call, and the two assertions after `if test.is_gpu_available():` have no
# visible `else:` between them. Confirm against the upstream file.
# Set the min_graph_nodes to -1 since the graph in this test is too small,
# and grappler would ignore it if we did not set this.
rewrites = rewriter_config_pb2.RewriterConfig()
rewrites.implementation_selector = rewriter_config_pb2.RewriterConfig.ON
rewrites.min_graph_nodes = -1
graph_options = config_pb2.GraphOptions(
rewrite_options=rewrites, build_cost_model=1)
config_proto = config_pb2.ConfigProto(graph_options=graph_options)
with context.graph_mode(), self.cached_session(
config=config_proto, graph=ops.Graph(), use_gpu=True):
'api_implements': 'random_boost',
'api_preferred_device': 'CPU'
def cpu_boost(x):
return math_ops.add(x, 2.0)
'api_implements': 'random_boost',
'api_preferred_device': 'GPU'
def gpu_boost(x):
return math_ops.add(x, 4.0)
x = constant_op.constant(1.0)
function.register(cpu_boost, x)
y = gpu_boost(x)
y_value = self.evaluate(y)
if test.is_gpu_available():
self.assertEqual(y_value, 5.0)
# Grappler falls back to the CPU impl even when the GPU function is called.
self.assertEqual(y_value, 3.0)
def testSwapImplementationInEager(self):
# Eager-only test: registers `on_cpu` (adds 2), then calls `on_gpu` (adds 4)
# under a 'CPU:0' device scope and expects 3 for the input 1, i.e. the
# `on_cpu` result.
# NOTE(review): this block looks truncated -- the optimizer-option entries
# ('min_graph_nodes', ...) and both `attributes={...}` fragments have no
# enclosing call, and `run_on_cpu` shows no decorator. Confirm against the
# upstream file.
if not context.executing_eagerly():
self.skipTest('eager only')
# testSharedRendezvous sets the disable_meta_optimizer flag to True
# if that subtest runs before this one, then having that set to True
# will cause this subtest to fail. To avoid that scenario, explicitly
# set the disable_meta_optimizer flag to false here
'min_graph_nodes': -1,
'implementation_selector': True,
'disable_meta_optimizer': False
attributes={'api_implements': 'foo',
'api_preferred_device': 'CPU'})
def on_cpu(x):
return x + 2
attributes={'api_implements': 'foo',
'api_preferred_device': 'GPU'})
def on_gpu(x):
return x + 4
def run_on_cpu(t):
function.register(on_cpu, t)
with ops.device('CPU:0'):
return on_gpu(t)
# Expect to run the on_cpu branch, regardless whether gpu is available.
self.assertEqual(run_on_cpu(constant_op.constant(1)).numpy(), 3)
def testDefunFunctionSeparateGraphs(self):
# In graph mode, calls `maybe_add` (which calls `add` only when `should_add`
# is true) inside two separate graphs and tracks the cache sizes: 1/1 after
# the first call, 2/1 after flipping `should_add`, and 3/2 after repeating
# the first call in a fresh graph.
# NOTE(review): `add` and `maybe_add` are undecorated although
# `total_function_cache` is applied to them; decorator lines look to be
# missing -- confirm against the upstream file.
with context.graph_mode():
def add(x):
return x + 5
def maybe_add(x, should_add):
if should_add:
return add(x)
return x
with ops.Graph().as_default():
x = constant_op.constant(11)
maybe_add(x, True)
self.assertLen(total_function_cache(maybe_add), 1)
self.assertLen(total_function_cache(add), 1)
maybe_add(x, False)
self.assertLen(total_function_cache(maybe_add), 2)
self.assertLen(total_function_cache(add), 1)
with ops.Graph().as_default():
x = constant_op.constant(11)
maybe_add(x, True)
self.assertLen(total_function_cache(maybe_add), 3)
self.assertLen(total_function_cache(add), 2)
def testCacheKeyOverlappingShapes(self):
# Calls `defined` with zeros of shape [12, 1] and then [1, 21] and expects
# the function cache to grow from one entry to two.
# NOTE(review): `defined` is undecorated although `total_function_cache` is
# applied to it; a decorator line looks to be missing -- confirm against the
# upstream file.
def defined(t):
return t
defined(array_ops.zeros([12, 1]))
self.assertLen(total_function_cache(defined), 1)
defined(array_ops.zeros([1, 21]))
self.assertLen(total_function_cache(defined), 2)
def testCacheKeyNestedLists(self):
# Calls `defined` with [[a], b, c] and then [[a, b], c] -- the same three
# tensors nested differently -- and expects the function cache to grow from
# one entry to two.
# NOTE(review): `defined` is undecorated although `total_function_cache` is
# applied to it; a decorator line looks to be missing -- confirm against the
# upstream file.
def defined(l):
return l
a = constant_op.constant(1.)
b = constant_op.constant(2.)
c = constant_op.constant(3.)
defined([[a], b, c])
self.assertLen(total_function_cache(defined), 1)
defined([[a, b], c])
self.assertLen(total_function_cache(defined), 2)
def testCacheKeyAttrsClass(self):
# Skipped when the `attr` module is unavailable. Declares `TestClass` with
# two `attr.ib()` fields and expects the function cache of `defined` to hold
# one entry at the first two checks and two entries at the last one.
# NOTE(review): this block looks truncated -- no class or function decorator
# is visible, the calls to `defined` are missing, and only the tail of the
# last call's argument list remains. Confirm against the upstream file.
if attr is None:
self.skipTest('attr module is unavailable.')
class TestClass(object):
a = attr.ib()
b = attr.ib()
def defined(l):
return l
self.assertLen(total_function_cache(defined), 1)
self.assertLen(total_function_cache(defined), 1)
constant_op.constant(2.)], constant_op.constant(3.)))
self.assertLen(total_function_cache(defined), 2)
def testCacheKeyVariables(self):
# Calls `defined` with different orderings and aliasings of three resource
# variables (plus a deep copy of one) and expects each new arrangement to
# add exactly one cache entry, while repeating an arrangement adds none.
# NOTE(review): `defined` is undecorated although `total_function_cache` is
# applied to it; a decorator line looks to be missing -- confirm against the
# upstream file.
def defined(a, b, c):
return a + b + c
x = resource_variable_ops.ResourceVariable(0.0)
y = resource_variable_ops.ResourceVariable(0.0)
z = resource_variable_ops.ResourceVariable(0.0)
# If tensor equality is not enabled, we always get a cache miss if the
# function is called with different variables. With equality enabled we
# should only get a miss if the aliasing changed.
defined(x, y, z)
self.assertLen(total_function_cache(defined), 1)
defined(x, y, z)
self.assertLen(total_function_cache(defined), 1)
# Re-arranging arguments causes cache miss
defined(z, y, x)
self.assertLen(total_function_cache(defined), 2)
defined(z, y, x)
self.assertLen(total_function_cache(defined), 2)
# Aliasing causes cache miss
defined(x, x, z)
self.assertLen(total_function_cache(defined), 3)
defined(x, x, z)
self.assertLen(total_function_cache(defined), 3)
# Re-arranging arguments causes cache miss
defined(y, y, z)
self.assertLen(total_function_cache(defined), 4)
defined(y, y, z)
self.assertLen(total_function_cache(defined), 4)
# Different alias positions causes cache miss
defined(z, y, y)
self.assertLen(total_function_cache(defined), 5)
defined(z, y, y)
self.assertLen(total_function_cache(defined), 5)
x_copy = copy.deepcopy(x)
# Deep copy causes cache miss
defined(x_copy, y, z)
self.assertLen(total_function_cache(defined), 6)
defined(x_copy, y, z)
self.assertLen(total_function_cache(defined), 6)
def testVariableRetracing(self):
  """Looks up a distinct constant for each of three variables by identity.

  Two freshly created variables and a deep copy of a third are keyed by
  `id()`; the lookup is expected to return 1, 2 and 3 for them in turn.
  """
  # NOTE(review): the test name suggests `lookup_tensor` should be wrapped in
  # a traced function, but no decorator is visible in this chunk -- confirm
  # against the upstream file.
  tracked = [
      variables.Variable(1.),
      variables.Variable(1.),
      copy.deepcopy(variables.Variable(1.)),
  ]
  var_dict = {
      id(var): constant_op.constant(tag)
      for tag, var in enumerate(tracked, start=1)
  }

  def lookup_tensor(v):
    return var_dict[id(v)]

  for expected, var in enumerate(tracked, start=1):
    self.assertEqual(expected, lookup_tensor(var).numpy())
def testDecoratedMethodInspect(self):
# Expects 'training' to appear among the argument names reported by
# `tf_inspect.getfullargspec` for a model method.
# NOTE(review): this block looks truncated -- `call` has no body or
# decorator, and the `getfullargspec(` call is never closed, so its argument
# is not visible. Confirm against the upstream file.
class DefunnedMiniModel(object):
def call(self, inputs, training=True):
m = DefunnedMiniModel()
fullargspec = tf_inspect.getfullargspec(
self.assertIn('training', fullargspec.args)
def testFunctionModifiesInputList(self):
# For each in-place list operation named below (append, extend, insert, pop,
# reverse, remove, clear on Python 3, and append through **kwargs), expects
# a ValueError whose message matches `expected_msg`.
# NOTE(review): this block looks truncated -- most nested functions have no
# body, no decorators are visible, and nothing inside the `assertRaisesRegex`
# scopes calls them. Confirm against the upstream file.
# Tests on `list` methods that do in place modification, except `list.sort`
# since it cannot even be "defunned" in the first place
def get_list():
return [constant_op.constant(0.), constant_op.constant(1.)]
expected_msg = '.*() should not modify'
with self.assertRaisesRegex(ValueError, expected_msg):
def append(l):
with self.assertRaisesRegex(ValueError, expected_msg):
def extend(l):
with self.assertRaisesRegex(ValueError, expected_msg):
def insert(l):
l.insert(0, constant_op.constant(0.))
with self.assertRaisesRegex(ValueError, expected_msg):
def pop(l):
with self.assertRaisesRegex(ValueError, expected_msg):
def reverse(l):
with self.assertRaisesRegex(ValueError, expected_msg):
def remove(l):
# `list.clear` is a method that is in Py3 but not Py2
if sys.version.startswith('3'):
with self.assertRaisesRegex(ValueError, expected_msg):
def clear(l):
# One last test for keyword arguments
with self.assertRaisesRegex(ValueError, expected_msg):
def kwdappend(**kwargs):
l = kwargs['l']
def testFunctionModifiesInputDict(self):
# For each in-place dict operation named below (clear, pop, popitem, update,
# setdefault), expects a ValueError whose message matches `expected_msg`.
# NOTE(review): this block looks truncated -- several nested functions have
# no body, no decorators are visible, and nothing inside the
# `assertRaisesRegex` scopes calls them. Confirm against the upstream file.
def get_dict():
return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
expected_msg = '.* should not modify'
with self.assertRaisesRegex(ValueError, expected_msg):
def clear(m):
with self.assertRaisesRegex(ValueError, expected_msg):
def pop(m):
with self.assertRaisesRegex(ValueError, expected_msg):
def popitem(m):
with self.assertRaisesRegex(ValueError, expected_msg):
def update(m):
m.update({'t1': constant_op.constant(3.)})
with self.assertRaisesRegex(ValueError, expected_msg):
def setdefault(m):
m.setdefault('t3', constant_op.constant(3.))
def testFunctionModifiesInputNest(self):
# Expects a ValueError naming `modify` and another naming `modify_same_flat`
# when those functions are applied to nested inputs.
# NOTE(review): this block looks truncated -- both nested functions have no
# body, both `nested_input` literals are left unclosed, and no calls are
# visible. Confirm against the upstream file.
with self.assertRaisesRegex(ValueError, 'modify.* should not modify'):
def modify(n):
nested_input = [{
't1': [constant_op.constant(0.),
with self.assertRaisesRegex(ValueError,
'modify_same_flat.* should not modify'):
# The flat list doesn't change whereas the true structure changes
def modify_same_flat(n):
nested_input = [[constant_op.constant(0.)],
def testExecutorType(self):
# Runs `add_five` under several `context.function_executor_type` settings:
# 'NON_EXISTENT_EXECUTOR' is expected to raise NotFoundError, while '',
# 'DEFAULT' and None are each exercised inside the loop.
# NOTE(review): this block looks truncated -- the two lines ending in
# `.numpy())` are the tails of calls whose opening (presumably an equality
# assertion) is not visible, and `add_five` shows no decorator. Confirm
# against the upstream file.
def add_five(x):
return x + 5
add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
with self.assertRaisesRegex(errors.NotFoundError, 'NON_EXISTENT_EXECUTOR'):
with context.function_executor_type('NON_EXISTENT_EXECUTOR'):
add_five(constant_op.constant(0, dtype=dtypes.int32))
for executor_type in ('', 'DEFAULT', None):
with context.function_executor_type(executor_type):
add_five(constant_op.constant(0, dtype=dtypes.int32)).numpy())
def testReferenceCycles(self):
  """A defun'ed function is freed as soon as its last reference is dropped."""
  doubler = function.defun(lambda x: 2. * x)
  probe = weakref.ref(doubler)
  del doubler
  # Tests that the weak reference we made to the function is now dead, which
  # means the object has been deleted. This should be true as long as the
  # function itself is not involved in a reference cycle.
  still_alive = probe()
  self.assertIs(None, still_alive)
def testFunctionStackInErrorMessage(self):
# Skipped when executing eagerly. Expects an InvalidArgumentError whose
# message contains 'fn -> fn2' and the failing assert node, but not 'fn3'.
# `fn` calls `fn2`, which asserts that `fn3(x)` (x + 2) equals 3.
# NOTE(review): this block looks truncated -- the `with self.assertRaises`
# scope has no visible body (the call that should fail is missing) and no
# decorators are visible on fn/fn2/fn3. Confirm against the upstream file.
if context.executing_eagerly():
# TODO(b/122736651): Remove this skipTest once fixed.
self.skipTest('Error interpolation is not working when function is '
'invoked without PartitionedCallOp.')
def fn3(x):
return x + 2
def fn2(x):
check_ops.assert_equal(fn3(x), 3)
return 2
def fn(x):
return fn2(x)
with self.assertRaises(errors.InvalidArgumentError) as cm:
e = cm.exception
self.assertIn('fn -> fn2', e.message)
self.assertIn('node assert_equal/Assert/Assert (defined at', e.message)
self.assertNotIn('fn3', e.message)
def testFunctionIsNotPinned(self):
  """Tests that functions aren't pinned to the CPU by the eager runtime."""
  # NOTE(review): no decorator on `func` and no GPU-availability guard are
  # visible in this chunk -- confirm against the upstream file.
  gpu_device = 'GPU:0'
  sample_shape = constant_op.constant([4, 7])
  random_kwargs = dict(dtype=dtypes.float32, seed=79, seed2=25)

  def func():
    with ops.device(gpu_device):
      sample = gen_random_ops.random_standard_normal(
          sample_shape, **random_kwargs)
    return sample

  with ops.device(gpu_device):
    x = func()
    self.assertRegex(x.device, 'GPU')
def testShapeCaching(self):
# Expects `func` (which returns `array_ops.shape(x)`) to report [1, 1] and
# [2, 2] for zero tensors of those shapes; `calls_func` forwards to `func`
# and is evaluated on a [3, 3] tensor.
# NOTE(review): this block looks truncated -- the `input_signature=[...]`
# line is the tail of a decorator call that is not visible, `func` shows no
# decorator, and the final `[3, 3], ...` lines are the tail of an assertion
# whose opening is missing. Confirm against the upstream file.
def func(x):
return array_ops.shape(x)
input_signature=[tensor_spec.TensorSpec([None, None], dtypes.float32)])
def calls_func(x):
return func(x)
self.assertAllEqual([1, 1], self.evaluate(func(array_ops.zeros([1, 1]))))
self.assertAllEqual([2, 2], self.evaluate(func(array_ops.zeros([2, 2]))))
[3, 3],
self.evaluate(calls_func(array_ops.zeros([3, 3]))))
def testLimitedRetracing(self):
# Calls `func` 50 times each with inputs of three different leading
# dimensions and expects the Python body to have run fewer than 13 times in
# total, counted through `trace_count`.
# NOTE(review): `func` is undecorated here, so as written every call would
# increment `trace_count`; a decorator line looks to be missing -- confirm
# against the upstream file.
trace_count = [0]
def func(x):
trace_count[0] += 1
return x
for _ in range(50):
func(constant_op.constant([[1., 2.]]))
func(constant_op.constant([[3., 4.], [5., 6.]]))
func(constant_op.constant([[3., 4.], [5., 6.], [7., 8.]]))
# Tracing more than twice per input doesn't make sense.
self.assertLess(trace_count[0], 13)
def testLimitedRetracingWithCompositeTensors(self):
# Calls `f` ten times each with three differently shaped ragged constants
# and expects the Python body to have run exactly three times, counted
# through `trace_count`.
# NOTE(review): `f` is undecorated here, so as written every call would
# increment `trace_count`; a decorator line looks to be missing -- confirm
# against the upstream file.
trace_count = [0]
def f(x):
trace_count[0] += 1
return x
for i in range(10):
f(ragged_factory_ops.constant([[1, 2], [i]]))
f(ragged_factory_ops.constant([[1, 2], [], [3, 4, 5]]))
f(ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]]))
self.assertEqual(trace_count[0], 3)
def test_concrete_function_shape_mismatch(self):
# Builds a concrete function for a one-element float vector, calls it with a
# two-element vector (expected result [2., 3.]), and expects a ValueError
# mentioning 'argument_name' in the final `assertRaisesRegex` scope.
# NOTE(review): this block looks truncated -- `f` and `g` show no
# decorators, the `[2., 3.], ...` lines are the tail of an assertion whose
# opening is missing, and the final `with` scope has no visible body.
# Confirm against the upstream file.
def f(argument_name):
return argument_name + 1.
f_concrete = f.get_concrete_function(constant_op.constant([1.]))
# Calling a function from eager doesn't do any shape checking above what
# kernels do while executing.
[2., 3.],
f_concrete(constant_op.constant([1., 2.])).numpy())
def g():
f_concrete(constant_op.constant([1., 2.]))
with self.assertRaisesRegex(ValueError, 'argument_name'):
def test_shape_inference_with_symbolic_shapes(self):
# Builds a concrete function for three float32 inputs of unknown shape whose
# output shape depends on the runtime batch sizes of `x` and `y`, then calls
# it twice from `_call_concrete` and expects output shapes (5, 4, 2) and
# (10, 4, 3), both inside `_call_concrete` and after evaluation.
# NOTE(review): `_uses_symbolic_shapes` and `_call_concrete` are undecorated
# although `get_concrete_function` is called on the former; decorator lines
# look to be missing -- confirm against the upstream file.
def _uses_symbolic_shapes(w, x, y):
x = array_ops.identity(x, name='name_collision')
x = array_ops.transpose(x, [1, 0, 2])
x_batch = array_ops.shape(x)[0]
y_batch = array_ops.shape(y)[0]
y *= w
n = y_batch // x_batch
return array_ops.reshape(y, [n, x_batch, -1])
conc = _uses_symbolic_shapes.get_concrete_function(
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32),
tensor_spec.TensorSpec(None, dtypes.float32))
def _call_concrete():
c = constant_op.constant(1.)
# Reuses the op name from inside the concrete function on purpose.
array_ops.identity(c, name='name_collision')
output1 = conc(array_ops.ones([2]),
array_ops.ones([5, 4, 2]),
array_ops.ones([20, 2]))
self.assertEqual([5, 4, 2], output1.shape)
output2 = conc(array_ops.ones([3]),
array_ops.ones([5, 4, 3]),
array_ops.ones([40, 3]))
self.assertEqual([10, 4, 3], output2.shape)
return output1, output2
output1, output2 = _call_concrete()
self.assertEqual((5, 4, 2), self.evaluate(output1).shape)
self.assertEqual((10, 4, 3), self.evaluate(output2).shape)
def testAutoGraphContext(self):
# Records the AutoGraph control status before the test body and expects it
# to be unchanged afterwards; inside `test_fn` the status is compared with
# `ag_ctx.Status.ENABLED`.
# NOTE(review): this block looks truncated -- the line inside `test_fn` is
# the tail of an assertion whose opening is missing, `test_fn` shows no
# decorator and is never called. Confirm against the upstream file.
def test_fn():
ag_ctx.control_status_ctx().status, ag_ctx.Status.ENABLED)
prev_status = ag_ctx.control_status_ctx().status
self.assertEqual(ag_ctx.control_status_ctx().status, prev_status)
def testCancelBeforeFunctionExecution(self):
# Eager-only test: wraps a queue-dequeue function with a CancellationManager
# and expects a CancelledError in the final scope.
# NOTE(review): this block looks truncated -- `f` shows no decorator, no
# cancellation call is visible, and the `assertRaises` scope has no body.
# Confirm against the upstream file.
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
with self.assertRaises(errors.CancelledError):
def testCancelBlockedFunctionExecution(self):
# Eager-only test: wraps a queue-dequeue function with a CancellationManager,
# prepares a checked thread running `cancel_thread`, and expects a
# CancelledError in the final scope.
# NOTE(review): this block looks truncated -- `f` shows no decorator,
# `cancel_thread` has no body, the thread is never started or joined, and the
# `assertRaises` scope has no body. Confirm against the upstream file.
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
def cancel_thread():
t = self.checkedThread(cancel_thread)
with self.assertRaises(errors.CancelledError):
def testCancelAfterFunctionExecution(self):
# Eager-only test: wraps a queue-dequeue function with a CancellationManager
# and expects the cancelable function to return 37.
# NOTE(review): this block looks truncated -- `f` shows no decorator, nothing
# visible enqueues the expected value 37, and the cancellation call that the
# last comment refers to is missing. Confirm against the upstream file.
if not context.executing_eagerly():
self.skipTest('eager only')
q = data_flow_ops.FIFOQueue(1, dtypes.int32)
def f():
return q.dequeue()
c_mgr = cancellation.CancellationManager()
cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())
self.assertAllEqual(37, cancelable_func().numpy())
# Cancellation after the function executes is a no-op.
def testAddFunctionCallback(self):
# Expects the `functions` list to hold one entry after `plus_one` first runs
# on a float32 input, still one after a repeat call, and two after a float64
# input.
# NOTE(review): this block looks truncated -- `function_callback` has no
# body, it is never registered, and `plus_one` shows no decorator. Confirm
# against the upstream file.
functions = []
def function_callback(f):
def plus_one(x):
return x + 1
x_float32 = numpy.array(3.0, dtype=numpy.float32)
self.assertAllClose(plus_one(x_float32), 4.0)
self.assertLen(functions, 1)
# Function is already created. Executing it again should not invoke the
# function callback.
self.assertAllClose(plus_one(x_float32), 4.0)
self.assertLen(functions, 1)
# Signature change leads to a new Function being built.
x_float64 = numpy.array(3.0, dtype=numpy.float64)
self.assertAllClose(plus_one(x_float64), 4.0)
self.assertLen(functions, 2)
def testRemoveFunctionCallback(self):
# Expects both callback lists to hold one entry after `plus_one` runs on a
# float32 input, and after a float64 input expects the first list to stay at
# one entry while the second grows to two.
# NOTE(review): this block looks truncated -- both callbacks have no body,
# neither is registered or removed in the visible code, and `plus_one` shows
# no decorator. Confirm against the upstream file.
functions_1 = []
def function_callback_1(f):
functions_2 = []
def function_callback_2(f):
def plus_one(x):
return x + 1
self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float32)), 4.0)
self.assertLen(functions_1, 1)
self.assertLen(functions_2, 1)
# The 1st callback should not be invoked after remove_function_callback()
# is called.
self.assertAllClose(plus_one(numpy.array(3.0, dtype=numpy.float64)), 4.0)
self.assertLen(functions_1, 1)
self.assertLen(functions_2, 2)
def testClearFunctionCallbacks(self):
  """Clearing function callbacks empties the module-level callback list.

  Two no-op callbacks are registered, the list is checked to hold both, and
  after clearing it is expected to be empty.
  """
  # pylint: disable=protected-access
  function.add_function_callback(lambda f: None)
  function.add_function_callback(lambda f: None)
  self.assertLen(function._function_callbacks, 2)
  # Without this call the list still holds both callbacks and the emptiness
  # check below cannot pass.
  function.clear_function_callbacks()
  self.assertEmpty(function._function_callbacks)
  # pylint: enable=protected-access
def testConcreteFunctionWithNestedTensorInputs(self):
# For a function taking a dict `x` and a tuple `y`, builds concrete
# functions through six argument spellings and calls each through five
# conventions (structured and flat), expecting the tuple (1200, 34) every
# time.
# NOTE(review): this block looks truncated -- `f` shows no decorator and
# both `for ... in [` lists are never closed. Confirm against the upstream
# file.
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = constant_op.constant(1000)
b = constant_op.constant(200)
c = constant_op.constant(30)
d = {'a': a, 'b': b}
e = (c, 4)
# Test different argument signatures when constructing the concrete func.
for cf in [
f.get_concrete_function(d, e),
f.get_concrete_function(d, y=e),
f.get_concrete_function(y=e, x=d),
f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
# Test different calling conventions when calling the concrete func.
for output in [
cf(d, e), # structured signature
cf(d, y=e), # structured signature w/ kwarg
cf(y=e, x=d), # structured signature w/ 2 kwargs
cf(a, b, c), # flat signature
cf(x=a, x_1=b, y=c) # flat signature w/ kwargs
self.assertIsInstance(output, tuple)
self.assertLen(output, 2)
self.assertAllEqual(output[0], 1200)
self.assertAllEqual(output[1], 34)
def testConcreteFunctionWithNestedNonTensorInputs(self):
# Builds concrete functions with `y` given as the Python tuple (50, 3) and
# expects the two outputs to sum to 1253 whether or not `y` is passed again
# at call time.
# NOTE(review): this block looks truncated -- `f` shows no decorator and the
# `for cf in [` list is never closed. Confirm against the upstream file.
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
b = (50, 3)
for cf in [ # argument y is bound to non-Tensor value (50, 3).
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 1253)
def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
# Builds concrete functions with `x` given as a dict of Python ints and
# expects the two outputs to sum to 3234 whether or not `x` is passed again
# at call time.
# NOTE(review): this block looks truncated -- `f` shows no decorator and the
# `for cf in [` list is never closed. Confirm against the upstream file.
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 3000, 'b': 200, 'c': 9000}
b = (constant_op.constant(30), 4)
for cf in [ # argument x is bound to non-tensor value `a`
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 3234)
def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
# Builds a concrete function with both arguments given as plain Python
# values and expects the two outputs to sum to 5555 when it is called with
# no arguments, with `a` only, or with `y=b` only.
# NOTE(review): `f` is undecorated although `get_concrete_function` is
# called on it; a decorator line looks to be missing -- confirm against the
# upstream file.
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 5000, 'b': 500}
b = (50, 5)
cf = f.get_concrete_function(a, b)
for output in [cf(), cf(a), cf(y=b)]:
self.assertAllEqual(output[0] + output[1], 5555)
def testConcreteFunctionMethodWithVarargs(self):
  """A varargs method with an input signature yields a callable concrete fn."""
  float32_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
  two_scalars = [float32_scalar] * 2

  class MyModel(module.Module):

    @def_function.function(input_signature=two_scalars)
    def add(self, *arg):
      return math_ops.add(*arg)

  concrete_add = MyModel().add.get_concrete_function()
  # Only checks that the call goes through; the result is not inspected.
  concrete_add(-12.0, 3.0)
def testConcreteFunctionStructuredSignatureKeywordOrder(self):
# Calls `g` with the same seven keyword arguments in three different orders
# and expects the joined string b'a=f, l=e, m=a, p=g, q=d, r=b, v=c' each
# time.
# NOTE(review): this block looks truncated -- `g` shows no decorator, the
# first argument of `string_ops.reduce_join(` is missing, the result of
# `get_concrete_function` is discarded (the later calls use `g` itself), and
# each `g(...)`/`b'...'` pair is the tail of an assertion whose opening is
# not visible. Confirm against the upstream file.
# Check that keyword-only arguments are sorted appropriately, so that they
# feed the right tensor into each input.
def g(**kwargs):
return string_ops.reduce_join(
separator=', ')
s = constant_op.constant('s')
g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
# Parameter table and body for the structured-signature error test. Each
# case supplies the arguments used to build a concrete function from `func`
# (`conc_args` / `conc_kwargs`), the arguments used to call it (`call_args` /
# `call_kwargs`), and the expected error message and exception type; the
# test body asserts that the call raises accordingly.
# NOTE(review): this block looks truncated -- the parameterized decorator
# call and the per-case openers (case names and `dict(`-style wrappers) are
# not visible, several tuples/dicts are left unclosed, the method's
# parameter list and docstring are cut short, and `func` shows no decorator.
# Also note 'structrued' in the docstring is a typo for 'structured'.
# Confirm against the upstream file.
# pylint: disable=g-long-lambda
conc_args=lambda: (1, constant_op.constant(2)),
call_args=lambda: (1,),
error=r'func\(x, y\) missing required arguments: y'),
conc_args=lambda: (1, 2, constant_op.constant(1.0)),
call_args=lambda: (1, 2),
error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2, 3),
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
conc_args=lambda: (1, 2),
conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
call_args=lambda: (1, 2),
error=r'func\(x, y, \*, c\) missing required arguments: c'),
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'c': constant_op.constant(1.0)},
error=r'func\(x, y\) got unexpected keyword arguments: c'),
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: ({
'a': constant_op.constant([1, 2, 3])
error=r'func\(x, y\): argument x had incorrect type\n'
r' expected: RaggedTensor\n'
r" got: {'a': (Eager)?Tensor}"),
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
error=r'func\(x, y\): argument x had incorrect type\n'),
conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
error=r'func\(x, y\): argument x had incorrect type\n'),
conc_args=lambda: ({
'a': constant_op.constant(1),
'b': constant_op.constant(1)
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\): argument x had incorrect type\n'),
((constant_op.constant(1), constant_op.constant(2)),),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\): argument x had incorrect type\n'),
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(ValueError, errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (5,),
error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
conc_args=lambda: (5,),
call_args=lambda: (8,),
error=r'ConcreteFunction func\(x, y\) was constructed with int '
r'value 5 in x, but was called with int value 8'),
conc_args=lambda: (5,),
call_args=lambda: (constant_op.constant(6),),
error=r'ConcreteFunction func\(x, y\) was constructed with int '
'value 5 in x, but was called with (Eager)?Tensor value .*'),
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'x': 3},
error=r"func\(x, y\) got two values for argument 'x'"),
# pylint: enable=g-long-lambda
def testConcreteFunctionStructuredSignatureError(self,
"""Tests for errors in the structrued signature.
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
with self.assertRaisesRegex(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
# Parameter table and body for the flat-signature error test. Each case
# supplies the arguments used to build a concrete function from `func`
# (`conc_args` / `conc_kwargs`), the arguments used to call it (`call_args` /
# `call_kwargs`), and the expected error message and exception type. The
# body clears the function spec via `_set_function_spec(None)` before the
# call so that only the flat signature applies, then asserts the call raises.
# NOTE(review): this block looks truncated -- the parameterized decorator
# call and the per-case openers are not visible, several tuples/dicts are
# left unclosed, the method's parameter list and docstring are cut short,
# and `func` shows no decorator. Confirm against the upstream file.
# pylint: disable=g-long-lambda
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\) missing required arguments: y'),
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {
'x': constant_op.constant(1),
'y': constant_op.constant(1)
error=r"func\(x, y\) got two values for argument 'x'"),
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x\) got unexpected keyword arguments: c'),
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, varargs_0\) missing required '
r'arguments: varargs_0'),
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': constant_op.constant(1)},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, c\) missing required arguments: c'),
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (5, constant_op.constant(2)),
error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
r'a Tensor; got int \(5\)'),
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(ValueError, errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
# pylint: enable=g-long-lambda
def testConcreteFunctionFlatSignatureError(self,
"""Tests for errors in the flat signature.
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
# Remove _function_spec, to disable the structured signature.
conc._set_function_spec(None) # pylint: disable=protected-access
with self.assertRaisesRegex(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
def testConcreteFunctionAmbiguousSignature(self):
  """Checks which binding wins when a keyword call is ambiguous.

  The TensorSpec names are deliberately swapped relative to the parameter
  names, so `conc(x=..., y=...)` could be bound by either signature.
  """
  # When both the flat & structured signatures are applicable, but they
  # give different results, we use the structured signature. Note: we expect
  # this to be extremely rare.
  @def_function.function
  def f(x, y):
    return x * 10 + y

  conc = f.get_concrete_function(
      x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
      y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))

  # Binding by parameter name gives 5 * 10 + 6.
  result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
  self.assertAllEqual(result, 56)
def testPrettyPrintedSignature(self):
  """Checks short and verbose pretty-printed signatures, repr() and str().

  Covers plain tensor arguments (c1), a ragged argument plus a Python
  constant (c2), a nested-structure argument (c3), and a function taking
  *args/**kwargs (c4, c5).
  """
  # NOTE(review): the leading whitespace inside the expected `*_details`
  # strings looks collapsed to a single space -- confirm against the real
  # output of pretty_printed_signature() before relying on these patterns.

  @def_function.function
  def func(x, kangaroo=None, octopus=7):
    del octopus, kangaroo
    return x

  scalar = constant_op.constant(5)
  vector = constant_op.constant([10, 10, 20])
  ragged = ragged_factory_ops.constant([[10, 20], [40]])

  c1 = func.get_concrete_function(scalar, vector)
  c1_summary = r'func\(x, kangaroo, octopus=7\)'
  c1_details = (r' Args:\n'
                r' x: int32 Tensor, shape=\(\)\n'
                r' kangaroo: int32 Tensor, shape=\(3,\)\n'
                r' Returns:\n'
                r' int32 Tensor, shape=\(\)')
  self.assertRegex(c1.pretty_printed_signature(verbose=False), c1_summary)
  self.assertRegex(
      c1.pretty_printed_signature(verbose=True),
      c1_summary + '\n' + c1_details)
  self.assertRegex(
      repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
  self.assertRegex(
      str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))

  c2 = func.get_concrete_function(scalar, ragged, 3)
  c2_summary = r'func\(x, kangaroo, octopus=3\)'
  c2_details = (r' Args:\n'
                r' x: int32 Tensor, shape=\(\)\n'
                r' kangaroo: RaggedTensorSpec\(.*\)\n'
                r' Returns:\n'
                r' int32 Tensor, shape=\(\)')
  self.assertRegex(
      c2.pretty_printed_signature(verbose=True),
      c2_summary + '\n' + c2_details)

  c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
  c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
  c3_details = (r' Args:\n'
                r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
                r' <1>: int32 Tensor, shape=\(\)\n'
                r' <2>: RaggedTensorSpec\(.*\)\n'
                r' <3>: RaggedTensorSpec\(.*\)\n'
                r' Returns:\n'
                r" {'a': <1>, 'b': \[<2>, <3>\]}\n"
                r' <1>: int32 Tensor, shape=\(\)\n'
                r' <2>: RaggedTensorSpec\(.*\)\n'
                r' <3>: RaggedTensorSpec\(.*\)')

  # Python 3.5 does not guarantee deterministic iteration of dict contents,
  # which can lead to a mismatch on pretty_printed_signature output for
  # "Args".
  if sys.version_info >= (3, 6):
    self.assertRegex(
        c3.pretty_printed_signature(verbose=True),
        c3_summary + '\n' + c3_details)

  # pylint: disable=keyword-arg-before-vararg
  @def_function.function
  def func2(x, y=3, *args, **kwargs):
    return (x, y, args, kwargs)

  c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
  c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
  self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)

  c5 = func2.get_concrete_function(8, vector)
  c5_summary = 'func2(x=8, y)'
  self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)
def testPrettyPrintedExplicitSignatureWithKeywordArg(self):  # b/159639913
  """Checks the printed signature when an input signature omits `b`.

  `get_concrete_function()` is called without arguments, so the function
  needs an explicit input signature; a single float32 spec of unknown shape
  matches the expected 'float32 Tensor, shape=<unknown>' text.
  """
  # NOTE(review): the indentation inside the expected verbose string looks
  # collapsed to single spaces -- confirm against the real output.

  @def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
  def fn(a, b=1):
    return a + b

  concrete_fn = fn.get_concrete_function()
  self.assertEqual(concrete_fn.pretty_printed_signature(False), 'fn(a)')
  self.assertEqual(
      concrete_fn.pretty_printed_signature(True), 'fn(a)\n'
      ' Args:\n'
      ' a: float32 Tensor, shape=<unknown>\n'
      ' Returns:\n'
      ' float32 Tensor, shape=<unknown>')
def testIndexedSlicesAsGradientsForConcreteFunctions(self):
  """Takes a gradient through a gather applied to a nested function's output.

  The test passes as long as computing the gradient raises no error.
  """

  @def_function.function
  def summing_rnn(inputs):
    return math_ops.reduce_sum(inputs, axis=1)

  @def_function.function
  def gradients(inputs):
    with backprop.GradientTape() as tape:
      # `inputs` is a plain tensor, so it has to be watched explicitly for
      # the tape to record operations on it.
      tape.watch(inputs)
      hidden = summing_rnn(inputs)
      hidden = array_ops.gather(hidden, constant_op.constant([0]))
      loss = math_ops.reduce_mean(hidden)
    return tape.gradient(loss, inputs)

  gradients(constant_op.constant([[[1.0], [2.0]]]))  # No error is raised
def testFollowTypeHintsTraceBasic(self):
  """Compares trace counts with type hints followed and ignored.

  `trace_count` is incremented inside the Python body, so it counts how often
  `func` is traced rather than how often it is called.
  """
  trace_count = [0]

  def func(x: ops.Tensor):
    trace_count[0] += 1
    return x

  enabled = def_function.function(func, experimental_follow_type_hints=True)
  disabled = def_function.function(func, experimental_follow_type_hints=False)

  enabled(1)  # Initial call gets traced
  enabled(2)
  enabled(3)
  self.assertEqual(trace_count[0], 1)

  # Three distinct Python ints are needed to reach the three traces asserted
  # below.
  trace_count = [0]
  disabled(1)
  disabled(2)  # Retrace
  disabled(3)  # Retrace
  self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithArgs(self):
  """Compares trace counts for a Tensor-annotated *args parameter."""
  trace_count = [0]

  def func(*args: ops.Tensor):
    trace_count[0] += 1
    return args

  enabled = def_function.function(func, experimental_follow_type_hints=True)
  disabled = def_function.function(func, experimental_follow_type_hints=False)

  # Two tuples of equal length and element type but different contents.
  # NOTE(review): the element values are reconstructed; any pair of distinct
  # same-typed tuples serves the purpose -- confirm against upstream.
  args = (
      'abc',
      'def',
  ) * 20
  args2 = (
      'def',
      'abc',
  ) * 20

  enabled(args)
  enabled(args2)
  self.assertEqual(trace_count[0], 1)

  trace_count = [0]
  disabled(args)
  disabled(args2)  # Retrace
  self.assertEqual(trace_count[0], 2)
def testFollowTypeHintsTraceWithKwargs(self):
  """Compares trace counts when `t` and **kwargs are annotated as Tensors."""
  num_traces = [0]

  def func(t: ops.Tensor, **kwargs: ops.Tensor):
    num_traces[0] += 1  # Counts traces, not calls.
    del kwargs
    return t

  first_kwargs = dict(x=1, y=1.0, z='one')
  second_kwargs = dict(x=2, y=2.0, z='two')

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  unhinted = def_function.function(func, experimental_follow_type_hints=False)

  # Hints followed: the second call reuses the first trace.
  hinted(1, **first_kwargs)
  hinted(2, **second_kwargs)
  self.assertEqual(num_traces[0], 1)

  # Hints ignored: the second call retraces.
  num_traces[0] = 0
  unhinted(1, **first_kwargs)
  unhinted(2, **second_kwargs)
  self.assertEqual(num_traces[0], 2)
def testFollowTypeHintsTraceWithMultipleInputTypes(self):
  """Passes ints, tensors, strings and floats through t, *args and **kwargs."""
  num_traces = [0]

  def func(t: ops.Tensor, *args: ops.Tensor, **kwargs: ops.Tensor):
    num_traces[0] += 1  # Counts traces, not calls.
    del args, kwargs
    return t

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  unhinted = def_function.function(func, experimental_follow_type_hints=False)

  # Hints followed: one trace serves both calls.
  hinted(1, constant_op.constant(1), 'str', x=4.0)
  hinted(2, constant_op.constant(2), 'str2', x=5.0)
  self.assertEqual(num_traces[0], 1)

  # Hints ignored: the second call retraces.
  num_traces[0] = 0
  unhinted(1, constant_op.constant(1), 'str', x=4.0)
  unhinted(2, constant_op.constant(2), 'str2', x=5.0)
  self.assertEqual(num_traces[0], 2)
def testFollowTypeHintsTraceWithOnlyArgNamed(self):
  """Only `t` is Tensor-annotated; changing the int-annotated `i` retraces."""
  num_traces = [0]

  def func(t: ops.Tensor, i: int = 1, **kwargs):  # pylint: disable=bad-whitespace
    num_traces[0] += 1  # Counts traces, not calls.
    del i, kwargs
    return t

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  # The second pair triggers a retrace.
  for t_value, i_value in ((1, 3), (2, 4)):
    hinted(t_value, i_value, x=4.0, y='str')
  self.assertEqual(num_traces[0], 2)
def testFollowTypeHintsTraceWithNotAllNamed(self):
  """Mixes an untyped, a Tensor-typed and an int-typed positional argument."""
  num_traces = [0]

  def func(x, y: ops.Tensor, z: int):
    num_traces[0] += 1  # Counts traces, not calls.
    del y, z
    return x

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  call_args = (
      (1, 2, 3),   # Initial trace
      (1, 20, 3),  # No retrace - change in ops.Tensor typed arg
      (2, 2, 3),   # Retrace - change in untyped arg
      (2, 2, 4),   # Retrace - change in typed arg
  )
  for positional in call_args:
    hinted(*positional)
  self.assertEqual(num_traces[0], 3)
def testFollowTypeHintsTraceWithOnlyArgsNamed(self):
  """Only *args is Tensor-annotated; `x` and `y` stay untyped."""
  num_traces = [0]

  def func(x, y, *args: ops.Tensor):
    num_traces[0] += 1  # Counts traces, not calls.
    del y, args
    return x

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  call_args = (
      (1, 20, 3, 4, 5, 6),    # Initial trace
      (1, 20, 3, 4, 5, 60),   # No retrace - change in *args
      (1, 30, 7, 8, 9, 10),   # Retrace - change in args
  )
  for positional in call_args:
    hinted(*positional)
  self.assertEqual(num_traces[0], 2)
def testFollowTypeHintsTraceWithOnlyKwargsNamed(self):
  """Only **kwargs is Tensor-annotated; positional arguments stay untyped.

  `trace_count` is incremented inside the Python body, so it counts traces
  rather than calls.
  """
  trace_count = [0]

  def func(x, y, *args, **kwargs: ops.Tensor):
    del y, args, kwargs
    trace_count[0] += 1
    return x

  enabled = def_function.function(func, experimental_follow_type_hints=True)

  enabled(1, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)
  enabled(
      1, 2, 3, 4, 5, 6, a=1.5, b=2.5,
      c=3.5)  # No retrace - change in **kwargs
  self.assertEqual(trace_count[0], 1)
  enabled(100, 2, 3, 4, 5, 6, a=1.0, b=2.0, c=3.0)  # Retrace - change in args
  enabled(
      1, 2, 3, 4, 5, 100, a=1.0, b=2.0, c=3.0)  # Retrace - change in *args
  self.assertEqual(trace_count[0], 3)
def testFollowTypeHintsTraceWithArgsEquals(self):
  """Passes every argument by keyword to a function with defaulted params."""
  num_traces = [0]

  def func(
      x: ops.Tensor = 0,  # pylint:disable=bad-whitespace
      y: int = 1,  # pylint:disable=bad-whitespace
      **kwargs: ops.Tensor):
    num_traces[0] += 1  # Counts traces, not calls.
    del y, kwargs
    return x

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  keyword_calls = (
      dict(x=1, y=2, z=3),        # Initial trace
      dict(x=1, y=3, z=3),        # Retrace - change in args
      dict(x=2, y=2, z=4),        # No retrace - change in args and **kwargs
      dict(x=2, y=2, z=4, u=5),   # Retrace - change in **kwargs
  )
  for keywords in keyword_calls:
    hinted(**keywords)
  self.assertEqual(num_traces[0], 3)
def testFollowTypeHintsTraceWithArgsEqualsTypedKwargs(self):
  """Keyword-only calls where just **kwargs carries the Tensor annotation."""
  num_traces = [0]

  def func(x, y, **kwargs: ops.Tensor):
    num_traces[0] += 1  # Counts traces, not calls.
    del y, kwargs
    return x

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  keyword_calls = (
      dict(x=1, y=2, z=3),        # Initial trace
      dict(x=1, y=3, z=3),        # Retrace
      dict(x=1, y=2, z=4),        # No retrace
      dict(x=2, y=2, z=4),        # Retrace
      dict(x=2, y=2, z=4, u=5),   # Retrace
  )
  for keywords in keyword_calls:
    hinted(**keywords)
  self.assertEqual(num_traces[0], 4)
def testFollowTypeHintsTraceWithArgsEqualsTypedArgs(self):
  """Keyword-only calls where `x` is Tensor-typed and **kwargs is untyped."""
  num_traces = [0]

  def func(x: ops.Tensor, y: int, **kwargs):
    num_traces[0] += 1  # Counts traces, not calls.
    del y, kwargs
    return x

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  keyword_calls = (
      dict(x=1, y=2, z=3),        # Initial trace
      dict(x=1, y=3, z=3),        # Retrace
      dict(x=1, y=2, z=4),        # Retrace
      dict(x=2, y=2, z=3),        # No retrace
      dict(x=2, y=2, z=4, u=5),   # Retrace
  )
  for keywords in keyword_calls:
    hinted(**keywords)
  self.assertEqual(num_traces[0], 4)
def testFollowTypeHintsTraceWithKwOnlyArgsBasic(self):
  """Keyword-only parameters: `a` is Tensor-typed, `b` is untyped."""
  num_traces = [0]

  def func(*, a: ops.Tensor = None, b=1):  # pylint: disable=bad-whitespace
    num_traces[0] += 1  # Counts traces, not calls.
    del b
    return a

  hinted = def_function.function(func, experimental_follow_type_hints=True)
  keyword_calls = (
      dict(a=1, b=2),   # Initial trace
      dict(a=2, b=2),   # No retrace
      dict(a=1, b=1),   # Retrace
  )
  for keywords in keyword_calls:
    hinted(**keywords)
  self.assertEqual(num_traces[0], 2)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArg(self):
  """Full signature where only the leading `arg` is Tensor-annotated."""
  num_traces = [0]

  def func(arg: ops.Tensor, *args, kwonly, **kwargs):
    num_traces[0] += 1  # Counts traces, not calls.
    del args, kwonly, kwargs
    return arg

  traced = def_function.function(func, experimental_follow_type_hints=True)

  # Baseline call, then vary only the typed `arg`: no retrace.
  traced(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
  traced(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
  traced(1000, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
  # Varying *args, kwonly or **kwargs retraces each time.
  traced(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)
  traced(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)
  traced(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)
  self.assertEqual(num_traces[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedArgs(self):
  """Full signature where only *args is Tensor-annotated."""
  num_traces = [0]

  def func(arg, *args: ops.Tensor, kwonly, **kwargs):
    num_traces[0] += 1  # Counts traces, not calls.
    del args, kwonly, kwargs
    return arg

  traced = def_function.function(func, experimental_follow_type_hints=True)

  traced(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Initial trace
  traced(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Retrace: `arg` changed
  # Varying only the typed *args: no retrace.
  traced(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)
  traced(1, 200, 300, 400, kwonly=5, kwarg1=6, kwarg2=7)
  # Varying kwonly or **kwargs retraces.
  traced(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)
  traced(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)
  self.assertEqual(num_traces[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwOnlyArg(self):
  """Full signature where only the keyword-only `kwonly` is Tensor-annotated."""
  num_traces = [0]

  def func(arg, *args, kwonly: ops.Tensor, **kwargs):
    num_traces[0] += 1  # Counts traces, not calls.
    del args, kwonly, kwargs
    return arg

  traced = def_function.function(func, experimental_follow_type_hints=True)

  traced(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Initial trace
  # Varying `arg` or *args retraces.
  traced(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
  traced(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)
  # Varying only the typed `kwonly`: no retrace.
  traced(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)
  traced(1, 2, 3, 4, kwonly=500, kwarg1=6, kwarg2=7)
  traced(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)  # Retrace: **kwargs
  self.assertEqual(num_traces[0], 4)
def testFollowTypeHintsTraceWithArgsKwOnlyArgsKwargsAndTypedKwargs(self):
  """Full signature where only **kwargs is Tensor-annotated."""
  num_traces = [0]

  def func(arg, *args, kwonly, **kwargs: ops.Tensor):
    num_traces[0] += 1  # Counts traces, not calls.
    del args, kwonly, kwargs
    return arg

  traced = def_function.function(func, experimental_follow_type_hints=True)

  traced(1, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)  # Initial trace
  # Varying `arg`, *args or `kwonly` retraces each time.
  traced(100, 2, 3, 4, kwonly=5, kwarg1=6, kwarg2=7)
  traced(1, 20, 30, 40, kwonly=5, kwarg1=6, kwarg2=7)
  traced(1, 2, 3, 4, kwonly=50, kwarg1=6, kwarg2=7)
  # Varying only the typed **kwargs: no retrace.
  traced(1, 2, 3, 4, kwonly=5, kwarg1=60, kwarg2=70)
  traced(1, 2, 3, 4, kwonly=5, kwarg1=600, kwarg2=700)
  self.assertEqual(num_traces[0], 4)
def testWithModuleNameScope(self):
  """Calls a name-scoped Module method that relies on a default argument.

  Currently skipped unconditionally; see the bug referenced in the skip
  message.
  """
  self.skipTest('b/166158748:function does not handle this case correctly.')

  class Foo(module.Module):

    def __init__(self):
      # Module subclasses must run the base initializer.
      super(Foo, self).__init__()
      self.var = None

    # NOTE(review): decorators reconstructed from the test name and skip
    # message -- confirm against upstream.
    @def_function.function
    @module.Module.with_name_scope
    def add(self, x, y, z=1):
      if self.var is None:
        return x + y + z

  foo = Foo()
  self.assertEqual(foo.add(2, 3), 6)
def testWithModuleNameScopeRedundantArgs(self):
  """Expects a TypeError when `x` is given both positionally and by keyword.

  Currently skipped unconditionally; see the bug referenced in the skip
  message.
  """
  self.skipTest('b/166158748:function does not handle this case correctly.')

  class Foo(module.Module):

    def __init__(self):
      # Module subclasses must run the base initializer.
      super(Foo, self).__init__()
      self.var = None

    # NOTE(review): decorators reconstructed from the test name and skip
    # message -- confirm against upstream.
    @def_function.function
    @module.Module.with_name_scope
    def add(self, x, y):
      if self.var is None:
        return x + y

  foo = Foo()
  with self.assertRaisesRegex(TypeError, 'got two values for argument'):
    foo.add(2, x=3)  # pylint: disable=redundant-keyword-arg,no-value-for-parameter
def testWithModuleNameScopeMissingArgs(self):
  """Expects a TypeError naming `y` when it is omitted from the call.

  Currently skipped unconditionally; see the bug referenced in the skip
  message.
  """
  self.skipTest('b/166158748:function does not handle this case correctly.')

  class Foo(module.Module):

    def __init__(self):
      # Module subclasses must run the base initializer.
      super(Foo, self).__init__()
      self.var = None

    # NOTE(review): decorators reconstructed from the test name and skip
    # message -- confirm against upstream.
    @def_function.function
    @module.Module.with_name_scope
    def add(self, x, y):
      if self.var is None:
        return x + y

  foo = Foo()
  with self.assertRaisesRegex(TypeError, 'missing required arguments: y'):
    foo.add(2)  # pylint: disable=no-value-for-parameter
def testShapeInferencePropagateConstNestedStack(self):
  """Checks static shape inference through a nested call using `stack`.

  `g` passes the constant 5 to `f`; the in-graph assert requires the result's
  static shape to be fully known as [3, 5].
  """

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
  ])
  def f(x, s):
    old_shape = array_ops.shape(x)
    new_shape = array_ops.stack([old_shape[0], s], axis=0)
    y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
    return y

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
  ])
  def g(x):
    y = f(x, s=5)
    assert y.shape.as_list() == [3, 5], y.shape.as_list()
    return y

  self.assertAllEqual(
      g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
def testShapeInferencePropagateConstNestedUnstackStack(self):
  """Checks static shape inference through `unstack` followed by `stack`.

  `g` passes the constant 5 to `f`; the in-graph assert requires the result's
  static shape to be fully known as [3, 5].
  """

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec((None, None), dtype=dtypes.int32),
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
  ])
  def f(x, s):
    s0, _ = array_ops.unstack(array_ops.shape(x), axis=0)
    new_shape = array_ops.stack([s0, s], axis=0)
    y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
    return y

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec(shape=(3, 6), dtype=dtypes.int32)
  ])
  def g(x):
    y = f(x, s=5)
    assert y.shape.as_list() == [3, 5], y.shape.as_list()
    return y

  self.assertAllEqual(
      g(array_ops.zeros([3, 6], dtype=dtypes.int32)), array_ops.ones([3, 5]))
def testShapeInferencePropagateConstNestedConcat(self):
  """Checks static shape inference through a nested call using `concat`.

  `g` calls `f` with Python constants; the in-graph assert requires the
  result's static shape to be fully known as [1, 2, 3].
  """

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
  ])
  def f(d1, d2, d3):
    new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
    y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
    return y

  @def_function.function()
  def g():
    y = f(1, 2, 3)
    assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
    return y

  self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
def testShapeInferencePropagateConstDoubleNested(self):
  """Checks static shape inference when `f` is wrapped a second time.

  `g` re-wraps the already decorated `f` before calling it with Python
  constants; the in-graph assert requires a fully known shape of [1, 2, 3].
  """

  @def_function.function(input_signature=[
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
      tensor_spec.TensorSpec((), dtype=dtypes.int32),
  ])
  def f(d1, d2, d3):
    new_shape = array_ops.concat([[d1], [d2], [d3]], axis=-1)
    y = array_ops.ones(shape=new_shape, dtype=dtypes.int32)
    return y

  @def_function.function()
  def g():
    y = def_function.function(f)(1, 2, 3)
    assert y.shape.as_list() == [1, 2, 3], y.shape.as_list()
    return y

  self.assertAllEqual(g(), array_ops.ones([1, 2, 3]))
def testControlDependencyAfterInline(self):
  """Checks that `assign` runs before `assign_add` inside an outer function.

  The expected values (1. then 2.) only hold if the two nested calls keep
  their program order.
  """
  v = variables.Variable(0.)

  @def_function.function
  def assign():
    return v.assign(1.)

  @def_function.function
  def assign_add():
    return v.assign_add(1.)

  @def_function.function
  def f():
    check_ops.assert_equal_v2(assign(), 1.)
    check_ops.assert_equal_v2(assign_add(), 2.)

  # We don't have a way to inspect the inlined graph in Python, so we run it
  # multiple times to have more confidence the dependency is correct.
  for _ in range(30):
    f()
def testReadInFuncWriteOutside(self):
  """Reads a variable in a nested function, then writes it in the caller.

  The nested read must observe the value from before the caller's write, so
  the result stays 2.0 on every iteration.
  """
  # Run many times since we are testing for a potential race condition.
  for _ in range(30):
    # pylint: disable=cell-var-from-loop
    v = variables.Variable(1.)

    @def_function.function
    def add_one():
      return v + 1.

    @def_function.function
    def get_v_plus_one():
      v_plus_one = add_one()
      # NOTE(review): the write was missing from this chunk; the increment
      # amount is reconstructed -- confirm against upstream.
      v.assign_add(2.0)
      return v_plus_one

    self.assertAllEqual(get_v_plus_one(), 2.0)
# Groups the CPU/GPU device-placement function tests, followed by the
# watched-variable and deferred-capture tests.
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
def testMultiDeviceOutput(self):
  """Tests that functions can produce outputs on multiple devices."""
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.

  @function.defun
  def func(a, b, transpose_a):
    with ops.device('/device:CPU:0'):
      m1 = math_ops.matmul(a, b, transpose_a=transpose_a)
    with ops.device('/device:GPU:0'):
      m2 = math_ops.matmul(a, b, transpose_a=transpose_a)
    return m1, m2

  t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
  m1, m2 = func(t, t, transpose_a=True)
  # Both outputs hold t^T @ t; only their backing devices differ.
  self.assertAllEqual(m1.numpy(), [[10, 14], [14, 20]])
  self.assertRegex(m1.backing_device, 'CPU')
  self.assertAllEqual(m2.numpy(), [[10, 14], [14, 20]])
  self.assertRegex(m2.backing_device, 'GPU')
def testEmptyBody(self):
  """Tests a function that only returns its inputs, swapped.

  Each output is expected to stay on the device of the input it came from.
  """
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.

  @function.defun
  def func(a, b):
    return b, a

  with ops.device('/device:CPU:0'):
    a = array_ops.identity(3.0)
  with ops.device('/device:GPU:0'):
    b = array_ops.identity(5.0)

  m1, m2 = func(a, b)
  self.assertAllEqual(m1.numpy(), 5.0)
  self.assertRegex(m1.backing_device, 'GPU')
  self.assertAllEqual(m2.numpy(), 3.0)
  self.assertRegex(m2.backing_device, 'CPU')
def testMultiDeviceInt32(self):
  """Tests that multi-device functions can take and output INT32s.

  When an INT32 device tensor is fed into a function, it is copied to CPU
  by the eager runtime. The function sees all INT32 inputs on CPU.

  We set allocator attribute 'on_host' for INT32 outputs. They can be
  partitioned into the GPU component function, but will be allocated on
  CPU nevertheless.

  There is experimental support for `ints_on_device` in
  FunctionLibraryRuntime now. We can try that.
  """
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.
  with ops.device('/device:CPU:0'):
    int_cpu = constant_op.constant(3, dtype=dtypes.int32)
    resource = resource_variable_ops.ResourceVariable(5, dtype=dtypes.int32)
  with ops.device('/device:GPU:0'):
    int_gpu = constant_op.constant(7, dtype=dtypes.int32)

  @function.defun
  def func(int_cpu, resource, int_gpu):
    with ops.device('/device:CPU:0'):
      m1 = int_cpu * resource + int_gpu
    with ops.device('/device:GPU:0'):
      # This computation will happen on GPU but m2 will be copied to CPU.
      m2 = int_gpu * resource + int_cpu + 1
    return m1, m2

  # m1 = 3 * 5 + 7, m2 = 7 * 5 + 3 + 1.
  m1, m2 = func(int_cpu, resource, int_gpu)
  self.assertAllEqual(m1.numpy(), 22)
  self.assertRegex(m1.backing_device, 'CPU')
  self.assertAllEqual(m2.numpy(), 39)
  self.assertRegex(m2.backing_device, 'CPU')

  # flip arguments
  m1, m2 = func(int_gpu, resource, int_cpu)
  self.assertAllEqual(m1.numpy(), 38)
  self.assertRegex(m1.backing_device, 'CPU')
  self.assertAllEqual(m2.numpy(), 23)
  self.assertRegex(m2.backing_device, 'CPU')
def testMultiDeviceColocateWith(self):
  """Tests that function's outputs respect colocation constraints."""
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.

  @function.defun
  def func(a, b):
    with ops.colocate_with(a):
      ra = 2 * a
    with ops.colocate_with(b):
      rb = 3 * b
    return ra, rb

  devices = ['/device:CPU:0', '/device:GPU:0']
  # Try every (device of a, device of b) combination.
  for dev1, dev2 in itertools.product(devices, devices):
    with ops.device(dev1):
      a = array_ops.identity(1.0)
    with ops.device(dev2):
      b = array_ops.identity(10.0)

    ra, rb = func(a, b)
    self.assertEqual(ra.numpy(), 2.0)
    self.assertRegex(ra.backing_device, dev1)
    self.assertEqual(rb.numpy(), 30.0)
    self.assertRegex(rb.backing_device, dev2)
def testMultiDeviceResources(self):
  """Tests a function mixing resource arguments and captured resources.

  Each multiplication pairs a passed-in variable with a captured variable,
  under an explicit CPU or GPU device scope.
  """
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.
  with ops.device('/device:CPU:0'):
    c1 = resource_variable_ops.ResourceVariable(2.0)
    c2 = resource_variable_ops.ResourceVariable(7.0)
  with ops.device('/device:GPU:0'):
    g1 = resource_variable_ops.ResourceVariable(3.0)
    g2 = resource_variable_ops.ResourceVariable(5.0)

  @function.defun
  def func(resource1, resource2):
    with ops.device('/device:CPU:0'):
      result1 = resource1 * g2
    with ops.device('/device:GPU:0'):
      result2 = resource2 * c2
    return result1, result2

  # r1 = c1 * g2 = 2 * 5, r2 = g1 * c2 = 3 * 7.
  r1, r2 = func(c1, g1)
  self.assertEqual(r1.numpy(), 10.0)
  self.assertRegex(r1.backing_device, 'CPU')
  self.assertEqual(r2.numpy(), 21.0)
  self.assertRegex(r2.backing_device, 'GPU')

  # Call with flipped inputs. Check that we look at resource's
  # device and reinstantiates the function when inputs' devices change.
  r1, r2 = func(g1, c1)
  self.assertEqual(r1.numpy(), 15.0)
  self.assertRegex(r1.backing_device, 'CPU')
  self.assertEqual(r2.numpy(), 14.0)
  self.assertRegex(r2.backing_device, 'GPU')
def testOutputResources(self):
  """Tests a function that returns resource handles alongside tensors.

  The returned handles are read back with a read-variable op to check they
  still refer to the right variables.
  """
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.
  with ops.device('/device:CPU:0'):
    c1 = resource_variable_ops.ResourceVariable(2.0)
  with ops.device('/device:GPU:0'):
    g1 = resource_variable_ops.ResourceVariable(3.0)

  @function.defun
  def func(resource1, resource2):
    with ops.device('/device:CPU:0'):
      result1 = resource1 * 5
    with ops.device('/device:GPU:0'):
      result2 = resource2 * 7
    return result1, resource1.handle, result2, resource2.handle

  r1, res1, r2, res2 = func(c1, g1)
  self.assertEqual(r1.numpy(), 10.0)
  self.assertRegex(r1.backing_device, 'CPU')
  self.assertEqual(r2.numpy(), 21.0)
  self.assertRegex(r2.backing_device, 'GPU')

  def check_handle(handle, expected_value):
    self.assertRegex(handle.backing_device, 'CPU')
    tensor = gen_resource_variable_ops.read_variable_op(
        handle, dtypes.float32)
    self.assertEqual(tensor.numpy(), expected_value)

  # Check that handles returned from functions are on CPU and an op using
  # the resource handle is correctly placed on the device backing the
  # resource.
  check_handle(res1, 2.0)
  check_handle(res2, 3.0)

  # Call with flipped inputs to make sure the same the function is
  # reinstantiated and eager runtime does not mess up the device assignment
  # for ops consuming handles returned from defuns.
  r1, res1, r2, res2 = func(g1, c1)
  self.assertEqual(r1.numpy(), 15.0)
  self.assertRegex(r1.backing_device, 'CPU')
  self.assertEqual(r2.numpy(), 14.0)
  self.assertRegex(r2.backing_device, 'GPU')
  check_handle(res1, 3.0)
  check_handle(res2, 2.0)
def testPassResourceThroughNestedFunctionCall(self):
  """Test passing GPU resource to noinline function call placed on CPU.

  PartitionedCallOp must not enforce any particular device assignment for the
  resource output. Inner function marked as `_nospecialize`, so Grappler would
  not prune unused function output.
  """
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.
  with ops.device('/device:GPU:0'):
    g1 = resource_variable_ops.ResourceVariable(3.0)

  @function.defun_with_attributes(attributes={
      '_noinline': True,
      '_nospecialize': True
  })
  def inner(resource1):
    return resource1 * 2, resource1.handle

  @function.defun
  def outer(resource1):
    with ops.device('/device:CPU:0'):
      # The handle output is deliberately left unused.
      r1, _ = inner(resource1)
    return r1

  r1 = outer(g1)

  self.assertEqual(r1.numpy(), 6.0)
  self.assertRegex(r1.backing_device, 'CPU')
def testReturnResourceFromNestedFunctionCall(self):
  """Test returning GPU resource from noinline function call placed on CPU.

  When inferring output devices for the return value, do not set a device for
  returns of DT_RESOURCE data type based on the device assignment of the node
  that produced that resource. As an example function call placed on CPU can
  return resources on GPU.
  """
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.
  with ops.device('/device:GPU:0'):
    g1 = resource_variable_ops.ResourceVariable(3.0)

  @function.defun_with_attributes(attributes={
      '_noinline': True
  })
  def inner(resource1):
    # Bump the variable from 3.0 to 5.0 so the expectations below
    # (5.0 * 2 == 10.0 and a stored value of 5.0) hold.
    resource1.assign_add(2.0)
    return resource1 * 2, resource1.handle

  @function.defun
  def outer(resource1):
    with ops.device('/device:CPU:0'):
      r1, res1 = inner(resource1)
    return r1, res1

  r1, res1 = outer(g1)

  self.assertEqual(r1.numpy(), 10.0)
  self.assertRegex(r1.backing_device, 'CPU')

  def check_handle(handle, expected_value):
    self.assertRegex(handle.backing_device, 'CPU')
    tensor = gen_resource_variable_ops.read_variable_op(
        handle, dtypes.float32)
    self.assertEqual(tensor.numpy(), expected_value)

  # Check that handles returned from functions are on CPU and an op using
  # the resource handle is correctly placed on the device backing the
  # resource.
  check_handle(res1, 5.0)
def testComplexInputOutputDevicePattern(self):
  """Tests input/output mapping logic in partitioning."""
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.
  # Naming: r* are resource variables, c* are plain tensors; the second
  # letter is the device (c = CPU, g = GPU).
  with ops.device('/device:CPU:0'):
    rc0 = resource_variable_ops.ResourceVariable(2.0)
    rc1 = resource_variable_ops.ResourceVariable(3.0)
    cc0 = array_ops.identity(5.0)
    cc1 = array_ops.identity(7.0)
  with ops.device('/device:GPU:0'):
    rg0 = resource_variable_ops.ResourceVariable(11.0)
    rg1 = resource_variable_ops.ResourceVariable(13.0)
    cg0 = array_ops.identity(17.0)
    cg1 = array_ops.identity(19.0)

  # Make sure tensors are on expected devices.
  for tensor in [cc0, cc1]:
    self.assertRegex(tensor.backing_device, 'CPU:0')
  for tensor in [cg0, cg1]:
    self.assertRegex(tensor.backing_device, 'GPU:0')

  @function.defun
  def func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1):
    with ops.device('/device:CPU:0'):
      m1 = rc0 * cg0
    with ops.device('/device:GPU:0'):
      m2 = rg0 * cc0

    # Each final result consumes the intermediate from the other device.
    with ops.device('/device:CPU:0'):
      r1 = 1000.0 * m2 + rc1 * cg1
    with ops.device('/device:GPU:0'):
      r2 = 1000.0 * m1 + rg1 * cc1

    return r1, r2, m2, m1

  r1, r2, m2, m1 = func(rc0, cc0, cg0, rc1, cg1, rg0, rg1, cc1)
  self.assertRegex(m1.backing_device, 'CPU')
  self.assertRegex(r1.backing_device, 'CPU')
  self.assertRegex(m2.backing_device, 'GPU')
  self.assertRegex(r2.backing_device, 'GPU')
  self.assertEqual(m1.numpy(), 34.0)
  self.assertEqual(r1.numpy(), 55000.0 + 3.0 * 19.0)
  self.assertEqual(m2.numpy(), 55.0)
  self.assertEqual(r2.numpy(), 34000.0 + 13.0 * 7.0)
def testArgumentPruning(self):
  """Tests functions taking unnecessary arguments."""
  # NOTE(review): needs a GPU device; presumably gated by a GPU-only
  # decorator that is not visible in this chunk -- confirm.
  with ops.device('/device:CPU:0'):
    c1 = constant_op.constant(5.0)
    c2 = constant_op.constant(7.0)

  with ops.device('/device:GPU:0'):
    g1 = constant_op.constant(11.0)
    g2 = constant_op.constant(13.0)
    g3 = constant_op.constant(17.0)

  @function.defun
  def func(g1, g2, c1, g3, c2):  # pylint: disable=unused-argument
    # arguments g1 and g2 are unused and can be pruned by grappler.
    return c1 * g3 * c2

  result = func(g1, g2, c1, g3, c2)
  self.assertEqual(result.numpy(), 5.0 * 7.0 * 17.0)
def testNestedCallWatchedVariables(self):
  """Checks that a tape records `v` through one, two and three call levels."""
  v = variables.Variable(4.)

  @def_function.function
  def f():
    return v ** 2.

  with backprop.GradientTape() as tape:
    f()

  self.assertEqual((v,), tape.watched_variables())

  @def_function.function
  def g():
    return f()

  with backprop.GradientTape() as tape:
    g()

  self.assertEqual((v,), tape.watched_variables())

  # f() can rely on the variable being read during its trace. g() checks that
  # variables from a function which knows about them are recorded on the
  # tape. h() tests that functions forward knowledge of variables to callers.

  @def_function.function
  def h():
    return g()

  with backprop.GradientTape() as tape:
    h()

  self.assertEqual((v,), tape.watched_variables())
def testDeferredCapture(self):
  """Checks that a call-time capture is re-evaluated on every call."""
  value = 1.0

  # capture_call_time_value is looked up on the default graph, so
  # `lazy_capture` has to be traced as a function.
  @def_function.function
  def lazy_capture(x):
    y = ops.get_default_graph().capture_call_time_value(
        lambda: value, tensor_spec.TensorSpec(None))
    return x + y

  self.assertAllEqual(lazy_capture(2.0), 3.0)
  # After changing the value of `value` the function call should return a
  # different result.
  value = 2.0
  self.assertAllEqual(lazy_capture(2.0), 4.0)
def testDeferredCaptureWithKey(self):
  """Checks that call-time captures sharing a key reuse the first closure.

  `z` is registered under the same key as `y`; its closure raises if it is
  ever run, and the expected sums treat `z` as equal to `value1`.
  """
  value0 = 1.0
  value1 = 2.0

  # capture_call_time_value is looked up on the default graph, so
  # `lazy_capture` has to be traced as a function.
  @def_function.function
  def lazy_capture(x):
    w = ops.get_default_graph().capture_call_time_value(
        lambda: value0, tensor_spec.TensorSpec(None), key=0)
    y = ops.get_default_graph().capture_call_time_value(
        lambda: value1, tensor_spec.TensorSpec(None), key=1)

    def bad_closure():
      raise ValueError('Should not run')

    z = ops.get_default_graph().capture_call_time_value(
        bad_closure, tensor_spec.TensorSpec(None), key=1)
    return x + y + w + z

  # 2 + 2 + 1 + 2.
  self.assertAllEqual(lazy_capture(2.0), 7.0)
  value0 = 2.0
  value1 = 3.0
  # 2 + 3 + 2 + 3.
  self.assertAllEqual(lazy_capture(2.0), 10.0)
def testDeferredCaptureTypeError(self):
  """Checks errors when a captured value stops matching its declared spec."""
  value = constant_op.constant(1.0)

  # capture_call_time_value is looked up on the default graph, so
  # `lazy_capture` has to be traced as a function.
  @def_function.function
  def lazy_capture(x):
    y = ops.get_default_graph().capture_call_time_value(
        lambda: value, tensor_spec.TensorSpec(()))
    return x + y

  self.assertAllEqual(lazy_capture(2.0), 3.0)

  # dtype mismatch
  value = constant_op.constant(1)
  with self.assertRaisesRegex(ValueError, 'Value .* to a tensor with dtype'):
    lazy_capture(2.0)

  # shape mismatch
  value = constant_op.constant([1.0])
  with self.assertRaisesRegex(ValueError, 'Value .* shape'):
    lazy_capture(2.0)
def testDeferredCaptureReturnNestWithCompositeTensor(self):
  """Captures a nested structure of composite tensors at call time.

  The structure holds an IndexedSlices, a RaggedTensor and a SparseTensor;
  each component of the result is compared with the captured original.
  """
  # NOTE(review): the dense_shape argument was missing from this chunk; an
  # int32 [2] matches the spec's dense_shape_dtype and the two values --
  # confirm against upstream.
  i_s = indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]),
      constant_op.constant([0, 1], dtype=dtypes.int64),
      constant_op.constant([2]))
  r_t = ragged_factory_ops.constant([[[1, 2], [3]], [[4, 5, 6]]])
  s_t = sparse_tensor.SparseTensor(
      values=[1, 2, 3], indices=[[0], [8], [10]], dense_shape=[20])

  # capture_call_time_value is looked up on the default graph, so
  # `lazy_capture` has to be traced as a function.
  @def_function.function
  def lazy_capture():
    y = ops.get_default_graph().capture_call_time_value(
        lambda: {'i': i_s, 't': (r_t, s_t)},
        {'i': indexed_slices.IndexedSlicesSpec(
            dtype=dtypes.int32, dense_shape_dtype=dtypes.int32),
         't': (ragged_tensor.RaggedTensorSpec([2, None, None], dtypes.int32),
               sparse_tensor.SparseTensorSpec([None], dtypes.int32))})
    return y['i'], y['t']

  i, (r, s) = lazy_capture()
  self.assertAllEqual(i_s.values, i.values)
  self.assertAllEqual(i_s.indices, i.indices)
  self.assertAllEqual(i_s.dense_shape, i.dense_shape)
  self.assertAllEqual(r_t, r)
  self.assertAllEqual(s_t.indices, s.indices)
  self.assertAllEqual(s_t.values, s.values)
  self.assertAllEqual(s_t.dense_shape, s.dense_shape)
def testDeferredCaptureCompositeTensorSpecTypeMismatch(self):
  """Checks errors when a captured IndexedSlices stops matching its spec."""
  value = indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]),
      constant_op.constant([0, 1], dtype=dtypes.int64))

  # NOTE(review): the spec argument was missing from this chunk; an int32
  # spec with default index dtype and no dense shape fits the three cases
  # below -- confirm against upstream.
  @def_function.function
  def lazy_capture():
    return ops.get_default_graph().capture_call_time_value(
        lambda: value,
        indexed_slices.IndexedSlicesSpec(dtype=dtypes.int32))

  # Type matches spec.
  lazy_capture()

  # Extra dense shape component.
  value = indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]),
      constant_op.constant([0, 1], dtype=dtypes.int64),
      constant_op.constant([2]))
  with self.assertRaises(ValueError):
    lazy_capture()

  # Index dtype mismatch int32 vs. int64.
  value = indexed_slices.IndexedSlices(
      constant_op.constant([1, 2]),
      constant_op.constant([0, 1]))
  with self.assertRaises(ValueError):
    lazy_capture()
if __name__ == '__main__':