| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import functools |
| import itertools |
| import re |
| import weakref |
| |
| from absl.testing import parameterized |
| from six.moves import range |
| |
| from tensorflow.python.autograph.core import converter |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.eager import lift_to_graph |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.keras.engine import training |
| from tensorflow.python.keras.layers import core |
| from tensorflow.python.module import module |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| from tensorflow.python.training import adam |
| |
| |
class _ModelWithOptimizer(training.Model):
  """Keras model whose `call` takes one optimizer step inside a tf.function.

  Used by `DefFunctionTest.test_optimizer` to check that computing gradients
  and applying them with an optimizer works from within a traced function
  that has a fixed input signature.
  """

  def __init__(self):
    super(_ModelWithOptimizer, self).__init__()
    self.dense = core.Dense(1)
    self.optimizer = adam.AdamOptimizer(0.01)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec([None, 2], dtypes.float32),
                       tensor_spec.TensorSpec([None], dtypes.float32)))
  def call(self, x, y):
    """Takes one Adam step on the mean squared error of `dense(x)` vs `y`.

    Args:
      x: float32 tensor of shape [batch, 2] (fixed by the input signature).
      y: float32 tensor of shape [batch] holding the regression targets.

    Returns:
      A dict whose 'loss' entry is the loss computed before the update.
    """
    with backprop.GradientTape() as tape:
      # `Dense(1)` yields [batch, 1]; drop the unit axis so subtracting the
      # rank-1 `y` is elementwise instead of broadcasting to [batch, batch].
      predictions = array_ops.squeeze(self.dense(x), axis=[1])
      loss = math_ops.reduce_mean((predictions - y) ** 2.)
    trainable_variables = self.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    return {'loss': loss}
| |
| |
class _HasDecoratedMethod(object):
  """Owns a `def_function.function`-decorated method.

  `DefFunctionTest.testMethodReferenceCycles` uses it to check that a bound
  decorated method does not keep its owner alive through a reference cycle.
  """

  @def_function.function
  def f(self, x):
    """Returns `x` multiplied by 3."""
    return x * 3.


class DefFunctionTest(test.TestCase, parameterized.TestCase):
  """Tests for `def_function.function`: variables, tracing, and cloning."""

  def testNoVariables(self):

    @def_function.function
    def fn(x):
      return 2 * x

    self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)

  def testFailIfVariablesAreCreatedMoreThanOnce(self):

    @def_function.function
    def fn(x):
      return variables.Variable(1.0) + x

    # `fn` creates a variable unconditionally on every call; this is expected
    # to be rejected with a ValueError.
    with self.assertRaises(ValueError):
      fn(1.0)

  def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
    state = []

    @def_function.function
    def fn(x):
      state.append(variables.Variable(1.0))
      return state[-1] + x

    with self.assertRaises(ValueError):
      fn(1.0)

  def testRange(self):

    @def_function.function
    def f(unused_x):
      return 1.0

    self.assertAllEqual(f(range(5)), 1.0)

  def testCorrectVariableCreation(self):

    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(2.0))
      return state[0] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
    self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

  def testFunctionInitializer(self):

    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(lambda: 2.0))
      return state[0] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)

  def testFunctionMultipleVariableInitializer(self):

    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(lambda: 2.0))
        state.append(variables.Variable(lambda: 5.0))
      return state[0] * x, state[1] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])

  def testFunctionInitializationFunction(self):

    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(2.0))
      return state[0] * x

    # Getting the initialization function creates the variable but must not
    # initialize it; only calling `init_fn` does.
    init_fn = fn.get_initialization_function(constant_op.constant(1.0))
    self.assertEqual(len(state), 1)
    self.assertFalse(
        resource_variable_ops.var_is_initialized_op(state[0].handle))
    init_fn()
    self.assertEqual(state[0].numpy(), 2.0)

  def testVariableInitializerNotConstant(self):

    state = []

    @def_function.function
    def fn(x):
      if not state:
        state.append(variables.Variable(2.0 * x))
      return state[0] * x

    self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
    self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)

  def testLegacyGraphModeVariables(self):
    with ops.Graph().as_default(), self.session() as sess:
      state = []

      @def_function.function
      def fn(x):
        if not state:
          state.append(variables.Variable(2.0))
        return state[0] * x

      result = fn(3.0)

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(sess.run(state[0]), 2.0)
      self.assertAllEqual(self.evaluate(result), 6.0)

  def testLegacyGraphModeVariablesNonTrivialInitializer(self):
    with ops.Graph().as_default(), self.session() as sess:
      state = []

      @def_function.function
      def fn(x):
        if not state:
          two = constant_op.constant(2.0)
          four = two * two
          two_again = math_ops.sqrt(four)
          state.append(variables.Variable(two_again + four))
        return state[0] * x

      result = fn(3.0)

      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(sess.run(state[0]), 6.0)
      self.assertAllEqual(self.evaluate(result), 18.0)

  def testLegacyGraphModeInputDependentInitializerFails(self):
    with ops.Graph().as_default():
      state = []

      @def_function.function
      def fn(x):
        if not state:
          state.append(variables.Variable(2.0 * x))
        return state[0] * x

      # The initializer depends on the function input `x`, so it cannot be
      # lifted out of the function graph.
      with self.assertRaisesRegexp(
          lift_to_graph.UnliftableError, r'transitively.* mul .* x'):
        fn(constant_op.constant(3.0))

  def testMethod(self):

    class MyModel(object):

      def __init__(self):
        self.var = None

      @def_function.function
      def apply(self, x):
        if self.var is None:
          self.var = variables.Variable(2.0)
        return self.var * x

    m0 = MyModel()
    self.assertAllEqual(m0.apply(3.0), 6.0)
    # Calling twice to exercise that we do not recreate variables.
    m0.var.assign(3.0)
    self.assertAllEqual(m0.apply(3.0), 9.0)

    m1 = MyModel()
    self.assertAllEqual(m1.apply(3.0), 6.0)

  def test_functools_partial(self):
    self.assertAllClose(
        3.,
        def_function.function(functools.partial(lambda x, y: x + y, 1.))(
            constant_op.constant(2.)))

  def test_functools_partial_new_default(self):
    def f(x=3, y=7):
      return x + y

    func = def_function.function(functools.partial(f, y=6))
    self.assertEqual(func().numpy(), 9)
    self.assertEqual(func(y=8).numpy(), 11)

  def test_functools_partial_keywords(self):
    def f(x, y):
      return x + y

    func = def_function.function(
        functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
    self.assertAllEqual(func(), [0.0])

  def test_functools_partial_single_positional(self):
    def f(x, y):
      return x + y

    func = def_function.function(
        functools.partial(f, constant_op.constant(1)))
    self.assertAllEqual(func(5), 6)

  def test_complicated_partial_with_defaults(self):

    def identity(*args):
      return args

    def dynamic_unroll(core_fn,
                       input_sequence,
                       initial_state,
                       sequence_length=None,
                       parallel_iterations=1,
                       swap_memory=False):
      del core_fn
      self.assertIs(None, sequence_length)
      self.assertEqual(1, parallel_iterations)
      self.assertTrue(swap_memory)
      return input_sequence, initial_state

    input_sequence = random_ops.random_uniform([1, 1, 1])
    initial_state = random_ops.random_uniform([1, 1])

    func = def_function.function(
        functools.partial(dynamic_unroll, identity, swap_memory=True))
    func(input_sequence, initial_state)

  def test_unspecified_default_argument(self):
    wrapped = def_function.function(
        lambda x, y=2: x + y,
        input_signature=[tensor_spec.TensorSpec((), dtypes.int32)])
    self.assertEqual(3, wrapped(constant_op.constant(1)).numpy())

  def test_optimizer(self):
    x = constant_op.constant([[3., 4.]])
    y = constant_op.constant([2.])
    model = _ModelWithOptimizer()
    model(x, y)

  def test_concrete_function_from_signature(self):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
    def compute(x):
      return 2. * x

    concrete = compute.get_concrete_function()
    self.assertAllClose(1., concrete(constant_op.constant(0.5)))
    concrete = compute.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32))
    self.assertAllClose(4., concrete(constant_op.constant(2.)))
    signature_args, _ = concrete.structured_input_signature
    self.assertEqual(signature_args,
                     (tensor_spec.TensorSpec(
                         None, dtypes.float32, name='x'),))

  @test_util.run_in_graph_and_eager_modes
  def test_variable_naming(self):
    class HasVars(module.Module):

      def __init__(self):
        self.x = None
        self.y = None
        self.z = None

      @def_function.function
      def make_x(self):
        if self.x is None:
          self.x = variables.Variable(1., name='v')

      def make_y(self):
        if self.y is None:
          self.y = variables.Variable(1., name='v')

      def make_z(self):
        if self.z is None:
          with ops.name_scope('z_scope', skip_on_eager=False):
            self.z = variables.Variable(1., name='z')

    root = HasVars()
    root.make_x()
    root.make_y()
    root.make_z()
    self.assertEqual('v:0', root.x.name)
    self.assertEqual('z_scope/z:0', root.z.name)

  def test_concrete_function_keyword_arguments(self):
    @def_function.function
    def f(x):
      return x

    conc = f.get_concrete_function(
        tensor_spec.TensorSpec(None, dtypes.float32, 'y'))
    conc(y=constant_op.constant(3.0))
    signature_args, _ = conc.structured_input_signature
    self.assertEqual('y', signature_args[0].name)

    conc = f.get_concrete_function(tensor_spec.TensorSpec(None, dtypes.float32))
    conc(x=constant_op.constant(3.0))
    signature_args, _ = conc.structured_input_signature
    self.assertEqual('x', signature_args[0].name)

    @def_function.function
    def g(x):
      return x[0]

    conc = g.get_concrete_function(
        [tensor_spec.TensorSpec(None, dtypes.float32, 'z'), 2])
    conc(z=constant_op.constant(3.0))
    signature_args, _ = conc.structured_input_signature
    self.assertEqual('z', signature_args[0][0].name)

    with self.assertRaisesRegexp(
        ValueError, 'either zero or all names have to be specified'):
      conc = g.get_concrete_function([
          tensor_spec.TensorSpec(None, dtypes.float32, 'z'),
          tensor_spec.TensorSpec(None, dtypes.float32),
      ])

  def test_error_inner_capture(self):

    @def_function.function
    def f(inputs):
      num_steps, _ = inputs.shape[:2]
      outputs = []
      for t in math_ops.range(num_steps):
        outputs.append(inputs[t])
      return outputs

    with self.assertRaisesRegexp(errors.InaccessibleTensorError,
                                 'defined in another function or code block'):
      f(array_ops.zeros(shape=(8, 42, 3)))

  def testRuntimeErrorNotSticky(self):

    @def_function.function
    def fail(i):
      control_flow_ops.Assert(math_ops.equal(i, 0), ['ick'])

    fail(constant_op.constant(0))  # OK
    with self.assertRaises(errors.InvalidArgumentError):
      fail(constant_op.constant(1))  # InvalidArgument: "ick"
    fail(constant_op.constant(0))  # OK

  def testUnderscoreName(self):

    @def_function.function
    def f(_):
      return _ + _

    self.assertAllEqual(2.0, f(constant_op.constant(1.0)))

  def test_serialization_signature_cache(self):

    @def_function.function
    def f(x, y):
      return x, y

    f(constant_op.constant([[3., 4.]]), constant_op.constant([2.]))
    f(constant_op.constant([[3, 4, 5]]), constant_op.constant([2]))

    signatures_args = set()
    concrete_functions = f._list_all_concrete_functions_for_serialization()
    for concrete_function in concrete_functions:
      args, kwargs = concrete_function.structured_input_signature
      signatures_args.add(args)
      self.assertEqual(dict(), kwargs)

    self.assertEqual(
        signatures_args,
        set(((tensor_spec.TensorSpec([1, 2], dtypes.float32, name='x'),
              tensor_spec.TensorSpec([1], dtypes.float32, name='y')),
             (tensor_spec.TensorSpec([1, 3], dtypes.int32, name='x'),
              tensor_spec.TensorSpec([1], dtypes.int32, name='y')))))

  @test_util.assert_no_garbage_created
  def testFunctionReferenceCycles(self):
    fn = def_function.function(lambda x: 2. * x)
    fn(constant_op.constant(4.0))
    weak_fn = weakref.ref(fn)
    del fn
    # Tests that the weak reference we made to the function is now dead, which
    # means the object has been deleted. This should be true as long as the
    # function itself is not involved in a reference cycle.
    self.assertIs(None, weak_fn())

  @test_util.assert_no_garbage_created
  def testMethodReferenceCycles(self):
    has_decorated_method = _HasDecoratedMethod()
    has_decorated_method.f(constant_op.constant(5.))
    weak_fn = weakref.ref(has_decorated_method.f)
    del has_decorated_method
    # Tests that the weak reference we made to the function is now dead, which
    # means the object has been deleted. This should be true as long as the
    # function itself is not involved in a reference cycle.
    self.assertIs(None, weak_fn())

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testErrorMessageWhenGraphTensorIsPassedToEager(self):

    @def_function.function
    def failing_function():
      a = constant_op.constant(1.)

      # `a` lives in the function graph; using it under `init_scope` leaks
      # it outside the function, which must produce a clear TypeError.
      with ops.init_scope():
        _ = a + a

    with self.assertRaisesRegexp(
        TypeError,
        re.compile('An op outside of the function.*passed.*Const', re.DOTALL)):
      failing_function()

  def testNonUniqueNamesGetConcreteFunction(self):
    @def_function.function
    def non_unique_arg_names(x, **kwargs):
      a, b, c = x
      d = kwargs['d']
      return a + b + c + d

    # The three unnamed specs nested under `x` are exposed as `x`, `x_1` and
    # `x_2` on the concrete function.
    concrete = non_unique_arg_names.get_concrete_function(
        (tensor_spec.TensorSpec(None, dtypes.float32),
         tensor_spec.TensorSpec(None, dtypes.float32),
         tensor_spec.TensorSpec(None, dtypes.float32)),
        d=tensor_spec.TensorSpec(None, dtypes.float32))
    self.assertAllClose(
        10.,
        concrete(x=constant_op.constant(1.),
                 x_1=constant_op.constant(2.),
                 x_2=constant_op.constant(3.),
                 d=constant_op.constant(4.)))
    self.assertAllClose(
        10.,
        concrete(constant_op.constant(1.),
                 constant_op.constant(2.),
                 constant_op.constant(3.),
                 constant_op.constant(4.)))

  def testVariableCreatorScope(self):
    created_variables = []
    captured_variables = []

    @def_function.function
    def f():
      if not created_variables:
        created_variables.append(variables.Variable(1.))
      return created_variables[0] + 1.

    def capture_creator(next_creator, **kwargs):
      created = next_creator(**kwargs)
      captured_variables.append(created)
      return created

    with variable_scope.variable_creator_scope(capture_creator):
      f()
    self.assertEqual(created_variables, captured_variables)

  def testVarAlreadyInitializedNoClobbering(self):
    v_holder = []

    @def_function.function
    def add_var(x):
      if not v_holder:
        v = variables.Variable([1., 2.])
        v_holder.append(v)
        already_initialized = variables.Variable(3.)
        with ops.init_scope():
          already_initialized.assign(10.)
        v_holder.append(already_initialized)
      return v_holder[0] + v_holder[1] + x

    add_var.get_concrete_function(constant_op.constant(2.))
    # The eager assign of 10. must survive initialization:
    # [1., 2.] + 10. + 2. == [13., 14.] rather than using the initial 3.
    self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))

  def testSameVariableTwice(self):
    v = variables.Variable(1.0)

    @def_function.function
    def add(a, b):
      return a + b

    self.assertAllEqual(add(v, v), 2.0)

  def testVariableUpdate(self):
    v1 = variables.Variable(1.0)
    v2 = variables.Variable(2.0)
    v3 = variables.Variable(4, dtype=dtypes.int32)

    trace_count = [0]

    @def_function.function
    def double_variable(x):
      trace_count[0] += 1
      x.assign_add(x.read_value())

    # The expected trace counts depend on `ops.Tensor._USE_EQUALITY`: when it
    # is set, the float32 variables v1 and v2 share one trace; otherwise each
    # variable argument triggers its own trace.
    self.assertEqual(trace_count[0], 0)
    double_variable(v1)
    self.assertEqual(trace_count[0], 1)
    self.assertEqual(self.evaluate(v1), 2.0)
    double_variable(v2)
    self.assertEqual(trace_count[0], 1 if ops.Tensor._USE_EQUALITY else 2)
    self.assertEqual(self.evaluate(v2), 4.0)
    double_variable(v3)
    self.assertEqual(trace_count[0], 2 if ops.Tensor._USE_EQUALITY else 3)
    self.assertEqual(self.evaluate(v3), 8)

  def testShapeCache(self):
    @def_function.function
    def func(x):
      return 2 * x

    func_a = func.get_concrete_function(
        tensor_spec.TensorSpec([None], dtypes.int32))
    func_b = func.get_concrete_function(
        tensor_spec.TensorSpec([None], dtypes.int32))

    self.assertIs(func_a, func_b)

  def testInitializationInNestedCall(self):
    v_holder = []

    @def_function.function
    def add_var(x):
      if not v_holder:
        v = variables.Variable([1., 2.])
        v_holder.append(v)
        already_initialized = variables.Variable(3.)
        with ops.init_scope():
          already_initialized.assign(10.)
        v_holder.append(already_initialized)
      return v_holder[0] + v_holder[1] + x

    @def_function.function
    def wrapper(x):
      return add_var(x)

    self.assertAllClose([13., 14.], wrapper(constant_op.constant(2.)))
    v_holder[1].assign(11.)
    self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))

  # TODO(b/137148281): reenable
  @test_util.run_gpu_only
  def testDeviceAnnotationRespected(self):
    a = []

    @def_function.function()
    def create_variable():
      with ops.init_scope():
        initial_value = random_ops.random_uniform(
            (2, 2), maxval=1000000, dtype=dtypes.int64)

      if not a:
        with ops.device('CPU:0'):
          a.append(resource_variable_ops.ResourceVariable(initial_value))

      return a[0].read_value()

    # Only the placement of the created variable is checked; the read value
    # itself is not needed.
    create_variable()
    self.assertRegexpMatches(a[0].device, 'CPU')

  def testDecorate(self):
    func = def_function.function(lambda: 1)
    def decorator(f):
      return lambda: 1 + f()

    func._decorate(decorator)
    self.assertEqual(func().numpy(), 2)

  @parameterized.parameters(*itertools.product(
      (None, (tensor_spec.TensorSpec([]),)),  # input_signature
      (True, False),  # autograph
      (None, converter.Feature.ALL),  # autograph_options
      (None, 'foo.bar'),  # implements
      (None, True, False),  # relax_shapes
      (True, False),  # compile
      (True, False),  # override_function
  ))
  def testClone(self, input_signature, autograph, autograph_options, implements,
                relax_shapes, compile_, override_function):
    original_py_function = lambda x: x

    # NOTE(review): this overrides the parameterized `compile_` value, so the
    # `compile` axis of the parameter product never reaches
    # `experimental_compile=True` and half of the cases are duplicates. The
    # `if not compile_` guard below already skips execution when compiling;
    # confirm whether this override is still needed before removing it.
    compile_ = False
    func = def_function.function(
        func=original_py_function,
        input_signature=input_signature,
        autograph=autograph,
        experimental_implements=implements,
        experimental_autograph_options=autograph_options,
        experimental_relax_shapes=relax_shapes,
        experimental_compile=compile_)

    if override_function:
      cloned_py_function = lambda x: x + 1
    else:
      cloned_py_function = original_py_function

    cloned = func._clone(python_function=cloned_py_function)

    self.assertEqual(cloned_py_function, cloned._python_function)
    self.assertEqual(func._name, cloned._name)
    self.assertEqual(input_signature, cloned._input_signature)
    self.assertEqual(autograph, cloned._autograph)
    self.assertEqual(implements, cloned._implements)
    self.assertEqual(autograph_options, cloned._experimental_autograph_options)
    self.assertEqual(relax_shapes, cloned.experimental_relax_shapes)
    self.assertEqual(compile_, cloned._experimental_compile)

    # This test does not run with XLA JIT support linked in so we can only check
    # the output of the function if compile is disabled.
    if not compile_:
      x = array_ops.zeros([])
      self.assertEqual(self.evaluate(cloned(x)),
                       self.evaluate(cloned_py_function(x)))

  def testLiftPlaceholderInitializedVariable(self):
    with ops.Graph().as_default():
      var_list = []

      @def_function.function
      def use_variable():
        if not var_list:
          initial_value = array_ops.placeholder(shape=[], dtype=dtypes.float32)
          v = variables.Variable(initial_value)
          var_list.append(v)
        return var_list[0] + 1.

      var_plus_one = use_variable()
      with self.session() as session:
        init_op = var_list[0].initializer
        # The placeholder feeding the initial value is fed explicitly here.
        session.run(init_op, feed_dict={init_op.inputs[1]: 2.})
        self.assertEqual(3., session.run(var_plus_one))

  def testDecorate_rejectedAfterTrace(self):
    func = def_function.function(lambda: 1)
    self.assertEqual(func().numpy(), 1)
    msg = 'Functions cannot be decorated after they have been traced.'
    with self.assertRaisesRegexp(ValueError, msg):
      func._decorate(lambda f: f)

  def testGetConcreteFunctionGraphLifetime(self):

    @def_function.function
    def func():
      pass

    graph = func.get_concrete_function().graph
    del func

    # If the graph is deleted, then an exception is raised on reading `captures`
    self.assertEmpty(graph.captures)
| |
| |
if __name__ == '__main__':
  # Eager execution is switched on before the test runner starts, so tests
  # that do not enter an explicit `ops.Graph()` run eagerly.
  ops.enable_eager_execution()
  test.main()