| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for while_v2.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| from absl.testing import parameterized |
| |
| from tensorflow.core.protobuf import rewriter_config_pb2 |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import meta_graph |
| from tensorflow.python.framework import ops |
| from tensorflow.python.grappler import tf_optimizer |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import gradients_impl |
| from tensorflow.python.ops import list_ops |
| from tensorflow.python.ops import while_v2 |
| from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1 |
| from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2 |
| from tensorflow.python.platform import test |
| |
| |
class WhileV2Test(test.TestCase, parameterized.TestCase):
  """Tests for loops built with `while_v2.while_loop` and their gradients.

  `testPruning` is the one exception: it builds its loop with the v1
  `while_loop` and inspects the Grappler-optimized graph.
  """

  def testSingleLoopVar(self):
    """Checks the value and gradient of a loop that squares one variable."""
    x = constant_op.constant(2.)
    # 2 -> 4 -> 16, i.e. ret = x**4.
    ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
    grad = gradients_impl.gradients(ret, [x])  # 4*x**3
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testMultipleLoopVarsBasic(self):
    """Checks a two-variable loop in which only the first variable changes."""
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, w), [x, y])
    # ret = [x*y**2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
    with self.test_session() as sess:
      self.assertSequenceEqual(sess.run(ret), [45., 3.])
      self.assertSequenceEqual(sess.run(grad), [9.])

  def testMultipleLoopVars(self):
    """Checks values and all partial gradients of a coupled two-var loop."""
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    #   y = x + y
    ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, v + w),
                        [x, y])
    # ret = [y*x**2 + x*y**2, x*y + x + y]

    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + y + 1]
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
    with self.test_session() as sess:
      self.assertSequenceEqual(sess.run(ret), [120., 23.])
      self.assertSequenceEqual(sess.run(gradx_0), [39.])
      self.assertSequenceEqual(sess.run(gradx_1), [4.])
      self.assertSequenceEqual(sess.run(gradx_2), [43.])
      self.assertSequenceEqual(sess.run(grady_0), [55.])
      self.assertSequenceEqual(sess.run(grady_1), [6.])
      self.assertSequenceEqual(sess.run(grady_2), [61.])

  def testMultipleWhileLoops(self):
    """Checks first and second derivatives through two chained loops."""
    x = constant_op.constant(2.)
    ret1 = while_loop_v2(lambda v: v < 4., lambda v: v * v, [x])  # x**2
    ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1)  # x**4
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.test_session() as sess:
      self.assertSequenceEqual(sess.run(grad), [32.])
      self.assertSequenceEqual(sess.run(grad_grad), [48.])

  def testDoubleDerivative(self):
    """Checks the value, gradient and second derivative of a single loop."""
    x = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x])  # x**4
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])
      self.assertSequenceEqual(sess.run(grad_grad), [48.])

  def testPruning(self):
    """Counts `Enter` nodes in the optimized graph of a v1 while loop.

    One `Enter` node is expected while only the first loop output is
    registered as a train op, and two once the stacked TensorList output is
    registered as well.
    """
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = while_loop_v1(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    def GetOptimizedGraph():
      """Runs the Grappler optimizer over the current default graph."""
      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      rewriter_config = rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      return tf_optimizer.OptimizeGraph(rewriter_config, mg)

    def CountEnterNodes(graph_def):
      """Returns how many nodes of `graph_def` have op type "Enter"."""
      return len([n for n in graph_def.node if n.op == "Enter"])

    g = GetOptimizedGraph()
    self.assertEqual(CountEnterNodes(g), 1)

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
    g = GetOptimizedGraph()
    self.assertEqual(CountEnterNodes(g), 2)

  def testCaptureExternalTensorInCond(self):
    """Checks a loop whose condition reads a tensor defined outside it."""
    x = constant_op.constant(2.)
    y = constant_op.constant(1.)
    # 2 -> 6 -> 18, i.e. ret = 9*x.
    ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
    grad = gradients_impl.gradients(ret, [x])
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret), 18.)
      self.assertSequenceEqual(sess.run(grad), [9.])

  def testCaptureExternalTensorInBody(self):
    """Checks a loop whose body reads a tensor defined outside it."""
    x = constant_op.constant(2.)
    y = constant_op.constant(3.)
    # 2 -> 6 -> 18, i.e. ret = x*y**2.
    ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
    grad = gradients_impl.gradients(ret, [x])
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret), 18.)
      self.assertSequenceEqual(sess.run(grad), [9.])

  def testLoopWithTensorListPushBack(self):
    """Checks a loop whose body pushes onto a TensorList loop variable."""
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      tl = list_ops.tensor_list_push_back(tl, x)
      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
      return x**2., tl

    ret = while_loop_v2(Cond, Body, [x, tensor_list])
    grad = gradients_impl.gradients(ret[0], x)
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testDuplicateAccumulator(self):
    """Checks that `x` has exactly one TensorListPushBack consumer."""
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(Cond, Body, [x, tensor_list])

    # Use the last "While" op in the graph. Assert that one exists so a
    # missing op is reported as a test failure instead of a NameError.
    while_ops = [
        op for op in ops.get_default_graph().get_operations()
        if op.type == "While"
    ]
    self.assertTrue(while_ops, "Expected a While op in the default graph.")
    while_op = while_ops[-1]

    body_graph = while_v2._get_body_graph(while_op)
    # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
    x_input_t = body_graph.inputs[1]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])

  @parameterized.named_parameters(
      ("UnknownShape", None),
      ("PartiallyDefinedShape", [None, 2]),
      ("FullyDefinedShape", [1, 2]),
  )
  def testTensorListOutputElementShape(self, shape):
    """Checks the static shape of elements popped from accumulators.

    Args:
      shape: Static shape given to the placeholder `y`; None means unknown.
    """

    def MatchShape(actual_tensor_shape):
      # Compare the shapes, treating None dimensions as equal. We do not
      # directly check actual_tensor_shape and tf.TensorShape(shape) for
      # equality because tf.Dimension.__eq__ returns None if either dimension is
      # None.
      if shape is None:
        self.assertIsNone(actual_tensor_shape.dims)
      else:
        self.assertListEqual(actual_tensor_shape.as_list(), shape)

    def GetAccumulatorForInputAtIndex(while_op, idx):
      """Returns the `while_op` output fed by the push-back of input `idx`."""
      body_graph = while_v2._get_body_graph(while_op)
      y_input_t = body_graph.inputs[idx]
      push_back_node = [c for c in y_input_t.consumers()
                        if c.type == "TensorListPushBack"][0]
      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
      return while_op.outputs[output_idx]

    x = constant_op.constant(2.)
    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

    # Forward pass.
    ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u), [x, y])
    while_op = ret[0].op
    # Get the TensorList output of While op containing the accumulated values
    # of y.
    # while_op.inputs: [counter_arg, x_arg, y_arg, *accumulators]
    output = GetAccumulatorForInputAtIndex(while_op, 2)
    _, val = list_ops.tensor_list_pop_back(output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

    # Gradient pass.
    grad = gradients_impl.gradients(ret[1], y)
    grad_while_op = grad[0].op
    # Get the TensorList output of gradient While op containing the accumulated
    # values of grad_y.
    # grad_while_op.inputs:
    # [counter_arg, total_iters_arg, grad_x_arg, grad_y_arg, *other_args]
    grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 4)
    _, val = list_ops.tensor_list_pop_back(grad_output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)
| |
| |
def ScalarShape():
  """Returns the empty int32 tensor the tests pass as `element_shape`."""
  no_dims = []
  return ops.convert_to_tensor(no_dims, dtype=dtypes.int32)
| |
| |
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == "__main__":
  test.main()