# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
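"""Tests for def_function (the `tf.function` decorator)."""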

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import itertools
import pickle
import re
import weakref

from absl.testing import parameterized
from six.moves import range

from tensorflow.python.autograph.core import converter
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
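

# A small Keras model whose training step is a `tf.function` with a fixed
# input signature; exercised by test_optimizer below.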
class _ModelWithOptimizer(training.Model):

  def __init__(self):
    super(_ModelWithOptimizer, self).__init__()
    self.dense = core.Dense(1)
    self.optimizer = adam.AdamOptimizer(0.01)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
                       tensor_spec.TensorSpec([None], dtypes.float32)))
  def call(self, x, y):
    with backprop.GradientTape() as tape:
      loss = math_ops.reduce_mean((self.dense(x) - y) ** 2.)
    trainable_variables = self.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    return {'loss': loss}
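

# test_pickle needs a function defined at module level, since functions
# defined inside a test method cannot be pickled by reference.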
def undecorated_function(x):
  return x * 3.


class _HasDecoratedMethod(object):

  @def_function.function
  def f(self, x):
    return x * 3.


class DefFunctionTest(test.TestCase, parameterized.TestCase):

  def testNoVariables(self):

    @def_function.function
    def fn(x):
      return 2 * x

    self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)
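
  # tf.function only supports creating variables on the first trace; a
  # function that makes a fresh variable on every call must raise.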
  def testFailIfVariablesAreCreatedMoreThanOnce(self):

    @def_function.function
    def fn(x):
      return variables.Variable(1.0) + x

    with self.assertRaises(ValueError):
      fn(1.0)

  def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
    state = []

    @def_function.function
    def fn(x):
      state.append(variables.Variable(1.0))
      return state[-1] + x

    with self.assertRaises(ValueError):
      fn(1.0)

  def testRange(self):

    @def_function.function
    def f(unused_x):
      return 1.0

    self.assertAllEqual(f(range(5)), 1.0)
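
  # The supported pattern: guard creation with Python state so the variable
  # is created only on the first trace and reused afterwards.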
  def testCorrectVariableCreation(self):
    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(2.0))
      return state[0] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
    self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

  def testFunctionInitializer(self):
    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(lambda: 2.0))
      return state[0] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)

  def testFunctionMultipleVariableInitializer(self):
    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(lambda: 2.0))
        state.append(variables.Variable(lambda: 5.0))
      return state[0] * x, state[1] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])
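
  # get_initialization_function returns a callable that runs only the
  # variable initializers for the given arguments, not the function body.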
  def testFunctionInitializationFunction(self):
    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(2.0))
      return state[0] * x

    init_fn = fn.get_initialization_function(constant_op.constant(1.0))
    self.assertLen(state, 1)
    self.assertFalse(
        resource_variable_ops.var_is_initialized_op(state[0].handle))
    init_fn()
    self.assertEqual(state[0].numpy(), 2.0)

  def testVariableInitializerNotConstant(self):
    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(2.0 * x))
      return state[0] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
    self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

  def testLegacyGraphModeVariables(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      state = []

      @def_function.function
      def fn(x):
        if not state:
          state.append(variables.Variable(2.0))
        return state[0] * x

      result = fn(3.0)

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(sess.run(state[0]), 2.0)
      self.assertAllEqual(self.evaluate(result), 6.0)

  def testLegacyGraphModeVariablesNonTrivialInitializer(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      state = []

      @def_function.function
      def fn(x):
        if not state:
          two = constant_op.constant(2.0)
          four = two * two
          two_again = math_ops.sqrt(four)
          state.append(variables.Variable(two_again + four))
        return state[0] * x

      result = fn(3.0)

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(sess.run(state[0]), 6.0)
      self.assertAllEqual(self.evaluate(result), 18.0)

  def testLegacyGraphModeInputDependentInitializerFails(self):
    with ops.Graph().as_default():
      state = []

      @def_function.function
      def fn(x):
        if not state:
          state.append(variables.Variable(2.0 * x))
        return state[0] * x

      with self.assertRaisesRegexp(
          lift_to_graph.UnliftableError, r'transitively.* mul .* x'):
        fn(constant_op.constant(3.0))
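
  # Per-instance state: each MyModel instance creates its own variable on
  # the first call and reuses it afterwards.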
  def testMethod(self):

    class MyModel(object):

      def __init__(self):
        self.var = None

      @def_function.function
      def apply(self, x):
        if self.var is None:
          self.var = variables.Variable(2.0)
        return self.var * x

    m0 = MyModel()
    self.assertAllEqual(m0.apply(3.0), 6.0)
    # Calling twice to exercise that we do not recreate variables.
    m0.var.assign(3.0)
    self.assertAllEqual(m0.apply(3.0), 9.0)

    m1 = MyModel()
    self.assertAllEqual(m1.apply(3.0), 6.0)

  def test_functools_partial(self):
    self.assertAllClose(
        3.,
        def_function.function(functools.partial(lambda x, y: x + y, 1.))(
            constant_op.constant(2.)))

  def test_functools_partial_new_default(self):

    def f(x=3, y=7):
      return x + y

    func = def_function.function(functools.partial(f, y=6))
    self.assertEqual(func().numpy(), 9)
    self.assertEqual(func(y=8).numpy(), 11)

  def test_functools_partial_keywords(self):

    def f(x, y):
      return x + y

    func = def_function.function(
        functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
    self.assertAllEqual(func(), [0.0])

  def test_functools_partial_single_positional(self):

    def f(x, y):
      return x + y

    func = def_function.function(
        functools.partial(f, constant_op.constant(1)))
    self.assertAllEqual(func(5), 6)

  def test_complicated_partial_with_defaults(self):

    def identity(*args):
      return args

    def dynamic_unroll(core_fn,
                       input_sequence,
                       initial_state,
                       sequence_length=None,
                       parallel_iterations=1,
                       swap_memory=False):
      del core_fn
      self.assertIs(None, sequence_length)
      self.assertEqual(1, parallel_iterations)
      self.assertTrue(swap_memory)
      return input_sequence, initial_state

    input_sequence = random_ops.random_uniform([1, 1, 1])
    initial_state = random_ops.random_uniform([1, 1])

    func = def_function.function(
        functools.partial(dynamic_unroll, identity, swap_memory=True))
    func(input_sequence, initial_state)

  def test_unspecified_default_argument(self):
    wrapped = def_function.function(
        lambda x, y=2: x + y,
        input_signature=[tensor_spec.TensorSpec((), dtypes.int32)])
    self.assertEqual(3, wrapped(constant_op.constant(1)).numpy())

  def test_optimizer(self):
    x = constant_op.constant([[3., 4.]])
    y = constant_op.constant([2.])
    model = _ModelWithOptimizer()
    model(x, y)

  def test_concrete_function_from_signature(self):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def compute(x):
      return 2. * x

    concrete = compute.get_concrete_function()
    self.assertAllClose(1., concrete(constant_op.constant(0.5)))
    concrete = compute.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32))
    self.assertAllClose(4., concrete(constant_op.constant(2.)))
    signature_args, _ = concrete.structured_input_signature
    self.assertEqual(signature_args,
                     (tensor_spec.TensorSpec(
                         None, dtypes.float32, name='x'),))

  @test_util.run_in_graph_and_eager_modes
  def test_variable_naming(self):

    class HasVars(module.Module):

      def __init__(self):
        self.x = None
        self.y = None
        self.z = None

      @def_function.function
      def make_x(self):
        if self.x is None:
          self.x = variables.Variable(1., name='v')

      def make_y(self):
        if self.y is None:
          self.y = variables.Variable(1., name='v')

      def make_z(self):
        if self.z is None:
          with ops.name_scope('z_scope', skip_on_eager=False):
            self.z = variables.Variable(1., name='z')

    root = HasVars()
    root.make_x()
    root.make_y()
    root.make_z()
    self.assertEqual('v:0', root.x.name)
    self.assertEqual('z_scope/z:0', root.z.name)

  def test_concrete_function_keyword_arguments(self):

    @def_function.function
    def f(x):
      return x

    conc = f.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32, 'y'))
    conc(y=constant_op.constant(3.0))
    signature_args, _ = conc.structured_input_signature
    self.assertEqual('y', signature_args[0].name)

    conc = f.get_concrete_function(tensor_spec.TensorSpec(None, dtypes.float32))
    conc(x=constant_op.constant(3.0))
    signature_args, _ = conc.structured_input_signature
    self.assertEqual('x', signature_args[0].name)

    @def_function.function
    def g(x):
      return x[0]

    conc = g.get_concrete_function(
        [tensor_spec.TensorSpec(None, dtypes.float32, 'z'), 2])
    conc(z=constant_op.constant(3.0))
    signature_args, _ = conc.structured_input_signature
    self.assertEqual('z', signature_args[0][0].name)

    with self.assertRaisesRegexp(
        ValueError, 'either zero or all names have to be specified'):
      conc = g.get_concrete_function([
          tensor_spec.TensorSpec(None, dtypes.float32, 'z'),
          tensor_spec.TensorSpec(None, dtypes.float32),
      ])

  def test_error_inner_capture(self):

    @def_function.function
    def f(inputs):
      num_steps, _ = inputs.shape[:2]
      outputs = []
      for t in math_ops.range(num_steps):
        outputs.append(inputs[t])
      return outputs

    with self.assertRaisesRegexp(errors.InaccessibleTensorError,
                                 'defined in another function or code block'):
      f(array_ops.zeros(shape=(8, 42, 3)))

  def testRuntimeErrorNotSticky(self):

    @def_function.function
    def fail(i):
      control_flow_ops.Assert(math_ops.equal(i, 0), ['ick'])

    fail(constant_op.constant(0))  # OK
    with self.assertRaises(errors.InvalidArgumentError):
      fail(constant_op.constant(1))  # InvalidArgument: "ick"
    fail(constant_op.constant(0))  # OK

  def testUnderscoreName(self):

    @def_function.function
    def f(_):
      return _ + _

    self.assertAllEqual(2.0, f(constant_op.constant(1.0)))

  def test_serialization_signature_cache(self):

    @def_function.function
    def f(x, y):
      return x, y

    f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
    f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))

    signatures_args = set()
    concrete_functions = f._list_all_concrete_functions_for_serialization()
    for concrete_function in concrete_functions:
      args, kwargs = concrete_function.structured_input_signature
      signatures_args.add(args)
      self.assertEqual(dict(), kwargs)

    self.assertEqual(
        signatures_args,
        set(((tensor_spec.TensorSpec([1, 2], dtypes.float32, name='x'),
              tensor_spec.TensorSpec([1], dtypes.float32, name='y')),
             (tensor_spec.TensorSpec([1, 3], dtypes.int32, name='x'),
              tensor_spec.TensorSpec([1], dtypes.int32, name='y')))))

  @test_util.assert_no_garbage_created
  def testFunctionReferenceCycles(self):
    fn = def_function.function(lambda x: 2. * x)
    fn(constant_op.constant(4.0))
    weak_fn = weakref.ref(fn)
    del fn
    # Tests that the weak reference we made to the function is now dead, which
    # means the object has been deleted. This should be true as long as the
    # function itself is not involved in a reference cycle.
    self.assertIs(None, weak_fn())

  @test_util.assert_no_garbage_created
  def testMethodReferenceCycles(self):
    has_decorated_method = _HasDecoratedMethod()
    has_decorated_method.f(constant_op.constant(5.))
    weak_fn = weakref.ref(has_decorated_method.f)
    del has_decorated_method
    # Tests that the weak reference we made to the function is now dead, which
    # means the object has been deleted. This should be true as long as the
    # function itself is not involved in a reference cycle.
    self.assertIs(None, weak_fn())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testErrorMessageWhenGraphTensorIsPassedToEager(self):

    @def_function.function
    def failing_function():
      a = constant_op.constant(1.)

      with ops.init_scope():
        _ = a + a

    with self.assertRaisesRegexp(
        TypeError,
        re.compile('An op outside of the function.*passed.*Const', re.DOTALL)):
      failing_function()

  def testNonUniqueNamesGetConcreteFunction(self):

    @def_function.function
    def non_unique_arg_names(x, **kwargs):
      a, b, c = x
      d = kwargs['d']
      return a + b + c + d

    concrete = non_unique_arg_names.get_concrete_function(
        (tensor_spec.TensorSpec(None, dtypes.float32),
         tensor_spec.TensorSpec(None, dtypes.float32),
         tensor_spec.TensorSpec(None, dtypes.float32)),
        d=tensor_spec.TensorSpec(None, dtypes.float32))
    self.assertAllClose(
        10.,
        concrete(x=constant_op.constant(1.),
                 x_1=constant_op.constant(2.),
                 x_2=constant_op.constant(3.),
                 d=constant_op.constant(4.)))
    self.assertAllClose(
        10.,
        concrete(constant_op.constant(1.),
                 constant_op.constant(2.),
                 constant_op.constant(3.),
                 constant_op.constant(4.)))
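
  # Custom variable creators installed outside the function should observe
  # variables created inside it.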
  def testVariableCreatorScope(self):
    created_variables = []
    captured_variables = []

    @def_function.function
    def f():
      if not created_variables:
        created_variables.append(variables.Variable(1.))
      return created_variables[0] + 1.

    def capture_creator(next_creator, **kwargs):
      created = next_creator(**kwargs)
      captured_variables.append(created)
      return created

    with variable_scope.variable_creator_scope(capture_creator):
      f()
    self.assertEqual(created_variables, captured_variables)

  def testVarAlreadyInitializedNoClobbering(self):
    v_holder = []

    @def_function.function
    def add_var(x):
      if not v_holder:
        v = variables.Variable([1., 2.])
        v_holder.append(v)
        already_initialized = variables.Variable(3.)
        with ops.init_scope():
          already_initialized.assign(10.)
        v_holder.append(already_initialized)
      return v_holder[0] + v_holder[1] + x

    add_var.get_concrete_function(constant_op.constant(2.))
    self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))

  def testSameVariableTwice(self):
    v = variables.Variable(1.0)

    @def_function.function
    def add(a, b):
      return a + b

    self.assertAllEqual(add(v, v), 2.0)
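
  # Retracing is keyed by dtype/shape: with tensor equality enabled
  # (ops.Tensor._USE_EQUALITY), v1 and v2 (both float32) share one trace.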
  def testVariableUpdate(self):
    v1 = variables.Variable(1.0)
    v2 = variables.Variable(2.0)
    v3 = variables.Variable(4, dtype=dtypes.int32)

    trace_count = [0]

    @def_function.function
    def double_variable(x):
      trace_count[0] += 1
      x.assign_add(x.read_value())

    self.assertEqual(trace_count[0], 0)
    double_variable(v1)
    self.assertEqual(trace_count[0], 1)
    self.assertEqual(self.evaluate(v1), 2.0)
    double_variable(v2)
    self.assertEqual(trace_count[0], 1 if ops.Tensor._USE_EQUALITY else 2)
    self.assertEqual(self.evaluate(v2), 4.0)
    double_variable(v3)
    self.assertEqual(trace_count[0], 2 if ops.Tensor._USE_EQUALITY else 3)
    self.assertEqual(self.evaluate(v3), 8)
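
  # Asking for a concrete function twice with identical TensorSpecs must
  # return the same object rather than retracing.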
  def testShapeCache(self):

    @def_function.function
    def func(x):
      return 2 * x

    func_a = func.get_concrete_function(
        tensor_spec.TensorSpec([None], dtypes.int32))
    func_b = func.get_concrete_function(
        tensor_spec.TensorSpec([None], dtypes.int32))
    self.assertIs(func_a, func_b)

  def testInitializationInNestedCall(self):
    v_holder = []

    @def_function.function
    def add_var(x):
      if not v_holder:
        v = variables.Variable([1., 2.])
        v_holder.append(v)
        already_initialized = variables.Variable(3.)
        with ops.init_scope():
          already_initialized.assign(10.)
        v_holder.append(already_initialized)
      return v_holder[0] + v_holder[1] + x

    @def_function.function
    def wrapper(x):
      return add_var(x)

    self.assertAllClose([13., 14.], wrapper(constant_op.constant(2.)))
    v_holder[1].assign(11.)
    self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))
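
  # Explicit CPU device annotations should be respected for variables and
  # their initializers, even when a GPU is available.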
  @test_util.run_gpu_only
  def testDeviceAnnotationRespected(self):
    a = []

    @def_function.function()
    def create_variable():
      with ops.init_scope():
        initial_value = random_ops.random_uniform(
            (2, 2), maxval=1000000, dtype=dtypes.int64)

        if not a:
          with ops.device('CPU:0'):
            a.append(resource_variable_ops.ResourceVariable(initial_value))

      return a[0].read_value()

    create_variable()
    self.assertRegexpMatches(a[0].device, 'CPU')

  @test_util.run_gpu_only
  def testDeviceAnnotationForInitializerRespected(self):
    a = []
    initial_value = []

    def initial_value_fn():
      initial_value.append(random_ops.random_uniform((2, 3)))
      return initial_value[0]

    @def_function.function()
    def create_variable():
      with ops.init_scope():
        if not a:
          a.append(variables.Variable(initial_value_fn))

    with ops.device('CPU:0'):
      create_variable()
    self.assertRegexpMatches(a[0].device, 'CPU')
    self.assertRegexpMatches(initial_value[0].device, 'CPU')

  def testDecorate(self):
    func = def_function.function(lambda: 1)

    def decorator(f):
      return lambda: 1 + f()

    func._decorate(decorator)
    self.assertEqual(func().numpy(), 2)
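
  # _clone should copy every configuration attribute of the original unless
  # overridden. (The body forces compile_ to False, so the `compile`
  # parameter below only varies the parameterization, not the config.)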
  @parameterized.parameters(*itertools.product(
      (None, (tensor_spec.TensorSpec([]),)),  # input_signature
      (True, False),  # autograph
      (None, converter.Feature.ALL),  # autograph_options
      (None, 'foo.bar'),  # implements
      (None, True, False),  # relax_shapes
      (True, False),  # compile
      (True, False),  # override_function
  ))
  def testClone(self, input_signature, autograph, autograph_options,
                implements, relax_shapes, compile_, override_function):
    original_py_function = lambda x: x

    compile_ = False
    func = def_function.function(
        func=original_py_function,
        input_signature=input_signature,
        autograph=autograph,
        experimental_implements=implements,
        experimental_autograph_options=autograph_options,
        experimental_relax_shapes=relax_shapes,
        experimental_compile=compile_)

    if override_function:
      cloned_py_function = lambda x: x + 1
    else:
      cloned_py_function = original_py_function

    cloned = func._clone(python_function=cloned_py_function)

    self.assertEqual(cloned_py_function, cloned._python_function)
    self.assertEqual(func._name, cloned._name)
    self.assertEqual(input_signature, cloned._input_signature)
    self.assertEqual(autograph, cloned._autograph)
    self.assertEqual(implements, cloned._implements)
    self.assertEqual(autograph_options, cloned._experimental_autograph_options)
    self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
    self.assertEqual(compile_, cloned._experimental_compile)

    # This test does not run with XLA JIT support linked in, so we can only
    # check the output of the function if compile is disabled.
    if not compile_:
      x = array_ops.zeros([])
      self.assertEqual(self.evaluate(cloned(x)),
                       self.evaluate(cloned_py_function(x)))
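
  # In graph mode, a placeholder initializer is lifted out of the function
  # graph, so it can be fed when the init op runs.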
  def testLiftPlaceholderInitializedVariable(self):
    with ops.Graph().as_default():
      var_list = []

      @def_function.function
      def use_variable():
        if not var_list:
          initial_value = array_ops.placeholder(shape=[], dtype=dtypes.float32)
          v = variables.Variable(initial_value)
          var_list.append(v)
        return var_list[0] + 1.

      var_plus_one = use_variable()
      with self.session() as session:
        init_op = var_list[0].initializer
        session.run(init_op, feed_dict={init_op.inputs[1]: 2.})
        self.assertEqual(3., session.run(var_plus_one))

  def testDecorate_rejectedAfterTrace(self):
    func = def_function.function(lambda: 1)
    self.assertEqual(func().numpy(), 1)
    msg = 'Functions cannot be decorated after they have been traced.'
    with self.assertRaisesRegexp(ValueError, msg):
      func._decorate(lambda f: f)

  def testGetConcreteFunctionGraphLifetime(self):

    @def_function.function
    def func():
      pass

    graph = func.get_concrete_function().graph
    del func

    # If the graph is deleted, an exception is raised on reading `captures`.
    self.assertEmpty(graph.captures)
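
  # Pickling requires the wrapped Python function to be importable, hence
  # the module-level undecorated_function imported inside the test.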
  @parameterized.parameters(*itertools.product(
      (None, (tensor_spec.TensorSpec([]),)),  # input_signature
      (True, False),  # autograph
      (None, converter.Feature.ALL),  # autograph_options
      (None, 'foo.bar'),  # implements
      (None, True, False),  # relax_shapes
  ))
  def test_pickle(self, input_signature, autograph, autograph_options,
                  implements, relax_shapes):
    """@function objects can be pickled and unpickled."""
    # Can't pickle functions in __main__:
    from tensorflow.python.eager.def_function_test import undecorated_function

    original_py_function = undecorated_function
    func = def_function.function(
        func=original_py_function,
        input_signature=input_signature,
        autograph=autograph,
        experimental_implements=implements,
        experimental_autograph_options=autograph_options,
        experimental_relax_shapes=relax_shapes,
    )

    cloned = pickle.loads(pickle.dumps(func))

    self.assertEqual(func._name, cloned._name)
    self.assertEqual(input_signature, cloned._input_signature)
    self.assertEqual(autograph, cloned._autograph)
    self.assertEqual(implements, cloned._implements)
    self.assertEqual(autograph_options, cloned._experimental_autograph_options)
    self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)

    x = array_ops.ones([])
    self.assertEqual(self.evaluate(cloned(x)),
                     self.evaluate(func(x)))


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()