| # Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for control_flow_ops.py.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import itertools |
| import time |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.core.framework import graph_pb2 |
| from tensorflow.core.framework import node_def_pb2 |
| from tensorflow.core.protobuf import config_pb2 |
| from tensorflow.python import tf2 |
| from tensorflow.python.autograph.lang import directives |
| from tensorflow.python.client import session |
| from tensorflow.python.eager import backprop |
| from tensorflow.python.eager import context |
| from tensorflow.python.eager import def_function |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import test_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import check_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import control_flow_util_v2 |
| from tensorflow.python.ops import control_flow_v2_toggles |
| from tensorflow.python.ops import custom_gradient |
| from tensorflow.python.ops import embedding_ops |
| from tensorflow.python.ops import gradients_impl |
| from tensorflow.python.ops import init_ops |
| from tensorflow.python.ops import linalg_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops import random_ops |
| from tensorflow.python.ops import script_ops |
| from tensorflow.python.ops import state_ops |
| from tensorflow.python.ops import summary_ops_v2 |
| from tensorflow.python.ops import tensor_array_ops |
| from tensorflow.python.ops import variable_scope |
| from tensorflow.python.ops import variables |
| from tensorflow.python.ops import while_v2 |
| import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import |
| from tensorflow.python.platform import googletest |
| from tensorflow.python.training import momentum |
| from tensorflow.python.util import nest |
| |
| TestTuple = collections.namedtuple("TestTuple", "a b") |
| SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a") |
| |
| |
| class GroupTestCase(test_util.TensorFlowTestCase): |
| |
| def _StripNode(self, nd): |
| snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input) |
| if nd.device: |
| snode.device = nd.device |
| return snode |
| |
| def _StripGraph(self, gd): |
| """Copy gd keeping only, node.name, node.op, node.input, and node.device.""" |
| return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node]) |
| |
| def testGroup_NoDevices(self): |
| with ops.Graph().as_default() as g: |
| a = constant_op.constant(0, name="a") |
| b = constant_op.constant(0, name="b") |
| c = constant_op.constant(0, name="c") |
| control_flow_ops.group(a.op, b.op, c.op, name="root") |
| gd = g.as_graph_def() |
| self.assertProtoEquals( |
| """ |
| node { name: "a" op: "Const"} |
| node { name: "b" op: "Const"} |
| node { name: "c" op: "Const"} |
| node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" } |
| """, self._StripGraph(gd)) |
| |
| def testGroup_OneDevice(self): |
| with ops.Graph().as_default() as g: |
| with g.device("/task:0"): |
| a = constant_op.constant(0, name="a") |
| b = constant_op.constant(0, name="b") |
| control_flow_ops.group(a.op, b.op, name="root") |
| gd = g.as_graph_def() |
| self.assertProtoEquals( |
| """ |
| node { name: "a" op: "Const" device: "/task:0" } |
| node { name: "b" op: "Const" device: "/task:0" } |
| node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" } |
| """, self._StripGraph(gd)) |
| |
| def testGroup_MultiDevice(self): |
| with ops.Graph().as_default() as g: |
| with g.device("/task:0"): |
| a = constant_op.constant(0, name="a") |
| b = constant_op.constant(0, name="b") |
| with g.device("/task:1"): |
| c = constant_op.constant(0, name="c") |
| d = constant_op.constant(0, name="d") |
| with g.device("/task:2"): |
| control_flow_ops.group(a.op, b.op, c.op, d.op, name="root") |
| gd = g.as_graph_def() |
| self.assertProtoEquals( |
| """ |
| node { name: "a" op: "Const" device: "/task:0"} |
| node { name: "b" op: "Const" device: "/task:0"} |
| node { name: "c" op: "Const" device: "/task:1"} |
| node { name: "d" op: "Const" device: "/task:1"} |
| node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b" |
| device: "/task:0" } |
| node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d" |
| device: "/task:1" } |
| node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1" |
| device: "/task:2" } |
| """, self._StripGraph(gd)) |
| |
| def testPassingList(self): |
| with ops.Graph().as_default() as g: |
| a = constant_op.constant(0, name="a") |
| b = constant_op.constant(0, name="b") |
| control_flow_ops.group([a.op, b.op], name="root") |
| gd = g.as_graph_def() |
| self.assertProtoEquals( |
| """ |
| node { name: "a" op: "Const"} |
| node { name: "b" op: "Const"} |
| node { name: "root" op: "NoOp" input: "^a" input: "^b" } |
| """, self._StripGraph(gd)) |
| |
| @test_util.run_deprecated_v1 |
| def testPassingNonTensors(self): |
| with self.assertRaises(TypeError): |
| control_flow_ops.group(1, 2) |
| |
| |
| class ShapeTestCase(test_util.TensorFlowTestCase): |
| |
| def testShape(self): |
| tensor = constant_op.constant([1.0, 2.0]) |
| self.assertEqual([2], tensor.get_shape()) |
| self.assertEqual([2], |
| control_flow_ops.with_dependencies( |
| [constant_op.constant(1.0)], tensor).get_shape()) |
| |
| |
| class WithDependenciesTestCase(test_util.TensorFlowTestCase): |
| |
| @test_util.run_deprecated_v1 |
| def testTupleDependencies(self): |
| counter = variable_scope.get_variable( |
| "my_counter", shape=[], initializer=init_ops.zeros_initializer()) |
| increment_counter = state_ops.assign_add(counter, 1) |
| const_with_dep = control_flow_ops.with_dependencies( |
| (increment_counter, constant_op.constant(42)), constant_op.constant(7)) |
| |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(0, self.evaluate(counter)) |
| self.assertEqual(7, self.evaluate(const_with_dep)) |
| self.assertEqual(1, self.evaluate(counter)) |
| |
| @test_util.run_deprecated_v1 |
| def testListDependencies(self): |
| counter = variable_scope.get_variable( |
| "my_counter", shape=[], initializer=init_ops.zeros_initializer()) |
| increment_counter = state_ops.assign_add(counter, 1) |
| const_with_dep = control_flow_ops.with_dependencies( |
| [increment_counter, constant_op.constant(42)], constant_op.constant(7)) |
| |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertEqual(0, self.evaluate(counter)) |
| self.assertEqual(7, self.evaluate(const_with_dep)) |
| self.assertEqual(1, self.evaluate(counter)) |
| |
| |
| class SwitchTestCase(test_util.TensorFlowTestCase): |
| |
| @test_util.run_deprecated_v1 |
| def testIndexedSlicesWithDenseShape(self): |
| with self.cached_session(): |
| data = ops.IndexedSlices( |
| constant_op.constant([1, 2, 3]), |
| constant_op.constant([0, 1, 2]), |
| dense_shape=constant_op.constant([3])) |
| zero = constant_op.constant(0) |
| one = constant_op.constant(1) |
| less_op = math_ops.less(zero, one) |
| _, switch_true = control_flow_ops.switch(data, less_op) |
| self.assertAllEqual([1, 2, 3], switch_true.values) |
| self.assertAllEqual([0, 1, 2], switch_true.indices) |
| |
| @test_util.run_deprecated_v1 |
| def testIndexedSlicesGradient(self): |
| embedding_matrix = variable_scope.get_variable( |
| "embedding_matrix", [5, 5], |
| initializer=init_ops.random_normal_initializer()) |
| |
| def cond(it, _): |
| return it < 5 |
| |
| def body(it, cost): |
| embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0]) |
| cost += math_ops.reduce_sum(embedding) |
| return it + 1, cost |
| |
| _, cost = control_flow_ops.while_loop( |
| cond, body, [constant_op.constant(0), |
| constant_op.constant(0.0)]) |
| optimizer = momentum.MomentumOptimizer(0.1, 0.9) |
| train_op = optimizer.minimize(cost) |
| with self.cached_session(): |
| self.evaluate(variables.global_variables_initializer()) |
| for _ in range(10): |
| self.evaluate([train_op]) |
| |
| def testResourceReadInLoop(self): |
| embedding_matrix = variable_scope.get_variable( |
| "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True) |
| |
| def cond(it, _): |
| return it < 5 |
| |
| def body(it, cost): |
| embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) |
| cost += math_ops.reduce_sum(embedding) |
| return it + 1, cost |
| |
| _, cost = control_flow_ops.while_loop( |
| cond, body, [constant_op.constant(0), |
| constant_op.constant(0.0)]) |
| with self.cached_session(): |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertAllEqual(10.0, self.evaluate(cost)) |
| |
| def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): |
| embedding_matrix = variable_scope.get_variable( |
| "embedding_matrix", [5, 5], |
| initializer=init_ops.random_normal_initializer(), |
| use_resource=use_resource) |
| |
| def cond(it, _): |
| return it < 5 |
| |
| def body(it, cost): |
| embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) |
| cost = control_flow_ops.cond( |
| math_ops.equal(it, 3), lambda: math_ops.square(cost), |
| (lambda: cost + math_ops.reduce_sum(embedding))) |
| return it + 1, cost |
| |
| _, cost = control_flow_ops.while_loop( |
| cond, body, [constant_op.constant(0), |
| constant_op.constant(0.0)]) |
| |
| dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0] |
| dynamic_grads = math_ops.segment_sum(dynamic_grads.values, |
| dynamic_grads.indices) |
| |
| embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) |
| static = math_ops.square( |
| math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) + |
| math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding) |
| static_grads = gradients_impl.gradients(static, [embedding_matrix])[0] |
| static_grads = math_ops.segment_sum(static_grads.values, |
| static_grads.indices) |
| |
| with self.cached_session(): |
| self.evaluate(variables.global_variables_initializer()) |
| self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads])) |
| |
| def testIndexedSlicesGradientInCondInWhileLoop(self): |
| self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False) |
| |
| def testIndexedSlicesGradientInCondInWhileLoopResource(self): |
| self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True) |
| |
| @test_util.run_v1_only("b/120545219") |
| def testIndexedSlicesWithShapeGradientInWhileLoop(self): |
| for dtype in [dtypes.float32, dtypes.float64]: |
| with self.cached_session() as sess: |
| num_steps = 9 |
| |
| inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps]) |
| initial_outputs = tensor_array_ops.TensorArray( |
| dtype=dtype, size=num_steps) |
| initial_i = constant_op.constant(0, dtype=dtypes.int32) |
| |
| def cond(i, _): |
| return i < num_steps # pylint: disable=cell-var-from-loop |
| |
| def body(i, outputs): |
| x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop |
| outputs = outputs.write(i, x) |
| return i + 1, outputs |
| |
| _, outputs = control_flow_ops.while_loop(cond, body, |
| [initial_i, initial_outputs]) |
| |
| outputs = math_ops.reduce_sum(outputs.stack()) |
| r = gradients_impl.gradients([outputs], [inputs])[0] |
| grad_wr_inputs = ops.convert_to_tensor(r) |
| o, grad = sess.run([outputs, grad_wr_inputs], |
| feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]}) |
| self.assertEqual(o, 20) |
| self.assertAllEqual(grad, [1] * num_steps) |
| |
| @test_util.run_v1_only("b/120545219") |
| def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self): |
| for dtype in [dtypes.float32, dtypes.float64]: |
| with self.cached_session() as sess: |
| inputs = array_ops.placeholder(dtype=dtype) |
| initial_outputs = tensor_array_ops.TensorArray( |
| dtype=dtype, dynamic_size=True, size=1) |
| initial_i = constant_op.constant(0, dtype=dtypes.int32) |
| |
| def cond(i, _): |
| return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop |
| |
| def body(i, outputs): |
| x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop |
| outputs = outputs.write(i, x) |
| return i + 1, outputs |
| |
| _, outputs = control_flow_ops.while_loop(cond, body, |
| [initial_i, initial_outputs]) |
| |
| outputs = math_ops.reduce_sum(outputs.stack()) |
| r = gradients_impl.gradients([outputs], [inputs])[0] |
| grad_wr_inputs = ops.convert_to_tensor(r) |
| o, grad = sess.run([outputs, grad_wr_inputs], |
| feed_dict={inputs: [1, 3, 2]}) |
| self.assertEqual(o, 6) |
| self.assertAllEqual(grad, [1] * 3) |
| |
| @test_util.run_deprecated_v1 |
| def testGradientThroughSingleBranchOutsideOfContext(self): |
| x = constant_op.constant(2.) |
| s = constant_op.constant(True) |
| x_false, x_true = control_flow_ops.switch(x, s) |
| grad_x_true = gradients_impl.gradients(x_true, x)[0] |
| grad_x_false = gradients_impl.gradients(x_false, x)[0] |
| self.assertEqual(self.evaluate(grad_x_true), 1.) |
| self.assertEqual(self.evaluate(grad_x_false), 0.) |
| |
| |
| class CondTest(test_util.TensorFlowTestCase): |
| |
| def testCondTrue(self): |
| x = constant_op.constant(2) |
| y = constant_op.constant(5) |
| z = control_flow_ops.cond( |
| math_ops.less(x, y), lambda: math_ops.multiply(x, 17), |
| lambda: math_ops.add(y, 23)) |
| self.assertEqual(self.evaluate(z), 34) |
| |
| def testCondFalse(self): |
| x = constant_op.constant(2) |
| y = constant_op.constant(1) |
| z = control_flow_ops.cond( |
| math_ops.less(x, y), lambda: math_ops.multiply(x, 17), |
| lambda: math_ops.add(y, 23)) |
| self.assertEqual(self.evaluate(z), 24) |
| |
| def testCondTrueLegacy(self): |
| x = constant_op.constant(2) |
| y = constant_op.constant(5) |
| z = control_flow_ops.cond( |
| math_ops.less(x, y), |
| fn1=lambda: math_ops.multiply(x, 17), |
| fn2=lambda: math_ops.add(y, 23)) |
| self.assertEqual(self.evaluate(z), 34) |
| |
| def testCondFalseLegacy(self): |
| x = constant_op.constant(2) |
| y = constant_op.constant(1) |
| z = control_flow_ops.cond( |
| math_ops.less(x, y), |
| fn1=lambda: math_ops.multiply(x, 17), |
| fn2=lambda: math_ops.add(y, 23)) |
| self.assertEqual(self.evaluate(z), 24) |
| |
| @test_util.run_v1_only("Exercises Ref variables") |
| def testCondModifyBoolPred(self): |
| # We want to use the GPU here because we want to ensure that we can update |
| # a boolean ref variable on the GPU. |
| with test_util.use_gpu(): |
| bool_var = variable_scope.get_variable( |
| "bool_var", dtype=dtypes.bool, initializer=True) |
| cond_on_bool_var = control_flow_ops.cond( |
| pred=bool_var, |
| true_fn=lambda: state_ops.assign(bool_var, False), |
| false_fn=lambda: True) |
| self.evaluate(bool_var.initializer) |
| self.assertEqual(self.evaluate(cond_on_bool_var), False) |
| self.assertEqual(self.evaluate(cond_on_bool_var), True) |
| |
| def testCondMissingArg1(self): |
| x = constant_op.constant(1) |
| with self.assertRaises(TypeError): |
| control_flow_ops.cond(True, false_fn=lambda: x) |
| |
| def testCondMissingArg2(self): |
| x = constant_op.constant(1) |
| with self.assertRaises(TypeError): |
| control_flow_ops.cond(True, lambda: x) |
| |
| def testCondDuplicateArg1(self): |
| x = constant_op.constant(1) |
| with self.assertRaises(TypeError): |
| control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x) |
| |
| def testCondDuplicateArg2(self): |
| x = constant_op.constant(1) |
| with self.assertRaises(TypeError): |
| control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) |
| |
| @test_util.enable_control_flow_v2 |
| @test_util.run_in_graph_and_eager_modes |
| def testCond_gradient(self): |
| true_in, false_in = array_ops.constant(1.), array_ops.constant(5.) |
| with backprop.GradientTape(persistent=True) as tape: |
| tape.watch(true_in) |
| tape.watch(false_in) |
| cond_true = control_flow_ops.cond( |
| array_ops.constant(True), lambda: true_in**2., lambda: false_in**2.) |
| cond_false = control_flow_ops.cond( |
| array_ops.constant(False), lambda: true_in**2., lambda: false_in**2.) |
| grads_true = tape.gradient( |
| cond_true, [true_in, false_in], output_gradients=3.) |
| grads_false = tape.gradient( |
| cond_false, [true_in, false_in], output_gradients=3.) |
| self.assertEqual(3. * 2. * 1., self.evaluate(grads_true[0])) |
| self.assertEqual(None if context.executing_eagerly() else 0., |
| self.evaluate(grads_true[1])) |
| self.assertEqual(3. * 2. * 5., self.evaluate(grads_false[1])) |
| self.assertEqual(None if context.executing_eagerly() else 0., |
| self.evaluate(grads_false[0])) |
| |
| def testCondWithGroupAndSummaries(self): |
| with ops.Graph().as_default(): |
| writer = summary_ops_v2.create_file_writer(self.get_temp_dir()) |
| with writer.as_default(), summary_ops_v2.always_record_summaries(): |
| op = control_flow_ops.cond( |
| constant_op.constant(1) >= 0, |
| lambda: control_flow_ops.group(summary_ops_v2.scalar("loss", 0.2)), |
| control_flow_ops.no_op) |
| self.evaluate(variables.global_variables_initializer()) |
| self.evaluate(summary_ops_v2.summary_writer_initializer_op()) |
| self.assertEqual(self.evaluate(op), True) |
| |
| |
| class ContextTest(test_util.TensorFlowTestCase): |
| |
| @test_util.run_deprecated_v1 |
| def testCondContext(self): |
| with self.cached_session() as sess: |
| x = constant_op.constant(2) |
| y = constant_op.constant(5) |
| control_flow_ops.cond( |
| math_ops.less(x, y), lambda: math_ops.multiply(x, 17), |
| lambda: math_ops.add(y, 23)) |
| for op in sess.graph.get_operations(): |
| c = op._get_control_flow_context() |
| if c: |
| self.assertProtoEquals( |
| c.to_proto(), |
| control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto()) |
| |
| def _testWhileContextHelper(self, maximum_iterations=None): |
| with self.cached_session() as sess: |
| i = constant_op.constant(0) |
| c = lambda i: math_ops.less(i, 10) |
| b = lambda i: math_ops.add(i, 1) |
| control_flow_ops.while_loop( |
| c, b, [i], maximum_iterations=maximum_iterations) |
| for op in sess.graph.get_operations(): |
| control_flow_context = op._get_control_flow_context() |
| if control_flow_context: |
| self.assertProtoEquals( |
| control_flow_context.to_proto(), |
| control_flow_ops.WhileContext.from_proto( |
| control_flow_context.to_proto()).to_proto()) |
| |
| @test_util.run_deprecated_v1 |
| def testWhileContext(self): |
| self._testWhileContextHelper() |
| |
| @test_util.run_deprecated_v1 |
| def testWhileContextWithMaximumIterations(self): |
| self._testWhileContextHelper(maximum_iterations=10) |
| |
| @test_util.run_deprecated_v1 |
| def testControlContextImportScope(self): |
| |
| class NoABCControlFlowContext(control_flow_ops.ControlFlowContext): |
| """A noop wrapper around `ControlFlowContext`. |
| |
| `ControlFlowContext` is an ABC and therefore cannot be instantiated. |
| """ |
| |
| # pylint: disable=useless-super-delegation |
| |
| def to_control_flow_context_def(self, context_def, export_scope=None): |
| super(NoABCControlFlowContext, |
| self).to_control_flow_context_def(context_def, export_scope) |
| |
| with self.cached_session(): |
| constant_op.constant(0, name="a") |
| constant_op.constant(2, name="test_scope/a") |
| b1 = constant_op.constant(1, name="b") |
| b2 = constant_op.constant(3, name="test_scope/b") |
| |
| c = NoABCControlFlowContext() |
| c._values = ["a", "b"] |
| c._external_values = {"a": b1} |
| |
| c_with_scope = NoABCControlFlowContext( |
| values_def=c._to_values_def(), import_scope="test_scope") |
| |
| # _values and _external_values should be have scope prepended. |
| self.assertEqual(c_with_scope._values, |
| set(["test_scope/a", "test_scope/b"])) |
| self.assertEqual(c_with_scope._external_values, {"test_scope/a": b2}) |
| |
| # Calling _to_proto() with export_scope should remove "test_scope". |
| self.assertProtoEquals( |
| c._to_values_def(), |
| c_with_scope._to_values_def(export_scope="test_scope")) |
| |
| |
| def _get_nested_shape(nested): |
| |
| def _get_shape(tensor): |
| if isinstance(tensor, tensor_array_ops.TensorArray): |
| return tensor_array_ops.TensorArray |
| elif isinstance(tensor, ops.IndexedSlices): |
| return tensor.dense_shape |
| else: |
| return tensor.get_shape() |
| |
| return nest.map_structure(_get_shape, nested) |
| |
| |
| def _create_tensor_array(size, shape): |
| ta = tensor_array_ops.TensorArray( |
| dtype=dtypes.float32, size=size, clear_after_read=False) |
| for i in range(size): |
| ta = ta.write(i, array_ops.zeros(shape)) |
| return ta |
| |
| |
| def _raw_nested_shape(nested_shape): |
| |
| def _raw_shape(shape): |
| if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None: |
| return [x.value for x in shape.dims] |
| else: |
| return None |
| |
| return nest.map_structure(_raw_shape, nested_shape) |
| |
| |
| # TODO(yori): Add tests for indexed slices. |
| class DataTypesTest(test_util.TensorFlowTestCase): |
| |
| def assertAllEqualNested(self, a, b): |
| if isinstance(a, (list, tuple)): |
| for entry_a, entry_b in zip(a, b): |
| self.assertAllEqualNested(entry_a, entry_b) |
| else: |
| self.assertAllEqual(a, b) |
| |
| def _testShape(self, fn_true, fn_false, expected_shape, strict=False): |
| condition = array_ops.placeholder(dtypes.bool) |
| output_cond = control_flow_ops.cond( |
| condition, fn_true, fn_false, strict=strict) |
| self.assertEqual( |
| _raw_nested_shape(_get_nested_shape(output_cond)), |
| _raw_nested_shape(expected_shape)) |
| |
| output_case = control_flow_ops.case([(condition, fn_true)], |
| fn_false, |
| strict=strict) |
| self.assertEqual( |
| _raw_nested_shape(_get_nested_shape(output_case)), |
| _raw_nested_shape(expected_shape)) |
| |
| def _testReturnValues(self, |
| fn_true, |
| fn_false, |
| expected_value_true, |
| expected_value_false, |
| strict=False, |
| check_cond=True, |
| feed_dict=None): |
| if feed_dict is None: |
| feed_dict = {} |
| |
| condition = array_ops.placeholder(dtypes.bool) |
| output_cond = control_flow_ops.cond( |
| condition, fn_true, fn_false, strict=strict) |
| output_case = control_flow_ops.case([(condition, fn_true)], |
| fn_false, |
| strict=strict) |
| |
| with self.cached_session() as sess: |
| self.evaluate(variables.global_variables_initializer()) |
| true_feed_dict = {condition: True} |
| true_feed_dict.update(feed_dict) |
| result_cond, result_case = sess.run([output_cond, output_case], |
| feed_dict=true_feed_dict) |
| self.assertAllEqualNested(result_cond, expected_value_true) |
| if check_cond: |
| self.assertAllEqualNested(result_case, expected_value_true) |
| false_feed_dict = {condition: False} |
| false_feed_dict.update(feed_dict) |
| result_cond, result_case = sess.run([output_cond, output_case], |
| feed_dict=false_feed_dict) |
| self.assertAllEqualNested(result_cond, expected_value_false) |
| if check_cond: |
| self.assertAllEqualNested(result_case, expected_value_false) |
| |
| @test_util.run_deprecated_v1 |
| def test_int(self): |
| shape = tensor_shape.TensorShape([]) |
| fn_true = lambda: 1 |
| fn_false = lambda: 2 |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, 1, 2) |
| self._testShape(fn_true, fn_false, shape, strict=True) |
| self._testReturnValues(fn_true, fn_false, 1, 2, strict=True) |
| |
| @test_util.run_deprecated_v1 |
| def test_float(self): |
| shape = tensor_shape.TensorShape([]) |
| fn_true = lambda: 1.0 |
| fn_false = lambda: 2.0 |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, 1.0, 2.0) |
| |
| @test_util.run_deprecated_v1 |
| def test_noop(self): |
| shape = tensor_shape.TensorShape(None) |
| self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape) |
| self._testReturnValues( |
| control_flow_ops.no_op, |
| control_flow_ops.no_op, |
| True, |
| False, |
| check_cond=False) |
| |
| @test_util.run_deprecated_v1 |
| def test_string(self): |
| shape = tensor_shape.TensorShape([]) |
| fn_true = lambda: "abc" |
| fn_false = lambda: "xyz" |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, b"abc", b"xyz") |
| |
| @test_util.run_v1_only("b/138741991") |
| def test_variable(self): |
| shape = tensor_shape.TensorShape([]) |
| fn_true = lambda: variables.Variable(3.0) |
| fn_false = lambda: variables.Variable(4.0) |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, 3.0, 4.0) |
| |
| @test_util.run_v1_only("b/120553181") |
| def test_none(self): |
| fn_none = lambda: None |
| fn_tensor = lambda: constant_op.constant(1) |
| |
| with self.assertRaises(ValueError): |
| control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor) |
| |
| with self.assertRaises(ValueError): |
| control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none) |
| |
| @test_util.run_deprecated_v1 |
| def test_tensors(self): |
| |
| def _build_true_branch(dtype): |
| |
| def _build(): |
| return (array_ops.zeros([2, 2], |
| dtype=dtype), array_ops.ones([3, 3], |
| dtype=dtype)) |
| |
| return _build |
| |
| def _build_false_branch(dtype): |
| |
| def _build(): |
| return (array_ops.ones([2, 2], |
| dtype=dtype), array_ops.zeros([3, 3], |
| dtype=dtype)) |
| |
| return _build |
| |
| for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): |
| shape = (tensor_shape.TensorShape([2, |
| 2]), tensor_shape.TensorShape([3, 3])) |
| fn_true = _build_true_branch(dtype) |
| fn_false = _build_false_branch(dtype) |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, |
| (np.zeros([2, 2]), np.ones([3, 3])), |
| (np.ones([2, 2]), np.zeros([3, 3]))) |
| |
| @test_util.run_deprecated_v1 |
| def test_tensors_unknown_shape(self): |
| |
| def _build_true_branch(dtype): |
| tensor = array_ops.placeholder(dtype=dtype, shape=None) |
| |
| def _build(): |
| return tensor |
| |
| return _build, tensor |
| |
| def _build_false_branch(dtype): |
| tensor = array_ops.placeholder(dtype=dtype, shape=None) |
| |
| def _build(): |
| return tensor |
| |
| return _build, tensor |
| |
| for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): |
| shape = tensor_shape.TensorShape(None) |
| fn_true, true_tensor = _build_true_branch(dtype) |
| fn_false, false_tensor = _build_false_branch(dtype) |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues( |
| fn_true, |
| fn_false, |
| np.zeros([2, 2]), |
| np.ones([2, 2]), |
| feed_dict={ |
| true_tensor: np.zeros([2, 2]), |
| false_tensor: np.ones([2, 2]) |
| }) |
| |
| @test_util.run_deprecated_v1 |
| def test_sparse_tensors(self): |
| shape = tensor_shape.TensorShape([None, None]) |
| |
| def true_fn(): |
| return [ |
| sparse_tensor.SparseTensor( |
| indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) |
| ] |
| |
| def false_fn(): |
| return [ |
| sparse_tensor.SparseTensor( |
| indices=[[0, 0], [2, 1]], values=[3, 4], dense_shape=[3, 4]) |
| ] |
| |
| value1 = sparse_tensor.SparseTensorValue( |
| indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) |
| value2 = sparse_tensor.SparseTensorValue( |
| indices=[[0, 0], [2, 1]], values=[3, 4], dense_shape=[3, 4]) |
| # Non-strict cond is only available in v1 |
| if not tf2.enabled(): |
| self._testShape(true_fn, false_fn, shape) |
| self._testReturnValues(true_fn, false_fn, value1, value2) |
| self._testShape(true_fn, false_fn, [shape], strict=True) |
| self._testReturnValues(true_fn, false_fn, [value1], [value2], strict=True) |
| |
| @test_util.run_deprecated_v1 |
| def test_tensors_with_partially_specified_shapes(self): |
| |
| def _build_branch(dtype, shape): |
| a = array_ops.placeholder(dtype=dtype, shape=shape[0]) |
| b = array_ops.placeholder(dtype=dtype, shape=shape[1]) |
| c = array_ops.placeholder(dtype=dtype, shape=shape[2]) |
| |
| def _build(): |
| return a, b, c |
| |
| return _build, (a, b, c) |
| |
| for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8): |
| shape = (tensor_shape.TensorShape([None, |
| 2]), tensor_shape.TensorShape([None]), |
| tensor_shape.TensorShape([3, None])) |
| fn_true, true_tensors = _build_branch(dtype, shape) |
| fn_false, false_tensors = _build_branch(dtype, shape) |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues( |
| fn_true, |
| fn_false, (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])), |
| (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])), |
| feed_dict={ |
| true_tensors[0]: np.zeros([2, 2]), |
| false_tensors[0]: np.zeros([2, 2]), |
| true_tensors[1]: np.zeros([5]), |
| false_tensors[1]: np.zeros([5]), |
| true_tensors[2]: np.ones([3, 3]), |
| false_tensors[2]: np.ones([3, 3]) |
| }) |
| |
| @test_util.run_deprecated_v1 |
| def test_tensor_arrays(self): |
| element_shape = tensor_shape.TensorShape([2]) |
| ta1 = _create_tensor_array(4, element_shape) |
| ta2 = _create_tensor_array(4, element_shape) |
| shape = tensor_array_ops.TensorArray |
| fn_true = lambda: ta1 |
| fn_false = lambda: ta2 |
| self._testShape(fn_true, fn_false, shape) |
| |
| @test_util.run_deprecated_v1 |
| def test_tensor_array_reads(self): |
| shape = tensor_shape.TensorShape([2]) |
| ta = _create_tensor_array(4, shape) |
| fn_true = lambda: ta.read(0) |
| fn_false = lambda: ta.read(1) |
| self._testShape(fn_true, fn_false, shape) |
| |
| @test_util.run_v1_only("b/138741991") |
| def test_list(self): |
| shape = [ |
| tensor_shape.TensorShape([]), |
| tensor_shape.TensorShape([]), |
| tensor_shape.TensorShape([]) |
| ] |
| fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)] |
| fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)] |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0]) |
| |
| @test_util.run_v1_only("Non-strict cond is only available in v1") |
| def test_non_strict(self): |
| shape = tensor_shape.TensorShape([]) |
| fn_tensor = lambda: constant_op.constant(1) |
| fn_list = lambda: [constant_op.constant(2)] |
| fn_tuple = lambda: (constant_op.constant(3),) |
| self._testShape(fn_tensor, fn_list, shape) |
| self._testShape(fn_tensor, fn_tuple, shape) |
| self._testShape(fn_list, fn_tuple, shape) |
| self._testReturnValues(fn_tensor, fn_list, 1, 2) |
| self._testReturnValues(fn_tensor, fn_tuple, 1, 3) |
| self._testReturnValues(fn_list, fn_tuple, 2, 3) |
| |
| @test_util.run_v1_only("b/120553181") |
| def test_singleton_strict(self): |
| fn_tensor = lambda: constant_op.constant(1) |
| fn_list = lambda: [constant_op.constant(2)] |
| fn_tuple = lambda: (constant_op.constant(3),) |
| |
| with self.assertRaises(ValueError): |
| control_flow_ops.cond( |
| constant_op.constant(True), fn_tensor, fn_list, strict=True) |
| |
| with self.assertRaises(TypeError): |
| control_flow_ops.cond( |
| constant_op.constant(True), fn_list, fn_tuple, strict=True) |
| |
| with self.assertRaises(ValueError): |
| control_flow_ops.case([(constant_op.constant(True), fn_tensor)], |
| fn_list, |
| strict=True) |
| |
| with self.assertRaises(TypeError): |
| control_flow_ops.case([(constant_op.constant(True), fn_list)], |
| fn_tuple, |
| strict=True) |
| |
| @test_util.run_deprecated_v1 |
| def test_singleton_list(self): |
| shape = tensor_shape.TensorShape([]) |
| fn_true = lambda: [constant_op.constant(1)] |
| fn_false = lambda: [constant_op.constant(3)] |
| # Non-strict cond is only available in v1 |
| if not tf2.enabled(): |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, 1, 3) |
| self._testShape(fn_true, fn_false, [shape], strict=True) |
| self._testReturnValues(fn_true, fn_false, [1], [3], strict=True) |
| |
| @test_util.run_deprecated_v1 |
| def test_singleton_tuple(self): |
| shape = tensor_shape.TensorShape([]) |
| fn_true = lambda: (constant_op.constant(1),) |
| fn_false = lambda: (constant_op.constant(3),) |
| # Non-strict cond is only available in v1 |
| if not tf2.enabled(): |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, 1, 3) |
| self._testShape(fn_true, fn_false, (shape,), strict=True) |
| self._testReturnValues(fn_true, fn_false, (1,), (3,), strict=True) |
| |
| @test_util.run_deprecated_v1 |
| def test_singleton_namedtuple(self): |
| shape = tensor_shape.TensorShape([]) |
| fn_true = lambda: SingletonTestTuple(constant_op.constant(1)) |
| fn_false = lambda: SingletonTestTuple(constant_op.constant(3)) |
| # Non-strict cond is only available in v1 |
| if not tf2.enabled(): |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, 1, 3) |
| self._testShape(fn_true, fn_false, SingletonTestTuple(shape), strict=True) |
| self._testReturnValues( |
| fn_true, |
| fn_false, |
| SingletonTestTuple(1), |
| SingletonTestTuple(3), |
| strict=True) |
| |
| @test_util.run_deprecated_v1 |
| def test_tuple(self): |
| shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) |
| fn_true = lambda: (constant_op.constant(1), 2) |
| fn_false = lambda: (constant_op.constant(3), 4) |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4)) |
| |
| @test_util.run_deprecated_v1 |
| def test_namedtuple(self): |
| shape = TestTuple( |
| tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) |
| fn_true = lambda: TestTuple(constant_op.constant(1), 2) |
| fn_false = lambda: TestTuple(constant_op.constant(3), 4) |
| self._testShape(fn_true, fn_false, shape) |
| self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4)) |
| |
| @test_util.run_deprecated_v1 |
| def test_nested(self): |
| shape = [ |
| tensor_shape.TensorShape([]), |
| TestTuple( |
| tensor_shape.TensorShape([]), |
| [tensor_shape.TensorShape([]), |
| tensor_shape.TensorShape([])]), |
| tensor_shape.TensorShape([5, 5]), |
| tensor_shape.TensorShape([]) |
| ] |
| |
| def true_fn(): |
| return [ |
| constant_op.constant(1), |
| TestTuple(constant_op.constant(2), [3, 4]), |
| array_ops.zeros([5, 5]), 6 |
| ] |
| |
| def false_fn(): |
| return [ |
| constant_op.constant(11), |
| TestTuple(constant_op.constant(12), [13, 14]), |
| array_ops.ones([5, 5]), 16 |
| ] |
| |
| self._testShape(true_fn, false_fn, shape) |
| self._testReturnValues( |
| true_fn, false_fn, |
| [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6], |
| [11, TestTuple(12, [13, 14]), |
| np.ones([5, 5]), 16]) |
| |
| @test_util.run_deprecated_v1 |
| def test_cond_inside_while_loop(self): |
| |
| def body(i, matrix): |
| result_tuple, unused_matrix = control_flow_ops.cond( |
| constant_op.constant(True), lambda: |
| (TestTuple(matrix * 2, matrix * 4), matrix), lambda: |
| (TestTuple(matrix * 4, matrix * 2), matrix)) |
| return [i + 1, result_tuple.a] |
| |
| iteration, matrix = control_flow_ops.while_loop( |
| lambda i, matrix: i < 10, |
| body, |
| loop_vars=[constant_op.constant(0), |
| array_ops.ones([2, 2])]) |
| |
| self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([])) |
| self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2])) |
| |
| |
| @test_util.run_all_in_graph_and_eager_modes |
| class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase): |
| |
| def make_name(self): |
| name = self.id().split(".")[-1].replace("(", "_").replace(")", "") |
| return name.replace(" ", "_") |
| |
| def disabled_testCase_ticklesGpuVsHostMemoryIssueWithInt32(self): |
| nbranches = 5 |
| |
| def make_func(bi): |
| return lambda: array_ops.constant(bi * 10, name="br{}_out".format(bi)) |
| |
| branches = [(i, make_func(i)) for i in range(nbranches)] |
| for bi in range(nbranches): |
| branch_index = array_ops.placeholder_with_default(bi, []) |
| case_out = control_flow_ops.switch_case(branch_index, branches) |
| self.assertEqual(bi * 10, self.evaluate(case_out)) |
| |
| @parameterized.parameters((0,), (2,), (3,)) |
| def testCase(self, bi): |
| nbranches = 5 |
| |
| def make_func(bi): |
| return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) |
| |
| branches = [(i, make_func(i)) for i in range(nbranches)] |
| branch_index = array_ops.placeholder_with_default(bi, []) |
| case_out = control_flow_ops.switch_case( |
| branch_index, branches, name=self.make_name()) |
| self.assertEqual(bi * 10., self.evaluate(case_out)) |
| |
| @parameterized.parameters((-1,), (2,), (4,), (5,), (6,)) |
| def testCase_withDefault(self, bi): |
| nbranches = 5 |
| |
| def make_func(bi): |
| return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) |
| |
| branches = [(i, make_func(i)) for i in range(nbranches)] |
| branch_index = array_ops.placeholder_with_default(bi, []) |
| case_out = control_flow_ops.switch_case( |
| branch_index, branches, default=make_func(6), name=self.make_name()) |
| if bi < 0 or bi >= nbranches: |
| expected = 60. |
| else: |
| expected = bi * 10. |
| self.assertEqual(expected, self.evaluate(case_out)) |
| |
| @parameterized.parameters((-1,), (0,), (3,), (5,)) |
| def testCase_dictWithDefault(self, bi): |
| nbranches = 5 |
| |
| def make_func(bi): |
| return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) |
| |
| branches = {i: make_func(i) for i in range(nbranches)} |
| branch_index = array_ops.placeholder_with_default(bi, []) |
| case_out = control_flow_ops.switch_case( |
| branch_index, branches, default=make_func(6), name=self.make_name()) |
| if bi < 0 or bi >= nbranches: |
| expected = 60. |
| else: |
| expected = bi * 10. |
| self.assertEqual(expected, self.evaluate(case_out)) |
| |
| @parameterized.parameters((-1,), (1,), (4,), (5,)) |
| def testCase_gradient_disable_lowering(self, bi): |
| self._testCase_gradient(True, bi) |
| |
| @parameterized.parameters((-1,), (1,), (4,), (5,)) |
| def testCase_gradient_enable_lowering(self, bi): |
| self._testCase_gradient(False, bi) |
| |
| def _testCase_gradient(self, disable_lowering, bi): |
| default_lowering = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE |
| control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = disable_lowering |
| nbranches = 5 |
| inputs = [ |
| array_ops.constant(float(bi), name="br{}_in".format(bi)) |
| for bi in range(nbranches) |
| ] |
| |
| def make_func(bi): |
| return lambda: inputs[bi]**2. |
| |
| branches = {bi: make_func(bi) for bi in range(nbranches)} |
| |
| branch_index = array_ops.placeholder_with_default(bi, []) |
| with backprop.GradientTape() as tape: |
| for x in inputs: |
| tape.watch(x) |
| case_out = control_flow_ops.switch_case(branch_index, branches) |
| out_grad = 3. |
| actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad) |
| expected_grads = [None if context.executing_eagerly() else 0.] * nbranches |
| used_branch_idx = nbranches - 1 if bi < 0 or bi >= nbranches - 1 else bi |
| expected_grads[used_branch_idx] = out_grad * 2. * used_branch_idx |
| self.assertEqual(len(expected_grads), len(actual_grads)) |
| for expected, actual in zip(expected_grads, actual_grads): |
| self.assertEqual(expected, self.evaluate(actual)) |
| # reset to default value |
| control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = default_lowering |
| |
| @parameterized.parameters((-2,), (2,), (5,)) |
| def testCase_gradient_diffShapedIntermediates(self, bi): |
| nbranches = 5 |
| inputs = [ |
| array_ops.constant( |
| float(bi), shape=[bi + 1], name="br{}_in".format(bi)) |
| for bi in range(nbranches) |
| ] |
| |
| def make_func(bi): |
| |
| def f(): |
| x = inputs[bi]**2 * inputs[bi][:bi + 1, None] |
| return math_ops.reduce_sum(x) |
| |
| return f |
| |
| branches = {bi: make_func(bi) for bi in range(nbranches)} |
| |
| branch_index = array_ops.placeholder_with_default(bi, []) |
| with backprop.GradientTape() as tape: |
| for x in inputs: |
| tape.watch(x) |
| case_out = control_flow_ops.switch_case( |
| branch_index, branches, name=self.make_name()) |
| out_grad = 3. |
| actual_grads = tape.gradient(case_out, inputs, output_gradients=out_grad) |
| used_bi = (nbranches - 1) if (bi < 0 or bi >= nbranches - 1) else bi |
| expected_grads = [] |
| for input_idx in range(nbranches): |
| if used_bi == input_idx: |
| with backprop.GradientTape() as tape: |
| tape.watch(inputs[used_bi]) |
| y = make_func(used_bi)() |
| expected_grads.append( |
| self.evaluate( |
| tape.gradient(y, inputs[used_bi], output_gradients=out_grad))) |
| else: |
| expected_grads.append(None if context.executing_eagerly() else [0.] * |
| (input_idx + 1)) |
| |
| self.assertEqual(len(expected_grads), len(actual_grads)) |
| for expected, actual in zip(expected_grads, actual_grads): |
| if expected is None: |
| self.assertIsNone(actual) |
| else: |
| self.assertAllEqual(expected, self.evaluate(actual)) |
| |
| @test_util.run_gpu_only |
| @test_util.disable_xla("Wants RunMetadata") |
| def testParallelExecution(self): |
| """Verify disjoint branches across while iterations are run in parallel.""" |
| if control_flow_v2_toggles.control_flow_v2_enabled(): |
| self.skipTest("b/138870290") |
| |
| with ops.Graph().as_default() as g: |
| nbranches = 7 |
| matrices = array_ops.unstack( # Ensure all are ready before while. |
| array_ops.matrix_diag( |
| random_ops.random_uniform([nbranches, 8, 512]) + 1e-3)) |
| |
| def make_branch(i, mat, name): |
| |
| def branch_fn(): |
| next_i = i + 1 |
| with ops.device("gpu:0"): |
| return next_i, math_ops.reduce_sum( |
| linalg_ops.cholesky(mat, name=name + "_Cholesky")) |
| |
| return branch_fn |
| |
| def make_branches(i): |
| return [ |
| make_branch(i, matrices[bi], "br{}".format(bi)) |
| for bi in range(nbranches) |
| ] |
| |
| def cond(i, _): |
| return i < nbranches |
| |
| def body(i, result): |
| with ops.device("cpu:0"): |
| next_i, branch_out = control_flow_ops.switch_case(i, make_branches(i)) |
| return next_i, result + branch_out |
| |
| _, result = control_flow_ops.while_loop(cond, body, [0, 0.]) |
| |
| run_metadata = config_pb2.RunMetadata() |
| run_options = config_pb2.RunOptions( |
| trace_level=config_pb2.RunOptions.FULL_TRACE) |
| config = config_pb2.ConfigProto( |
| allow_soft_placement=False, log_device_placement=True) |
| |
| with session.Session(config=config, graph=g) as sess: |
| _ = sess.run(result, options=run_options, run_metadata=run_metadata) |
| chol_node_stats = [] |
| for dev_stats in run_metadata.step_stats.dev_stats: |
| for node_stats in dev_stats.node_stats: |
| if (node_stats.node_name.endswith("Cholesky") and |
| node_stats.all_start_nanos > 0): |
| chol_node_stats.append(node_stats) |
| |
| self.assertLen(chol_node_stats, nbranches) |
| |
| chol_node_stats = sorted(chol_node_stats, key=lambda stats: stats.node_name) |
| op_start_nanos = [stats.all_start_nanos for stats in chol_node_stats] |
| op_end_nanos = [ |
| stats.all_start_nanos + stats.op_end_rel_nanos |
| for stats in chol_node_stats |
| ] |
| |
| def overlap(range1, range2): |
| s1, e1 = range1 |
| s2, e2 = range2 |
| if s1 < s2: |
| return 0 if s2 > e1 else e1 - s2 |
| return 0 if s1 > e2 else e2 - s1 |
| |
| timespans = list(zip(op_start_nanos, op_end_nanos)) |
| overlaps_chol0 = [overlap(timespans[0], r2) for r2 in timespans[1:]] |
| # There are nbranches-1 overlaps, sometimes all nonzero, but we |
| # conservatively check for at least one here, to avoid test flakiness. |
| self.assertGreater(np.count_nonzero(overlaps_chol0), 0) |
| |
| def testCase_validateIndicesContiguous(self): |
| |
| def make_func(bi): |
| return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) |
| |
| branches = {i: make_func(i) for i in range(0, 6, 2)} |
| with self.assertRaisesRegex(ValueError, "must form contiguous"): |
| control_flow_ops.switch_case(array_ops.constant(0), branches) |
| |
| def testCase_validateIndicesDup(self): |
| |
| def make_func(bi): |
| return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) |
| |
| branches = [(i, make_func(i)) for i in range(0, 6, 2)] |
| branches.append((0, make_func(7))) |
| with self.assertRaisesRegex(ValueError, "must form contiguous"): |
| control_flow_ops.switch_case(array_ops.constant(0), branches) |
| |
| def testCase_validateBranchIndex(self): |
| |
| def make_func(bi): |
| return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) |
| |
| branches = {i: make_func(i) for i in range(5)} |
| with self.assertRaisesRegex(TypeError, "branch_index.*Tensor"): |
| control_flow_ops.switch_case(1, branches) |
| |
| def testCase_validateNonIntKeys(self): |
| |
| def make_func(bi): |
| return lambda: array_ops.constant(bi * 10., name="br{}_out".format(bi)) |
| |
| branches = [(array_ops.constant(i), make_func(i)) for i in range(5)] |
| with self.assertRaisesRegex(TypeError, "must be a Python `int`"): |
| control_flow_ops.switch_case(array_ops.constant(1), branches) |
| |
| |
| class ExecuteFnForDeviceTest(test_util.TensorFlowTestCase): |
| |
| # The same test can run with and without XLA compilation. |
| # In non-XLA gpu case, it exercises gpu branch. |
| # In XLA gpu cases, it exercises the default case. |
| # This test is to test the non-XLA case so that we disable XLA. |
| @test_util.disable_xla("xla has different execution branch") |
| def testCommonCases(self): |
| |
| def cpu_fn(x): |
| return x + x |
| |
| def gpu_fn(x): |
| return x * x |
| |
| def flexible_fn(a): |
| branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)} |
| return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a)) |
| |
| @def_function.function |
| def flexible_defun(a): |
| return flexible_fn(a) |
| |
| def run_defun_and_tape(a): |
| with backprop.GradientTape() as tape: |
| tape.watch(a) |
| result = flexible_defun(a) |
| grad = tape.gradient(result, a) |
| r = flexible_fn(a) |
| return r, result, grad |
| |
| a = array_ops.constant(3.) |
| with ops.device("cpu:0"): |
| r, result, grad = run_defun_and_tape(a) |
| self.assertEqual(6., self.evaluate(r)) |
| self.assertEqual(6., self.evaluate(result)) |
| self.assertEqual([2.], self.evaluate(grad)) |
| |
| if test_util.is_gpu_available(): |
| with ops.device("gpu:0"): |
| r, result, grad = run_defun_and_tape(a) |
| self.assertEqual(9., self.evaluate(r)) |
| self.assertEqual(9., self.evaluate(result)) |
| self.assertEqual([6.], self.evaluate(grad)) |
| |
| # no device annotation |
| r, result, grad = run_defun_and_tape(a) |
| if test_util.is_gpu_available(): |
| self.assertEqual(9., self.evaluate(r)) |
| self.assertEqual(9., self.evaluate(result)) |
| self.assertEqual([6.], self.evaluate(grad)) |
| else: |
| self.assertEqual(6., self.evaluate(r)) |
| self.assertEqual(6., self.evaluate(result)) |
| self.assertEqual([2.], self.evaluate(grad)) |
| |
| def testCompile(self): |
| if not test_util.is_gpu_available(): |
| return |
| |
| def cpu_fn(x): |
| return x + x |
| |
| def gpu_fn(x): |
| return x * x |
| |
| @def_function.function(jit_compile=True) |
| def flexible_defun(a): |
| branches = {"CPU": lambda: cpu_fn(a), "GPU": lambda: gpu_fn(a)} |
| return control_flow_ops.execute_fn_for_device(branches, lambda: cpu_fn(a)) |
| |
| # Always execute the default branch in xla compilation case. |
| a = array_ops.constant(3.) |
| r = flexible_defun(a) |
| self.assertEqual(6., self.evaluate(r)) |
| |
| def testFallBack(self): |
| |
| def default_fn(x): |
| return x |
| |
| def tpu_fn(x): |
| return x * x * x |
| |
| def flexible_fn(a): |
| branches = {"TPU": lambda: tpu_fn(a)} |
| return control_flow_ops.execute_fn_for_device( |
| branches, default_fn=lambda: default_fn(a)) |
| |
| @def_function.function |
| def flexible_defun(a): |
| return flexible_fn(a) |
| |
| a = array_ops.constant(3.) |
| with ops.device("cpu:0"): |
| result_defun = flexible_defun(a) |
| result_defun = flexible_fn(a) |
| self.assertEqual(3., self.evaluate(result_defun)) |
| # execute_fn_for_device is not inside defun_function. |
| result = flexible_fn(a) |
| self.assertEqual(3., self.evaluate(result)) |
| |
| if test_util.is_gpu_available(): |
| with ops.device("gpu:0"): |
| result_defun = flexible_defun(a) |
| self.assertEqual(3., self.evaluate(result_defun)) |
| # execute_fn_for_device is not inside defun_function. |
| result = flexible_fn(a) |
| self.assertEqual(3., self.evaluate(result)) |
| |
| |
| class CaseTest(test_util.TensorFlowTestCase): |
| |
| @test_util.run_deprecated_v1 |
| def testCase_withDefault(self): |
| x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) |
| conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), |
| (math_ops.equal(x, 2), lambda: constant_op.constant(4))] |
| default = lambda: constant_op.constant(6) |
| output = control_flow_ops.case(conditions, default, exclusive=True) |
| with self.cached_session() as sess: |
| self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) |
| self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) |
| self.assertEqual(sess.run(output, feed_dict={x: 3}), 6) |
| |
| @test_util.run_deprecated_v1 |
| def testCase_multiple_matches_exclusive(self): |
| x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) |
| conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), |
| (math_ops.equal(x, 2), lambda: constant_op.constant(4)), |
| (math_ops.equal(x, 2), lambda: constant_op.constant(6))] |
| default = lambda: constant_op.constant(8) |
| output = control_flow_ops.case(conditions, default, exclusive=True) |
| with self.cached_session() as sess: |
| self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) |
| self.assertEqual(sess.run(output, feed_dict={x: 3}), 8) |
| with self.assertRaisesRegex(errors.InvalidArgumentError, "Input error:"): |
| sess.run(output, feed_dict={x: 2}) |
| |
| @test_util.run_deprecated_v1 |
| def testCase_multiple_matches_non_exclusive(self): |
| x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) |
| conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), |
| (math_ops.equal(x, 2), lambda: constant_op.constant(4)), |
| (math_ops.equal(x, 2), lambda: constant_op.constant(6))] |
| default = lambda: constant_op.constant(8) |
| output = control_flow_ops.case(conditions, default, exclusive=False) |
| with self.cached_session() as sess: |
| self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) |
| self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) |
| self.assertEqual(sess.run(output, feed_dict={x: 3}), 8) |
| |
| @test_util.run_deprecated_v1 |
| def testCase_withoutDefault(self): |
| x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) |
| conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), |
| (math_ops.equal(x, 2), lambda: constant_op.constant(4)), |
| (math_ops.equal(x, 3), lambda: constant_op.constant(6))] |
| output = control_flow_ops.case(conditions, exclusive=True) |
| with self.cached_session() as sess: |
| self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) |
| self.assertEqual(sess.run(output, feed_dict={x: 2}), 4) |
| self.assertEqual(sess.run(output, feed_dict={x: 3}), 6) |
| with self.assertRaisesRegex(errors.InvalidArgumentError, "Input error:"): |
| sess.run(output, feed_dict={x: 4}) |
| |
| @test_util.run_deprecated_v1 |
| def testCase_withoutDefault_oneCondition(self): |
| x = array_ops.placeholder(dtype=dtypes.int32, shape=[]) |
| conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))] |
| output = control_flow_ops.case(conditions, exclusive=True) |
| with self.cached_session() as sess: |
| self.assertEqual(sess.run(output, feed_dict={x: 1}), 2) |
| with self.assertRaisesRegex(errors.InvalidArgumentError, "Input error:"): |
| sess.run(output, feed_dict={x: 4}) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testCase_dict(self): |
| x = constant_op.constant(2) |
| conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)), |
| (math_ops.equal(x, 2), lambda: constant_op.constant(4))] |
| output = control_flow_ops.case(conditions, exclusive=True) |
| self.assertEqual(4, self.evaluate(output)) |
| |
| |
| class WhileLoopTestCase(test_util.TensorFlowTestCase): |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testWhileLoopWithSingleVariable(self): |
| i = constant_op.constant(0) |
| c = lambda i: math_ops.less(i, 10) |
| b = lambda i: math_ops.add(i, 1) |
| r = control_flow_ops.while_loop(c, b, [i]) |
| |
| self.assertEqual(self.evaluate(r), 10) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self): |
| i = constant_op.constant(0) |
| c = lambda i: math_ops.less(i, 10) |
| b = lambda i: (math_ops.add(i, 1),) |
| r = control_flow_ops.while_loop(c, b, [i]) |
| |
| # Expect a tuple since that is what the body returns. |
| self.assertEqual(self.evaluate(r), (10,)) |
| |
| @test_util.run_v1_only("Unsupported in cfv2") |
| def testWhileLoopSameReturnShape_False(self): |
| i = constant_op.constant(0) |
| c = lambda i, _: math_ops.less(i, 10) |
| |
| # Body returns a [tensor, []] |
| b = lambda i, _: [math_ops.add(i, 1), []] |
| |
| # Should only return the tensor. |
| r = control_flow_ops.while_loop(c, b, [i, []]) |
| self.assertEqual(self.evaluate(r), 10) |
| |
| # Adding maximum_iterations should yield the same result. |
| r = control_flow_ops.while_loop(c, b, [i, []], maximum_iterations=50) |
| # Note: this result is still incorrect - it should be just 10. |
| self.assertEqual(self.evaluate(r), [10, []]) |
| |
| def testWhileLoopSameReturnShape_FalseSingleLoopVar(self): |
| i = constant_op.constant(0) |
| c = lambda i: math_ops.less(i, 10) |
| |
| # Body return must be unpacked in this case. |
| b = lambda i: math_ops.add(i, 1) |
| |
| # Should only return the tensor. |
| r = control_flow_ops.while_loop(c, b, [i]) |
| self.assertEqual(self.evaluate(r), 10) |
| |
| # Adding maximum_iterations should yield the same result. |
| r = control_flow_ops.while_loop(c, b, [i], maximum_iterations=50) |
| self.assertEqual(self.evaluate(r), 10) |
| |
| def testWhileLoopSameReturnShape_True(self): |
| i = constant_op.constant(0) |
| c = lambda i, _: math_ops.less(i, 10) |
| |
| # Body returns a [tensor, []] |
| b = lambda i, _: [math_ops.add(i, 1), []] |
| |
| # Should only return the original structure. |
| r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True) |
| self.assertEqual(self.evaluate(r), [10, []]) |
| |
| # Adding maximum_iterations should yield the same result. |
| r = control_flow_ops.while_loop( |
| c, b, [i, []], return_same_structure=True, maximum_iterations=50) |
| self.assertEqual(self.evaluate(r), [10, []]) |
| |
| def testWhileLoopSameReturnShape_TrueSingleLoopVar(self): |
| i = constant_op.constant(0) |
| c = lambda i: math_ops.less(i, 10) |
| |
| b = lambda i: [math_ops.add(i, 1)] |
| |
| # Should not unpack the single variable |
| r = control_flow_ops.while_loop(c, b, [i], return_same_structure=True) |
| self.assertEqual(self.evaluate(r), [10]) |
| |
| # Adding maximum_iterations should yield the same result. |
| r = control_flow_ops.while_loop( |
| c, b, [i], return_same_structure=True, maximum_iterations=50) |
| self.assertEqual(self.evaluate(r), [10]) |
| |
| @test_util.enable_control_flow_v2 |
| @test_util.run_in_graph_and_eager_modes |
| def testSkipsUnnecessaryCaptureGradients(self): |
| |
| @custom_gradient.custom_gradient |
| def gradient_trap(t): |
| |
| def grad(w): |
| # Computing this gradient should fail the test |
| check_ops.assert_equal(0, 1) |
| return w |
| |
| return t, grad |
| |
| x = array_ops.constant(0.0, name="x") |
| y = array_ops.constant(1.0, name="y") |
| |
| def cond(s): |
| return s < 10.0 |
| |
| def body(s): |
| return s + 2 * x + gradient_trap(y) |
| |
| with backprop.GradientTape() as tape: |
| tape.watch(x) |
| out = control_flow_ops.while_loop(cond, body, (array_ops.constant(0.0),)) |
| |
| grad = tape.gradient(out, x) |
| self.assertAllEqual(grad, 20.0) |
| |
| |
| class WhileLoopParallelismTest(test_util.TensorFlowTestCase, |
| parameterized.TestCase): |
| |
| def setUp(self): |
| super().setUp() |
| self._while_paralelism = while_v2.glob_stateful_parallelism |
| |
| def tearDown(self): |
| while_v2.glob_stateful_parallelism = self._while_paralelism |
| super().tearDown() |
| |
| @parameterized.parameters(*itertools.product( |
| (False, True), |
| (False, True), |
| (False, True), |
| (False, True), |
| (False, True), |
| )) |
| def testResourceHandlingInLoop(self, read_before, read_after, modify_in_loop, |
| modify_before, modify_after): |
| |
| if not tf2.enabled(): |
| self.skipTest("V2-only test.") |
| |
| while_v2.glob_stateful_parallelism = True |
| |
| ticker = variables.Variable(0) |
| |
| @def_function.function |
| def run_loop(n): |
| ticker.assign(0) |
| i = constant_op.constant(0) |
| t_acc = tensor_array_ops.TensorArray( |
| dtypes.int32, size=0, dynamic_size=True) |
| |
| if read_before: |
| rb = ticker.read_value() |
| else: |
| rb = constant_op.constant(0) |
| if modify_before: |
| ticker.assign_add(1) |
| |
| while i < n: |
| directives.set_loop_options(parallel_iterations=10) |
| |
| if modify_in_loop: |
| ticker.assign_add(1) |
| t_acc = t_acc.write(i, ticker.read_value()) |
| i += 1 |
| |
| if read_after: |
| ra = ticker.read_value() |
| else: |
| ra = constant_op.constant(0) |
| if modify_after: |
| ticker.assign_add(1) |
| |
| return t_acc.stack(), rb, ra |
| |
| # Warm-up. |
| self.evaluate(run_loop(1)) |
| |
| self.evaluate(ticker.assign(123)) |
| acc, rb, ra = run_loop(3) |
| self.assertEqual( |
| self.evaluate(math_ops.reduce_max(acc)), |
| int(modify_before) + 3 * int(modify_in_loop)) |
| |
| # Double check variable reads are still sequenced. |
| self.assertEqual(self.evaluate(rb), 0) |
| |
| if read_after: |
| expected_ra = int(modify_before) + 3 * int(modify_in_loop) |
| else: |
| expected_ra = 0 |
| self.assertEqual(self.evaluate(ra), expected_ra) |
| |
| # Double-check that the loop ran completely. |
| self.assertEqual( |
| self.evaluate(ticker.read_value()), |
| int(modify_before) + 3 * int(modify_in_loop) + int(modify_after)) |
| |
| def testMultiReadsBeforeWrite(self): |
| |
| if not tf2.enabled(): |
| self.skipTest("V2-only test.") |
| |
| while_v2.glob_stateful_parallelism = True |
| |
| ticker = variables.Variable(0) |
| |
| @def_function.function |
| def run_loop(n): |
| ticker.assign(0) |
| i = constant_op.constant(0) |
| t_acc = tensor_array_ops.TensorArray( |
| dtypes.int32, size=0, dynamic_size=True) |
| |
| while i < n: |
| directives.set_loop_options(parallel_iterations=10) |
| |
| a = ticker.read_value() |
| b = ticker.read_value() |
| t_acc = t_acc.write(2 * i, a) |
| t_acc = t_acc.write(2 * i + 1, b) |
| |
| # Slow write forces reads to sprint ahead if they can. |
| # This test verifies that they don't. |
| ticker.assign_add( |
| math_ops.cast( |
| math_ops.reduce_max( |
| random_ops.random_uniform( |
| shape=(1000,), minval=1.0, maxval=1.001)), |
| dtypes.int32)) |
| |
| i += 1 |
| |
| a = ticker.read_value() |
| b = ticker.read_value() |
| t_acc = t_acc.write(2 * i, a) |
| t_acc = t_acc.write(2 * i + 1, b) |
| |
| return t_acc.stack() |
| |
| # Warm-up. |
| self.evaluate(run_loop(1)) |
| |
| acc = run_loop(3) |
| self.assertAllEqual(acc, [0, 0, 1, 1, 2, 2, 3, 3]) |
| |
| def testCondDependenceOnMutatedResource(self): |
| |
| if not tf2.enabled(): |
| self.skipTest("V2-only test.") |
| |
| # TODO(b/152548567): Enable this. |
| while_v2.glob_stateful_parallelism = False |
| |
| ticker = variables.Variable(0) |
| counter = variables.Variable(1) |
| |
| @def_function.function |
| def run_loop(n): |
| ticker.assign(0) |
| counter.assign(0) |
| |
| while ticker.read_value() < n: |
| directives.set_loop_options(parallel_iterations=10) |
| |
| # Run a slow assign, to make sure counter sprints ahead. |
| ticker.assign_add( |
| math_ops.cast( |
| math_ops.reduce_max( |
| random_ops.random_uniform( |
| shape=(1000,), minval=1.0, maxval=1.001)), |
| dtypes.int32)) |
| counter.assign_add(1) |
| |
| return ticker.read_value(), counter.read_value() |
| |
| # Warm-up. |
| self.evaluate(run_loop(1)) |
| |
| t, c = run_loop(3) |
| self.assertEqual(self.evaluate(t), 3) |
| self.assertEqual(self.evaluate(c), 3) |
| |
| def testIndependentSideEffectsInCond(self): |
| |
| if not tf2.enabled(): |
| self.skipTest("V2-only test.") |
| |
| # TODO(b/152548567): Enable experimental_stateful_parallelism. |
| # Without proper wiring of control deps in the cond branch, the test is |
| # non-deterministic, running cond's record_side_effect ahead of its |
| # counterpart in the body. |
| while_v2.glob_stateful_parallelism = False |
| |
| state = [] |
| |
| def record_side_effect(c): |
| |
| def side_effect_py_fn(): |
| state.append(c) |
| return 0 |
| |
| script_ops.eager_py_func(side_effect_py_fn, [], [dtypes.int32]) |
| |
| @def_function.function |
| def run_loop(n): |
| |
| def complex_cond(i): |
| record_side_effect("A") |
| return i < n |
| |
| i = constant_op.constant(0) |
| |
| while complex_cond(i): |
| directives.set_loop_options(parallel_iterations=10) |
| |
| record_side_effect("B") |
| i += 1 |
| |
| return i |
| |
| # Warm-up. |
| self.evaluate(run_loop(1)) |
| |
| state.clear() |
| i = run_loop(3) |
| self.assertEqual(self.evaluate(i), 3) |
| self.assertListEqual(state, ["A", "B", "A", "B", "A", "B", "A"]) |
| |
| def testStatelessLoop(self): |
| |
| while_v2.glob_stateful_parallelism = True |
| |
| @def_function.function |
| def run_loop(n): |
| |
| a = 0 |
| b = 1 |
| |
| i = constant_op.constant(0) |
| while i < n: |
| directives.set_loop_options(parallel_iterations=10) |
| i += 1 |
| a += 2 |
| b *= 3 |
| |
| return i, a, b |
| |
| i, a, b = run_loop(3) |
| self.assertEqual(self.evaluate(i), 3) |
| self.assertEqual(self.evaluate(a), 6) |
| self.assertEqual(self.evaluate(b), 27) |
| |
| def testStatefulParallelism(self): |
| |
| if not tf2.enabled(): |
| self.skipTest("V2-only test.") |
| |
| while_v2.glob_stateful_parallelism = True |
| |
| ticker = variables.Variable(0) |
| # Secondary state for the pyfunc that lets us verify that things ran in |
| # the correct relative order. |
| ticker_state = [] |
| |
| def wait_then_tick(i): |
| # The contents of py_funcs is opaque, so TF doesn't see this variable |
| # assignment. In turn, this allows us to run it in parallel with |
| # the variable read. |
| def wait_then_tick_py_fn(i): |
| time.sleep(1) |
| ticker.assign_add(1) |
| ticker_state.append(i.numpy().item()) |
| return 1 |
| |
| return script_ops.eager_py_func(wait_then_tick_py_fn, [i], |
| [dtypes.int32])[0] |
| |
| @def_function.function |
| def run_loop(n): |
| ticker.assign(0) |
| i = constant_op.constant(0) |
| t_acc = tensor_array_ops.TensorArray( |
| dtypes.int32, size=0, dynamic_size=True) |
| |
| while i < n: |
| directives.set_loop_options(parallel_iterations=10) |
| |
| wait_then_tick(i + 1) |
| # The read is expected to run in much less than `wait_then_tick`, |
| # which sleeps for 1s. Hence all reads should complete before the first |
| # `wait_then_tick` increments the `ticker` variable. |
| t_acc = t_acc.write(i, ticker.read_value()) |
| i += 1 |
| |
| return t_acc.stack() |
| |
| # Warm-up. |
| self.evaluate(run_loop(1)) |
| |
| # This test is deterministic so long as the runtime is fast enough to |
| # execute `t_acc = t_acc.write(i, ticker.read_value())` in much less than |
| # one second. |
| self.evaluate(ticker.assign(123)) |
| ticker_state.clear() |
| acc = run_loop(3) |
| # Because the loop iterations are allowed to run in parallel, reads from |
| # different iterations may proceed ahead of pyfuncs from other iterations. |
| # Because reads are much faster, they should all complete before a single |
| # pyfunc does. |
| self.assertEqual(self.evaluate(math_ops.reduce_max(acc)), 0) |
| |
| # Double-check that the loop ran completely. |
| self.assertEqual(self.evaluate(ticker.read_value()), 3) |
| # Double check that the pyfuncs ran in order. |
| self.assertListEqual(ticker_state, [1, 2, 3]) |
| |
| |
| class AssertTest(test_util.TensorFlowTestCase): |
| |
| @test_util.run_deprecated_v1 |
| def testAssert(self): |
| i = constant_op.constant(0) |
| c = control_flow_ops.Assert(i < 10, [i, [10], [i + 1]]) |
| self.evaluate(c) |
| |
| i = constant_op.constant(10) |
| c = control_flow_ops.Assert(i < 10, [i, [10], [i + 1]]) |
| with self.assertRaises(errors.InvalidArgumentError): |
| self.evaluate(c) |
| |
| @test_util.run_in_graph_and_eager_modes |
| def testAssertInFunction(self): |
| # TODO(fishx): Re-enable this test for GPU. |
| # NOTE(fishx): Disable this test for now because, in GPU, multiple errors |
| # will be thrown. But since the root cause error is marked as "derived" |
| # error. So it might be ignored. |
| if test_util.is_gpu_available(): |
| self.skipTest("Skip GPU Test") |
| |
| @def_function.function |
| def whiny(value): |
| control_flow_ops.Assert(value, ["Raised false"]) |
| return constant_op.constant(5) |
| |
| with self.assertRaises(errors.InvalidArgumentError): |
| self.evaluate(whiny(False)) |
| |
| self.assertAllEqual(whiny(True), 5) |
| |
| |
| if __name__ == "__main__": |
| googletest.main() |