| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import functools |
| from multiprocessing.pool import ThreadPool |
| import sys |
| import weakref |
| |
| import numpy |
| |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.protobuf import rewriter_config_pb2 |
| from tensorflow.python import keras |
| from tensorflow.python.data.ops import iterator_ops |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import function |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import function as tf_function |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import random_seed |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_spec |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.keras.engine import training as keras_training |
| from tensorflow.python.layers import convolutional |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import clip_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import gradients_impl |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import list_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import resource_variable_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.platform import test |
| from tensorflow.python.training import adam |
| from tensorflow.python.training import momentum |
| from tensorflow.python.training import training_ops |
| from tensorflow.python.util import compat |
| from tensorflow.python.util import nest |
| |
| |
| class MiniModel(keras_training.Model): |
| """Minimal model for mnist. |
| |
| Useful for testing and debugging on slow TPU simulators. |
| """ |
| |
| def __init__(self): |
| super(MiniModel, self).__init__(name='') |
| self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones', |
| bias_initializer='ones') |
| |
| def call(self, inputs, training=True): |
| return self.fc(inputs) |
| |
| |
| class DefunnedMiniModel(MiniModel): |
| |
| @function.defun |
| def call(self, inputs, training=True): |
| return super(DefunnedMiniModel, self).call(inputs, training=training) |
| |
| |
| @test_util.with_c_shapes |
| class FunctionTest(test.TestCase): |
| |
| def testBasic(self): |
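| # `function.defun` traces the Python callable into a graph function the |
| # first time it is called with a new input signature. |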
| matmul = function.defun(math_ops.matmul) |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| sq = matmul(t, t, transpose_a=True) |
| sq2 = matmul(sq, t, transpose_a=True) |
| self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) |
| self.assertAllEqual(sq2.numpy().reshape(-1), [52, 76, 74, 108]) |
| |
| def testBasicGraphMode(self): |
| matmul = function.defun(math_ops.matmul) |
| |
| @function.defun |
| def sq(a): |
| return matmul(a, a) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| out = sq(t) |
| self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) |
| |
| def testNestedInputsGraphMode(self): |
| matmul = function.defun(math_ops.matmul) |
| |
| pair = collections.namedtuple('pair', ['a', 'b']) |
| |
| @function.defun |
| def a_times_b(inputs): |
| return matmul(inputs.a['a'], inputs.b['b']) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| |
| out = a_times_b(pair({'a': t}, {'b': t})) |
| self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) |
| |
| def testGraphModeWithGradients(self): |
| v = resource_variable_ops.ResourceVariable(1.0, name='v') |
| |
| @function.defun |
| def step(): |
| def inner(): |
| return v * v |
| |
| return backprop.implicit_grad(inner)()[0][0] |
| |
| self.assertAllEqual(step(), 2.0) |
| |
| def testGraphGradientVariable(self): |
| with ops.Graph().as_default(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @function.defun |
| def f(): |
| return 2.0 * v |
| |
| node = f() |
| grads, = gradients_impl.gradients(node, v) |
| v.initializer.run() |
| self.assertAllEqual(grads.eval(), 2.0) |
| self.assertEqual(grads.shape, v.shape) |
| |
| def testGraphEagerIsolation(self): |
| |
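| # A variable created inside a defun while executing eagerly should remain |
| # usable when the same function is later called under a graph. |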
| @function.defun |
| def f(): |
| self.v = resource_variable_ops.ResourceVariable(1.0) |
| return self.v.read_value() |
| |
| self.assertAllEqual(f(), 1.0) |
| |
| with ops.Graph().as_default(): |
| self.assertEqual(f().shape, ()) |
| |
| def testBasicGraphFunction(self): |
| matmul = function.defun(math_ops.matmul) |
| |
| @function.defun |
| def sq(a): |
| return matmul(a, a) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| |
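| # `get_concrete_function` traces `sq` for the dtype and shape of `t` and |
| # returns the resulting graph function without executing it. |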
| sq_op = sq.get_concrete_function(t) |
| self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) |
| out = sq_op(t) |
| self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) |
| |
| def testExecutingStatelessDefunConcurrently(self): |
| |
| @function.defun |
| def stateless(x): |
| return math_ops.multiply(2.0, x) |
| |
| pool = ThreadPool() |
| inputs = [constant_op.constant(1.0 * x) for x in range(100)] |
| outputs = [float(out) for out in pool.map(stateless, inputs)] |
| expected = [float(2.0 * x) for x in inputs] |
| self.assertSequenceEqual(outputs, expected) |
| |
| def testExecutingManyStatelessDefunsConcurrently(self): |
| |
| @function.defun |
| def stateless(x): |
| del x |
| return math_ops.multiply(2.0, 2.0) |
| |
| pool = ThreadPool() |
| # `pool.map` below instantiates 100 functions, one for each object. |
| outputs = [ |
| float(out) |
| for out in pool.map(stateless, [object() for _ in range(100)]) |
| ] |
| expected = [4.0] * 100 |
| self.assertSequenceEqual(outputs, expected) |
| |
| def testExecutingStatefulDefunConcurrently(self): |
| |
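| # Every concurrent call assigns 0.0 to the same captured variable, so the |
| # final value must be 0.0. |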
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @function.defun |
| def stateful(x): |
| v.assign(x) |
| |
| pool = ThreadPool() |
| inputs = [constant_op.constant(0.0)] * 100 |
| pool.map(stateful, inputs) |
| self.assertEqual(float(v.read_value()), 0.0) |
| |
| def testExecutingManyStatefulDefunsConcurrently(self): |
| |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @function.defun |
| def stateful(x): |
| del x |
| return v.assign(0.0) |
| |
| pool = ThreadPool() |
| # `pool.map` below instantiates 100 functions, one for each object. |
| pool.map(stateful, [object() for _ in range(100)]) |
| self.assertEqual(float(v.read_value()), 0.0) |
| |
| def disabled_testRandomSeed(self): |
| |
| @function.defun |
| def f(): |
| return random_ops.random_normal(()) |
| |
| random_seed.set_random_seed(1) |
| x = f() |
| self.assertNotEqual(x, f()) |
| random_seed.set_random_seed(1) |
| self.assertAllEqual(f(), x) |
| |
| def testSymGradGatherNd(self): |
| with ops.Graph().as_default(), self.cached_session() as sess: |
| |
| @function.defun |
| def f(x): |
| return array_ops.gather_nd(x, [[0]]) |
| |
| c = constant_op.constant([[2.]]) |
| f_c = f(c) |
| g, = gradients_impl.gradients(f_c, c) |
| self.assertAllEqual(sess.run(g), [[1.0]]) |
| |
| def testNestedInputsGraphFunction(self): |
| matmul = function.defun(math_ops.matmul) |
| |
| pair = collections.namedtuple('pair', ['a', 'b']) |
| |
| @function.defun |
| def a_times_b(inputs): |
| return matmul(inputs.a['a'], inputs.b['b']) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| inputs = pair({'a': t}, {'b': t}) |
| sq_op = a_times_b.get_concrete_function(inputs) |
| self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2])) |
| out = sq_op(inputs) |
| self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) |
| |
| def testNestedOutputGraphFunction(self): |
| matmul = function.defun(math_ops.matmul) |
| |
| @function.defun |
| def sq(a): |
| return (matmul(a, a), {'b': constant_op.constant(1.0)}) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| |
| sq_op = sq.get_concrete_function(t) |
| self.assertEqual(sq_op.output_shapes, |
| (tensor_shape.TensorShape([2, 2]), |
| {'b': tensor_shape.TensorShape([])})) |
| self.assertEqual(sq_op.output_dtypes, |
| (dtypes.float32, {'b': dtypes.float32})) |
| (a, b) = sq_op(t) |
| self.assertAllEqual(a, math_ops.matmul(t, t).numpy()) |
| self.assertAllEqual(b['b'].numpy(), 1.0) |
| |
| def testGraphFunctionWithGradients(self): |
| v = resource_variable_ops.ResourceVariable(1.0, name='v') |
| |
| @function.defun |
| def step(): |
| def inner(): |
| return v * v |
| |
| return backprop.implicit_grad(inner)()[0][0] |
| |
| step_op = step.get_concrete_function() |
| self.assertEqual(step_op.output_dtypes, dtypes.float32) |
| self.assertEqual(step_op.output_shapes, tensor_shape.TensorShape([])) |
| self.assertAllEqual(step_op(), 2.0) |
| |
| def testGraphFunctionNoneOutput(self): |
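| # A Python function that returns None should yield a concrete function |
| # with no output dtypes or shapes. |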
| @function.defun |
| def fn(unused_a, unused_b): |
| return None |
| |
| x = constant_op.constant(1) |
| fn_op = fn.get_concrete_function(x, x) |
| self.assertEqual(fn_op.output_dtypes, None) |
| self.assertEqual(fn_op.output_shapes, None) |
| self.assertAllEqual(fn_op(x, x), None) |
| |
| @test_util.run_in_graph_and_eager_modes() |
| def testDefunCondGradient(self): |
| |
| @function.defun |
| def f(x): |
| return control_flow_ops.cond(x > 0.5, lambda: 2 * x, lambda: 3 * x) |
| |
| with backprop.GradientTape() as t: |
| x = constant_op.constant(1.0) |
| t.watch(x) |
| y = f(x) |
| self.assertAllEqual(self.evaluate(t.gradient(y, x)), 2.0) |
| |
| @test_util.run_in_graph_and_eager_modes() |
| def testGraphLoopGradient(self): |
| |
| @function.defun |
| def f(x): |
| return control_flow_ops.while_loop(lambda _, i: i < 2, |
| lambda x, i: (2*x, i + 1), |
| [x, 0])[0] |
| |
| with backprop.GradientTape() as t: |
| x = constant_op.constant(1.0) |
| t.watch(x) |
| y = f(x) |
| self.assertAllEqual(self.evaluate(t.gradient(y, x)), 4.0) |
| |
| def testDefunNumpyArraysConvertedToTensors(self): |
| |
| def f(x): |
| self.assertIsInstance(x, ops.Tensor) |
| return x |
| |
| x = random_ops.random_uniform([2, 2]).numpy() |
| defined = function.defun(f) |
| defined(x) |
| self.assertEqual(len(defined._function_cache), 1) |
| |
| x = random_ops.random_uniform([2, 2]).numpy() |
| defined(x) |
| # A NumPy array with different values but the same shape and dtype |
| # shouldn't trigger another function definition. |
| self.assertEqual(len(defined._function_cache), 1) |
| |
| # Test that NumPy arrays are properly converted to arguments of the graph |
| # function. |
| self.assertEqual(1., defined(numpy.ones([])).numpy()) |
| self.assertEqual(0., defined(numpy.zeros([])).numpy()) |
| self.assertEqual(1., defined(array_ops.ones([])).numpy()) |
| self.assertEqual(0., defined(array_ops.zeros([])).numpy()) |
| |
| def testDefunCapturedInt32(self): |
| x = constant_op.constant(1, dtype=dtypes.int32) |
| |
| @function.defun |
| def add_int32s(): |
| return x + x |
| |
| self.assertEqual(2, int(add_int32s())) |
| |
| def testDefunReadVariable(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @function.defun |
| def f(): |
| return v.read_value() |
| |
| self.assertEqual(1.0, float(f())) |
| |
| def testDefunAssignAddVariable(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| x = constant_op.constant(2.0) |
| |
| @function.defun |
| def test_assign_add(): |
| v.assign_add(x) |
| return v.read_value() |
| |
| self.assertEqual(3.0, float(test_assign_add())) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testTensorInitializationInFunctionRaisesError(self): |
| error_msg = ('Tensor-typed variable initializers must either be ' |
| 'wrapped in an init_scope or callable.*') |
| |
| @function.defun |
| def tensor_init(): |
| with self.assertRaisesRegexp(ValueError, error_msg): |
| resource_variable_ops.ResourceVariable(constant_op.constant(2.0)) |
| |
| tensor_init() |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testCallableTensorInitializationInFunction(self): |
| |
| @function.defun |
| def tensor_init(): |
| self.v = resource_variable_ops.ResourceVariable( |
| lambda: constant_op.constant(2.0)) |
| return self.v.read_value() |
| |
| value = tensor_init() |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(self.evaluate(value), 2.0) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testInitScopeTensorInitializationInFunction(self): |
| |
| @function.defun |
| def tensor_init(): |
| with ops.init_scope(): |
| const = constant_op.constant(2.0) |
| self.v = resource_variable_ops.ResourceVariable(const) |
| return self.v.read_value() |
| |
| value = tensor_init() |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(self.evaluate(value), 2.0) |
| |
| def testDefunShapeInferenceWithCapturedResourceVariable(self): |
| v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) |
| |
| def f(): |
| x = constant_op.constant([[1, 2], [3, 4]]) |
| out = math_ops.matmul(v, x) |
| self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) |
| # We do not return v directly since the tensor conversion function of |
| # ResourceVariable returns the read value and not the resource itself. |
| return v._handle |
| |
| compiled = function.defun(f) |
| var_handle = compiled() |
| self.assertEqual(var_handle.dtype, dtypes.resource) |
| self.assertEqual(var_handle.shape, tensor_shape.scalar()) |
| var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) |
| self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) |
| |
| def testVariableInLoopInFunction(self): |
| |
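| # `while_loop` traces `loop_body` when the function is built even though |
| # the condition is always False, so the variable is created but the loop |
| # never runs. |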
| @function.defun |
| def test_function(): |
| |
| def loop_test(_): |
| return False |
| |
| def loop_body(_): |
| return variable_scope.get_variable('a', shape=()) |
| |
| return control_flow_ops.while_loop(loop_test, loop_body, [0.0]) |
| |
| self.assertEqual(test_function().shape, []) |
| |
| def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self): |
| with context.graph_mode(): |
| v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]]) |
| |
| def f(): |
| x = constant_op.constant([[1, 2], [3, 4]]) |
| out = math_ops.matmul(v, x) |
| self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) |
| # We do not return v directly since the tensor conversion function of |
| # ResourceVariable returns the read value and not the resource itself. |
| return v._handle |
| |
| compiled = function.defun(f) |
| var_handle = compiled() |
| self.assertEqual(var_handle.dtype, dtypes.resource) |
| self.assertEqual(var_handle.shape, tensor_shape.scalar()) |
| var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype) |
| self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2])) |
| |
| def testDefunShapeInferenceWithCapturedVariableInGraphMode(self): |
| with context.graph_mode(): |
| v = variables.Variable([[1, 2], [3, 4]]) |
| |
| def f(): |
| x = constant_op.constant([[1, 2], [3, 4]]) |
| out = math_ops.matmul(v, x) |
| self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2])) |
| |
| # Check that shape inference works while creating the defun. |
| compiled = function.defun(f) |
| compiled() |
| |
| def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self): |
| with context.graph_mode(): |
| tensor_list = list_ops.empty_tensor_list( |
| element_dtype=dtypes.float32, |
| element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) |
| tensor_list = list_ops.tensor_list_push_back(tensor_list, |
| constant_op.constant(1.0)) |
| tensor_list = list_ops.tensor_list_push_back(tensor_list, |
| constant_op.constant(2.0)) |
| |
| def f(): |
| tl, value = list_ops.tensor_list_pop_back( |
| tensor_list, element_dtype=dtypes.float32) |
| self.assertEqual(value.shape, tensor_shape.scalar()) |
| return tl |
| |
| compiled = function.defun(f) |
| output_tensor_list = compiled() |
| _, value = list_ops.tensor_list_pop_back( |
| output_tensor_list, element_dtype=dtypes.float32) |
| self.assertEqual(value.shape, tensor_shape.scalar()) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testDefunForcesResourceVariables(self): |
| |
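| # Inside a defun, `variables.Variable` should create a ResourceVariable, |
| # even in graph mode where it would otherwise create a RefVariable. |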
| def variable_creator(): |
| self.v = variables.Variable(0.0) |
| return self.v.read_value() |
| |
| self.v = None |
| defined = function.defun(variable_creator) |
| defined() # Create the variable. |
| self.assertIsInstance( |
| self.v, resource_variable_ops.ResourceVariable) |
| |
| def testDefunDifferentiable(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @function.defun |
| def f(): |
| return v * v |
| |
| self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0) |
| |
| def testDefunCanBeDifferentiatedTwice(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @function.defun |
| def f(): |
| return v * v |
| |
| self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0) |
| # Ensure that v is watched again. |
| self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0) |
| |
| def testGraphModeCaptureVariable(self): |
| with context.graph_mode(), self.cached_session() as sess: |
| |
| class HasAVar(object): |
| |
| def __init__(self): |
| self.v = resource_variable_ops.ResourceVariable(1.0) |
| |
| def call(self): |
| return self.v * 2 |
| |
| o = HasAVar() |
| variables.global_variables_initializer().run() |
| call = function.defun(o.call) |
| op = call() |
| self.assertAllEqual(sess.run(op), 2.0) |
| |
| def testSymbolicGradientVariableZerosLike(self): |
| with ops.Graph().as_default(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| @function.defun |
| def f(x, v): |
| v.read_value() |
| return x * x |
| |
| x = constant_op.constant(1.0) |
| l = f(x, v) |
| _, dv = gradients_impl.gradients(l, [x, v]) |
| with self.cached_session(): |
| v.initializer.run() |
| self.assertAllEqual(dv.eval(), 0.0) |
| |
| def testGraphModeManyFunctions(self): |
| with context.graph_mode(), self.cached_session(): |
| |
| @function.defun |
| def f(x): |
| return x * x |
| |
| @function.defun |
| def g(x): |
| return f(x) + 1 |
| |
| self.assertAllEqual(g(constant_op.constant(2.0)).eval(), 5.0) |
| |
| def testDict(self): |
| |
| @function.defun |
| def f(x): |
| return {'name': x + 1} |
| |
| self.assertAllEqual(f(constant_op.constant(1.0))['name'], 2.0) |
| |
| def testTensorConversionWithDefun(self): |
| |
| @function.defun |
| def f(x): |
| return math_ops.add(x, constant_op.constant(3)) |
| |
| self.assertAllEqual(5, f(constant_op.constant(2))) |
| |
| def testTensorConversionCall(self): |
| |
| @function.defun |
| def f(x): |
| return math_ops.add(x, constant_op.constant(3)) |
| |
| @function.defun |
| def g(x): |
| return f(f(x)) |
| |
| self.assertAllEqual(8, g(constant_op.constant(2))) |
| |
| def testDefunCallBackprop(self): |
| |
| @function.defun |
| def f(x): |
| return math_ops.add(x, x) |
| |
| @function.defun |
| def g(x): |
| return backprop.gradients_function(f, [0])(x)[0] |
| |
| self.assertAllEqual(2, g(constant_op.constant(2.))) |
| |
| def testGraphModeEagerGradError(self): |
| with context.graph_mode(): |
| def f(): |
| x = variable_scope.get_variable( |
| 'v', initializer=constant_op.constant(1.0)) |
| return x * constant_op.constant(2.0) |
| |
| with self.assertRaisesRegexp(ValueError, |
| 'No trainable variables were accessed'): |
| backprop.implicit_val_and_grad(f)() |
| |
| def testDefunCallBackpropUsingSameObjectForMultipleArguments(self): |
| |
| @function.defun |
| def g(x): |
| return backprop.gradients_function(math_ops.multiply, [0, 1])(x, x) |
| |
| def np_g(x): |
| return [d.numpy() for d in g(x)] |
| |
| x = constant_op.constant(1.) |
| self.assertAllEqual([1., 1.], np_g(x)) |
| self.assertAllEqual([1., 1.], np_g(1.)) |
| |
| def testCallShape(self): |
| |
| @function.defun |
| def f(x): |
| return x + 1 |
| |
| @function.defun |
| def g(x): |
| x = f(x) |
| self.assertEqual(x.shape.as_list(), []) |
| return None |
| |
| g(constant_op.constant(1.0)) |
| |
| def testNestedDefunWithNoOutputAndTapedInput(self): |
| three = resource_variable_ops.ResourceVariable(3.0, name='v') |
| |
| @function.defun |
| def f(x): |
| # This function intentionally takes a taped variable as input, |
| # but does not return any values. |
| math_ops.add(x, three) |
| |
| @function.defun |
| def g(x): |
| y = math_ops.add(x, three) |
| f(y) |
| |
| g(three) |
| |
| def testGradientTensorConversionWithDefun(self): |
| three = resource_variable_ops.ResourceVariable(3.0, name='v') |
| |
| @function.defun |
| def f(x): |
| return math_ops.add(x, three) |
| |
| def g(x): |
| return f(x) |
| |
| g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0] |
| self.assertAllEqual(g, 1.0) |
| |
| def testGradient(self): |
| matmul = function.defun(math_ops.matmul) |
| |
| def sq(x): |
| return matmul(x, x, transpose_a=True) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| grad_t, = backprop.gradients_function(sq, [0])(t) |
| self.assertAllEqual(grad_t, [[6, 6], [14, 14]]) |
| |
| def testGradientInFunction(self): |
| |
| @function.defun |
| def f(x): |
| return backprop.gradients_function(lambda y: y * y, [0])(x)[0] |
| |
| self.assertAllEqual(f(constant_op.constant(1.0)), 2.0) |
| |
| def testGatherResourceWithDefun(self): |
| with ops.device('cpu:0'): |
| v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) |
| |
| def sum_gather(): |
| return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) |
| |
| defined = function.defun(sum_gather) |
| self.assertAllEqual(sum_gather(), defined()) |
| |
| def testGradientOfGatherWithDefun(self): |
| v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) |
| |
| def sum_gather(): |
| return math_ops.reduce_sum(array_ops.gather(v, [1, 2])) |
| |
| grad_fn = backprop.implicit_grad(sum_gather) |
| gradient = grad_fn() |
| defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather)) |
| defun_gradient = defun_grad_fn() |
| self.assertEqual(len(gradient), len(defun_gradient)) |
| |
| gradient = gradient[0][0] |
| defun_gradient = defun_gradient[0][0] |
| self.assertAllEqual(gradient.values, defun_gradient.values) |
| self.assertAllEqual(gradient.indices, defun_gradient.indices) |
| self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape) |
| |
| def testReturningIndexedSlicesWithDefun(self): |
| |
| def validate(indexed_slice): |
| @function.defun |
| def f(): |
| return indexed_slice |
| |
| output = f() |
| self.assertIsInstance(output, ops.IndexedSlices) |
| self.assertAllEqual(indexed_slice.values, output.values) |
| self.assertAllEqual(indexed_slice.indices, output.indices) |
| self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape) |
| |
| self.assertEqual( |
| f.get_concrete_function().output_shapes, |
| indexed_slice.values.shape) |
| |
| arg = ops.IndexedSlices( |
| values=constant_op.constant([1, 2]), |
| indices=constant_op.constant([0, 1]), |
| dense_shape=constant_op.constant([2])) |
| validate(arg) |
| |
| arg = ops.IndexedSlices( |
| values=constant_op.constant([1, 2]), |
| indices=constant_op.constant([0, 1]), |
| dense_shape=None) |
| validate(arg) |
| |
| def testIndexedSliceAsArgumentWithDefun(self): |
| |
| @function.defun |
| def f(indexed_slice): |
| return indexed_slice |
| |
| def validate(arg): |
| output = f(arg) |
| self.assertIsInstance(output, ops.IndexedSlices) |
| self.assertAllEqual(arg.values, output.values) |
| self.assertAllEqual(arg.indices, output.indices) |
| self.assertAllEqual(arg.dense_shape, output.dense_shape) |
| |
| indexed_slice = ops.IndexedSlices( |
| values=constant_op.constant([1]), |
| indices=constant_op.constant([0]), |
| dense_shape=constant_op.constant([1])) |
| validate(indexed_slice) |
| |
| # Test that `f` works even when `dense_shape` is None. |
| indexed_slice = ops.IndexedSlices( |
| values=constant_op.constant([1]), |
| indices=constant_op.constant([0]), |
| dense_shape=None) |
| validate(indexed_slice) |
| |
| def testFunctionOnDevice(self): |
| if not context.context().num_gpus(): |
| self.skipTest('No GPUs found') |
| |
| x = constant_op.constant([1.]).gpu() |
| f = function.defun(math_ops.add) |
| y = f(x, x).cpu() |
| self.assertAllEqual(y, [2.]) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testFunctionWithResourcesOnDifferentDevices(self): |
| if not context.context().num_gpus(): |
| self.skipTest('No GPUs found.') |
| |
| with ops.device('/cpu:0'): |
| v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) |
| |
| with ops.device('/gpu:0'): |
| v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0]) |
| |
| def sum_gather(): |
| cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2])) |
| gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2])) |
| return cpu_result, gpu_result |
| |
| defined = function.defun(sum_gather) |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| expected = self.evaluate(sum_gather()) |
| self.assertAllEqual(expected, self.evaluate(defined())) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testOpInFunctionWithConflictingResourceInputs(self): |
| if not context.context().num_gpus(): |
| self.skipTest('No GPUs found.') |
| |
| with ops.device('/cpu:0'): |
| v_cpu = resource_variable_ops.ResourceVariable( |
| [0.0, 1.0, 2.0], name='cpu') |
| v_also_cpu = resource_variable_ops.ResourceVariable( |
| [0.0, 1.0, 2.0], name='also_cpu') |
| |
| with ops.device('/gpu:0'): |
| v_gpu = resource_variable_ops.ResourceVariable( |
| [0.0, 1.0, 2.0], name='gpu') |
| |
| @function.defun |
| def resource_apply_adam(): |
| training_ops.resource_apply_adam( |
| v_cpu.handle, |
| v_gpu.handle, |
| v_also_cpu.handle, |
| 1.0, # beta1_power |
| 1.0, # beta2_power |
| 1.0, # learning_rate |
| 1.0, # beta1 |
| 1.0, # beta2 |
| 1.0, # epsilon |
| [1.0, 1.0, 1.0], # grad |
| False) # use_locking |
| return None |
| |
| with self.assertRaisesRegexp( |
| errors.InvalidArgumentError, 'Could not colocate node with its ' |
| 'resource and reference inputs.*'): |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| self.evaluate(resource_apply_adam()) |
| |
| def testFunctionHandlesInputsOnDifferentDevices(self): |
| if not context.context().num_gpus(): |
| self.skipTest('No GPUs found') |
| |
| # The Reshape op requires the shape tensor to be placed in host memory. |
| reshape = function.defun(array_ops.reshape) |
| value = constant_op.constant([1., 2.]).gpu() |
| shape = constant_op.constant([2, 1]) |
| reshaped = reshape(value, shape).cpu() |
| self.assertAllEqual(reshaped, [[1], [2]]) |
| |
| def testFunctionHandlesInputsPlacedOnTheWrongDeviceGracefully(self): |
| if not context.context().num_gpus(): |
| self.skipTest('No GPUs found') |
| |
| # The Reshape op requires the shape tensor to be placed in host memory. |
| reshape = function.defun(array_ops.reshape) |
| value = constant_op.constant([1., 2.]) |
| shape = constant_op.constant([2, 1]).gpu() |
| reshape(value, shape) # No error is raised |
| |
| def testDifferentiableFunctionNoneOutputs(self): |
| |
| @function.defun |
| def my_function(x): |
| return x, None |
| |
| def wrapper(x): |
| return my_function(x)[0] |
| |
| g = backprop.gradients_function(wrapper, [0])(constant_op.constant(0.0)) |
| self.assertAllEqual(g[0], 1.) |
| |
| @function.defun |
| def foo(a): |
| return None, a * a |
| |
| x = constant_op.constant(5.0) |
| with backprop.GradientTape() as tp: |
| tp.watch(x) |
| none, r = foo(x) |
| g = tp.gradient(r, x) |
| |
| self.assertIs(none, None) |
| self.assertAllEqual(r, 25.0) |
| self.assertAllEqual(g, 2 * 5.0) |
| |
| def testNestedDifferentiableFunction(self): |
| @function.defun |
| def inner_fn(a, b): |
| return a * math_ops.add(a, b) |
| |
| @function.defun |
| def outer_fn(x): |
| return inner_fn(x, 1.0) |
| |
| x = constant_op.constant(5.0) |
| with backprop.GradientTape() as tp: |
| tp.watch(x) |
| result = outer_fn(x) |
| grad = tp.gradient(result, x) |
| |
| self.assertAllEqual(grad, 2 * 5.0 + 1.0) |
| |
| def testNestedDifferentiableFunctionNoneOutputs(self): |
| @function.defun |
| def foo(a, b): |
| return None, a * math_ops.add(a, b), None, 2*a |
| |
| @function.defun |
| def bar(x): |
| return foo(x, 1.0) |
| |
| x = constant_op.constant(5.0) |
| with backprop.GradientTape(persistent=True) as tp: |
| tp.watch(x) |
| none1, r1, none2, r2 = bar(x) |
| g1 = tp.gradient(r1, x) |
| g2 = tp.gradient(r2, x) |
| |
| self.assertAllEqual(r1, 30.0) |
| self.assertAllEqual(r2, 10.0) |
| self.assertIs(none1, None) |
| self.assertIs(none2, None) |
| self.assertAllEqual(g1, 2 * 5.0 + 1.0) |
| self.assertAllEqual(g2, 2.0) |
| |
| def testNoneOutput(self): |
| |
| @function.defun |
| def my_function(_): |
| return None |
| |
| self.assertAllEqual(my_function(1), None) |
| |
| def testNestedFunctions(self): |
| # TensorFlow function (which is what would be used in TensorFlow graph |
| # construction). |
| @tf_function.Defun(dtypes.int32, dtypes.int32) |
| def add(a, b): |
| return math_ops.add(a, b) |
| |
| @function.defun |
| def add_one(x): |
| return add(x, 1) |
| |
| self.assertAllEqual(3, add_one(constant_op.constant(2))) |
| |
| def testVariableCaptureInNestedFunctions(self): |
| v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32) |
| |
| @function.defun |
| def inner_read(): |
| return v.read_value() |
| |
| @function.defun |
| def outer(): |
| return inner_read() |
| |
| self.assertEqual(1, int(outer())) |
| |
| def testReturnCapturedEagerTensor(self): |
| t = constant_op.constant(1) |
| |
| @function.defun |
| def read(): |
| return t |
| |
| self.assertEqual(1, int(read())) |
| |
| def testReturnCapturedGraphTensor(self): |
| with context.graph_mode(), self.cached_session(): |
| t = constant_op.constant(1) |
| |
| @function.defun |
| def read(): |
| return t |
| |
| self.assertEqual(1, int(self.evaluate(read()))) |
| |
| def testSequenceInputs(self): |
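| # A list of tensors should be flattened into the function's inputs, and |
| # the returned structure rebuilt on the way out. |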
| clip_by_global_norm = function.defun(clip_ops.clip_by_global_norm) |
| t_list = [constant_op.constant(1.0), constant_op.constant(2.0)] |
| clipped_list, global_norm = clip_by_global_norm(t_list, |
| constant_op.constant(.2)) |
| for t in clipped_list: |
| self.assertIsInstance(t, ops.Tensor) |
| self.assertIsInstance(global_norm, ops.Tensor) |
| |
| def testNestedSequenceInputs(self): |
| |
| def my_op(inputs): |
| a, b, c = inputs |
| e, f = b |
| g, h = e |
| return [a + a, [tuple([f + f, g + g]), h + h], c + c], a + f + g + h + c |
| |
| my_eager_op = function.defun(my_op) |
| ret = my_eager_op([ |
| constant_op.constant(1), [(constant_op.constant(2), |
| constant_op.constant(3)), |
| constant_op.constant(4)], |
| constant_op.constant(5) |
| ]) |
| self.assertEqual(len(ret), 2) |
| self.assertAllEqual(ret[0][0], 2) |
| self.assertAllEqual(ret[0][1][0][0], 8) |
| self.assertAllEqual(ret[0][1][0][1], 4) |
| self.assertIsInstance(ret[0][1][0], tuple) |
| self.assertAllEqual(ret[0][1][1], 6) |
| self.assertAllEqual(ret[0][2], 10) |
| self.assertAllEqual(ret[1], 15) |
| |
| def testVariableNamesRespectNameScopesWithDefun(self): |
| @function.defun |
| def create_variable(): |
| with ops.name_scope('foo'): |
| v = resource_variable_ops.ResourceVariable(0.0, name='bar') |
| self.assertEqual(v.name, 'foo/bar:0') |
| |
| create_variable() |
| |
| def testVariableNamesRespectNameScopesWithDefunInGraph(self): |
| with context.graph_mode(): |
| @function.defun |
| def create_variable(): |
| with ops.name_scope('foo'): |
| v = resource_variable_ops.ResourceVariable([1.0, 2.0], name='bar') |
| self.assertEqual(v.name, 'foo/bar:0') |
| |
| with ops.get_default_graph().as_default(): |
| create_variable() |
| |
| @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) |
| def testLayerInDefun(self): |
| conv = convolutional.Conv2D( |
| filters=1, |
| kernel_size=2, |
| kernel_initializer=init_ops.ones_initializer(), |
| bias_initializer=init_ops.zeros_initializer()) |
| |
| @function.defun |
| def model(x): |
| return conv(x) |
| |
| x = array_ops.ones([1, 2, 2, 1]) |
| y = model(x) |
| |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| |
| self.assertAllEqual([[[[4.0]]]], self.evaluate(y)) |
| |
| # Remove reference cycles in the model. |
| test_util.dismantle_polymorphic_function(model) |
| |
| @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) |
| def testDefunKerasModelCall(self): |
| model = MiniModel() |
| model.call = function.defun(model.call) |
| |
| x = array_ops.ones([1, 2]) |
| y = model(x) |
| |
| if not context.executing_eagerly(): |
| self.evaluate(variables.global_variables_initializer()) |
| |
| self.assertAllEqual([[3.0]], self.evaluate(y)) |
| |
| # Remove reference cycles in defun. |
| test_util.dismantle_polymorphic_function(model.call) |
| # Break the reference cycle between the MiniModel and the defun: |
| # MiniModel --(through its `call` method)--> PolymorphicFunction |
| # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel |
| del model.call |
| |
| # Note: The ConfigProto below unfortunately only configures graph |
| # construction. Eager's configuration is controlled in `__main__`. |
| @test_util.run_in_graph_and_eager_modes( |
| config=config_pb2.ConfigProto(device_count={'CPU': 4})) |
| def testDeviceAnnotationsRespected(self): |
| |
| def multi_device_fn(): |
| with ops.device('/cpu:0'): |
| s0 = iterator_ops.Iterator.from_structure( |
| (dtypes.float32,)).string_handle() |
| with ops.device('/cpu:1'): |
| s1 = iterator_ops.Iterator.from_structure( |
| (dtypes.float32,)).string_handle() |
| with ops.device('/cpu:2'): |
| s2 = iterator_ops.Iterator.from_structure( |
| (dtypes.float32,)).string_handle() |
| s3 = iterator_ops.Iterator.from_structure( |
| (dtypes.float32,)).string_handle() |
| return s0, s1, s2, s3 |
| |
| defined = function.defun(multi_device_fn) |
| outputs = self.evaluate(defined()) |
| self.assertEqual(len(defined._function_cache), 1) |
| self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) |
| self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) |
| self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) |
| |
| with ops.device('/cpu:3'): |
| outputs = self.evaluate(defined()) |
| self.assertEqual(len(defined._function_cache), 2) |
| self.assertIn(compat.as_bytes('CPU:0'), outputs[0]) |
| self.assertIn(compat.as_bytes('CPU:1'), outputs[1]) |
| self.assertIn(compat.as_bytes('CPU:2'), outputs[2]) |
| self.assertIn(compat.as_bytes('CPU:3'), outputs[3]) |
| |
| # This should retrieve the call-site-device-agnostic function. |
| defined() |
| self.assertEqual(len(defined._function_cache), 2) |
| |
| # And this should retrieve the function created for '/cpu:3'. |
| with ops.device('/cpu:3'): |
| defined() |
| self.assertEqual(len(defined._function_cache), 2) |
| |
| @test_util.run_in_graph_and_eager_modes( |
| config=config_pb2.ConfigProto(device_count={'CPU': 2})) |
| def testCallingGraphFunctionOnIncompatibleDeviceRaisesError(self): |
| |
| def func(): |
| return constant_op.constant(0) |
| |
| defined = function.defun(func) |
| with ops.device('cpu:0'): |
| cpu_graph_function = defined.get_concrete_function() |
| |
| with ops.device('cpu:0'): |
| self.assertEqual( |
| self.evaluate(cpu_graph_function()), self.evaluate(func())) |
| |
| with self.assertRaisesRegexp( |
| ValueError, |
| 'The current device stack does not match the device stack under ' |
| 'which the TensorFlow function \'.*func.*\' was created.\n' |
| 'Current device stack: .*\n.*func.* device stack.*'): |
| with ops.device('cpu:1'): |
| cpu_graph_function() |
| |
| with self.assertRaisesRegexp( |
| ValueError, |
| 'The current device stack does not match the device stack under ' |
| 'which the TensorFlow function \'.*func.*\' was created.\n' |
| 'Current device stack: .*\n.*func.* device stack.*'): |
| with ops.device(None): |
| cpu_graph_function() |
| |
| default_graph_function = defined.get_concrete_function() |
| self.assertEqual( |
| self.evaluate(default_graph_function()), self.evaluate(func())) |
| |
| with self.assertRaisesRegexp( |
| ValueError, |
| 'The current device stack does not match the device stack under ' |
| 'which the TensorFlow function \'.*func.*\' was created.\n' |
| 'Current device stack: .*\n.*func.* device stack.*'): |
| with ops.device('cpu:1'): |
| default_graph_function() |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testColocateWithRespected(self): |
| # TODO(b/113291792): Use multiple CPUs instead of a GPU. |
| if not context.context().num_gpus(): |
| self.skipTest('No GPUs found.') |
| |
| with ops.device('cpu:0'): |
| x = constant_op.constant(1.0) |
| |
| with ops.device('gpu:0'): |
| y = constant_op.constant(1.0) |
| |
| @function.defun |
| def foo(): |
| return iterator_ops.Iterator.from_structure( |
| (dtypes.float32,)).string_handle() |
| |
| with ops.colocate_with(x): |
| self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo())) |
| |
| with ops.colocate_with(y): |
| self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo())) |
| |
| def testVariablesAreTracked(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| |
| def foo(x): |
| return v * x |
| |
| defined = function.defun(foo) |
| |
| x = constant_op.constant([1.0]) |
| self.assertEqual(1., self.evaluate(defined(x))) |
| v.assign(2.) |
| |
| x = constant_op.constant([1.0, 2.0]) |
| self.assertAllEqual([2., 4.], self.evaluate(defined(x))) |
| |
| def testPythonFunctionWithDefaultArgs(self): |
| |
| def func(foo, bar=1, baz=2): |
| del foo |
| del bar |
| del baz |
| return |
| |
| defined = function.defun(func) |
| defined(0, baz=20) |
| |
| def cache_keys(): |
| """Sanitizes cache keys of non-input metadata.""" |
| return tuple(key[:3] for key in defined._function_cache) |
| |
| # The cache key records the bound argument values: foo=0, bar=1 (the |
| # default), and baz=20. |
| self.assertIn((0, 1, 20), cache_keys()) |
| |
| defined(1) # bar=1, baz=2 |
| self.assertIn((1, 1, 2), cache_keys()) |
| |
| # This matches the previous call. |
| defined(foo=1) |
| self.assertEqual(len(defined._function_cache), 2) |
| |
| defined(1, 2, 3) |
| self.assertIn((1, 2, 3), cache_keys()) |
| |
| # This matches the previous call. |
| defined(1, bar=2, baz=3) |
| self.assertEqual(len(defined._function_cache), 3) |
| |
| # This matches the previous call. |
| defined(1, baz=3, bar=2) |
| self.assertEqual(len(defined._function_cache), 3) |
| |
| def testFunctoolsPartialUnwrappedCorrectly(self): |
| |
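| # defun should unwrap `functools.partial` objects and honor the arguments |
| # bound by the partial. |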
| def full_function(a, b, c=3): |
| return a, b, c |
| |
| partial = functools.partial(full_function, 1, c=3) |
| a, b, c = partial(2) |
| |
| defined = function.defun(partial) |
| func_a, func_b, func_c = defined(2) |
| self.assertEqual(func_a.numpy(), a) |
| self.assertEqual(func_b.numpy(), b) |
| self.assertEqual(func_c.numpy(), c) |
| |
| def testInputSignatureWithCompatibleInputs(self): |
| |
| def foo(a): |
| self.assertEqual(a.shape, (2,)) |
| return a |
| |
| signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] |
| defined = function.defun(foo, input_signature=signature) |
| a = array_ops.ones([2]) |
| out = defined(a) |
| self.assertEqual(len(defined._function_cache), 1) |
| self.assertAllEqual(out, a) |
| |
| def bar(a): |
| self.assertEqual(a._shape_tuple(), (2, None)) |
| return a |
| |
| signature = [tensor_spec.TensorSpec((2, None), dtypes.float32)] |
| defined = function.defun(bar, input_signature=signature) |
| a = array_ops.ones([2, 1]) |
| out = defined(a) |
| self.assertEqual(len(defined._function_cache), 1) |
| self.assertAllEqual(out, a) |
| |
| # Changing the second dimension shouldn't create a new function. |
| b = array_ops.ones([2, 3]) |
| out = defined(b) |
| self.assertEqual(len(defined._function_cache), 1) |
| self.assertAllEqual(out, b) |
| |
| def testNestedInputSignatures(self): |
| |
| def foo(a, b): |
| self.assertEqual(a[0]._shape_tuple(), (2, None)) |
| self.assertEqual(a[1]._shape_tuple(), (2, None)) |
| self.assertEqual(b._shape_tuple(), (1,)) |
| return [a, b] |
| |
| signature = [[tensor_spec.TensorSpec((2, None), dtypes.float32)] * 2, |
| tensor_spec.TensorSpec((1,), dtypes.float32)] |
| defined = function.defun(foo, input_signature=signature) |
| a = array_ops.ones([2, 1]) |
| b = array_ops.ones([1]) |
| out = defined([a, a], b) |
| self.assertEqual(len(defined._function_cache), 1) |
| nest.assert_same_structure(out, [[a, a], b]) |
| self.assertAllEqual(out[0][0], a) |
| self.assertAllEqual(out[0][1], a) |
| self.assertAllEqual(out[1], b) |
| |
| # Changing the unspecified dimensions shouldn't create a new function. |
| a = array_ops.ones([2, 3]) |
| b = array_ops.ones([2, 5]) |
| c = array_ops.ones([1]) |
| out = defined([a, b], c) |
| self.assertEqual(len(defined._function_cache), 1) |
| nest.assert_same_structure(out, [[a, b], c]) |
| self.assertAllEqual(out[0][0], a) |
| self.assertAllEqual(out[0][1], b) |
| self.assertAllEqual(out[1], c) |
| |
| def bar(a): |
| self.assertEqual(a['a']._shape_tuple(), (2, None)) |
| self.assertEqual(a['b']._shape_tuple(), (2, None)) |
| self.assertEqual(a['c']._shape_tuple(), (1,)) |
| return a |
| |
| signature = [{ |
| 'a': tensor_spec.TensorSpec((2, None), dtypes.float32), |
| 'b': tensor_spec.TensorSpec((2, None), dtypes.float32), |
| 'c': tensor_spec.TensorSpec((1,), dtypes.float32) |
| }] |
| a = array_ops.ones([2, 3]) |
| b = array_ops.ones([1]) |
| inputs = {'a': a, 'b': a, 'c': b} |
| defined = function.defun(bar, input_signature=signature) |
| out = defined(inputs) |
| nest.assert_same_structure(out, inputs) |
| self.assertAllEqual(out['a'], inputs['a']) |
| self.assertAllEqual(out['b'], inputs['b']) |
| self.assertAllEqual(out['c'], inputs['c']) |
| |
| def testInputSignatureMustBeSequenceOfTensorSpecs(self): |
| |
| def foo(a, b): |
| del a |
| del b |
| |
| # Signatures must consist exclusively of `TensorSpec` objects. |
| signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)] |
| with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'): |
| function.defun(foo, input_signature=signature) |
| |
| # Signatures must be either lists or tuples on their outermost levels. |
| signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)} |
| with self.assertRaisesRegexp(TypeError, 'input_signature must be either a ' |
| 'tuple or a list.*'): |
| function.defun(foo, input_signature=signature) |
| |
| def testInputsIncompatibleWithSignatureRaisesError(self): |
| |
| def foo(a): |
| return a |
| |
| signature = [tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.float32)] |
| defined = function.defun(foo, input_signature=signature) |
| |
| # Invalid shapes. |
| with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'): |
| defined(array_ops.ones([3])) |
| |
| with self.assertRaisesRegexp(ValueError, 'Python inputs incompatible.*'): |
| defined(array_ops.ones([2, 1])) |
| |
| # Wrong number of arguments. |
| with self.assertRaisesRegexp(ValueError, |
| 'Structure of Python function inputs.*'): |
| defined(array_ops.ones([2]), array_ops.ones([2])) |
| with self.assertRaisesRegexp(ValueError, |
| 'Structure of Python function inputs.*'): |
| defined() |
| |
| def testInputSignatureForFunctionWithNonTensorInputsNotAllowed(self): |
| |
| def foo(a, training=True): |
| if training: |
| return a |
| else: |
| return -1.0 * a |
| |
| signature = [tensor_spec.TensorSpec([], dtypes.float32)] * 2 |
| defined = function.defun(foo, input_signature=signature) |
| a = constant_op.constant(1.0) |
| with self.assertRaisesRegexp( |
| ValueError, 'When input_signature is provided, ' |
| 'all inputs to the Python function must be Tensors.'): |
| defined(a, training=True) |
| |
| def testInputSignatureWithKeywordPositionalArgs(self): |
| |
| @function.defun(input_signature=[ |
| tensor_spec.TensorSpec([], dtypes.float32), |
| tensor_spec.TensorSpec([], dtypes.int64) |
| ]) |
| def foo(flt, integer): |
| return flt, integer |
| |
| flt = constant_op.constant(1.0) |
| integer = constant_op.constant(2, dtypes.int64) |
| |
| out1, out2 = foo(flt, integer) |
| self.assertEqual(len(foo._function_cache), 1) |
| self.assertEqual(out1.numpy(), 1.0) |
| self.assertEqual(out2.numpy(), 2) |
| |
| out1, out2 = foo(flt=flt, integer=integer) |
| self.assertEqual(len(foo._function_cache), 1) |
| self.assertEqual(out1.numpy(), 1.0) |
| self.assertEqual(out2.numpy(), 2) |
| |
| out1, out2 = foo(integer=integer, flt=flt) |
| self.assertEqual(len(foo._function_cache), 1) |
| self.assertEqual(out1.numpy(), 1.0) |
| self.assertEqual(out2.numpy(), 2) |
| |
| out1, out2 = foo(flt, integer=integer) |
| self.assertEqual(len(foo._function_cache), 1) |
| self.assertEqual(out1.numpy(), 1.0) |
| self.assertEqual(out2.numpy(), 2) |
| |
| def testInputSignatureWithKeywordArgsFails(self): |
| |
| def foo(a, **kwargs): |
| del a |
| del kwargs |
| |
| with self.assertRaisesRegexp( |
| ValueError, 'Cannot define a TensorFlow function from a Python ' |
| 'function with keyword arguments when input_signature.*'): |
| function.defun( |
| foo, |
| input_signature=[ |
| tensor_spec.TensorSpec([], dtypes.float32), |
| tensor_spec.TensorSpec([], dtypes.int64) |
| ]) |
| |
| def testTensorKeywordArguments(self): |
| |
| def foo(a, b): |
| del a |
| return b |
| |
| defined = function.defun(foo) |
| a = constant_op.constant(2.0) |
| b = constant_op.constant([1.0, 2.0]) |
| one = defined(a, b) |
| self.assertEqual(len(defined._function_cache), 1) |
| |
| two = defined(a=a, b=b) |
| self.assertEqual(len(defined._function_cache), 1) |
| |
| three = defined(b=b, a=a) |
| self.assertEqual(len(defined._function_cache), 1) |
| |
| four = defined(a, b=b) |
| self.assertEqual(len(defined._function_cache), 1) |
| |
| # The next call corresponds to a new input signature, hence |
| # we expect another function to be defined. |
| five = defined(b, a) |
| self.assertEqual(len(defined._function_cache), 2) |
| |
| six = defined(a=b, b=a) |
| self.assertEqual(len(defined._function_cache), 2) |
| |
| seven = defined(b=a, a=b) |
| self.assertEqual(len(defined._function_cache), 2) |
| |
| self.assertAllEqual(one, [1.0, 2.0]) |
| self.assertAllEqual(two, [1.0, 2.0]) |
| self.assertAllEqual(three, [1.0, 2.0]) |
| self.assertAllEqual(four, [1.0, 2.0]) |
| self.assertAllEqual(five, 2.0) |
| self.assertAllEqual(six, 2.0) |
| self.assertAllEqual(seven, 2.0) |
| |
| def testGradientWithKeywordArguments(self): |
| matmul = function.defun(math_ops.matmul) |
| |
| def sq(x): |
| return matmul(a=x, b=x, transpose_a=True) |
| |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| grad_t, = backprop.gradients_function(sq, [0])(t) |
| self.assertAllEqual(grad_t, [[6, 6], [14, 14]]) |
| |
| with backprop.GradientTape(persistent=True) as tape: |
| tape.watch(t) |
| one = matmul(t, b=t, transpose_a=True) |
| two = matmul(b=t, a=t, transpose_a=True) |
| three = matmul(a=t, b=t, transpose_a=True) |
| |
| for output in [one, two, three]: |
| self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]]) |
| |
| def testGradientInFunctionWithKeywordArguments(self): |
| |
| @function.defun |
| def f(x): |
| return backprop.gradients_function(lambda y: y * y, [0])(x)[0] |
| |
| self.assertAllEqual(f(x=constant_op.constant(1.0)), 2.0) |
| |
| def testDefuningInstanceMethod(self): |
| |
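| # defun-decorated instance methods should bind `self` and capture tensor |
| # default arguments correctly. |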
| integer = constant_op.constant(2, dtypes.int64) |
| |
| class Foo(object): |
| |
| def one(self, tensor): |
| return tensor |
| |
| @function.defun |
| def two(self, tensor, other=integer): |
| return self.one(tensor), other |
| |
| foo = Foo() |
| t = constant_op.constant(1.0) |
| one, two = foo.two(t) |
| self.assertEqual(one.numpy(), 1.0) |
| self.assertEqual(two.numpy(), 2) |
| |
| def testDefuningInstanceMethodWithDefaultArgument(self): |
| |
| integer = constant_op.constant(2, dtypes.int64) |
| |
| class Foo(object): |
| |
| @function.defun |
| def func(self, other=integer): |
| return other |
| |
| foo = Foo() |
| self.assertEqual(foo.func().numpy(), int(integer)) |
| |
| def testPythonCallWithSideEffects(self): |
| state = [] |
| |
| @function.defun |
| def side_effecting_function(): |
| state.append(0) |
| |
| side_effecting_function() |
| self.assertAllEqual(state, [0]) |
| |
| # The second invocation should call the graph function, which shouldn't |
| # trigger the list append. |
| side_effecting_function() |
| self.assertAllEqual(state, [0]) |
| |
| # Whereas calling the Python function directly should create a side effect. |
| side_effecting_function.python_function() |
| self.assertAllEqual(state, [0, 0]) |
| |
| def testFunctionWithExtraAttributes(self): |
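| # Whitelisted `experimental_*` attributes should be attached to the |
| # generated FunctionDefs. |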
| @function.defun_with_attributes(attributes={'experimental_1': 'value1', |
| 'experimental_2': 2}) |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| |
| def add(x, y): |
| return math_ops.add(x, y) |
| defun_add = function.defun_with_attributes( |
| add, attributes={'experimental_3': True, 'experimental_4': 1.0}) |
| |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| sq = matmul(t, t) |
| double = defun_add(t, t) |
| self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) |
| self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) |
| |
| graph = ops.get_default_graph() |
| # pylint: disable=protected-access |
| self.assertEqual(len(graph._functions), 2) |
| functions = list(graph._functions.values()) |
| self.assertRegexpMatches( |
| functions[0].definition.signature.name, '.*matmul.*') |
| attrs = functions[0].definition.attr |
| self.assertEqual(len(attrs), 2) |
| self.assertEqual(attrs['experimental_1'].s, b'value1') |
| self.assertEqual(attrs['experimental_2'].i, 2) |
| |
| self.assertRegexpMatches( |
| functions[1].definition.signature.name, '.*add.*') |
| attrs = functions[1].definition.attr |
| self.assertEqual(len(attrs), 2) |
| self.assertEqual(attrs['experimental_3'].b, True) |
| self.assertEqual(attrs['experimental_4'].f, 1.0) |
| # pylint: enable=protected-access |
| |
| def testFunctionWithInvalidAttribute(self): |
| @function.defun_with_attributes(attributes={'attr1': 'value1'}) |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| |
| with self.assertRaisesRegexp(ValueError, |
| '.*Attribute name is not whitelisted.*'): |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| matmul(t, t) |
| |
| @function.defun_with_attributes(attributes={'experimental_1': ['value1']}) |
| def add(x, y): |
| return math_ops.add(x, y) |
| |
| with self.assertRaisesRegexp(ValueError, |
| '.*Unsupported attribute type.*'): |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| add(t, t) |
| |
| def testRegisterFunction(self): |
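| # `function.register` traces a defun and adds it to the default graph |
| # without executing it. |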
| @function.defun |
| def add(x, y): |
| return math_ops.add(x, y) |
| |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| defun_matmul = function.defun(matmul) |
| |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| function.register(defun_matmul, t, t) |
| function.register(add, t, t) |
| |
| graph = ops.get_default_graph() |
| # pylint: disable=protected-access |
| self.assertEqual(len(graph._functions), 2) |
| functions = list(graph._functions.values()) |
| pre_register_matmul_func_name = functions[0].definition.signature.name |
| self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*') |
| pre_register_add_func_name = functions[1].definition.signature.name |
| self.assertRegexpMatches(pre_register_add_func_name, '.*add.*') |
| |
| sq = defun_matmul(t, t) |
| double = add(t, t) |
| self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) |
| self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) |
| # Make sure the pre-registered functions are used, and no other function |
| # is added. |
| self.assertEqual(len(graph._functions), 2) |
| functions = list(graph._functions.values()) |
| called_func_name = functions[0].definition.signature.name |
| self.assertEqual(pre_register_matmul_func_name, called_func_name) |
| called_func_name = functions[1].definition.signature.name |
| self.assertEqual(pre_register_add_func_name, called_func_name) |
| |
| def testRegisterFunctionWithInputSignature(self): |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| defun_matmul = function.defun( |
| matmul, |
| input_signature=[ |
| tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), |
| tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32) |
| ]) |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| function.register(defun_matmul, t, t) |
| |
| graph = ops.get_default_graph() |
| # pylint: disable=protected-access |
| self.assertEqual(len(graph._functions), 1) |
| |
| # Test input param shape mismatch. |
| t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |
| with self.assertRaisesRegexp( |
| ValueError, 'Python inputs incompatible with input_signature'): |
| function.register(defun_matmul, t2, t2) |
| |
| def testRegisterFunctionWithCache(self): |
| def matmul(x, y): |
| return math_ops.matmul(x, y) |
| defun_matmul = function.defun(matmul) |
| |
| with context.graph_mode(), self.cached_session(): |
| with ops.get_default_graph().as_default(): |
| t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) |
| t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]]) |
| function.register(defun_matmul, t, t) |
| function.register(defun_matmul, t2, t2) |
| |
| graph = ops.get_default_graph() |
| # Only one function is registered since the input params are of the same |
| # type. |
| # pylint: disable=protected-access |
| self.assertEqual(len(graph._functions), 1) |
| |
| def testCallingFunctionWithDifferentVariables(self): |
| |
| @function.defun |
| def foo(v): |
| v.assign_add(1.0) |
| return v.read_value() |
| |
| v = resource_variable_ops.ResourceVariable(0.0) |
| graph_function = foo.get_concrete_function(v) |
| self.assertEqual(len(graph_function.inputs), 1) |
| self.assertEqual(len(graph_function.captured_inputs), 0) |
| |
| self.assertEqual(float(graph_function(v)), 1.0) |
| self.assertEqual(float(graph_function(v)), 2.0) |
| |
| w = resource_variable_ops.ResourceVariable(0.0) |
| |
| @function.defun |
| def bar(v): |
| del v |
| return constant_op.constant(1.0) |
| |
| graph_function = bar.get_concrete_function(v) |
| self.assertEqual(float(graph_function(v)), 1.0) |
| self.assertEqual(float(graph_function(w)), 1.0) |
| |
| def testCallingFunctionWithNonTensorsFails(self): |
| |
| @function.defun |
| def foo(x): |
| return x |
| |
| graph_function = foo.get_concrete_function(constant_op.constant(1.0)) |
| with self.assertRaisesRegexp(ValueError, 'All inputs to `Function`s must ' |
| 'be Tensors;.*'): |
| graph_function('Not a Tensor.') |
| |
| def testSwapImplementationWithGrapplerPlugin(self): |
| rewrites = rewriter_config_pb2.RewriterConfig() |
| # function_optimizer has to be turned off; otherwise it will delete the |
| # registered function if it does not get called. |
| # TODO(scottzhu): Move the ExperimentalImplementationSelector to be called |
| # before function_optimizer in the future. |
| rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF |
| custom_optimizer = rewrites.custom_optimizers.add() |
| custom_optimizer.name = 'ExperimentalImplementationSelector' |
| rewrites.min_graph_nodes = -1 |
| graph_options = config_pb2.GraphOptions( |
| rewrite_options=rewrites, build_cost_model=1) |
| config = config_pb2.ConfigProto(graph_options=graph_options) |
| |
| with context.graph_mode(), self.cached_session( |
| config=config, graph=ops.Graph(), use_gpu=True) as sess: |
| |
| @function.defun_with_attributes( |
| attributes={ |
| 'experimental_api_implements': 'random_boost', |
| 'experimental_api_preferred_device': 'CPU' |
| }) |
| def cpu_boost(x): |
| return math_ops.add(x, 2.0) |
| |
| @function.defun_with_attributes( |
| attributes={ |
| 'experimental_api_implements': 'random_boost', |
| 'experimental_api_preferred_device': 'GPU' |
| }) |
| def gpu_boost(x): |
| return math_ops.add(x, 4.0) |
| |
| x = constant_op.constant(1.0) |
| |
| function.register(cpu_boost, x) |
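      # Both defuns implement the same 'random_boost' API, so the
      # ExperimentalImplementationSelector pass may substitute the registered
      # CPU implementation when no GPU is available.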
| y = gpu_boost(x) |
| y_value = sess.run(y) |
| |
| if test.is_gpu_available(): |
        self.assertEqual(y_value, 5.0)
      else:
        # Grappler falls back to the CPU implementation even though the GPU
        # function was called.
        self.assertEqual(y_value, 3.0)
| |
| |
| @test_util.with_c_shapes |
| class AutomaticControlDependenciesTest(test.TestCase): |
| |
| def testBasic(self): |
| with context.graph_mode(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| variables.global_variables_initializer().run() |
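      # The context manager sequences the stateful ops below in program order,
      # so the final read observes (1 + 1) * 2 = 4.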
| with function.AutomaticControlDependencies() as c: |
| v.assign(v + 1) |
| v.assign(2 * v) |
| val = v.read_value() |
| val = c.mark_as_return(val) |
| self.assertAllEqual(val.eval(), 4.0) |
| |
| def testCondMustRun(self): |
| with context.graph_mode(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| variables.global_variables_initializer().run() |
| p = array_ops.placeholder(dtype=dtypes.bool) |
| with function.AutomaticControlDependencies() as c: |
| |
| def true_fn(): |
| v.assign(v + 1) |
| return 0.0 |
| |
| def false_fn(): |
| v.assign(v + 4) |
| return 1.0 |
| |
| control_flow_ops.cond(p, true_fn, false_fn) |
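        # The cond's return value is unused, but its branches update `v`, so
        # automatic control dependencies must still force it to run before the
        # read below.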
| val = v.read_value() |
| val = c.mark_as_return(val) |
| self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) |
| self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0) |
| |
| def testCondMustRunSeparateRead(self): |
| with context.graph_mode(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| variables.global_variables_initializer().run() |
| p = array_ops.placeholder(dtype=dtypes.bool) |
| with function.AutomaticControlDependencies() as c: |
| |
| def true_fn(): |
| v.assign(v + 1) |
| return 0.0 |
| |
| def false_fn(): |
| v.assign(v + 4) |
| return 1.0 |
| |
| control_flow_ops.cond(p, true_fn, false_fn) |
| one = constant_op.constant(1.0) |
| one = c.mark_as_return(one) |
| one.eval(feed_dict={p: False}) |
| self.assertAllEqual(v.read_value().eval(), 5.0) |
| one.eval(feed_dict={p: True}) |
| self.assertAllEqual(v.read_value().eval(), 6.0) |
| |
| def testCondNested(self): |
| with context.graph_mode(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| variables.global_variables_initializer().run() |
| p = array_ops.placeholder(dtype=dtypes.bool) |
| q = array_ops.placeholder(dtype=dtypes.bool) |
| with function.AutomaticControlDependencies() as c: |
| |
| def true_fn(): |
| v.assign(v + 1, name='true') |
| return 1.0 |
| |
| def false_fn(): |
| |
| def inner_true_fn(): |
| v.assign(v * 2, name='false_true') |
| return 2.0 |
| |
| def inner_false_fn(): |
| v.assign(v * 3, name='false_false') |
| return 3.0 |
| |
| control_flow_ops.cond(q, inner_true_fn, inner_false_fn) |
| return 1.0 |
| |
| control_flow_ops.cond(p, true_fn, false_fn) |
| with ops.name_scope('final'): |
| val = v.read_value() |
| val = c.mark_as_return(val) |
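      # `v` persists across eval calls, so the expected values chain:
      # 1 * 3 = 3, then 3 * 2 = 6, then 6 + 1 = 7, then 7 + 1 = 8.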
| self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0) |
| self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0) |
| self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0) |
| self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0) |
| |
| def testCondOneBranch(self): |
| with context.graph_mode(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| variables.global_variables_initializer().run() |
| p = array_ops.placeholder(dtype=dtypes.bool) |
| with function.AutomaticControlDependencies() as c: |
| |
| def true_fn(): |
| return 0.0 |
| |
| def false_fn(): |
| v.assign(v + 4) |
| return 1.0 |
| |
| control_flow_ops.cond(p, true_fn, false_fn) |
| val = v.read_value() |
| val = c.mark_as_return(val) |
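      # Only the false branch updates `v` (1 + 4 = 5); the true branch leaves
      # it untouched, so the second eval still reads 5.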
| self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) |
| self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0) |
| |
| def testCondOneBranchUpdateBefore(self): |
| with context.graph_mode(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| variables.global_variables_initializer().run() |
| p = array_ops.placeholder(dtype=dtypes.bool) |
| with function.AutomaticControlDependencies() as c: |
| v.assign(v * 2) |
| |
| def true_fn(): |
| return 0.0 |
| |
| def false_fn(): |
| v.assign(v + 4) |
| return 1.0 |
| |
| control_flow_ops.cond(p, true_fn, false_fn) |
| val = v.read_value() |
| val = c.mark_as_return(val) |
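      # The doubling runs before the cond on every eval: (1 * 2) + 4 = 6, then
      # 6 * 2 = 12 when the true branch makes no further update.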
| self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0) |
| self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0) |
| |
| def testCondOneBranchUpdateAfter(self): |
| with context.graph_mode(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| variables.global_variables_initializer().run() |
| p = array_ops.placeholder(dtype=dtypes.bool) |
| with function.AutomaticControlDependencies() as c: |
| |
| def true_fn(): |
| return 0.0 |
| |
| def false_fn(): |
| v.assign(v + 4) |
| return 1.0 |
| |
| control_flow_ops.cond(p, true_fn, false_fn) |
| v.assign(v * 2) |
| val = v.read_value() |
| val = c.mark_as_return(val) |
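      # The doubling runs after the cond on every eval: (1 + 4) * 2 = 10, then
      # 10 * 2 = 20 when the true branch makes no further update.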
| self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0) |
| self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0) |
| |
| def testDefunWhileLoopWithCapturedLoopVars(self): |
| n = 3 |
| x = constant_op.constant(list(range(n))) |
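    # Both `n` and `x` are captured from the enclosing scope: `n` is baked in
    # as a Python constant at trace time, while `x` becomes a captured tensor.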
| |
| @function.defun |
| def loop(): |
| c = lambda i, x: i < n |
| b = lambda i, x: (i + 1, x + 1) |
| i, out = control_flow_ops.while_loop(c, b, (0, x)) |
| return i, out |
| |
| i, out = loop() |
| self.assertEqual(int(i), 3) |
| self.assertAllEqual(out, [3, 4, 5]) |
| |
| def testDecorator(self): |
| with context.graph_mode(), self.cached_session(): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| variables.global_variables_initializer().run() |
| |
| @function.automatic_control_dependencies |
| def f(): |
| v.assign(v + 1) |
| v.assign(2 * v) |
| return v.read_value() |
| |
| self.assertAllEqual(f().eval(), 4.0) |
| |
| def testOptimizerInDefun(self): |
| def loss(v): |
| return v**2 |
| |
| optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0) |
| |
| @function.defun |
| def train(): |
| self.v = resource_variable_ops.ResourceVariable(1.0) |
| grad = backprop.implicit_grad(loss)(self.v) |
| optimizer.apply_gradients(grad) |
| return self.v.read_value() |
| |
| value = train() |
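    # With learning_rate=1.0 and gradient d(v**2)/dv = 2 at v = 1, a single
    # momentum step gives v = 1 - 1.0 * 2 = -1.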
| self.assertEqual(value.numpy(), -1.0) |
| |
| def testReturningNonTensorRaisesError(self): |
| optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0) |
| optimizer.apply_gradients = function.defun(optimizer.apply_gradients) |
| v = resource_variable_ops.ResourceVariable(1.0) |
| grad = backprop.implicit_grad(lambda v: v**2)(v) |
| |
| with self.assertRaisesRegexp(TypeError, |
| '.*must return zero or more Tensors.*'): |
| # TODO(akshayka): We might want to allow defun-ing Python functions |
| # that return operations (and just execute the op instead of running it). |
| optimizer.apply_gradients(grad) |
| |
| # TODO(b/111663004): This should work when the outer context is graph |
| # building. |
| def testOptimizerNonSlotVarsInDefunNoError(self): |
| def loss(v): |
| return v**2 |
| |
| optimizer = adam.AdamOptimizer(learning_rate=1.0) |
| |
| @function.defun |
| def train(): |
| self.v = resource_variable_ops.ResourceVariable(1.0) |
| grad = backprop.implicit_grad(loss)(self.v) |
| optimizer.apply_gradients(grad) |
| return self.v.read_value() |
| |
| train() |
| |
| def testOptimizerInDefunWithCapturedVariable(self): |
| v = resource_variable_ops.ResourceVariable(1.0) |
| def loss(): |
| return v**2 |
| |
| optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0) |
| |
| @function.defun |
| def train(): |
| grad = backprop.implicit_grad(loss)() |
| optimizer.apply_gradients(grad) |
| |
| train() |
| self.assertEqual(v.numpy(), -1.0) |
| |
| def testFunctionModifiesInputList(self): |
    # Tests `list` methods that do in-place modification, except `list.sort`,
    # which cannot even be "defunned" in the first place.
| |
| def get_list(): |
| return [constant_op.constant(0.), constant_op.constant(1.)] |
| |
| expected_msg = ( |
| 'Function to be traced should not modify structure of input ' |
| 'arguments. Check if your function has list and dictionary ' |
| 'operations that alter input arguments, ' |
| 'such as `list.pop`, `list.append`') |
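    # Tracing flattens the inputs with `nest` and compares the structure
    # before and after the Python function runs, so in-place mutations surface
    # as a structure mismatch.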
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def append(l): |
| l.append(constant_op.constant(0.)) |
| |
| append(get_list()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def extend(l): |
| l.extend([constant_op.constant(0.)]) |
| |
| extend(get_list()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def insert(l): |
| l.insert(0, constant_op.constant(0.)) |
| |
| insert(get_list()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def pop(l): |
| l.pop() |
| |
| pop(get_list()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def reverse(l): |
| l.reverse() |
| |
| reverse(get_list()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def remove(l): |
| l.remove(l[0]) |
| |
| remove(get_list()) |
| |
    # `list.clear` exists in Py3 but not in Py2.
| if sys.version.startswith('3'): |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def clear(l): |
| l.clear() |
| |
| clear(get_list()) |
| |
| # One last test for keyword arguments |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def kwdappend(**kwargs): |
| l = kwargs['l'] |
| l.append(constant_op.constant(0.)) |
| |
| kwdappend(l=get_list()) |
| |
| def testFunctionModifiesInputDict(self): |
| |
| def get_dict(): |
| return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)} |
| |
| expected_msg = ( |
| 'Function to be traced should not modify structure of input ' |
| 'arguments. Check if your function has list and dictionary ' |
| 'operations that alter input arguments, ' |
| 'such as `list.pop`, `list.append`') |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def clear(m): |
| m.clear() |
| |
| clear(get_dict()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def pop(m): |
| m.pop('t1') |
| |
| pop(get_dict()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def popitem(m): |
| m.popitem() |
| |
| popitem(get_dict()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def update(m): |
| m.update({'t1': constant_op.constant(3.)}) |
| |
| update(get_dict()) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def setdefault(m): |
| m.setdefault('t3', constant_op.constant(3.)) |
| |
| setdefault(get_dict()) |
| |
| def testFunctionModifiesInputNest(self): |
| # Test on functions that modify structure of nested input arguments |
| expected_msg = ( |
| 'Function to be traced should not modify structure of input ' |
| 'arguments. Check if your function has list and dictionary ' |
| 'operations that alter input arguments, ' |
| 'such as `list.pop`, `list.append`') |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
| @function.defun |
| def modify(n): |
| n[0]['t1'].append(constant_op.constant(1.)) |
| |
| nested_input = [{ |
| 't1': [constant_op.constant(0.), |
| constant_op.constant(1.)], |
| }, |
| constant_op.constant(2.)] |
| |
| modify(nested_input) |
| |
| with self.assertRaisesRegexp(ValueError, expected_msg): |
| |
      # The flattened list is unchanged even though the true structure changes,
      # so the structure check must still catch it.
| @function.defun |
| def modify_same_flat(n): |
| n[0].append(n[1].pop(0)) |
| |
| nested_input = [[constant_op.constant(0.)], |
| [constant_op.constant(1.), |
| constant_op.constant(2.)]] |
| |
| modify_same_flat(nested_input) |
| |
| def testDecoratedMethodVariableCleanup(self): |
| m = DefunnedMiniModel() |
| m(array_ops.ones([1, 2])) |
| weak_variables = weakref.WeakSet(m.variables) |
| self.assertEqual(2, len(weak_variables)) |
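    # Deleting the model must drop the last strong references: the defunned
    # `call` should not keep the variables alive through its trace cache.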
| del m |
| self.assertEqual([], list(weak_variables)) |
| |
| if __name__ == '__main__': |
| ops.enable_eager_execution( |
| config=config_pb2.ConfigProto(device_count={'CPU': 4})) |
| test.main() |