# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for while_v2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2
from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
from tensorflow.python.platform import test
def random_gamma(shape): # pylint: disable=invalid-name
return random_ops.random_gamma(shape, 1.0)
def random_gamma_with_alpha_beta(shape): # pylint: disable=invalid-name
return random_ops.random_gamma(
shape, alpha=[[1.], [3.], [5.], [6.]], beta=[[3., 4.]])
def random_poisson_v2(shape): # pylint: disable=invalid-name
return random_ops.random_poisson_v2(shape, 1.0)
def random_poisson_v2_with_lam(shape): # pylint: disable=invalid-name
return random_ops.random_poisson_v2(shape, [12.2, 3.3])
def fill(shape): # pylint: disable=invalid-name
return array_ops.fill(shape, 1.0)
class WhileV2Test(test.TestCase, parameterized.TestCase):
@test_util.run_deprecated_v1
def testSingleLoopVar(self):
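    # x == 2 is squared until it reaches 8: 2 -> 4 -> 16, so ret == x**4 and
    # d(ret)/dx == 4 * x**3 == 32.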
x = constant_op.constant(2.)
ret = while_loop_v2(
lambda v: v < 8., lambda v: v * v, [x], return_same_structure=False)
grad = gradients_impl.gradients(ret, [x])
with self.cached_session():
self.assertEqual(self.evaluate(ret), 16.)
self.assertSequenceEqual(self.evaluate(grad), [32.])
@test_util.run_v1_only("b/120545219")
def testReturnSameStructureTrue(self):
x = constant_op.constant(2.)
ret = while_loop_v2(
lambda v: v < 8., lambda v: v * v, [x], return_same_structure=True)
grad = gradients_impl.gradients(ret, [x])
with self.cached_session() as sess:
eval_result = sess.run(ret)
self.assertIsInstance(eval_result, list)
self.assertLen(eval_result, 1)
self.assertEqual(16., eval_result[0])
self.assertSequenceEqual(sess.run(grad), [32.])
def testVerifyInputOutputTypesMatch(self):
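    # The body casts the float32 loop variable to float16, so tracing
    # BuildWhile should raise the TypeError asserted below.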
@def_function.function
def BuildWhile():
x = constant_op.constant(1., dtypes.float32)
def Body(x):
return math_ops.cast(x, dtypes.float16) + 1
while_loop_v2(lambda x: x < 10, Body, [x])
with self.assertRaisesRegexp(
TypeError,
r"Loop var Const:0 enters the loop with type <dtype: 'float32'> "
r"but has type <dtype: 'float16'> after 1 iteration."):
BuildWhile()
def testGradientTapeResourceVariable(self):
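    # The body multiplies x by the resource variable v on each of its two
    # iterations, so x == 2 * v**2 and tape.gradient(x, v) == 4 * v == 4 at
    # v == 1.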
with context.eager_mode():
v = variables.Variable(1.)
@def_function.function
def fnWithLoop(): # pylint: disable=invalid-name
with backprop.GradientTape() as tape:
_, x = while_loop_v2(
lambda i, _: i < 2,
lambda i, x: (i + 1, x * v),
[0, 2.])
return tape.gradient(x, v)
self.assertAllEqual(fnWithLoop(), 4.0)
def testExternalControlDependencies(self):
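    # The captured assign_add runs as a control dependency of the loop body,
    # so even a single loop iteration must bump v from 1. to 2.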
with ops.Graph().as_default(), self.test_session():
v = variables.Variable(1.)
v.initializer.run()
op = v.assign_add(1.)
def body_fn(i): # pylint: disable=invalid-name
with ops.control_dependencies([op]):
return i + 1
loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
loop[0].op.run()
self.assertAllEqual(self.evaluate(v), 2.0)
@test_util.run_deprecated_v1
def testMultipleLoopVarsBasic(self):
x = constant_op.constant(5.)
y = constant_op.constant(3.)
# x = 5.
# y = 3.
# while x < 45.:
# x = x * y
ret = while_loop_v2(
lambda v, _: v < 45.,
lambda v, w: (v * w, w), [x, y],
return_same_structure=False)
# ret = [x*y^2, y]
# Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
with self.cached_session():
self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
self.assertSequenceEqual(self.evaluate(grad), [9.])
@test_util.run_deprecated_v1
def testMultipleLoopNonscalarCond(self):
x = constant_op.constant([[5.]])
y = constant_op.constant(3.)
    # x = [[5.]]  (so the loop condition is a 1x1 tensor, not a scalar)
    # y = 3.
    # while x < 45.:
    #   x = x * y
ret = while_loop_v2(
lambda v, _: v < 45.,
lambda v, w: (v * w, w), [x, y],
return_same_structure=False)
# ret == [x*y^2, y]
# Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
    grad = gradients_impl.gradients(ret, [x])  # [y**2]
with self.cached_session():
self.assertSequenceEqual(self.evaluate(ret), [45., 3.])
self.assertSequenceEqual(self.evaluate(grad), [9.])
@test_util.run_deprecated_v1
def testMultipleLoopVars(self):
x = constant_op.constant(5.)
y = constant_op.constant(3.)
    # x = 5.
    # y = 3.
    # while x < 45.:
    #   x, y = x * y, x + y  (both updates use the previous values)
ret = while_loop_v2(
lambda v, _: v < 45.,
lambda v, w: (v * w, v + w), [x, y],
return_same_structure=False)
# ret = [y*x**2 + x*y**2, x*y + x + y]
gradx_0 = gradients_impl.gradients(ret[0], [x]) # [2*x*y + y**2]
gradx_1 = gradients_impl.gradients(ret[1], [x]) # [y + 1]
    gradx_2 = gradients_impl.gradients(ret, [x])  # [2*x*y + y**2 + y + 1]
grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
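    # At x = 5, y = 3 these evaluate to 39, 4, 43, 55, 6 and 61, matching the
    # assertions below.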
with self.cached_session():
self.assertSequenceEqual(self.evaluate(ret), [120., 23.])
self.assertSequenceEqual(self.evaluate(gradx_0), [39.])
self.assertSequenceEqual(self.evaluate(gradx_1), [4.])
self.assertSequenceEqual(self.evaluate(gradx_2), [43.])
self.assertSequenceEqual(self.evaluate(grady_0), [55.])
self.assertSequenceEqual(self.evaluate(grady_1), [6.])
self.assertSequenceEqual(self.evaluate(grady_2), [61.])
@test_util.run_deprecated_v1
def testGradientTape(self):
with backprop.GradientTape() as t:
x = constant_op.constant(2.)
t.watch(x)
ret = while_loop_v2(
lambda v: v < 4., lambda v: v * v, [x],
return_same_structure=False) # x**2
grad = t.gradient(ret, x)
with self.cached_session() as sess:
self.assertAllEqual(sess.run(grad), 4.0)
@test_util.run_deprecated_v1
def testMultipleWhileLoops(self):
x = constant_op.constant(2.)
ret1 = while_loop_v2(
lambda v: v < 4., lambda v: v * v, [x],
return_same_structure=False) # x**2
ret2 = while_loop_v2(
lambda v: v < 16., lambda v: v * v, [ret1],
return_same_structure=False) # x**4
grad = gradients_impl.gradients(ret2, [x]) # 4x**3
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
with self.cached_session():
self.assertSequenceEqual(self.evaluate(grad), [32.])
self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
def testMultipleWhileLoopsWithFunc(self):
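    # Neither loop touches a stateful op or a variable, so both should be
    # built as StatelessWhile ops with no control dependencies between them.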
if compat.forward_compatible(2019, 8, 23):
x = constant_op.constant(2.)
@def_function.function
def Fn():
ret1 = while_loop_v2(
lambda v: v < 4.,
lambda v: v * v, [x],
return_same_structure=False,
name="while_1") # x**2
ret2 = while_loop_v2(
lambda v: v < 16.,
lambda v: v * v, [x],
return_same_structure=False,
name="while_2") # x**4
return ret1, ret2
concrete_fn = Fn.get_concrete_function()
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
self.assertEqual(while_1.type, "StatelessWhile")
self.assertEqual(while_2.type, "StatelessWhile")
self.assertEmpty(while_1.control_inputs)
self.assertEmpty(while_2.control_inputs)
def testMultipleWhileLoopsWithDeps(self):
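    # Both loops read the resource variable x, so they are built as stateful
    # While ops and an automatic control dependency serializes while_2 after
    # while_1.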
if compat.forward_compatible(2019, 8, 23):
x = variables.Variable(2.)
c = constant_op.constant(2.)
@def_function.function
def Fn():
ret1 = while_loop_v2(
lambda v: v < 4.,
lambda v: v * x, [c],
return_same_structure=False,
name="while_1") # 2x
ret2 = while_loop_v2(
lambda v: v < 16.,
lambda v: v * x * x, [c],
return_same_structure=False,
name="while_2") # 4x
return ret1, ret2
concrete_fn = Fn.get_concrete_function()
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
self.assertEqual(while_1.type, "While")
self.assertEqual(while_2.type, "While")
self.assertEmpty(while_1.control_inputs)
self.assertLen(while_2.control_inputs, 1)
self.assertIs(while_2.control_inputs[0], while_1)
def testMultipleWhileLoopsWithVarsDeps(self):
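    # Loops that access the same variable are chained with control edges
    # (while_1 -> while_2 and while_3 -> while_4), loops on different
    # variables are left unordered relative to each other, and the loop that
    # touches no variable stays a StatelessWhile.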
if compat.forward_compatible(2019, 8, 23):
x1 = variables.Variable(2.)
x2 = variables.Variable(3.)
c = constant_op.constant(2.)
@def_function.function
def Fn():
ret1 = while_loop_v2(
lambda v: v < 4.,
lambda v: v * x1, [c],
return_same_structure=False,
name="while_1") # 2x
ret2 = while_loop_v2(
lambda v: v < 16.,
lambda v: v * x1 * x1, [c],
return_same_structure=False,
name="while_2") # 4x
ret3 = while_loop_v2(
lambda v: v < 4.,
lambda v: v * x2, [c],
return_same_structure=False,
name="while_3") # 3x
ret4 = while_loop_v2(
lambda v: v < 16.,
lambda v: v * x2 * x2, [c],
return_same_structure=False,
name="while_4") # 9x
ret5 = while_loop_v2(
lambda v: v < 16.,
lambda v: v * v, [c],
return_same_structure=False,
name="while_stateless") # x**2
return ret1, ret2, ret3, ret4, ret5
concrete_fn = Fn.get_concrete_function()
while_1 = concrete_fn.graph.get_operation_by_name("while_1")
while_2 = concrete_fn.graph.get_operation_by_name("while_2")
while_3 = concrete_fn.graph.get_operation_by_name("while_3")
while_4 = concrete_fn.graph.get_operation_by_name("while_4")
while_stateless = concrete_fn.graph.get_operation_by_name(
"while_stateless")
self.assertEqual(while_1.type, "While")
self.assertEqual(while_2.type, "While")
self.assertEqual(while_3.type, "While")
self.assertEqual(while_4.type, "While")
self.assertEqual(while_stateless.type, "StatelessWhile")
self.assertEmpty(while_1.control_inputs)
self.assertLen(while_2.control_inputs, 1)
self.assertIs(while_2.control_inputs[0], while_1)
self.assertEmpty(while_3.control_inputs)
self.assertLen(while_4.control_inputs, 1)
self.assertIs(while_4.control_inputs[0], while_3)
self.assertEmpty(while_stateless.control_inputs)
@test_util.run_deprecated_v1
def testDoubleDerivative(self):
x = constant_op.constant(2.)
ret = while_loop_v2(
lambda v: v < 8., lambda v: v**2, [x],
return_same_structure=False) # x**4
grad = gradients_impl.gradients(ret, [x]) # 4x**3
grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
with self.cached_session():
self.assertEqual(self.evaluate(ret), 16.)
self.assertSequenceEqual(self.evaluate(grad), [32.])
self.assertSequenceEqual(self.evaluate(grad_grad), [48.])
@test_util.run_v2_only
def testMultipleWhileLoopsEager(self):
@def_function.function
def Func():
x = constant_op.constant(2.)
ret1 = while_loop_v2(
lambda v: v < 4., lambda v: v * v, [x],
return_same_structure=False) # x**2
ret2 = while_loop_v2(
lambda v: v < 16.,
lambda v: v * v, [ret1],
return_same_structure=False) # x**4
grad = gradients_impl.gradients(ret2, [x])[0] # 4x**3
grad_grad = gradients_impl.gradients(grad, [x])[0] # 12x**2
return grad, grad_grad
grad, grad_grad = Func()
self.assertEqual(grad.numpy(), 32.)
self.assertEqual(grad_grad.numpy(), 48.)
@test_util.run_v2_only
def testDoubleDerivativeEager(self):
@def_function.function
def Func():
x = constant_op.constant(2.)
ret = while_loop_v2(
lambda v: v < 8., lambda v: v**2, [x],
return_same_structure=False) # x**4
grad = gradients_impl.gradients(ret, [x])[0] # 4x**3
grad_grad = gradients_impl.gradients(grad, [x])[0] # 12x**2
return ret, grad, grad_grad
ret, grad, grad_grad = Func()
self.assertEqual(ret.numpy(), 16.)
self.assertEqual(grad.numpy(), 32.)
self.assertEqual(grad_grad.numpy(), 48.)
def _testPruning(self):
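    # When only outputs[0] is fetched, Grappler should prune the TensorList
    # accumulator (no variant Enter, no TensorListPushBack); once the stacked
    # list is also fetched, the accumulator must be kept.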
x = constant_op.constant(1)
tensor_list = list_ops.empty_tensor_list(
element_dtype=x.dtype, element_shape=x.shape)
def Cond(x, tl):
del tl # Unused for Cond.
return x < 5
def Body(x, tl):
return x + 1, list_ops.tensor_list_push_back(tl, x)
outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(outputs[0])
g = GetOptimizedGraph()
# TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
# away, causing an extra Enter node.
enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
# Test that the TensorList is pruned out.
self.assertEmpty([
n for n in g.node if n.op == "Enter" and
n.attr["T"].type == dtypes.variant.as_datatype_enum
])
self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
train_op.append(stack)
g = GetOptimizedGraph()
# TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
# away, causing an extra Enter node.
enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
# Test that the TensorList is not pruned out.
self.assertNotEmpty([
n for n in g.node if n.op == "Enter" and
n.attr["T"].type == dtypes.variant.as_datatype_enum
])
self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])
@test_util.run_deprecated_v1
def testPruningV1(self):
self._testPruning()
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
def testPruningV2(self):
self._testPruning()
def _testDoNotAccumulateInvariants(self):
push_op = ("TensorListPushBack"
if control_flow_v2_toggles.control_flow_v2_enabled() else
"StackPushV2")
# Tests that loop invariants, i.e., tensors that are "captured" by the
# while loop and not passed as loop variables are not accumulated in
# gradient computation.
v = constant_op.constant(5.0, name="v")
r = control_flow_ops.while_loop(
lambda _: True, lambda x: v * x, [1.0], maximum_iterations=5)
output = gradients_impl.gradients(r, v)[0]
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(output)
g = GetOptimizedGraph()
# The gradient for v * x requires the value of both v and x. Since v is a
# loop invariant it is not accumulated so we have just one accumulator for
# x.
self.assertLen([n for n in g.node if n.op == push_op], 1)
@test_util.run_deprecated_v1
def testDoNotAccumulateInvariantsV1(self):
self._testDoNotAccumulateInvariants()
@test_util.run_deprecated_v1
@test_util.enable_control_flow_v2
def testDoNotAccumulateInvariantsV2(self):
self._testDoNotAccumulateInvariants()
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
@test_util.enable_output_all_intermediates
def testPruningNested(self):
assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
x = constant_op.constant(0)
tensor_list = list_ops.empty_tensor_list(
element_dtype=x.dtype, element_shape=x.shape)
def Cond(x, tl):
del tl # Unused for Cond.
return x < 25
def Body(x, tl):
def InnerCond(inner_x, unused_outer_x, unused_tl):
return inner_x < 5
def InnerBody(inner_x, outer_x, tl):
return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)
inner_x = constant_op.constant(0)
return control_flow_ops.while_loop(InnerCond, InnerBody,
[inner_x, x, tl])[1:]
outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(outputs[0])
g = GetOptimizedGraph()
# TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
# away, causing an extra Enter node.
# enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
# self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
# Test that the TensorList is pruned out.
self.assertEmpty([
n for n in g.node if n.op == "Enter" and
n.attr["T"].type == dtypes.variant.as_datatype_enum
])
self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
self.assertEmpty([n for n in g.node if n.op == "_While"])
stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
train_op.append(stack)
g = GetOptimizedGraph()
# TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
# away, causing an extra Enter node.
# enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
# self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
# Test that the TensorList is not pruned out.
self.assertNotEmpty([
n for n in g.node if n.op == "Enter" and
n.attr["T"].type == dtypes.variant.as_datatype_enum
])
self.assertNotEmpty([n for n in g.node if n.op == "TensorListPushBack"])
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
@test_util.enable_output_all_intermediates
def testPruningNested2(self):
assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
v = constant_op.constant(5.0, name="v")
p = array_ops.placeholder(dtype=dtypes.int32)
def MidBodyBuilder(iterations):
def MidBody(i, x):
r = control_flow_ops.while_loop(
lambda *_: True,
lambda i, x: (i + 1, math_ops.multiply(v, x, name="my_mul")),
(0, x),
maximum_iterations=iterations,
name="inner")
return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
return MidBody
def OuterBody(i, x):
iterations = array_ops.size(p, name="iterations")
return (i + 1, x + control_flow_ops.while_loop(
lambda *_: True,
MidBodyBuilder(iterations), (0, x),
maximum_iterations=iterations,
name="mid")[1])
def CreateWhileLoop():
with ops.device("/cpu:0"):
r = control_flow_ops.while_loop(
lambda *_: True,
OuterBody, (0, 1.0),
maximum_iterations=5,
name="outer")
return array_ops.identity(r[1])
output = CreateWhileLoop()
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(output)
g = GetOptimizedGraph()
self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
@test_util.enable_output_all_intermediates
def testPruningNested3(self):
assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
v = constant_op.constant(5.0, name="v")
def CreateWhileLoop():
r = control_flow_ops.while_loop(
lambda _: True,
lambda x: math_ops.multiply(v, x, name="my_mul"), [1.0],
maximum_iterations=5,
name="outer")
return array_ops.identity(r)
r = CreateWhileLoop()
output = gradients_impl.gradients(r, v)[0]
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
train_op.append(output)
g = GetOptimizedGraph()
self.assertLen([n for n in g.node if n.op == "TensorListPushBack"], 1)
@test_util.run_deprecated_v1
def testCaptureExternalTensorInCond(self):
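    # y is captured by the loop condition. Starting from x == 2 the body runs
    # twice (2 -> 6 -> 18), so ret == 9 * x and d(ret)/dx == 9.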
x = constant_op.constant(2.)
y = constant_op.constant(1.)
ret = while_loop_v2(
lambda v: v + y < 9.,
lambda v: v * 3., [x],
return_same_structure=False)
grad = gradients_impl.gradients(ret, [x])
with self.cached_session():
self.assertEqual(self.evaluate(ret), 18.)
self.assertSequenceEqual(self.evaluate(grad), [9.])
@test_util.run_deprecated_v1
def testCaptureExternalTensorInBody(self):
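    # y is captured by the loop body: ret == x * y**2 == 18 and
    # d(ret)/dx == y**2 == 9.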
x = constant_op.constant(2.)
y = constant_op.constant(3.)
ret = while_loop_v2(
lambda v: v < 8., lambda v: v * y, [x], return_same_structure=False)
grad = gradients_impl.gradients(ret, [x])
with self.cached_session():
self.assertEqual(self.evaluate(ret), 18.)
self.assertSequenceEqual(self.evaluate(grad), [9.])
@test_util.run_deprecated_v1
def testLoopWithTensorListPushBack(self):
x = constant_op.constant(2.)
tensor_list = list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_shape=ScalarShape())
def Cond(x, tl):
del tl # Unused for Cond.
return x < 5.
def Body(x, tl):
tl = list_ops.tensor_list_push_back(tl, x)
tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
return x**2., tl
ret = while_loop_v2(
Cond, Body, [x, tensor_list], return_same_structure=False)
grad = gradients_impl.gradients(ret[0], x)
with self.cached_session() as sess:
self.assertEqual(sess.run(ret[0]), 16.)
self.assertSequenceEqual(self.evaluate(grad), [32.])
@test_util.run_deprecated_v1
def testDuplicateAccumulator(self):
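    # The body already pushes x onto the TensorList, so gradient construction
    # should reuse that push as the accumulator for x rather than adding a
    # second TensorListPushBack (checked below on the body graph).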
x = constant_op.constant(2.)
tensor_list = list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_shape=ScalarShape())
def Cond(x, tl):
del tl # Unused for Cond.
return x < 5.
def Body(x, tl):
# There is an accumulator in the loop already so we should not add
# another.
tl = list_ops.tensor_list_push_back(tl, x)
return x**2., tl
ret = while_loop_v2(
Cond, Body, [x, tensor_list], return_same_structure=False)
for op in ops.get_default_graph().get_operations():
if op.type == "While" or op.type == "StatelessWhile":
while_op = op
body_graph = while_v2._get_graph(while_op, "body")
x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
x_input_t = body_graph.inputs[x_input_index]
accumulator_count = len(
[c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
self.assertEqual(accumulator_count, 1)
grad = gradients_impl.gradients(ret[0], x)
with self.cached_session() as sess:
self.assertEqual(sess.run(ret[0]), 16.)
self.assertSequenceEqual(self.evaluate(grad), [32.])
@parameterized.named_parameters(
("UnknownShape", None),
("PartiallyDefinedShape", [None, 2]),
("FullyDefinedShape", [1, 2]),
)
@test_util.run_deprecated_v1
def testAccumulatorElementShape(self, shape):
def MatchShape(actual_tensor_shape):
# Compare the shapes, treating None dimensions as equal. We do not
# directly check actual_tensor_shape and tf.TensorShape(shape) for
# equality because tf.Dimension.__eq__ returns None if either dimension is
# None.
if shape is None:
self.assertIsNone(actual_tensor_shape.dims)
else:
self.assertListEqual(actual_tensor_shape.as_list(), shape)
def GetAccumulatorForInputAtIndex(while_op, idx):
body_graph = while_v2._get_graph(while_op, "body")
y_input_t = body_graph.inputs[idx]
push_back_node = [c for c in y_input_t.consumers()
if c.type == "TensorListPushBack"][0]
output_idx = body_graph.outputs.index(push_back_node.outputs[0])
return while_op.outputs[output_idx]
x = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
# Forward pass.
ret = while_loop_v2(lambda v, u: v < 8.,
lambda v, u: (math_ops.pow(v, u), u),
[x, y],
return_same_structure=True)
while_op = ret[0].op.inputs[0].op
# Gradient pass.
grad = gradients_impl.gradients(ret[0], x)
# Note: There is an Identity b/w grad[0] and the While op.
grad_while_op = grad[0].op.inputs[0].op
# Get the TensorList output of While op containing the accumulated values
# of y.
x_input_index = [i for i, inp in enumerate(while_op.inputs) if x == inp][0]
output = GetAccumulatorForInputAtIndex(while_op, x_input_index)
_, val = list_ops.tensor_list_pop_back(output,
element_dtype=dtypes.float32)
MatchShape(val.shape)
# Take second derivative to generate intermediate grad_while_op outputs
gradients_impl.gradients(grad, x)
# Get the TensorList output of gradient While op containing the accumulated
# values of grad_x (note that grad_x is needed by the second derivative).
# grad_while_op.inputs:
grad_output_index = grad_while_op.outputs.index(grad[0].op.inputs[0])
grad_output = GetAccumulatorForInputAtIndex(grad_while_op,
grad_output_index)
_, val = list_ops.tensor_list_pop_back(grad_output,
element_dtype=dtypes.float32)
MatchShape(val.shape)
def _createWhile(self, name):
"""Helper function testDefaultName."""
output = while_v2.while_loop(
lambda i: i < 3,
lambda i: i + 1, [constant_op.constant(0)],
return_same_structure=False)
while_op = output.op.inputs[0].op
if compat.forward_compatible(2019, 8, 23):
self.assertEqual(while_op.type, "StatelessWhile")
return while_op
def testDefaultName(self):
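    # While ops and their cond/body functions should pick up the enclosing
    # name scope and be uniquified: "while", "foo/while", "foo/while_1".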
with ops.Graph().as_default():
while_op = self._createWhile(None)
self.assertEqual(while_op.name, "while")
self.assertRegexpMatches(
while_op.get_attr("cond").name, r"while_cond_\d*")
self.assertRegexpMatches(
while_op.get_attr("body").name, r"while_body_\d*")
with ops.Graph().as_default():
with ops.name_scope("foo"):
while1_op = self._createWhile("")
self.assertEqual(while1_op.name, "foo/while")
self.assertRegexpMatches(
while1_op.get_attr("cond").name, r"foo_while_cond_\d*")
self.assertRegexpMatches(
while1_op.get_attr("body").name, r"foo_while_body_\d*")
while2_op = self._createWhile(None)
self.assertEqual(while2_op.name, "foo/while_1")
self.assertRegexpMatches(
while2_op.get_attr("cond").name, r"foo_while_1_cond_\d*")
self.assertRegexpMatches(
while2_op.get_attr("body").name, r"foo_while_1_body_\d*")
@test_util.enable_control_flow_v2
@test_util.run_deprecated_v1
def testWhileAndTensorArray(self):
param = constant_op.constant(2.0)
y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
# map_fn uses TensorArray internally.
r = map_fn.map_fn(lambda x: math_ops.multiply(x, param), y0)
grad = gradients_impl.gradients(r, param)[0]
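    # r == param * y0 elementwise, so the gradient w.r.t. param aggregates to
    # sum(y0) == 21.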
self.assertAllClose([2.0, 4.0, 6.0, 8.0, 10.0, 12.0], self.evaluate(r))
self.assertAllClose(21.0, self.evaluate(grad))
@test_util.run_deprecated_v1
def testNestedWhile(self):
# Compute sum of geometric progression: n^0 + n^1 + ... + n^m
# We compute the pow using a while loop.
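    # With n = 3 and m = 5: sum = 3**0 + ... + 3**5 = 364 and
    # d(sum)/dn = sum_k k * n**(k - 1) = 547.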
n = constant_op.constant(3.)
m = constant_op.constant(5.)
sum_of_powers = constant_op.constant(0.)
def Body(i, previous_sum):
prod = constant_op.constant(1.)
return i - 1., previous_sum + while_loop_v2(
lambda c, _: c > 0,
lambda c, v: (c - 1., v * n), [i, prod],
return_same_structure=False)[1]
result = while_loop_v2(
lambda i, _: i >= 0,
Body, [m, sum_of_powers],
return_same_structure=False)[1]
grad = gradients_impl.gradients(result, [n])
with self.cached_session() as sess:
self.assertEqual(self.evaluate(result), 364.)
self.assertSequenceEqual(self.evaluate(grad), [547.])
@test_util.run_deprecated_v1
def testIdentityNodeInBody(self):
def Body(v):
v = array_ops.identity(v)
v = array_ops.identity(v)
return v * v
x = constant_op.constant(2.)
ret = while_loop_v2(
lambda v: v < 8., Body, [x], return_same_structure=False)
grad = gradients_impl.gradients(ret, [x])
with self.cached_session() as sess:
self.assertEqual(self.evaluate(ret), 16.)
self.assertSequenceEqual(self.evaluate(grad), [32.])
@test_util.run_deprecated_v1
def testForwardPassRewrite(self):
if compat.forward_compatible(2019, 8, 23):
x = constant_op.constant(1.0, name="x")
output = while_v2.while_loop(lambda x: x < 10.0,
lambda x: x * 2.0,
[x])[0]
while_op = output.op.inputs[0].op
self.assertEqual(while_op.type, "StatelessWhile")
# outputs = [loop_counter, max_iters, x]
self.assertLen(while_op.outputs, 3)
gradients_impl.gradients(output, x)
# while_op should have been rewritten to output intermediates.
# outputs = [loop_counter, max_iters, x, x_accumulator]
self.assertLen(while_op.outputs, 4)
gradients_impl.gradients(output, x)
# Computing the gradient again shouldn't rewrite while_op again.
self.assertLen(while_op.outputs, 4)
@parameterized.named_parameters(
("RandomUniform", random_ops.random_uniform, [5, 3]),
("RandomNormal", random_ops.random_normal, [5, 3]),
("ParameterizedTruncatedNormal",
random_ops.parameterized_truncated_normal, [5, 3]),
("TruncatedNormal", random_ops.truncated_normal, [5, 3]),
("RandomGamma", random_gamma, [5, 3]),
("RandomPoissonV2", random_poisson_v2, [5, 3]),
("RandomGammaWithAlphaBeta", random_gamma_with_alpha_beta, [5, 3, 4, 2]),
("RandomPoissonV2WithLam", random_poisson_v2_with_lam, [5, 3, 2]),
)
@test_util.run_deprecated_v1
def testRandomOpsShape(self, random_fn, expected_shape):
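    # The loop body rebuilds u with a shape derived from a captured constant;
    # shape inference should still produce the fully defined expected_shape,
    # which the assert in Body verifies on every iteration.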
shape = constant_op.constant([3])
def Body(i, u):
shape_extended = array_ops.concat([[5], shape], axis=0)
u = random_fn(shape_extended)
assert u.shape.as_list() == expected_shape, str(u.shape.as_list())
return i + 1, u
_, _ = while_loop_v2(
cond=lambda i, _: i < 3,
body=Body,
loop_vars=[
0,
array_ops.zeros(expected_shape, dtype=dtypes.float32),
])
@test_util.run_deprecated_v1
def testReshapeShape(self):
shape = constant_op.constant([3, 4])
def Body(i, u):
shape_extended = array_ops.concat([[5], shape], axis=0)
u = array_ops.reshape(u, [-1])
assert u.shape.as_list() == [60], str(u.shape.as_list())
u = array_ops.reshape(u, shape_extended)
assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
return i + 1, u
_, _ = while_loop_v2(
cond=lambda i, _: i < 3,
body=Body,
loop_vars=[
0,
array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
])
@parameterized.named_parameters(
("Zeros", array_ops.zeros),
("Ones", array_ops.ones),
("Fill", fill),
)
@test_util.run_deprecated_v1
def testFillOpsShape(self, fill_fn):
shape = constant_op.constant([3, 4])
def Body(i, u):
shape_extended = array_ops.concat([[5], shape], axis=0)
u = fill_fn(shape_extended)
assert u.shape.as_list() == [5, 3, 4], str(u.shape.as_list())
return i + 1, u
_, _ = while_loop_v2(
cond=lambda i, _: i < 3,
body=Body,
loop_vars=[
0,
array_ops.zeros([5, 3, 4], dtype=dtypes.float32),
])
@test_util.run_deprecated_v1
def testExternalColocationGrad(self):
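    # Colocating the body with a tensor defined outside the loop should not
    # break gradient construction: ret == v0**4 == 16 and
    # d(ret)/dv0 == 4 * v0**3 == 32.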
external_t = constant_op.constant(2.)
v0 = constant_op.constant(2.)
def Body(v):
with ops.colocate_with(external_t):
return v * v
ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
grad = gradients_impl.gradients(ret, [v0])[0]
self.assertAllEqual(ret, 16.)
self.assertAllEqual(grad, 32.)
@test_util.run_deprecated_v1
def testDoNotAccumulateConstNodes(self):
def Body(v):
return v * 2.0
v0 = constant_op.constant(2.)
ret = while_loop_v2(lambda v: v < 8., Body, [v0])[0]
# Gradients computation has the side-effect of updating the forward op
# which is what we want to test.
unused_grad = gradients_impl.gradients(ret, [v0])[0]
# ret is separated from the `While` op by an `Identity` so we skip over
# that.
forward_while_op = ret.op.inputs[0].op
body_graph = while_v2._get_graph(forward_while_op, "body")
push_back_nodes = [
o for o in body_graph.get_operations() if o.type == "TensorListPushBack"
]
# Gradient of `Mul` requires accumulating both its inputs. But since one
# of those is a Const (2.0), we should have just one accumulator.
self.assertLen(push_back_nodes, 1)
def ScalarShape():
return ops.convert_to_tensor([], dtype=dtypes.int32)
def GetOptimizedGraph():
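  """Runs Grappler on the default graph and returns the optimized GraphDef."""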
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
config = config_pb2.ConfigProto()
config.graph_options.rewrite_options.CopyFrom(
rewriter_config_pb2.RewriterConfig(
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
return tf_optimizer.OptimizeGraph(config, mg)
if __name__ == "__main__":
test.main()