# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while_v2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test


class WhileV2Test(test.TestCase, parameterized.TestCase):

  def testSingleLoopVar(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
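    # The loop squares x twice (2 -> 4 -> 16), so ret = x**4 and
    # d(ret)/dx = 4 * x**3 = 32.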
    grad = gradients_impl.gradients(ret, [x])
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testMultipleLoopVarsBasic(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x = x * y
    ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, w), [x, y])
    # ret = [x*y^2, y]

    # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
    with self.test_session() as sess:
      self.assertSequenceEqual(sess.run(ret), [45., 3.])
      self.assertSequenceEqual(sess.run(grad), [9.])

  def testMultipleLoopVars(self):
    x = constant_op.constant(5.)
    y = constant_op.constant(3.)

    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x, y = x * y, x + y  (both updates use the old values)
    ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, v + w),
                        [x, y])
    # ret = [y*x**2 + x*y**2, x*y + x + y]

    gradx_0 = gradients_impl.gradients(ret[0], [x])  # [2*x*y + y**2]
    gradx_1 = gradients_impl.gradients(ret[1], [x])  # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + y + 1]
    grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
    grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
    grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
    with self.test_session() as sess:
      self.assertSequenceEqual(sess.run(ret), [120., 23.])
      self.assertSequenceEqual(sess.run(gradx_0), [39.])
      self.assertSequenceEqual(sess.run(gradx_1), [4.])
      self.assertSequenceEqual(sess.run(gradx_2), [43.])
      self.assertSequenceEqual(sess.run(grady_0), [55.])
      self.assertSequenceEqual(sess.run(grady_1), [6.])
      self.assertSequenceEqual(sess.run(grady_2), [61.])

  def testMultipleWhileLoops(self):
    x = constant_op.constant(2.)
    ret1 = while_loop_v2(lambda v: v < 4., lambda v: v * v, [x])  # x**2
    ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1)  # x**4
    grad = gradients_impl.gradients(ret2, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.test_session() as sess:
      self.assertSequenceEqual(sess.run(grad), [32.])
      self.assertSequenceEqual(sess.run(grad_grad), [48.])

  def testDoubleDerivative(self):
    x = constant_op.constant(2.)
    ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x])  # x**4
    grad = gradients_impl.gradients(ret, [x])  # 4x**3
    grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])
      self.assertSequenceEqual(sess.run(grad_grad), [48.])

  def testPruning(self):
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5

    def Body(x, tl):
      return x + 1, list_ops.tensor_list_push_back(tl, x)

    outputs = while_loop_v1(Cond, Body, [x, tensor_list])

    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(outputs[0])

    def GetOptimizedGraph():
      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      rewriter_config = rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      return tf_optimizer.OptimizeGraph(rewriter_config, mg)
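
    # Only outputs[0] is fetched above, so the TensorList accumulator is dead
    # code and its loop variable should be pruned, leaving a single Enter op.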
    g = GetOptimizedGraph()
    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)

    stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
    train_op.append(stack)
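    # With the stacked TensorList now fetched as well, the accumulator must be
    # kept, so both Enter ops survive optimization.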
    g = GetOptimizedGraph()
    self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)

  def testCaptureExternalTensorInCond(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(1.)
    ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
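    # The loop triples x twice (2 -> 6 -> 18), so ret = 9 * x and
    # d(ret)/dx = 9.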
    grad = gradients_impl.gradients(ret, [x])
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret), 18.)
      self.assertSequenceEqual(sess.run(grad), [9.])

  def testCaptureExternalTensorInBody(self):
    x = constant_op.constant(2.)
    y = constant_op.constant(3.)
    ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
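    # The loop multiplies x by the captured y twice (2 -> 6 -> 18), so
    # ret = x * y**2 and d(ret)/dx = y**2 = 9.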
    grad = gradients_impl.gradients(ret, [x])
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret), 18.)
      self.assertSequenceEqual(sess.run(grad), [9.])

  def testLoopWithTensorListPushBack(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      tl = list_ops.tensor_list_push_back(tl, x)
      tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
      return x**2., tl

    ret = while_loop_v2(Cond, Body, [x, tensor_list])
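    # The loop squares x twice (2 -> 4 -> 16), so ret[0] = x**4 and its
    # gradient w.r.t. x is 4 * x**3 = 32.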
    grad = gradients_impl.gradients(ret[0], x)
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])

  def testDuplicateAccumulator(self):
    x = constant_op.constant(2.)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=ScalarShape())

    def Cond(x, tl):
      del tl  # Unused for Cond.
      return x < 5.

    def Body(x, tl):
      # There is an accumulator in the loop already so we should not add
      # another.
      tl = list_ops.tensor_list_push_back(tl, x)
      return x**2., tl

    ret = while_loop_v2(Cond, Body, [x, tensor_list])

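    # Find the While op that while_loop_v2 added so its body graph can be
    # inspected for accumulators.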
    for op in ops.get_default_graph().get_operations():
      if op.type == "While":
        while_op = op

    body_graph = while_v2._get_body_graph(while_op)
    # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
    x_input_t = body_graph.inputs[1]
    accumulator_count = len(
        [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
    self.assertEqual(accumulator_count, 1)

    grad = gradients_impl.gradients(ret[0], x)
    with self.test_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])

  @parameterized.named_parameters(
      ("UnknownShape", None),
      ("PartiallyDefinedShape", [None, 2]),
      ("FullyDefinedShape", [1, 2]),
  )
  def testTensorListOutputElementShape(self, shape):

    def MatchShape(actual_tensor_shape):
      # Compare the shapes, treating None dimensions as equal. We do not
      # directly check actual_tensor_shape and tf.TensorShape(shape) for
      # equality because tf.Dimension.__eq__ returns None if either dimension
      # is None.
      if shape is None:
        self.assertIsNone(actual_tensor_shape.dims)
      else:
        self.assertListEqual(actual_tensor_shape.as_list(), shape)

    def GetAccumulatorForInputAtIndex(while_op, idx):
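      # Returns the TensorList output of `while_op` that accumulates the loop
      # input at body-graph index `idx`, found via its TensorListPushBack
      # consumer.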
      body_graph = while_v2._get_body_graph(while_op)
      y_input_t = body_graph.inputs[idx]
      push_back_node = [c for c in y_input_t.consumers()
                        if c.type == "TensorListPushBack"][0]
      output_idx = body_graph.outputs.index(push_back_node.outputs[0])
      return while_op.outputs[output_idx]

    x = constant_op.constant(2.)
    y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)

    # Forward pass.
    ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u), [x, y])
    while_op = ret[0].op
    # Get the TensorList output of While op containing the accumulated values
    # of y.
    # while_op.inputs: [counter_arg, x_arg, y_arg, *accumulators]
    output = GetAccumulatorForInputAtIndex(while_op, 2)
    _, val = list_ops.tensor_list_pop_back(output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)

    # Gradient pass.
    grad = gradients_impl.gradients(ret[1], y)
    grad_while_op = grad[0].op
    # Get the TensorList output of gradient While op containing the
    # accumulated values of grad_y.
    # grad_while_op.inputs:
    # [counter_arg, total_iters_arg, grad_x_arg, grad_y_arg, *other_args]
    grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 4)
    _, val = list_ops.tensor_list_pop_back(grad_output,
                                           element_dtype=dtypes.float32)
    MatchShape(val.shape)


def ScalarShape():
  return ops.convert_to_tensor([], dtype=dtypes.int32)


if __name__ == "__main__":
  test.main()