# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critical Section object and execution logic."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# from tensorflow.core.protobuf import critical_section_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
# Public API of this module.
__all__ = ["CriticalSection"]
# Graph Keys
# Name of the graph collection to which each `CriticalSection` adds itself
# when it is constructed outside of eager execution.
CRITICAL_SECTIONS = "critical_sections"
# Name of the graph collection that is scanned when checking whether another
# critical section's execution touches the same resources.
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
class _ExecutionSignature(
("op", "handle",
"resources", "exclusive_resource_access"))):
"""A class storing an `ExecuteInCriticalResource` op and associated attrs."""
def _identity(x):
  """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`.

  Args:
    x: A `TensorArray`, an `Operation`, a value accepted by
      `array_ops.identity`, or (when executing eagerly) `None`.

  Returns:
    `x.identity()` for a `TensorArray`; a group op wrapping `x` for an
    `Operation`; `None` when executing eagerly and `x` is `None`; otherwise
    `array_ops.identity(x)`.
  """
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  elif isinstance(x, ops.Operation):
    # An Operation has no tensor value to pass through an identity, so wrap
    # it in a group op instead.
    return control_flow_ops.group(x)
  elif context.executing_eagerly() and x is None:
    return None
  else:
    return array_ops.identity(x)
def _get_colocation(op):
"""Get colocation symbol from op, if any."""
return op.get_attr("_class")
except ValueError:
return None
# NOTE(review): this module imports `tf_export` but no export decorator is
# visible on this class -- confirm whether one belongs here.
class CriticalSection(object):
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order.  A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  ```

  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources.  This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)

  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)

  ex1 = CriticalSection().execute(
    lambda: accumulate(1.0), exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
    lambda: accumulate(1.0), exclusive_resource_access=False)
  bad_sum = ex1 + ex2  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name scope for the underlying mutex op.
      shared_name: Optional shared name of the mutex resource.  Defaults to
        the scoped `name`.
      critical_section_def: Optional serialized definition to restore from.
        Restoring is not implemented yet.
      import_scope: Optional name scope used when restoring from
        `critical_section_def`.

    Raises:
      ValueError: If `critical_section_def` is given together with `name`.
      NotImplementedError: If `critical_section_def` is given.
    """
    # NOTE(review): the check below is on `name` while the message mentions
    # `shared_name` -- confirm which argument is meant to be exclusive with
    # `critical_section_def`.
    if critical_section_def and name is not None:
      raise ValueError("critical_section_def and shared_name are "
                       "mutually exclusive.")
    if critical_section_def:
      self._init_from_proto(critical_section_def, import_scope=import_scope)
    else:
      self._init_from_args(name, shared_name)

  def _init_from_proto(self, critical_section_def, import_scope):  # pylint: disable=invalid-name
    """Initialize from a serialized definition.  Not implemented yet."""
    raise NotImplementedError("Not yet implemented")
    # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    # assert isinstance(
    #     critical_section_def, critical_section_pb2.CriticalSectionDef)
    # # Create from critical_section_def.
    # g = ops.get_default_graph()
    # self._handle = g.as_graph_element(
    #     ops.prepend_name_scope(
    #         critical_section_def.critical_section_name,
    #         import_scope=import_scope))

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          # Fall back to the scoped op name as the shared name.
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    """The name of the op that produces this critical section's handle."""
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments.  To pass extra arguments when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`.  Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    with ops.name_scope(name, "critical_section_execute", []):

      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed.  This avoids certain types of deadlocks.
      lock = gen_resource_variable_ops.mutex_lock(self._handle)

      if not context.executing_eagerly():
        # NOTE(ebrevdo): This is to ensure we don't pick up spurious
        # Operations created by other threads.
        with ops.get_default_graph()._lock:  # pylint: disable=protected-access
          existing_ops = ops.get_default_graph().get_operations()
          with ops.control_dependencies([lock]):
            r = fn()
          # TODO(ebrevdo): If creating critical sections in a python loop, this
          # makes graph creation time quadratic.  Revisit if this
          # becomes a problem.
          created_ops = (set(ops.get_default_graph().get_operations())
                         .difference(existing_ops))
      else:
        with ops.control_dependencies([lock]):
          r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError("The function fn attempts to directly access the "
                           "CriticalSection in which it would be running.  "
                           "This is illegal and would cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        # Record this execution so that later executions in other critical
        # sections can be checked against the resources it uses.
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op."""
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = set([input_.op for op in created_ops for input_ in op.inputs])
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)
    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = dict((op._id, op) for op in all_args)

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`."""
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection does not work in eager
    # mode.  This is generally ok; since eager mode (as of
    # writing) executes sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: %s.  Either this "
            "lock (CriticalSection: %s) or lock '%s' "
            "(CriticalSection: %s) requested exclusive resource access "
            "of this resource.  Did you mean to call execute with keyword "
            "argument exclusive_resource_access=False?" %
            (list(resource_intersection), self._handle, sg, sg.handle))
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# def to_proto(self, export_scope=None):
# """Converts a `CriticalSection` to a `CriticalSectionDef` protocol buffer.
# Args:
# export_scope: Optional `string`. Name scope to remove.
# Returns:
# A `CriticalSectionDef` protocol buffer, or `None` if the
# `CriticalSection` is not in the specified name scope.
# """
# if export_scope is None or self._handle.name.startswith(export_scope):
# cs_def = critical_section_pb2.CriticalSectionDef()
# cs_def.critical_section_name = ops.strip_name_scope(
# self._handle.name, export_scope)
# return cs_def
# else:
# return None
# @staticmethod
# def from_proto(critical_section_def, import_scope=None):
# return CriticalSection(
# critical_section_def=critical_section_def, import_scope=import_scope)
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# def _execution_to_proto_fn(execution_signature, export_scope=None):
# """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
# Args:
# execution_signature: Instance of `_ExecutionSignature`.
# export_scope: The export scope, if any.
# Returns:
# An instance of `CriticalSectionExecutionDef`.
# """
# if (export_scope is None
# or execution_signature.op.name.startswith(export_scope)):
# op_def = critical_section_pb2.CriticalSectionExecutionDef()
# op_def.execute_in_critical_section_name = ops.strip_name_scope(
# execution_signature.op.name, export_scope)
# op_def.exclusive_resource_access = (
# execution_signature.exclusive_resource_access)
# return op_def
# else:
# return None
# def _execution_from_proto_fn(op_def, import_scope=None):
# """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
# assert isinstance(
# op_def, critical_section_pb2.CriticalSectionExecutionDef)
# # Create from op_def.
# g = ops.get_default_graph()
# execution_op = g.as_graph_element(
# ops.prepend_name_scope(
# op_def.execute_in_critical_section_name,
# import_scope=import_scope))
# return _ExecutionSignature(
# op=execution_op,
# exclusive_resource_access=op_def.exclusive_resource_access)
# ops.register_proto_function(
# proto_type=critical_section_pb2.CriticalSectionDef,
# to_proto=CriticalSection.to_proto,
# from_proto=CriticalSection.from_proto)
# ops.register_proto_function(
# proto_type=critical_section_pb2.CriticalSectionExecutionDef,
# to_proto=_execution_to_proto_fn,
# from_proto=_execution_from_proto_fn)