# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""critical section tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import critical_section_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# from tensorflow.python.training import saver as saver_lib
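# Rough sketch of the pattern exercised throughout this file (for reference
# only; it uses the modules imported above rather than any new API surface):
#
#   cs = critical_section_ops.CriticalSection(shared_name="cs")
#   v = resource_variable_ops.ResourceVariable(0.0)
#   r = cs.execute(lambda: v.assign_add(1.0))
#
# `execute` runs its callable while holding the section's lock, so concurrent
# calls against the same CriticalSection are serialized.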
@test_util.with_control_flow_v2
class CriticalSectionTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testCreateCriticalSection(self):
cs = critical_section_ops.CriticalSection(shared_name="cs")
v = resource_variable_ops.ResourceVariable(0.0, name="v")
def fn(a, b):
c = v.value()
with ops.control_dependencies([c]):
nv = v.assign_add(a * b)
with ops.control_dependencies([nv]):
return array_ops.identity(c)
num_concurrent = 100
r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)]
self.evaluate(v.initializer)
r_value = self.evaluate(r)
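# Because the executions are serialized by the critical section, each call
# observes a distinct value of v (0.0, 2.0, 4.0, ...), even though all
# num_concurrent executions were issued at once.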
self.assertAllClose([2.0 * i for i in range(num_concurrent)],
sorted(r_value))
@parameterized.named_parameters(
("Inner%sOuter%s" % (inner, outer), inner, outer)
for (inner, outer) in itertools.product(*([(False, True)] * 2)))
@test_util.run_in_graph_and_eager_modes
@test_util.xla_allow_fallback("b/128495870")
def testCriticalSectionWithControlFlow(self, outer_cond, inner_cond):
if (not context.executing_eagerly() and
control_flow_v2_toggles.control_flow_v2_enabled()):
self.skipTest("b/135070612")
cs = critical_section_ops.CriticalSection(shared_name="cs")
v = resource_variable_ops.ResourceVariable(0.0, name="v")
num_concurrent = 100
# pylint: disable=cell-var-from-loop
def fn(a, b):
c = v.read_value()
def true_fn():
with ops.control_dependencies([c]):
nv = v.assign_add(a * b)
with ops.control_dependencies([nv]):
return array_ops.identity(c)
return control_flow_ops.cond(
array_ops.identity(inner_cond), true_fn, lambda: c)
def execute():
return cs.execute(lambda: fn(1.0, 2.0))
r = [
control_flow_ops.cond(array_ops.identity(outer_cond),
execute,
v.read_value)
for _ in range(num_concurrent)
]
# pylint: enable=cell-var-from-loop
self.evaluate(v.initializer)
r_value = self.evaluate(r)
if inner_cond and outer_cond:
self.assertAllClose([2.0 * i for i in range(num_concurrent)],
sorted(r_value))
else:
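# If either condition is False, the assign_add in true_fn never executes,
# so v stays at its initial value and every result is 0.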
self.assertAllClose([0] * num_concurrent, r_value)
@test_util.run_v1_only("b/123990562 Sees CancelledError on some calls")
def testCriticalSectionInParallelDoesntDeadlockOnError(self):
# This test runs in graph mode only: eager execution never runs fn()
# in parallel, and parallel execution is exactly where the deadlock
# could occur.
cs = critical_section_ops.CriticalSection(shared_name="cs")
v = resource_variable_ops.ResourceVariable(0.0, name="v")
def fn(i):
error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
with ops.control_dependencies([error]):
return v.read_value()
num_concurrent = 2
@def_function.function(autograph=False)
def run_concurrently():
return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]
if not context.executing_eagerly():
run_concurrently = run_concurrently()
self.evaluate(v.initializer)
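# The Assert fails for i == 0, so every evaluation should surface the
# "Error" message; the point of the test is that the failed execution
# releases the critical section instead of deadlocking subsequent runs.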
for _ in range(100):
with self.assertRaisesOpError("Error"):
if context.executing_eagerly():
run_concurrently()
else:
self.evaluate(run_concurrently)
@test_util.run_in_graph_and_eager_modes
def testCreateCriticalSectionFnReturnsOp(self):
cs = critical_section_ops.CriticalSection(shared_name="cs")
v = resource_variable_ops.ResourceVariable(0.0, name="v")
def fn_return_op(a, b):
c = v.read_value()
with ops.control_dependencies([c]):
nv = v.assign_add(a * b)
with ops.control_dependencies([nv]):
return control_flow_ops.no_op()
num_concurrent = 100
r = [cs.execute(lambda: fn_return_op(1.0, 2.0))
for _ in range(num_concurrent)]
self.evaluate(v.initializer)
self.evaluate(r)
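# Each of the num_concurrent executions performs assign_add(1.0 * 2.0),
# so the variable should end at 2.0 * num_concurrent.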
final_v = self.evaluate(v)
self.assertAllClose(2.0 * num_concurrent, final_v)
@test_util.run_v1_only("Collections don't exist in TF2")
def testCollection(self):
cs = critical_section_ops.CriticalSection(shared_name="cs")
self.assertIn(
cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
add = lambda x: x + 1
execute = cs.execute(lambda: add(1.0), name="my_execute")
execute_op = [
x for x in execute.graph.get_operations()
if "my_execute" in x.name and "MutexLock" in x.type
][0]
self.assertIn(
execute_op,
[signature.op for signature in
ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
def testRecursiveCriticalSectionAccessIsIllegal(self):
# This check does not work properly in eager mode: eager users who do
# this simply hit a deadlock. At least a deadlock is relatively easy
# to notice and debug.
cs = critical_section_ops.CriticalSection()
add = lambda y: y + 1
def fn(x):
return cs.execute(lambda: add(x))
with self.assertRaisesRegexp(
ValueError,
r"attempts to directly access the CriticalSection in which it "
r"would be running"):
cs.execute(lambda: fn(1.0))
def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
# This one is subtle, and we're being overly cautious here. The
# deadlock we are ensuring we catch is:
#
# to_capture = CS[lambda x: x + 1](1.0)
# deadlocked = CS[lambda x: x + to_capture](1.0)
#
# This would have caused a deadlock because executing `deadlocked` will
# lock the mutex on CS; but then due to dependencies, will attempt
# to compute `to_capture`. This computation requires locking CS,
# but that is not possible now because CS is already locked by
# `deadlocked`.
#
# We check that CriticalSection.execute properly inserts new
# control dependencies to its lock to ensure all captured
# operations are finished before anything runs within the critical section.
cs = critical_section_ops.CriticalSection(shared_name="cs")
fn = array_ops.identity
to_capture = cs.execute(lambda: fn(1.0))
fn_captures = lambda x: x + to_capture
to_capture_too = array_ops.identity(to_capture)
ex_0 = cs.execute(lambda: fn_captures(1.0))
with ops.control_dependencies([to_capture]):
# This is OK because to_capture will execute before this next call
ex_1 = cs.execute(lambda: fn_captures(1.0))
dependency = array_ops.identity(to_capture)
fn_captures_dependency = lambda x: x + dependency
ex_2 = cs.execute(lambda: fn_captures_dependency(1.0))
with ops.control_dependencies([to_capture_too]):
ex_3 = cs.execute(lambda: fn_captures_dependency(1.0))
# Ensure there's no actual deadlock when evaluating ex_0 through ex_3.
self.assertEqual(2.0, self.evaluate(ex_0))
self.assertEqual(2.0, self.evaluate(ex_1))
self.assertEqual(2.0, self.evaluate(ex_2))
self.assertEqual(2.0, self.evaluate(ex_3))
def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
cs = critical_section_ops.CriticalSection(shared_name="cs")
def body_implicit_capture(i, j):
# This would have caused a deadlock if not for logic in execute
# that inserts additional control dependencies onto the lock op:
# * Loop body argument j is captured by fn()
# * i is running in parallel to move forward the execution
# * j is not being checked by the predicate function
# * output of cs.execute() is returned as next j.
fn = lambda: j + 1
return (i + 1, cs.execute(fn))
(i_n, j_n) = control_flow_ops.while_loop(
lambda i, _: i < 1000,
body_implicit_capture,
[0, 0],
parallel_iterations=25)
# For consistency between eager and graph mode.
i_n = array_ops.identity(i_n)
logging.warn(
"\n==============\nRunning "
"'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
"body_implicit_capture'\n"
"==============\n")
self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
logging.warn(
"\n==============\nSuccessfully finished running "
"'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
"body_implicit_capture'\n"
"==============\n")
def body_implicit_capture_protected(i, j):
# This version is ok because we manually add a control
# dependency on j, which is an argument to the while_loop body
# and captured by fn.
fn = lambda: j + 1
with ops.control_dependencies([j]):
return (i + 1, cs.execute(fn))
(i_n, j_n) = control_flow_ops.while_loop(
lambda i, _: i < 1000,
body_implicit_capture_protected,
[0, 0],
parallel_iterations=25)
# For consistency between eager and graph mode.
i_n = array_ops.identity(i_n)
logging.warn(
"\n==============\nRunning "
"'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
"body_implicit_capture_protected'\n"
"==============\n")
self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
logging.warn(
"\n==============\nSuccessfully finished running "
"'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
"body_implicit_capture_protected'\n"
"==============\n")
def body_args_capture(i, j):
# This version is ok because j is an argument to fn and we can
# ensure there's a control dependency on j.
fn = lambda x: x + 1
return (i + 1, cs.execute(lambda: fn(j)))
(i_n, j_n) = control_flow_ops.while_loop(
lambda i, _: i < 1000,
body_args_capture,
[0, 0],
parallel_iterations=25)
# For consistency between eager and graph mode.
i_n = array_ops.identity(i_n)
logging.warn(
"\n==============\nRunning "
"'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
"body_args_capture'\n"
"==============\n")
self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
logging.warn(
"\n==============\nSuccessfully finished running "
"'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
"body_args_capture'\n"
"==============\n")
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
# This check does not work properly in eager mode: eager users who do
# this simply hit a deadlock. At least a deadlock is relatively easy
# to notice and debug.
cs = critical_section_ops.CriticalSection(shared_name="cs")
cs_same = critical_section_ops.CriticalSection(shared_name="cs")
add = lambda x: x + 1
def fn(x):
return cs_same.execute(lambda: add(x))
with self.assertRaisesRegexp(
ValueError,
r"attempts to directly access the CriticalSection in which it "
r"would be running"):
cs.execute(lambda: fn(1.0))
@test_util.run_v1_only("b/123955885 Can't identify deadlocks in eager mode")
def testMultipleCSExecutionsRequestSameResource(self):
cs0 = critical_section_ops.CriticalSection()
cs1 = critical_section_ops.CriticalSection()
v = resource_variable_ops.ResourceVariable(0.0, name="v")
cs0.execute(lambda: v + 1)
# It's OK for the same CriticalSection to access this resource.
cs0.execute(lambda: v - 1)
# It's *not* OK for a different CriticalSection to access it by
# default.
with self.assertRaisesRegexp(
ValueError, "requested exclusive resource access"):
cs1.execute(lambda: v + 1)
# It's not even OK if the second call doesn't request exclusive access.
with self.assertRaisesRegexp(
ValueError, "requested exclusive resource access"):
cs1.execute(lambda: v + 1, exclusive_resource_access=False)
v2 = resource_variable_ops.ResourceVariable(0.0, name="v2")
cs0.execute(lambda: v2 + 1, exclusive_resource_access=False)
# It's OK if neither requests exclusive resource access.
cs1.execute(lambda: v2 + 1, exclusive_resource_access=False)
# It's not OK if the second request requires exclusive resource
# access.
with self.assertRaisesRegexp(
ValueError, "requested exclusive resource access"):
cs1.execute(lambda: v2 + 1)
def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
cs = critical_section_ops.CriticalSection()
v = resource_variable_ops.ResourceVariable(0, name="v")
# Make sure that the control dependencies on v do not cause issues
# in the lock_op's automatic control dependency adder.
#
# Note, here v must be a resource variable (or something similar),
# otherwise it gets hoisted into the while_loop by the time we add
# control dependencies to the lock_op.
def body(i):
add_j = lambda j: v + j + 1
return cs.execute(lambda: add_j(i))
out = control_flow_ops.while_loop(
lambda i: i < 10, body, [0])
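# v is never assigned, so each iteration just computes i + 1 under the
# lock and the loop exits with out == 10.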
self.evaluate(v.initializer)
self.assertEqual(10, self.evaluate(out))
@test_util.run_in_graph_and_eager_modes
def testInsideFunction(self):
if test_util.is_gpu_available():
self.skipTest(
"b/123899495: Colocation errors for critical sections in map on GPU")
cs = critical_section_ops.CriticalSection()
with ops.device("/gpu:0" if test_util.is_gpu_available() else "/cpu:0"):
v = resource_variable_ops.ResourceVariable(1)
def fn():
return v.read_value()
# map() creates a TensorFlow function.
ds = dataset_ops.Dataset.range(1)
if test_util.is_gpu_available():
ds = (ds.apply(prefetching_ops.copy_to_device("/gpu:0"))
.apply(prefetching_ops.map_on_gpu(lambda _: cs.execute(fn))))
else:
ds = ds.map(lambda _: cs.execute(fn))
def get_first():
if context.executing_eagerly():
return self.evaluate(ds.make_one_shot_iterator().get_next())
itr = ds.make_initializable_iterator()
self.evaluate([v.initializer, itr.initializer])
return self.evaluate(itr.get_next())
self.assertEqual(1, get_first())
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
#
# def testCriticalSectionAndExecuteOpSaverRoundTrip(self):
# cs = critical_section_ops.CriticalSection()
# r = cs.execute(lambda x: x + 1, 1.0)
# graph = ops.get_default_graph()
# meta_graph = saver_lib.export_meta_graph(
# graph=graph, collection_list=graph.get_all_collection_keys())
# graph_copy = ops.Graph()
# with graph_copy.as_default():
# _ = saver_lib.import_meta_graph(meta_graph, import_scope="imported")
# restored_cs = ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)
# restored_exec = ops.get_collection(
# critical_section_ops.CRITICAL_SECTION_EXECUTIONS)
# self.assertEqual(1, len(restored_cs))
# self.assertEqual(1, len(restored_exec))
# self.assertEqual(restored_cs[0].name, "imported/%s" % cs.name)
# self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name)
# def testToProto(self):
# cs = critical_section_ops.CriticalSection(shared_name="cs")
# proto = cs.to_proto()
# self.assertEqual(proto.critical_section_name, cs._handle.name)
# cs_copy = critical_section_ops.CriticalSection.from_proto(proto)
# self.assertEqual(cs_copy._handle, cs._handle)
if __name__ == "__main__":
test.main()