| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for while_v2.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| from absl.testing import parameterized |
| |
| from google.protobuf import text_format |
| from tensorflow.core.framework import graph_pb2 |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.core.protobuf import rewriter_config_pb2 |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import function |
| from tensorflow.python.framework import importer |
| from tensorflow.python.framework import meta_graph |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.grappler import tf_optimizer |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import control_flow_util |
| from tensorflow.python.ops import control_flow_util_v2 |
| from tensorflow.python.ops import control_flow_v2_toggles |
| from tensorflow.python.ops import custom_gradient |
| from tensorflow.python.ops import gen_array_ops |
| from tensorflow.python.ops import gradients_impl |
| from tensorflow.python.ops import list_ops |
| from tensorflow.python.ops import map_fn |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops import while_v2 |
| from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2 |
| from tensorflow.python.platform import test |
| |
| |
def random_gamma(shape):  # pylint: disable=invalid-name
  """Returns `random_ops.random_gamma` samples of `shape` with alpha 1.0."""
  alpha = 1.0
  return random_ops.random_gamma(shape, alpha)
| |
| |
def random_gamma_with_alpha_beta(shape):  # pylint: disable=invalid-name
  """Returns `random_ops.random_gamma` samples with fixed alpha/beta arrays."""
  alphas = [[1.], [3.], [5.], [6.]]
  betas = [[3., 4.]]
  return random_ops.random_gamma(shape, alpha=alphas, beta=betas)
| |
| |
def random_poisson_v2(shape):  # pylint: disable=invalid-name
  """Returns `random_ops.random_poisson_v2` samples with rate 1.0."""
  rate = 1.0
  return random_ops.random_poisson_v2(shape, rate)
| |
| |
def random_poisson_v2_with_lam(shape):  # pylint: disable=invalid-name
  """Returns `random_ops.random_poisson_v2` samples for two fixed rates."""
  rates = [12.2, 3.3]
  return random_ops.random_poisson_v2(shape, rates)
| |
| |
def fill(shape):  # pylint: disable=invalid-name
  """Returns a tensor of dimensions `shape` filled with 1.0."""
  fill_value = 1.0
  return array_ops.fill(shape, fill_value)
| |
| |
| class WhileV2Test(test.TestCase, parameterized.TestCase): |
| |
@test_util.run_deprecated_v1
def testSingleLoopVar(self):
  """Squares a single loop var until it reaches 8; checks value and grad."""
  init = constant_op.constant(2.)
  result = while_loop_v2(
      lambda val: val < 8., lambda val: val * val, [init],
      return_same_structure=False)
  # 2 -> 4 -> 16, i.e. result = x**4, so d(result)/dx = 4 * x**3 = 32.
  grads = gradients_impl.gradients(result, [init])
  with self.cached_session():
    self.assertEqual(self.evaluate(result), 16.)
    self.assertSequenceEqual(self.evaluate(grads), [32.])
| |
@test_util.run_deprecated_v1
def testSingleLoopVarBackPropFalse(self):
  """With back_prop=False the loop input receives a None gradient."""
  init = constant_op.constant(2.)
  result = while_loop_v2(
      lambda val: val < 8.,
      lambda val: val * val,
      [init],
      return_same_structure=False,
      back_prop=False)
  self.assertEqual(gradients_impl.gradients(result, [init]), [None])
  with self.cached_session():
    self.assertEqual(self.evaluate(result), 16.)
| |
@test_util.run_deprecated_v1
def testCustomGradient(self):
  """Differentiating the loop uses a custom_gradient defined in the body.

  The body computes v * v * m and registers a gradient of
  dy * 2 * v * n * m. With n = m = 1 this equals the true derivative, so the
  loop value is 16 and its gradient w.r.t. the initial value is 32.
  """
  init = constant_op.constant(2.)
  n = constant_op.constant(1., name="const-n")
  m = variables.Variable(1.0)
  self.evaluate(variables.global_variables_initializer())

  def body_fn(val):  # pylint: disable=invalid-name

    @custom_gradient.custom_gradient
    def inner_fn(inner_val):  # pylint: disable=invalid-name

      def grad_fn(dy, variables=None):  # pylint: disable=invalid-name, unused-argument, redefined-outer-name
        # Gradient for the input, then one gradient per captured variable.
        input_grad = dy * 2 * inner_val * n * m
        variable_grads = [inner_val * inner_val]
        return input_grad, variable_grads

      return inner_val * inner_val * m, grad_fn

    return inner_fn(val)

  result = while_loop_v2(
      lambda val: val < 8., body_fn, [init], return_same_structure=False)
  grads = gradients_impl.gradients(result, [init])
  with self.cached_session():
    self.assertEqual(self.evaluate(result), 16.)
    self.assertSequenceEqual(self.evaluate(grads), [32.])
| |
@test_util.run_v1_only("b/120545219")
def testReturnSameStructureTrue(self):
  """return_same_structure=True yields a one-element list for [x]."""
  init = constant_op.constant(2.)
  outputs = while_loop_v2(
      lambda val: val < 8., lambda val: val * val, [init],
      return_same_structure=True)
  grads = gradients_impl.gradients(outputs, [init])
  with self.cached_session() as sess:
    values = sess.run(outputs)
    self.assertIsInstance(values, list)
    self.assertLen(values, 1)
    self.assertEqual(16., values[0])
    self.assertSequenceEqual(sess.run(grads), [32.])
| |
def testVerifyInputOutputTypesMatch(self):
  """A body that changes a loop var's dtype raises TypeError at trace time."""

  @def_function.function
  def BuildWhile():
    start = constant_op.constant(1., dtypes.float32)

    def Body(val):
      # Cast to float16 so the output dtype differs from the float32 input.
      return math_ops.cast(val, dtypes.float16) + 1

    while_loop_v2(lambda val: val < 10, Body, [start])

  expected_error = (
      r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
      r"but has type <dtype: 'float16'> after 1 iteration.")
  with self.assertRaisesRegex(TypeError, expected_error):
    BuildWhile()
| |
@parameterized.parameters(dtypes.float32, dtypes.float64)
def testGradientTapeResourceVariable(self, dtype):
  """GradientTape differentiates the loop w.r.t. a captured variable."""
  with context.eager_mode():
    var = variables.Variable(1., dtype=dtype)

    @def_function.function
    def fnWithLoop():  # pylint: disable=invalid-name
      with backprop.GradientTape() as tape:
        # Two iterations: acc = 2 * var**2, so d(acc)/d(var) = 4 * var = 4.
        _, acc = while_loop_v2(
            lambda i, _: i < 2,
            lambda i, acc: (i + 1, acc * var),
            [0, constant_op.constant(2., dtype=dtype)])
      return tape.gradient(acc, var)

    self.assertAllEqual(fnWithLoop(), 4.0)
| |
def testDeviceLabelsInherited(self):
  """Body ops inherit the while_loop's device unless overridden inside."""

  def _LoopBody(i, y):
    cos_y = math_ops.cos(y)
    # Picked up from the ops.device("CPU:10") scope around the loop.
    self.assertIn("CPU:10", cos_y.device)
    with ops.device("CPU:11"):
      cos_y = array_ops.identity(cos_y)
    # The explicit scope inside the body wins.
    self.assertIn("CPU:11", cos_y.device)
    return i + 1, cos_y

  @def_function.function
  def _FunctionWithWhileLoop():
    init = constant_op.constant(1.)
    with ops.device("CPU:10"):
      _, out = while_loop_v2(lambda i, _: i < 2, _LoopBody, [0, init])
    return out

  # Tracing is sufficient: the assertions in _LoopBody execute at trace time.
  _FunctionWithWhileLoop.get_concrete_function()
| |
def testExternalControlDependencies(self):
  """A control dependency on an outer op inside the body makes it run."""
  with ops.Graph().as_default(), self.test_session():
    var = variables.Variable(1.)
    self.evaluate(var.initializer)
    increment = var.assign_add(1.)

    def body_fn(i):  # pylint: disable=invalid-name
      with ops.control_dependencies([increment]):
        return i + 1

    outputs = while_loop_v2(lambda i: i < 1, body_fn, [0])
    outputs[0].op.run()
    # 1.0 + 1.0: the assign_add ran once.
    self.assertAllEqual(self.evaluate(var), 2.0)
| |
@test_util.run_deprecated_v1
def testMultipleLoopVarsBasic(self):
  """Two loop vars, one of which is passed through unchanged."""
  x = constant_op.constant(5.)
  y = constant_op.constant(3.)

  # x = 5.
  # y = 3.
  # while x < 45.:
  #   x = x * y
  ret = while_loop_v2(
      lambda v, _: v < 45.,
      lambda v, w: (v * w, w), [x, y],
      return_same_structure=False)
  # Two iterations run (5 -> 15 -> 45), so ret = [x*y**2, y].

  # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
  grad = gradients_impl.gradients(ret, [x])  # [y**2]
  with self.cached_session():
    self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
    self.assertSequenceEqual(self.evaluate(grad), [9.])
| |
@test_util.run_deprecated_v1
def testMultipleLoopNonscalarCond(self):
  """Same as testMultipleLoopVarsBasic, but `cond` yields a [1, 1] tensor."""
  x = constant_op.constant([[5.]])
  y = constant_op.constant(3.)

  # x = [[5.]]
  # y = 3.
  # while x < 45.:
  #   x = x * y
  ret = while_loop_v2(
      lambda v, _: v < 45.,
      lambda v, w: (v * w, w), [x, y],
      return_same_structure=False)
  # Two iterations run (5 -> 15 -> 45), so ret == [x*y**2, y].

  # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
  grad = gradients_impl.gradients(ret, [x])  # [y**2]
  with self.cached_session():
    self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
    self.assertSequenceEqual(self.evaluate(grad), [9.])
| |
@test_util.run_deprecated_v1
def testMultipleLoopVars(self):
  """Two interdependent loop vars; checks gradients of each output."""
  x = constant_op.constant(5.)
  y = constant_op.constant(3.)

  # x = 5.
  # y = 3.
  # while x < 45.:
  #   x = x * y
  #   y = x + y
  ret = while_loop_v2(
      lambda v, _: v < 45.,
      lambda v, w: (v * w, v + w), [x, y],
      return_same_structure=False)
  # Two iterations run: (5, 3) -> (15, 8) -> (120, 23), so
  # ret = [y*x**2 + x*y**2, x*y + x + y]

  gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
  gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
  gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + y + 1]
  grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
  grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
  grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
  with self.cached_session():
    self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
    self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
    self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
    self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
    self.assertSequenceEqual(self.evaluate(grady_0), [55.])
    self.assertSequenceEqual(self.evaluate(grady_1), [6.])
    self.assertSequenceEqual(self.evaluate(grady_2), [61.])
| |
@test_util.run_deprecated_v1
def testGradientTape(self):
  """A GradientTape records the loop and differentiates through it."""
  with backprop.GradientTape() as tape:
    init = constant_op.constant(2.)
    tape.watch(init)
    result = while_loop_v2(
        lambda val: val < 4., lambda val: val * val, [init],
        return_same_structure=False)  # x**2
  dresult_dx = tape.gradient(result, init)
  with self.cached_session() as sess:
    self.assertAllEqual(sess.run(dresult_dx), 4.0)
| |
@test_util.run_deprecated_v1
def testMultipleWhileLoops(self):
  """Chains two loops and checks the first and second derivatives."""
  init = constant_op.constant(2.)
  squared = while_loop_v2(
      lambda val: val < 4., lambda val: val * val, [init],
      return_same_structure=False)  # x**2
  fourth = while_loop_v2(
      lambda val: val < 16., lambda val: val * val, [squared],
      return_same_structure=False)  # x**4
  first_deriv = gradients_impl.gradients(fourth, [init])  # 4x**3
  second_deriv = gradients_impl.gradients(first_deriv, [init])  # 12x**2
  with self.cached_session():
    self.assertSequenceEqual(self.evaluate(first_deriv), [32.])
    self.assertSequenceEqual(self.evaluate(second_deriv), [48.])
| |
def testMultipleWhileLoopsWithFunc(self):
  """Two independent stateless loops in a function have no control inputs."""
  init = constant_op.constant(2.)

  @def_function.function
  def Fn():
    first = while_loop_v2(
        lambda val: val < 4.,
        lambda val: val * val, [init],
        return_same_structure=False,
        name="while_1")  # x**2
    second = while_loop_v2(
        lambda val: val < 16.,
        lambda val: val * val, [init],
        return_same_structure=False,
        name="while_2")  # x**4
    return first, second

  graph = Fn.get_concrete_function().graph
  for op_name in ("while_1", "while_2"):
    loop_op = graph.get_operation_by_name(op_name)
    self.assertEqual(loop_op.type, "StatelessWhile")
    self.assertEmpty(loop_op.control_inputs)
| |
def testMultipleWhileLoopsGradStateless(self):
  """All While ops built for a tape gradient are stateless and unchained."""

  @def_function.function
  def Fn():
    init = constant_op.constant(2.)
    with backprop.GradientTape() as tape:
      tape.watch(init)
      first = while_loop_v2(
          lambda val: val < 4.,
          lambda val: val * val, [init],
          return_same_structure=False,
          name="while_1")  # x**2
      second = while_loop_v2(
          lambda val: val < 16.,
          lambda val: val * val, [init],
          return_same_structure=False,
          name="while_2")  # x**4
      loss = first + second
    return tape.gradient(loss, init)

  graph = Fn.get_concrete_function().graph
  # Expect four loops: the two forward loops plus, presumably, one gradient
  # loop for each.
  loop_ops = [
      candidate for candidate in graph.get_operations()
      if "While" in candidate.type
  ]
  self.assertAllEqual([loop_op.type for loop_op in loop_ops],
                      ["StatelessWhile"] * 4,
                      "Must have exactly 4 StatelessWhile ops.")
  for loop_op in loop_ops:
    self.assertEmpty(
        loop_op.control_inputs,
        "{} should not have any control inputs".format(loop_op.name))
| |
def testMultipleWhileLoopsWithDeps(self):
  """Two loops that both assign the same variable are chained in order."""
  var = variables.Variable(2.)
  init = constant_op.constant(2.)

  @def_function.function
  def Fn():

    def Body1(val):
      var.assign(var)
      return val * var

    first = while_loop_v2(
        lambda val: val < 4.,
        Body1, [init],
        return_same_structure=False,
        name="while_1")

    def Body2(val):
      var.assign(var)
      return val * var * var

    second = while_loop_v2(
        lambda val: val < 16.,
        Body2, [init],
        return_same_structure=False,
        name="while_2")
    return first, second

  graph = Fn.get_concrete_function().graph
  while_1 = graph.get_operation_by_name("while_1")
  while_2 = graph.get_operation_by_name("while_2")
  # Both bodies assign `var`; both loops are expected to be stateful While ops.
  self.assertEqual(while_1.type, "While")
  self.assertEqual(while_2.type, "While")
  # The second loop depends on the first, and on nothing else.
  self.assertEmpty(while_1.control_inputs)
  self.assertLen(while_2.control_inputs, 1)
  self.assertIs(while_2.control_inputs[0], while_1)
| |
def testMultipleWhileLoopsWithVarsDeps(self):
  """Control deps chain only loops that touch the same variable.

  while_1/while_2 assign x1 and while_3/while_4 assign x2, so each later
  loop should depend only on the earlier loop for its variable. The
  stateless loop should have no control inputs at all.
  """
  x1 = variables.Variable(2.)
  x2 = variables.Variable(3.)
  init = constant_op.constant(2.)

  @def_function.function
  def Fn():

    def Body1(val):
      x1.assign(x1)
      return val * x1

    out1 = while_loop_v2(
        lambda val: val < 4.,
        Body1, [init],
        return_same_structure=False,
        name="while_1")

    def Body2(val):
      x1.assign(x1)
      return val * x1 * x1

    out2 = while_loop_v2(
        lambda val: val < 16.,
        Body2, [init],
        return_same_structure=False,
        name="while_2")

    def Body3(val):
      x2.assign(x2)
      return val * x2

    out3 = while_loop_v2(
        lambda val: val < 4.,
        Body3, [init],
        return_same_structure=False,
        name="while_3")

    def Body4(val):
      x2.assign(x2)
      return val * x2 * x2

    out4 = while_loop_v2(
        lambda val: val < 16.,
        Body4, [init],
        return_same_structure=False,
        name="while_4")
    out5 = while_loop_v2(
        lambda val: val < 16.,
        lambda val: val * val, [init],
        return_same_structure=False,
        name="while_stateless")
    return out1, out2, out3, out4, out5

  graph = Fn.get_concrete_function().graph
  stateful_names = ("while_1", "while_2", "while_3", "while_4")
  loop_ops = {
      op_name: graph.get_operation_by_name(op_name)
      for op_name in stateful_names + ("while_stateless",)
  }
  for op_name in stateful_names:
    self.assertEqual(loop_ops[op_name].type, "While")
  self.assertEqual(loop_ops["while_stateless"].type, "StatelessWhile")
  # The first loop per variable, and the stateless loop, are unconstrained.
  for op_name in ("while_1", "while_3", "while_stateless"):
    self.assertEmpty(loop_ops[op_name].control_inputs)
  # Each second loop waits only on the previous loop for the same variable.
  for later, earlier in (("while_2", "while_1"), ("while_4", "while_3")):
    self.assertLen(loop_ops[later].control_inputs, 1)
    self.assertIs(loop_ops[later].control_inputs[0], loop_ops[earlier])
| |
@test_util.run_deprecated_v1
def testDoubleDerivative(self):
  """Differentiates a single loop twice in graph mode."""
  init = constant_op.constant(2.)
  result = while_loop_v2(
      lambda val: val < 8., lambda val: val**2, [init],
      return_same_structure=False)  # x**4
  first_deriv = gradients_impl.gradients(result, [init])  # 4x**3
  second_deriv = gradients_impl.gradients(first_deriv, [init])  # 12x**2
  with self.cached_session():
    self.assertEqual(self.evaluate(result), 16.)
    self.assertSequenceEqual(self.evaluate(first_deriv), [32.])
    self.assertSequenceEqual(self.evaluate(second_deriv), [48.])
| |
@test_util.run_v2_only
def testMultipleWhileLoopsEager(self):
  """Chained loops differentiated twice inside a tf.function."""

  @def_function.function
  def Func():
    init = constant_op.constant(2.)
    squared = while_loop_v2(
        lambda val: val < 4., lambda val: val * val, [init],
        return_same_structure=False)  # x**2
    fourth = while_loop_v2(
        lambda val: val < 16.,
        lambda val: val * val, [squared],
        return_same_structure=False)  # x**4
    first_deriv = gradients_impl.gradients(fourth, [init])[0]  # 4x**3
    second_deriv = gradients_impl.gradients(first_deriv, [init])[0]  # 12x**2
    return first_deriv, second_deriv

  first_deriv, second_deriv = Func()
  self.assertEqual(first_deriv.numpy(), 32.)
  self.assertEqual(second_deriv.numpy(), 48.)
| |
@test_util.run_v2_only
def testDoubleDerivativeEager(self):
  """Differentiates a single loop twice inside a tf.function."""

  @def_function.function
  def Func():
    init = constant_op.constant(2.)
    result = while_loop_v2(
        lambda val: val < 8., lambda val: val**2, [init],
        return_same_structure=False)  # x**4
    first_deriv = gradients_impl.gradients(result, [init])[0]  # 4x**3
    second_deriv = gradients_impl.gradients(first_deriv, [init])[0]  # 12x**2
    return result, first_deriv, second_deriv

  result, first_deriv, second_deriv = Func()
  self.assertEqual(result.numpy(), 16.)
  self.assertEqual(first_deriv.numpy(), 32.)
  self.assertEqual(second_deriv.numpy(), 48.)
| |
def _testPruning(self):
  """Checks that an unfetched TensorList accumulator is optimized away.

  The loop pushes every value of the counter onto a TensorList. When only
  the counter is fetched, the graph returned by GetOptimizedGraph() should
  contain no variant-typed Enter and no TensorListPushBack. Once the stacked
  list is fetched as well, both must remain.
  """
  start = constant_op.constant(1)

  empty_list = list_ops.empty_tensor_list(
      element_dtype=start.dtype, element_shape=start.shape)

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 5

  def Body(x, tl):
    return x + 1, list_ops.tensor_list_push_back(tl, x)

  outputs = control_flow_ops.while_loop(Cond, Body, [start, empty_list])

  def EnterNodes(graph_def, variant_only=False):
    # Enter nodes, optionally restricted to those carrying a variant tensor.
    return [
        node for node in graph_def.node
        if node.op == "Enter" and
        (not variant_only or
         node.attr["T"].type == dtypes.variant.as_datatype_enum)
    ]

  def PushBackNodes(graph_def):
    return [node for node in graph_def.node if node.op == "TensorListPushBack"]

  fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  fetches.append(outputs[0])

  optimized = GetOptimizedGraph()
  # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
  # away, causing an extra Enter node.
  extra_enters = 1 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 0
  self.assertLen(EnterNodes(optimized), 1 + extra_enters)
  # The TensorList must be pruned out.
  self.assertEmpty(EnterNodes(optimized, variant_only=True))
  self.assertEmpty(PushBackNodes(optimized))

  stacked = list_ops.tensor_list_stack(outputs[1], element_dtype=start.dtype)
  fetches.append(stacked)
  optimized = GetOptimizedGraph()
  self.assertLen(EnterNodes(optimized), 2 + extra_enters)
  # The TensorList is now needed and must survive.
  self.assertNotEmpty(EnterNodes(optimized, variant_only=True))
  self.assertNotEmpty(PushBackNodes(optimized))
| |
@test_util.run_deprecated_v1
def testPruningV1(self):
  """Runs `_testPruning` in graph mode without forcing control flow v2."""
  self._testPruning()
| |
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
def testPruningV2(self):
  """Runs `_testPruning` in graph mode with control flow v2 enabled."""
  self._testPruning()
| |
def _testDoNotAccumulateInvariants(self):
  """Checks that a captured loop invariant gets no gradient accumulator.

  `v` is captured by the body instead of being a loop var. The gradient of
  v * x needs both values, but only x should be pushed each iteration.
  """
  if control_flow_v2_toggles.control_flow_v2_enabled():
    push_op = "TensorListPushBack"
  else:
    push_op = "StackPushV2"

  v = constant_op.constant(5.0, name="v")

  loop_out = control_flow_ops.while_loop(
      lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)

  dloop_dv = gradients_impl.gradients(loop_out, v)[0]
  fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  fetches.append(dloop_dv)

  optimized = GetOptimizedGraph()
  # v is a loop invariant, so the only accumulator is the one for x.
  push_nodes = [node for node in optimized.node if node.op == push_op]
  self.assertLen(push_nodes, 1)
| |
@test_util.run_deprecated_v1
def testDoNotAccumulateInvariantsV1(self):
  """Runs `_testDoNotAccumulateInvariants` without forcing control flow v2."""
  self._testDoNotAccumulateInvariants()
| |
@test_util.run_deprecated_v1
@test_util.enable_control_flow_v2
def testDoNotAccumulateInvariantsV2(self):
  """Runs `_testDoNotAccumulateInvariants` with control flow v2 enabled."""
  self._testDoNotAccumulateInvariants()
| |
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
@test_util.enable_output_all_intermediates
def testPruningNested(self):
  """Checks pruning of a TensorList pushed to from an inner while loop.

  The list is threaded through both the outer and the inner loop. When only
  the outer counter is fetched, the optimized graph should contain neither
  a variant-typed Enter nor a TensorListPushBack. Once the stacked list is
  fetched too, both must remain.
  """
  # Sanity check that the decorator above set the override.
  assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  x = constant_op.constant(0)

  tensor_list = list_ops.empty_tensor_list(
      element_dtype=x.dtype, element_shape=x.shape)

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 25

  def Body(x, tl):

    def InnerCond(inner_x, unused_outer_x, unused_tl):
      return inner_x < 5

    def InnerBody(inner_x, outer_x, tl):
      # Note: pushes `x` captured from the enclosing Body, not `outer_x`.
      return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)

    inner_x = constant_op.constant(0)
    # Drop the inner counter; (outer_x, tl) become the outer loop's vars.
    return control_flow_ops.while_loop(InnerCond, InnerBody,
                                       [inner_x, x, tl])[1:]

  outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

  # Fetch only the counter first.
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(outputs[0])

  g = GetOptimizedGraph()
  # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
  # away, causing an extra Enter node. The Enter-count check is disabled:
  # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
  # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
  # Test that the TensorList is pruned out.
  self.assertEmpty([
      n for n in g.node if n.op == "Enter" and
      n.attr["T"].type == dtypes.variant.as_datatype_enum
  ])
  self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
  # NOTE(review): presumably checks that no unlowered `_While` op remains.
  self.assertEmpty([n for n in g.node if n.op == "_While"])

  # Now also fetch the stacked list, so the accumulator is needed.
  stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
  train_op.append(stack)
  g = GetOptimizedGraph()
  # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
  # away, causing an extra Enter node. The Enter-count check is disabled:
  # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
  # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
  # Test that the TensorList is not pruned out.
  self.assertNotEmpty([
      n for n in g.node if n.op == "Enter" and
      n.attr["T"].type == dtypes.variant.as_datatype_enum
  ])
  self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])
| |
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
@test_util.enable_output_all_intermediates
def testPruningNested2(self):
  """Checks accumulator pruning for three nested loops with an inner gradient.

  The innermost loop multiplies by the captured constant `v`. The middle
  loop body differentiates that inner loop w.r.t. `v`, and the outer loop
  sums the middle loop's results. After optimization, exactly one
  TensorListPushBack is expected to remain.
  """
  # Sanity check that the decorator above set the override.
  assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  v = constant_op.constant(5.0, name="v")

  # Only used through array_ops.size(p), as the inner iteration count.
  p = array_ops.placeholder(dtype=dtypes.int32)

  def MidBodyBuilder(iterations):
    """Returns a middle-loop body whose inner loop runs `iterations` times."""

    def MidBody(i, x):
      r = control_flow_ops.while_loop(
          lambda *_: True,
          lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
          (0, x),
          maximum_iterations=iterations,
          name="inner")
      # Gradient of x + r[1] w.r.t. v, taken inside the middle loop.
      return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

    return MidBody

  def OuterBody(i, x):
    iterations = array_ops.size(p, name="iterations")
    return (i + 1, x + control_flow_ops.while_loop(
        lambda *_: True,
        MidBodyBuilder(iterations), (0, x),
        maximum_iterations=iterations,
        name="mid")[1])

  def CreateWhileLoop():
    with ops.device("/cpu:0"):
      r = control_flow_ops.while_loop(
          lambda *_: True,
          OuterBody, (0, 1.0),
          maximum_iterations=5,
          name="outer")
      return array_ops.identity(r[1])

  output = CreateWhileLoop()
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(output)

  g = GetOptimizedGraph()
  self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)
| |
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
@test_util.enable_output_all_intermediates
def testPruningNested3(self):
  """With all intermediates output, the gradient keeps one accumulator."""
  assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  v = constant_op.constant(5.0, name="v")

  def CreateWhileLoop():
    loop_out = control_flow_ops.while_loop(
        lambda _: True,
        lambda acc: math_ops.multiply(v, acc, name="my_mul"), [1.0],
        maximum_iterations=5,
        name="outer")
    return array_ops.identity(loop_out)

  loop_result = CreateWhileLoop()
  dresult_dv = gradients_impl.gradients(loop_result, v)[0]
  fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  fetches.append(dresult_dv)

  optimized = GetOptimizedGraph()
  push_nodes = [
      node for node in optimized.node if node.op == "TensorListPushBack"
  ]
  self.assertLen(push_nodes, 1)
| |
def _assertNotAccumulated(self, while_op, index):
  """Asserts that `while_op` input at `index` is not accumulated.

  Fetches the loop's body graph and checks that no consumer of the body's
  `index`-th input is a TensorListPushBack op.
  """
  body = while_v2._get_graph(while_op, "body", "_body_graph")
  consumer_types = [
      consumer.type for consumer in body.inputs[index].consumers()
  ]
  self.assertNotIn("TensorListPushBack", consumer_types)
| |
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
@test_util.enable_output_all_intermediates
def testDoNotOutputLoopCounterAsIntermediate(self):
  """Body input 0 (the loop counter) must not be accumulated."""
  assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  v = constant_op.constant(5.0, name="v")
  loop_out = control_flow_ops.while_loop(
      lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
  # Skip over the Identity between the loop output and the While op.
  loop_op = loop_out.op.inputs[0].op
  self._assertNotAccumulated(loop_op, 0)
| |
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
@test_util.enable_output_all_intermediates
def testDoNotOutputLoopInvariantAsIntermediate(self):
  """A captured loop invariant must not be accumulated."""
  assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE

  def GetInputIndex(op, tensor):
    # Match inputs by identity rather than with op.inputs.index(); returns
    # None when `tensor` is not an input of `op`.
    return next(
        (idx for idx, inp in enumerate(op.inputs) if inp is tensor), None)

  v = constant_op.constant(5.0, name="v")
  loop_out = control_flow_ops.while_loop(
      lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
  # Skip over Identity.
  loop_op = loop_out.op.inputs[0].op
  self._assertNotAccumulated(loop_op, GetInputIndex(loop_op, v))
| |
@test_util.run_deprecated_v1
def testCaptureExternalTensorInCond(self):
  """`cond` may capture a tensor created outside the loop."""
  init = constant_op.constant(2.)
  offset = constant_op.constant(1.)
  result = while_loop_v2(
      lambda val: val + offset < 9.,
      lambda val: val * 3., [init],
      return_same_structure=False)
  # 2 -> 6 -> 18: two iterations, so result = 9 * x and the gradient is 9.
  grads = gradients_impl.gradients(result, [init])
  with self.cached_session():
    self.assertEqual(self.evaluate(result), 18.)
    self.assertSequenceEqual(self.evaluate(grads), [9.])
| |
@test_util.run_deprecated_v1
def testCaptureExternalTensorInBody(self):
  """`body` may capture a tensor created outside the loop."""
  init = constant_op.constant(2.)
  factor = constant_op.constant(3.)
  result = while_loop_v2(
      lambda val: val < 8., lambda val: val * factor, [init],
      return_same_structure=False)
  # 2 -> 6 -> 18: result = x * factor**2, so the gradient is 9.
  grads = gradients_impl.gradients(result, [init])
  with self.cached_session():
    self.assertEqual(self.evaluate(result), 18.)
    self.assertSequenceEqual(self.evaluate(grads), [9.])
| |
@test_util.run_deprecated_v1
def testLoopWithTensorListPushBack(self):
  """A TensorList loop var pushed to in the body doesn't break gradients."""
  init = constant_op.constant(2.)

  empty_list = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=ScalarShape())

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 5.

  def Body(x, tl):
    tl = list_ops.tensor_list_push_back(tl, x)
    tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
    return x**2., tl

  outputs = while_loop_v2(
      Cond, Body, [init, empty_list], return_same_structure=False)
  grads = gradients_impl.gradients(outputs[0], init)
  with self.cached_session() as sess:
    self.assertEqual(sess.run(outputs[0]), 16.)
    self.assertSequenceEqual(self.evaluate(grads), [32.])
| |
@test_util.run_deprecated_v1
def testDuplicateAccumulator(self):
  """A loop var that the body already pushes gets no extra accumulator."""
  x = constant_op.constant(2.)

  tensor_list = list_ops.empty_tensor_list(
      element_dtype=dtypes.float32, element_shape=ScalarShape())

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 5.

  def Body(x, tl):
    # There is an accumulator in the loop already so we should not add
    # another.
    tl = list_ops.tensor_list_push_back(tl, x)
    return x**2., tl

  ret = while_loop_v2(
      Cond, Body, [x, tensor_list], return_same_structure=False)

  # Locate the loop's op explicitly. The previous loop left `while_op`
  # unbound (NameError) when no While op was found.
  while_ops = [
      op for op in ops.get_default_graph().get_operations()
      if op.type in ("While", "StatelessWhile")
  ]
  self.assertLen(while_ops, 1)
  while_op = while_ops[0]

  body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
  x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
  x_input_t = body_graph.inputs[x_input_index]
  # Exactly one push of x: the user's, with no extra one for gradients.
  accumulator_count = len(
      [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
  self.assertEqual(accumulator_count, 1)

  grad = gradients_impl.gradients(ret[0], x)
  with self.cached_session() as sess:
    self.assertEqual(sess.run(ret[0]), 16.)
    self.assertSequenceEqual(self.evaluate(grad), [32.])
| |
@parameterized.named_parameters(
    ("UnknownShape", None),
    ("PartiallyDefinedShape", [None, 2]),
    ("FullyDefinedShape", [1, 2]),
)
@test_util.run_deprecated_v1
def testAccumulatorElementShape(self, shape):
  """Accumulator TensorLists keep the (possibly partial) element shape.

  Checks two accumulators: the forward While op's accumulator for `x`, and
  the gradient While op's accumulator for grad_x, which is created once a
  second derivative is requested.
  """

  def MatchShape(actual_tensor_shape):
    # Compare the shapes, treating None dimensions as equal. We do not
    # directly check actual_tensor_shape and tf.TensorShape(shape) for
    # equality because tf.Dimension.__eq__ returns None if either dimension is
    # None.
    if shape is None:
      self.assertIsNone(actual_tensor_shape.dims)
    else:
      self.assertListEqual(actual_tensor_shape.as_list(), shape)

  def GetAccumulatorForInputAtIndex(while_op, idx):
    """Returns the `while_op` output holding the accumulator of input `idx`."""
    body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
    input_t = body_graph.inputs[idx]
    push_back_node = [c for c in input_t.consumers()
                      if c.type == "TensorListPushBack"][0]
    output_idx = body_graph.outputs.index(push_back_node.outputs[0])
    return while_op.outputs[output_idx]

  x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
  y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

  # Forward pass.
  ret = while_loop_v2(lambda v, u: v < 8.,
                      lambda v, u: (math_ops.pow(v, u), u),
                      [x, y],
                      return_same_structure=True)
  # Skip over the Identity between ret[0] and the While op.
  while_op = ret[0].op.inputs[0].op
  # Gradient pass.
  grad = gradients_impl.gradients(ret[0], x)
  # Note: There is an Identity b/w grad[0] and the While op.
  grad_while_op = grad[0].op.inputs[0].op

  # Get the TensorList output of the While op containing the accumulated
  # values of x.
  x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
  output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
  _, val = list_ops.tensor_list_pop_back(output,
                                         element_dtype=dtypes.float32)
  MatchShape(val.shape)

  # Take second derivative to generate intermediate grad_while_op outputs
  gradients_impl.gradients(grad, x)

  # Get the TensorList output of gradient While op containing the accumulated
  # values of grad_x (note that grad_x is needed by the second derivative).
  grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
  grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
                                              grad_output_index)
  _, val = list_ops.tensor_list_pop_back(grad_output,
                                         element_dtype=dtypes.float32)
  MatchShape(val.shape)
| |
def _createWhile(self, name):
  """Helper for testDefaultName: builds a counting loop, returns its op.

  Args:
    name: Forwarded as the `name` argument of `while_v2.while_loop`.
      testDefaultName passes `None` and "" to exercise default naming.
      Previously this argument was silently ignored.

  Returns:
    The StatelessWhile op that produces the loop output.
  """
  output = while_v2.while_loop(
      lambda i: i < 3,
      lambda i: i + 1, [constant_op.constant(0)],
      return_same_structure=False,
      name=name)
  # Skip over the Identity between the loop output and the While op.
  while_op = output.op.inputs[0].op
  self.assertEqual(while_op.type, "StatelessWhile")
  return while_op
| |
| def testDefaultName(self): |
| with ops.Graph().as_default(): |
| while_op = self._createWhile(None) |
| self.assertEqual(while_op.name, "while") |
| self.assertRegex(while_op.get_attr("cond").name, r"while_cond_\d*") |
| self.assertRegex(while_op.get_attr("body").name, r"while_body_\d*") |
| |
| with ops.Graph().as_default(): |
| with ops.name_scope("foo"): |
| while1_op = self._createWhile("") |
| self.assertEqual(while1_op.name, "foo/while") |
| self.assertRegex(while1_op.get_attr("cond").name, r"foo_while_cond_\d*") |
| self.assertRegex(while1_op.get_attr("body").name, r"foo_while_body_\d*") |
| |
| while2_op = self._createWhile(None) |
| self.assertEqual(while2_op.name, "foo/while_1") |
| self.assertRegex( |
| while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*") |
| self.assertRegex( |
| while2_op.get_attr("body").name, r"foo_while_1_body_\d*") |
| |
| @test_util.enable_control_flow_v2 |
| @test_util.run_deprecated_v1 |
| def testWhileAndTensorArray(self): |
| param = constant_op.constant(2.0) |
| y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") |
| # map_fn uses TensorArray internally. |
| r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0) |
| grad = gradients_impl.gradients(r, param)[0] |
| self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r)) |
| self.assertAllClose(21.0, self.evaluate(grad)) |
| |
| @test_util.run_deprecated_v1 |
| def testNestedWhile(self): |
| # Compute sum of geometric progression: n^0 + n^1 + ... + n^m |
| # We compute the pow using a while loop. |
| n = constant_op.constant(3.) |
| m = constant_op.constant(5.) |
| sum_of_powers = constant_op.constant(0.) |
| |
| def Body(i, previous_sum): |
| prod = constant_op.constant(1.) |
| return i - 1., previous_sum + while_loop_v2( |
| lambda c, _: c > 0, |
| lambda c, v: (c - 1., v * n), [i, prod], |
| return_same_structure=False)[1] |
| |
| result = while_loop_v2( |
| lambda i, _: i >= 0, |
| Body, [m, sum_of_powers], |
| return_same_structure=False)[1] |
| grad = gradients_impl.gradients(result, [n]) |
| self.assertEqual(self.evaluate(result), 364.) |
| self.assertSequenceEqual(self.evaluate(grad), [547.]) |
| |
| @test_util.run_deprecated_v1 |
| def testNestedWhileWithLegacyDefun(self): |
| n = constant_op.constant(3.) |
| m = constant_op.constant(5.) |
| sum_of_powers = constant_op.constant(0.) |
| |
| def Body(i, previous_sum): |
| prod = constant_op.constant(1.) |
| |
| def InnerBodyWrapper(c, v): |
| |
| @function.Defun(dtypes.float32, dtypes.float32) |
| def InnerBody(c, v): |
| return c - 1., v * n |
| |
| results = InnerBody(c, v) |
| results[0].set_shape([]) |
| results[1].set_shape([]) |
| return results |
| |
| return i - 1., previous_sum + while_loop_v2( |
| lambda c, _: c > 0, |
| InnerBodyWrapper, [i, prod], |
| return_same_structure=False)[1] |
| |
| result = while_loop_v2( |
| lambda i, _: i >= 0, |
| Body, [m, sum_of_powers], |
| return_same_structure=False)[1] |
| grad = gradients_impl.gradients(result, [n]) |
| self.assertEqual(self.evaluate(result), 364.) |
| self.assertSequenceEqual(self.evaluate(grad), [547.]) |
| |
| @test_util.run_deprecated_v1 |
| def testIdentityNodeInBody(self): |
| |
| def Body(v): |
| v = array_ops.identity(v) |
| v = array_ops.identity(v) |
| return v * v |
| |
| x = constant_op.constant(2.) |
| ret = while_loop_v2( |
| lambda v: v < 8., Body, [x], return_same_structure=False) |
| grad = gradients_impl.gradients(ret, [x]) |
| self.assertEqual(self.evaluate(ret), 16.) |
| self.assertSequenceEqual(self.evaluate(grad), [32.]) |
| |
| @test_util.run_deprecated_v1 |
| def testForwardPassRewrite(self): |
| x = constant_op.constant(1.0, name="x") |
| output = while_v2.while_loop(lambda x: x < 10.0, |
| lambda x: x * 2.0, |
| [x])[0] |
| while_op = output.op.inputs[0].op |
| self.assertEqual(while_op.type, "StatelessWhile") |
| # outputs = [loop_counter, max_iters, x] |
| self.assertLen(while_op.outputs, 3) |
| |
| gradients_impl.gradients(output, x) |
| # while_op should have been rewritten to output intermediates. |
| # outputs = [loop_counter, max_iters, x, x_accumulator] |
| self.assertLen(while_op.outputs, 4) |
| |
| gradients_impl.gradients(output, x) |
| # Computing the gradient again shouldn't rewrite while_op again. |
| self.assertLen(while_op.outputs, 4) |
| |
| @parameterized.named_parameters( |
| ("RandomUniform", random_ops.random_uniform, [5, 3]), |
| ("RandomNormal", random_ops.random_normal, [5, 3]), |
| ("ParameterizedTruncatedNormal", |
| random_ops.parameterized_truncated_normal, [5, 3]), |
| ("TruncatedNormal", random_ops.truncated_normal, [5, 3]), |
| ("RandomGamma", random_gamma, [5, 3]), |
| ("RandomPoissonV2", random_poisson_v2, [5, 3]), |
| ("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]), |
| ("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]), |
| ) |
| @test_util.run_deprecated_v1 |
| def testRandomOpsShape(self, random_fn, expected_shape): |
| shape = constant_op.constant([3]) |
| |
| def Body(i, u): |
| shape_extended = array_ops.concat([[5], shape], axis=0) |
| u = random_fn(shape_extended) |
| assert u.shape.as_list() == expected_shape, str(u.shape.as_list()) |
| return i + 1, u |
| |
| _, _ = while_loop_v2( |
| cond=lambda i, _: i < 3, |
| body=Body, |
| loop_vars=[ |
| 0, |
| array_ops.zeros(expected_shape, dtype=dtypes.float32), |
| ]) |
| |
| @test_util.run_deprecated_v1 |
| def testReshapeShape(self): |
| shape = constant_op.constant([3, 4]) |
| |
| def Body(i, u): |
| shape_extended = array_ops.concat([[5], shape], axis=0) |
| u = array_ops.reshape(u, [-1]) |
| assert u.shape.as_list() == [60], str(u.shape.as_list()) |
| u = array_ops.reshape(u, shape_extended) |
| assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list()) |
| return i + 1, u |
| |
| _, _ = while_loop_v2( |
| cond=lambda i, _: i < 3, |
| body=Body, |
| loop_vars=[ |
| 0, |
| array_ops.zeros([5, 3, 4], dtype=dtypes.float32), |
| ]) |
| |
| @parameterized.named_parameters( |
| ("Zeros", array_ops.zeros), |
| ("Ones", array_ops.ones), |
| ("Fill", fill), |
| ) |
| @test_util.run_deprecated_v1 |
| def testFillOpsShape(self, fill_fn): |
| shape = constant_op.constant([3, 4]) |
| |
| def Body(i, u): |
| shape_extended = array_ops.concat([[5], shape], axis=0) |
| u = fill_fn(shape_extended) |
| assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list()) |
| return i + 1, u |
| |
| _, _ = while_loop_v2( |
| cond=lambda i, _: i < 3, |
| body=Body, |
| loop_vars=[ |
| 0, |
| array_ops.zeros([5, 3, 4], dtype=dtypes.float32), |
| ]) |
| |
| @test_util.run_deprecated_v1 |
| def testExternalColocationGrad(self): |
| external_t = constant_op.constant(2.) |
| v0 = constant_op.constant(2.) |
| |
| def Body(v): |
| with ops.colocate_with(external_t): |
| return v * v |
| |
| ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0] |
| grad = gradients_impl.gradients(ret, [v0])[0] |
| self.assertAllEqual(ret, 16.) |
| self.assertAllEqual(grad, 32.) |
| |
| @test_util.run_deprecated_v1 |
| def testDoNotAccumulateConstNodes(self): |
| |
| def Body(v): |
| return v * 2.0 |
| |
| v0 = constant_op.constant(2.) |
| ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0] |
| # Gradients computation has the side-effect of updating the forward op |
| # which is what we want to test. |
| unused_grad = gradients_impl.gradients(ret, [v0])[0] |
| # ret is separated from the `While` op by an `Identity` so we skip over |
| # that. |
| forward_while_op = ret.op.inputs[0].op |
| body_graph = while_v2._get_graph(forward_while_op, "body", "_body_graph") |
| push_back_nodes = [ |
| o for o in body_graph.get_operations() if o.type == "TensorListPushBack" |
| ] |
| # Gradient of `Mul` requires accumulating both its inputs. But since one |
| # of those is a Const (2.0), we should have just one accumulator. |
| self.assertLen(push_back_nodes, 1) |
| |
| def testDoNotAccumulateForwardTensorsForReductionOps(self): |
| |
| @def_function.function |
| def Fn(): |
| with backprop.GradientTape() as tape: |
| x = constant_op.constant(2.) |
| tape.watch(x) |
| |
| def Body(i, x): |
| forward_graph = ops.get_default_graph() |
| |
| @custom_gradient.custom_gradient |
| def SquaredWithZeroGrad(x): |
| |
| def Grad(unused_g, variables=None): # pylint: disable=redefined-outer-name |
| del variables |
| gradient_graph = ops.get_default_graph() |
| shape = gen_array_ops.shape(x) |
| assert shape.graph is forward_graph |
| rank = gen_array_ops.rank(x) |
| assert rank.graph is forward_graph |
| size = gen_array_ops.size(x) |
| assert size.graph is forward_graph |
| zeros = array_ops.zeros(shape) |
| assert zeros.graph is gradient_graph |
| return zeros |
| |
| return x * 2, Grad |
| |
| return i + 1, SquaredWithZeroGrad(x) |
| |
| _, result = while_loop_v2(lambda i, _: i < 2, Body, [0, x]) |
| grad = tape.gradient(result, x) |
| return grad |
| |
| Fn() |
| |
| @test_util.run_v2_only |
| def testInheritParentNameScope(self): |
| |
| @def_function.function |
| def F(): |
| with ops.name_scope("foo"): |
| |
| def Cond(unused_i): |
| with ops.name_scope("cond"): |
| actual_name_scope = ops.get_name_scope() |
| expected_name_scope = "foo/while/cond" |
| assert actual_name_scope == expected_name_scope, ( |
| "%s does not match %s" % |
| (actual_name_scope, expected_name_scope)) |
| return False |
| |
| def Body(i): |
| with ops.name_scope("body"): |
| actual_name_scope = ops.get_name_scope() |
| expected_name_scope = "foo/while/body" |
| assert actual_name_scope == expected_name_scope, ( |
| "%s does not match %s" % |
| (actual_name_scope, expected_name_scope)) |
| return i |
| |
| return while_v2.while_loop(Cond, Body, [0.]) |
| |
| F() |
| |
| @test_util.run_deprecated_v1 # Need to pass RunMetadata. |
| def testDisableLowering(self): |
| old = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE |
| control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = True |
| with self.session() as sess: |
| x = constant_op.constant(2.) |
| ret = while_loop_v2( |
| lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False) |
| |
| opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) |
| run_metadata = config_pb2.RunMetadata() |
| self.assertEqual(sess.run(ret, options=opts, run_metadata=run_metadata), |
| 16) |
| for dev_stat in run_metadata.step_stats.dev_stats: |
| for ns in dev_stat.node_stats: |
| self.assertNotIn("switch", ns.node_name) |
| control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = old |
| |
| def _runBasicWithConfig(self, config): |
| with ops.device("/cpu:0"): |
| x = constant_op.constant(0) |
| ret, = while_loop_v2(lambda x: x < 1000, lambda x: x + 1, [x]) |
| with self.cached_session(config=config): |
| self.assertEqual(1000, self.evaluate(ret)) |
| |
| @test_util.run_deprecated_v1 |
| def testRunKernelsInline(self): |
| config = config_pb2.ConfigProto() |
| config.inter_op_parallelism_threads = -1 |
| self._runBasicWithConfig(config) |
| |
| @test_util.run_deprecated_v1 |
| def testSingleThreadedExecution(self): |
| config = config_pb2.ConfigProto() |
| config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR" |
| self._runBasicWithConfig(config) |
| |
| def testIsControlFlowGraph(self): |
| x = constant_op.constant(0) |
| |
| @def_function.function |
| def F(c): |
| |
| def Cond(i): |
| self.assertTrue(i.graph.is_control_flow_graph) |
| return i < 2 |
| |
| def Body(i): |
| i = i + 1 |
| self.assertTrue(i.graph.is_control_flow_graph) |
| return i |
| |
| return while_loop_v2(Cond, Body, [c]) |
| |
| ret, = F(x) |
| self.assertEqual(2, self.evaluate(ret)) |
| |
  def testImportFromSerializedWithFunctionInBody(self):
    """Differentiates through an imported While whose body calls a function.

    The serialized GraphDef below contains a `StatelessWhile` op ("while")
    whose body function `while_body_822` invokes `__inference_f_841` through a
    `PartitionedCall`. The graph is imported inside a `tf.function`, and the
    gradient of the loop's float output with respect to its initial value is
    checked.
    """
    serialized = """node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 1.0
}
}
}
}
node {
name: "while/maximum_iterations"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: -1
}
}
}
}
node {
name: "while/loop_counter"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "while"
op: "StatelessWhile"
input: "while/loop_counter"
input: "while/maximum_iterations"
input: "Const"
attr {
key: "T"
value {
list {
type: DT_INT32
type: DT_INT32
type: DT_FLOAT
}
}
}
attr {
key: "_lower_using_switch_merge"
value {
b: true
}
}
attr {
key: "_num_original_outputs"
value {
i: 3
}
}
attr {
key: "_read_only_resource_inputs"
value {
list {
}
}
}
attr {
key: "body"
value {
func {
name: "while_body_822"
}
}
}
attr {
key: "cond"
value {
func {
name: "while_cond_821"
}
}
}
attr {
key: "output_shapes"
value {
list {
shape {
}
shape {
}
shape {
}
}
}
}
attr {
key: "parallel_iterations"
value {
i: 10
}
}
}
node {
name: "while/Identity"
op: "Identity"
input: "while"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Identity_1"
op: "Identity"
input: "while:1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Identity_2"
op: "Identity"
input: "while:2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
function {
signature {
name: "while_body_822"
input_arg {
name: "while_loop_counter"
type: DT_INT32
}
input_arg {
name: "while_maximum_iterations_0"
type: DT_INT32
}
input_arg {
name: "placeholder"
type: DT_FLOAT
}
output_arg {
name: "add"
type: DT_INT32
}
output_arg {
name: "while_maximum_iterations"
type: DT_INT32
}
output_arg {
name: "partitionedcall"
type: DT_FLOAT
}
}
node_def {
name: "PartitionedCall"
op: "PartitionedCall"
input: "placeholder"
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "_collective_manager_ids"
value {
list {
}
}
}
attr {
key: "_read_only_resource_inputs"
value {
list {
}
}
}
attr {
key: "config"
value {
s: ""
}
}
attr {
key: "config_proto"
value {
s: ""
}
}
attr {
key: "executor_type"
value {
s: ""
}
}
attr {
key: "f"
value {
func {
name: "__inference_f_841"
}
}
}
experimental_debug_info {
original_node_names: "PartitionedCall"
}
}
node_def {
name: "add/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
experimental_debug_info {
original_node_names: "add/y"
}
}
node_def {
name: "add_0"
op: "AddV2"
input: "while_loop_counter"
input: "add/y:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "add"
}
}
ret {
key: "add"
value: "add_0:z:0"
}
ret {
key: "partitionedcall"
value: "PartitionedCall:output:0"
}
ret {
key: "while_maximum_iterations"
value: "while_maximum_iterations_0"
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
arg_attr {
key: 1
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
arg_attr {
key: 2
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
function {
signature {
name: "while_cond_821"
input_arg {
name: "while_loop_counter"
type: DT_INT32
}
input_arg {
name: "while_maximum_iterations"
type: DT_INT32
}
input_arg {
name: "placeholder"
type: DT_FLOAT
}
output_arg {
name: "less"
type: DT_BOOL
}
}
node_def {
name: "Less/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 5.0
}
}
}
experimental_debug_info {
original_node_names: "Less/y"
}
}
node_def {
name: "Less"
op: "Less"
input: "placeholder"
input: "Less/y:output:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "Less"
}
}
ret {
key: "less"
value: "Less:z:0"
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
arg_attr {
key: 1
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
arg_attr {
key: 2
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
function {
signature {
name: "__inference_f_841"
input_arg {
name: "mul_placeholder"
type: DT_FLOAT
}
output_arg {
name: "identity"
type: DT_FLOAT
}
}
node_def {
name: "mul/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
experimental_debug_info {
original_node_names: "mul/y"
}
}
node_def {
name: "mul"
op: "Mul"
input: "mul_placeholder"
input: "mul/y:output:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "mul"
}
}
node_def {
name: "Identity"
op: "Identity"
input: "mul:z:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "Identity"
}
}
ret {
key: "identity"
value: "Identity:output:0"
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
}
versions {
producer: 399
min_consumer: 12
}
"""
    # Code for generating above graph:
    #
    # def Body(i):
    #   @tf.function
    #   def f():
    #     return i * 2
    #   return f()
    # tf.while_loop(lambda i: i < 5., Body, [tf.constant(1.)])
    graph_def = graph_pb2.GraphDef()
    text_format.Parse(serialized, graph_def)
    @def_function.function
    def F():
      # "Const:0" is the loop's initial float value. "while:2" is the While
      # op's third output, the float loop variable; outputs 0 and 1 are the
      # int32 loop counter and maximum_iterations (see the op's "T" attr).
      x, y = importer.import_graph_def(
          graph_def, return_elements=["Const:0", "while:2"])
      grad_out, = gradients_impl.gradients(y, x)
      return grad_out
    # Starting from 1.0, the body doubles the value while it is < 5.0
    # (1 -> 2 -> 4 -> 8), so y == 2^3 * x and dy/dx == 8.
    self.assertAllEqual(F(), 8.0)
| |
| def testIndexedSlicesInIncomingGrads(self): |
| @def_function.function |
| def F(): |
| x = constant_op.constant([2.]) |
| # Computes x^4 |
| ret = while_loop_v2( |
| lambda _: True, lambda v: v * v, [x], return_same_structure=False, |
| maximum_iterations=2) |
| v = array_ops.gather(ret, [0]) |
| return gradients_impl.gradients(v, [x])[0] # 4*x^3 |
| self.assertAllEqual(self.evaluate(F()), [32.]) |
| |
| |
def ScalarShape():
  """Returns the shape of a scalar as a tensor: an empty int32 vector."""
  no_dims = []
  return ops.convert_to_tensor(no_dims, dtype=dtypes.int32)
| |
| |
def GetOptimizedGraph():
  """Runs Grappler over the current default graph.

  Constant folding is turned off and memory optimization is set to MANUAL.

  Returns:
    Whatever `tf_optimizer.OptimizeGraph` returns for the default graph's
    MetaGraphDef under that configuration.
  """
  rewrite_options = rewriter_config_pb2.RewriterConfig(
      constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(rewrite_options)
  meta_graph_def = meta_graph.create_meta_graph_def(
      graph=ops.get_default_graph())
  return tf_optimizer.OptimizeGraph(config, meta_graph_def)
| |
| |
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()