blob: 16419e45bdaa587d61c4d9ed400c4a12d9369824 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Critical Section object and execution logic."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# from tensorflow.core.protobuf import critical_section_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
# Public API of this module.
__all__ = ["CriticalSection"]

# Graph collection keys.  CriticalSection objects register themselves under
# CRITICAL_SECTIONS; each call to CriticalSection.execute() registers an
# _ExecutionSignature under CRITICAL_SECTION_EXECUTIONS so later executions
# can detect conflicting resource access across critical sections.
CRITICAL_SECTIONS = "critical_sections"
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
class _ExecutionSignature(
collections.namedtuple("_ExecutionSignature",
("op", "handle",
"resources", "exclusive_resource_access"))):
"""A class storing an `ExecuteInCriticalResource` op and associated attrs."""
pass
def _identity(x):
  """Return an identity of `x`, dispatching on its type.

  `TensorArray` values use their own `identity` method; `Operation`s are
  wrapped in a `control_flow_ops.group`; in eager mode a `None` value is
  passed through unchanged; everything else goes through
  `array_ops.identity`.
  """
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  if isinstance(x, ops.Operation):
    return control_flow_ops.group(x)
  if context.executing_eagerly() and x is None:
    return None
  return array_ops.identity(x)
def _get_colocation(op):
"""Get colocation symbol from op, if any."""
try:
return op.get_attr("_class")
except ValueError:
return None
@tf_export("CriticalSection")
class CriticalSection(object):
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order.  A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`.  In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```

  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources.  This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)

  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)

  ex1 = CriticalSection().execute(
      accumulate, 1.0, exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
      accumulate, 1.0, exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name for the critical section's ops.
      shared_name: Optional shared name of the underlying mutex resource;
        defaults to `name` when not provided.
      critical_section_def: Optional `CriticalSectionDef` protocol buffer to
        restore from (restoring is not yet implemented).
      import_scope: Optional name scope to prepend when restoring from
        `critical_section_def`.

    Raises:
      ValueError: If both `critical_section_def` and `name` are given.
    """
    context.ensure_initialized()
    if critical_section_def and name is not None:
      # BUG FIX: the condition checks `name`, but the message previously
      # (incorrectly) said "shared_name" conflicted with critical_section_def.
      raise ValueError("critical_section_def and name are "
                       "mutually exclusive.")
    if critical_section_def:
      self._init_from_proto(critical_section_def, import_scope=import_scope)
    else:
      self._init_from_args(name, shared_name)

  def _init_from_proto(self, critical_section_def, import_scope):  # pylint: disable=invalid-name
    """Re-creates the critical section from a proto (not yet supported)."""
    raise NotImplementedError("Not yet implemented")
    # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    # assert isinstance(
    #     critical_section_def, critical_section_pb2.CriticalSectionDef)
    # # Create from critical_section_def.
    # g = ops.get_default_graph()
    # self._handle = g.as_graph_element(
    #     ops.prepend_name_scope(
    #         critical_section_def.critical_section_name,
    #         import_scope=import_scope))

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        # The MutexV2 op is the actual serialization primitive; all
        # executions of this section acquire a lock on this resource.
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    """Name of the underlying mutex op."""
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments. To add extra arguments to when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute. Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`. Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn``. Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    with ops.name_scope(name, "critical_section_execute", []):
      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed. This avoids certain types of deadlocks.
      lock = gen_resource_variable_ops.mutex_lock(self._handle)

      if not context.executing_eagerly():
        # NOTE(ebrevdo): This is to ensure we don't pick up spurious
        # Operations created by other threads.
        with ops.get_default_graph()._lock:  # pylint: disable=protected-access
          existing_ops = ops.get_default_graph().get_operations()
          with ops.control_dependencies([lock]):
            r = fn()
          # TODO(ebrevdo): If creating critical sections in a python loop, this
          # makes graph creation time quadratic.  Revisit if this
          # becomes a problem.
          created_ops = (set(ops.get_default_graph().get_operations())
                         .difference(existing_ops))
      else:
        with ops.control_dependencies([lock]):
          r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError("The function fn attempts to directly access the "
                           "CriticalSection in which it would be running. "
                           "This is illegal and would cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op."""
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = set([input_.op for op in created_ops for input_ in op.inputs])
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)

    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = dict((op._id, op) for op in all_args)

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`."""
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection does not work in eager
    # mode.  This is generally ok; since eager mode (as of
    # writing) executes sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: %s. Either this "
            "lock (CriticalSection: %s) or lock '%s' "
            "(CriticalSection: %s) requested exclusive resource access "
            "of this resource. Did you mean to call execute with keyword "
            "argument exclusive_resource_access=False?" %
            (list(resource_intersection), self._handle, sg, sg.handle))

  # TODO(ebrevdo): Re-enable once CriticalSection is in core.
  # def to_proto(self, export_scope=None):
  #   """Converts a `CriticalSection` to a `CriticalSectionDef` protocol buffer.
  #
  #   Args:
  #     export_scope: Optional `string`. Name scope to remove.
  #
  #   Returns:
  #     A `CriticalSectionDef` protocol buffer, or `None` if the
  #     `CriticalSection` is not in the specified name scope.
  #   """
  #   if export_scope is None or self.handle.name.startswith(export_scope):
  #     cs_def = critical_section_pb2.CriticalSectionDef()
  #     cs_def.critical_section_name = ops.strip_name_scope(
  #         self._handle.name, export_scope)
  #     return cs_def
  #   else:
  #     return None
  #
  # @staticmethod
  # def from_proto(critical_section_def, import_scope=None):
  #   return CriticalSection(
  #       critical_section_def=critical_section_def, import_scope=import_scope)
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
# def _execution_to_proto_fn(execution_signature, export_scope=None):
# """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`.
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
# Args:
# execution_signature: Instance of `_ExecutionSignature`.
# export_scope: The export scope, if any.
# Returns:
# An instance of `CriticalSectionExecutionDef`.
# """
# if (export_scope is None
# or execution_signature.op.name.startswith(export_scope)):
# op_def = critical_section_pb2.CriticalSectionExecutionDef()
# op_def.execute_in_critical_section_name = ops.strip_name_scope(
# execution_signature.op.name, export_scope)
# op_def.exclusive_resource_access = (
# execution_signature.exclusive_resource_access)
# return op_def
# else:
# return None
# def _execution_from_proto_fn(op_def, import_scope=None):
# """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`."""
# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list.
# assert isinstance(
# op_def, critical_section_pb2.CriticalSectionExecutionDef)
# # Create from op_def.
# g = ops.get_default_graph()
# execution_op = g.as_graph_element(
# ops.prepend_name_scope(
# op_def.execute_in_critical_section_name,
# import_scope=import_scope))
# return _ExecutionSignature(
# op=execution_op,
# exclusive_resource_access=op_def.exclusive_resource_access)
# ops.register_proto_function(
# CRITICAL_SECTIONS,
# proto_type=critical_section_pb2.CriticalSectionDef,
# to_proto=CriticalSection.to_proto,
# from_proto=CriticalSection.from_proto)
# ops.register_proto_function(
# CRITICAL_SECTION_EXECUTIONS,
# proto_type=critical_section_pb2.CriticalSectionExecutionDef,
# to_proto=_execution_to_proto_fn,
# from_proto=_execution_from_proto_fn)