| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Critical Section object and execution logic.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| |
| # TODO(ebrevdo): Re-enable once CriticalSection is in core. |
| # from tensorflow.core.protobuf import critical_section_pb2 |
| |
| from tensorflow.python.eager import context |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import gen_resource_variable_ops |
| from tensorflow.python.ops import tensor_array_ops |
| from tensorflow.python.util import nest |
| from tensorflow.python.util import object_identity |
| from tensorflow.python.util.tf_export import tf_export |
| |
| |
__all__ = ["CriticalSection"]


# Graph collection keys. `CriticalSection` objects are added to the
# CRITICAL_SECTIONS collection on construction; an `_ExecutionSignature`
# is added to CRITICAL_SECTION_EXECUTIONS for each `execute()` call
# (graph mode only).
CRITICAL_SECTIONS = "critical_sections"
CRITICAL_SECTION_EXECUTIONS = "critical_section_executions"
| |
| |
| class _ExecutionSignature( |
| collections.namedtuple("_ExecutionSignature", |
| ("op", "handle", |
| "resources", "exclusive_resource_access"))): |
| """A class storing an `ExecuteInCriticalResource` op and associated attrs.""" |
| pass |
| |
| |
def _identity(x):
  """Identity op that recognizes `TensorArray`, `Operation`, and `Tensor`."""
  # TensorArrays provide their own identity-like operation.
  if isinstance(x, tensor_array_ops.TensorArray):
    return x.identity()
  # Operations have no value to forward; grouping preserves the dependency.
  if isinstance(x, ops.Operation):
    return control_flow_ops.group(x)
  # In eager mode a `None` result is passed through untouched.
  if x is None and context.executing_eagerly():
    return None
  return array_ops.identity(x)
| |
| |
| def _get_colocation(op): |
| """Get colocation symbol from op, if any.""" |
| try: |
| return op.get_attr("_class") |
| except ValueError: |
| return None |
| |
| |
@tf_export("CriticalSection")
class CriticalSection(object):
  """Critical section.

  A `CriticalSection` object is a resource in the graph which executes subgraphs
  in **serial** order. A common example of a subgraph one may wish to run
  exclusively is the one given by the following function:

  ```python
  v = resource_variable_ops.ResourceVariable(0.0, name="v")

  def count():
    value = v.read_value()
    with tf.control_dependencies([value]):
      with tf.control_dependencies([v.assign_add(1)]):
        return tf.identity(value)
  ```

  Here, a snapshot of `v` is captured in `value`; and then `v` is updated.
  The snapshot value is returned.

  If multiple workers or threads all execute `count` in parallel, there is no
  guarantee that access to the variable `v` is atomic at any point within
  any thread's calculation of `count`. In fact, even implementing an atomic
  counter that guarantees that the user will see each value `0, 1, ...,` is
  currently impossible.

  The solution is to ensure any access to the underlying resource `v` is
  only processed through a critical section:

  ```python
  cs = CriticalSection()
  f1 = cs.execute(count)
  f2 = cs.execute(count)
  output = f1 + f2
  session.run(output)
  ```
  The functions `f1` and `f2` will be executed serially, and updates to `v`
  will be atomic.

  **NOTES**

  All resource objects, including the critical section and any captured
  variables of functions executed on that critical section, will be
  colocated to the same device (host and cpu/gpu).

  When using multiple critical sections on the same resources, there is no
  guarantee of exclusive access to those resources. This behavior is disallowed
  by default (but see the kwarg `exclusive_resource_access`).

  For example, running the same function in two separate critical sections
  will not ensure serial execution:

  ```python
  v = tf.compat.v1.get_variable("v", initializer=0.0, use_resource=True)
  def accumulate(up):
    x = v.read_value()
    with tf.control_dependencies([x]):
      with tf.control_dependencies([v.assign_add(up)]):
        return tf.identity(x)
  ex1 = CriticalSection().execute(
    accumulate, 1.0, exclusive_resource_access=False)
  ex2 = CriticalSection().execute(
    accumulate, 1.0, exclusive_resource_access=False)
  bad_sum = ex1 + ex2
  sess.run(v.initializer)
  sess.run(bad_sum)  # May return 0.0
  ```
  """

  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section.

    Args:
      name: Optional name for the critical section.
      shared_name: Optional shared name for the underlying mutex resource.
        Mutually exclusive with `critical_section_def`.
      critical_section_def: Optional `CriticalSectionDef` protocol buffer to
        initialize from (not yet supported; see `_init_from_proto`).
      import_scope: Optional `string`. Name scope to add when initializing
        from `critical_section_def`.

    Raises:
      ValueError: If both `critical_section_def` and `shared_name` are
        provided.
    """
    context.ensure_initialized()
    # BUGFIX: the check previously tested `name`, contradicting the error
    # message and the documented contract; it is `shared_name` that is
    # mutually exclusive with `critical_section_def`.
    if critical_section_def and shared_name is not None:
      raise ValueError("critical_section_def and shared_name are "
                       "mutually exclusive.")
    if critical_section_def:
      self._init_from_proto(critical_section_def, import_scope=import_scope)
    else:
      self._init_from_args(name, shared_name)

  def _init_from_proto(self, critical_section_def, import_scope):  # pylint: disable=invalid-name
    raise NotImplementedError("Not yet implemented")
    # TODO(ebrevdo): Re-enable once CriticalSection is in core.
    # assert isinstance(
    #     critical_section_def, critical_section_pb2.CriticalSectionDef)
    # # Create from critical_section_def.
    # g = ops.get_default_graph()
    # self._handle = g.as_graph_element(
    #     ops.prepend_name_scope(
    #         critical_section_def.critical_section_name,
    #         import_scope=import_scope))

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments.

    Creates the underlying `MutexV2` resource handle and, in graph mode,
    registers this section in the CRITICAL_SECTIONS collection.
    """
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)

  @property
  def name(self):
    # Name of the underlying mutex op.
    return self._handle.op.name

  def execute(self, fn, exclusive_resource_access=True, name=None):
    """Execute function `fn()` inside the critical section.

    `fn` should not accept any arguments. To add extra arguments to when
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute. Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`. Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`. Note, even if `exclusive_resource_access` is
        `True`, if another execution in another `CriticalSection` was created
        without `exclusive_resource_access=True`, a `ValueError` will be raised.
    """
    with ops.name_scope(name, "critical_section_execute", []):

      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed. This avoids certain types of deadlocks.
      lock = gen_resource_variable_ops.mutex_lock(self._handle)

      if not context.executing_eagerly():
        # NOTE(ebrevdo): This is to ensure we don't pick up spurious
        # Operations created by other threads.
        with ops.get_default_graph()._lock:  # pylint: disable=protected-access
          existing_ops = ops.get_default_graph().get_operations()
          with ops.control_dependencies([lock]):
            r = fn()
          # TODO(ebrevdo): If creating critical sections in a python loop, this
          # makes graph creation time quadratic.  Revisit if this
          # becomes a problem.
          created_ops = (set(ops.get_default_graph().get_operations())
                         .difference(existing_ops))
      else:
        with ops.control_dependencies([lock]):
          r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a list of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = object_identity.ObjectIdentitySet([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), themselves attempt to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError("The function fn attempts to directly access the "
                           "CriticalSection in which it would be running. "
                           "This is illegal and would cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return outputs

  def _add_control_dependencies_to_lock(self, created_ops, lock_op):
    """To avoid deadlocks, all args must be executed before lock_op."""
    # Get all arguments (explicit and captured) of all ops created by fn().
    all_args = set([input_.op for op in created_ops for input_ in op.inputs])
    all_args.update(
        input_op for op in created_ops for input_op in op.control_inputs)
    # Unfortunately, we can't use sets throughout because TF seems to
    # create new Operation objects for the same op sometimes; and we
    # can't rely on id(op).

    # pylint: disable=protected-access
    all_args_dict = dict((op._id, op) for op in all_args)

    # Remove ops created within fn, or that lock_op already has a
    # control dependency on.  Also remove a possible self-loop.
    for op in created_ops:
      all_args_dict.pop(op._id, None)
    for op in lock_op.control_inputs:
      all_args_dict.pop(op._id, None)
    for input_ in lock_op.inputs:
      all_args_dict.pop(input_.op._id, None)
    all_args_dict.pop(lock_op._id, None)

    all_args = all_args_dict.values()

    if not all_args:
      # No control dependencies to add; return early.
      return

    # This group is important: it ensures that any ops in all_args
    # outside the control context of the lock_op (and this fn, which
    # runs in the same context) are added to this context before
    # being added to the control dependencies of lock_op.
    all_args = control_flow_ops.group(*all_args)

    lock_op._add_control_input(all_args)
    # pylint: enable=protected-access

  def _is_self_handle(self, x):
    """Check if the tensor `x` is the same Mutex as `self._handle`."""
    if isinstance(x, ops.EagerTensor):
      return x is self._handle
    return (x.op.type == "MutexV2"
            # blank shared_name means the op will create a unique one.
            and x.op.get_attr("shared_name")
            and (x.op.get_attr("shared_name") ==
                 self._handle.op.get_attr("shared_name"))
            and (x.op.device == self._handle.op.device
                 or _get_colocation(x.op) == _get_colocation(self._handle.op)))

  def _check_multiple_access_to_resources(
      self, captured_resources, exclusive_resource_access):
    """Raise if captured_resources are accessed by another CriticalSection.

    Args:
      captured_resources: Set of tensors of type resource.
      exclusive_resource_access: Whether this execution requires exclusive
        resource access.

    Raises:
      ValueError: If any tensors in `captured_resources` are also accessed
        by another `CriticalSection`, and at least one of them requires
        exclusive resource access.
    """
    # Collections and op introspection does not work in eager
    # mode.  This is generally ok; since eager mode (as of
    # writing) executes sequentially anyway.
    for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
      if self._is_self_handle(sg.handle):
        # Other executions in the same critical section are allowed.
        continue
      if not (exclusive_resource_access or sg.exclusive_resource_access):
        # Neither execution requested exclusive access.
        continue
      resource_intersection = captured_resources.intersection(sg.resources)
      if resource_intersection:
        raise ValueError(
            "This execution would access resources: %s. Either this "
            "lock (CriticalSection: %s) or lock '%s' "
            "(CriticalSection: %s) requested exclusive resource access "
            "of this resource. Did you mean to call execute with keyword "
            "argument exclusive_resource_access=False?" %
            (list(resource_intersection), self._handle, sg, sg.handle))
| |
| # TODO(ebrevdo): Re-enable once CriticalSection is in core. |
| |
| # def to_proto(self, export_scope=None): |
  #   """Converts a `CriticalSection` to a `CriticalSectionDef` protocol buffer.
| |
| # Args: |
| # export_scope: Optional `string`. Name scope to remove. |
| |
| # Returns: |
| # A `CriticalSectionDef` protocol buffer, or `None` if the |
| # `CriticalSection` is not in the specified name scope. |
| # """ |
| # if export_scope is None or self.handle.name.startswith(export_scope): |
| # cs_def = critical_section_pb2.CriticalSectionDef() |
| # cs_def.critical_section_name = ops.strip_name_scope( |
| # self._handle.name, export_scope) |
| # return cs_def |
| # else: |
| # return None |
| |
| # @staticmethod |
| # def from_proto(critical_section_def, import_scope=None): |
| # return CriticalSection( |
| # critical_section_def=critical_section_def, import_scope=import_scope) |
| |
| |
| # TODO(ebrevdo): Re-enable once CriticalSection is in core. |
| |
| # def _execution_to_proto_fn(execution_signature, export_scope=None): |
| # """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`. |
| # # TODO(ebrevdo): Update for _ExecutionSignature storing resource list. |
| |
| # Args: |
| # execution_signature: Instance of `_ExecutionSignature`. |
| # export_scope: The export scope, if any. |
| |
| # Returns: |
| # An instance of `CriticalSectionExecutionDef`. |
| # """ |
| # if (export_scope is None |
| # or execution_signature.op.name.startswith(export_scope)): |
| # op_def = critical_section_pb2.CriticalSectionExecutionDef() |
| # op_def.execute_in_critical_section_name = ops.strip_name_scope( |
| # execution_signature.op.name, export_scope) |
| # op_def.exclusive_resource_access = ( |
| # execution_signature.exclusive_resource_access) |
| # return op_def |
| # else: |
| # return None |
| |
| |
| # def _execution_from_proto_fn(op_def, import_scope=None): |
| # """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`.""" |
| # # TODO(ebrevdo): Update for _ExecutionSignature storing resource list. |
| # assert isinstance( |
| # op_def, critical_section_pb2.CriticalSectionExecutionDef) |
| |
| # # Create from op_def. |
| # g = ops.get_default_graph() |
| # execution_op = g.as_graph_element( |
| # ops.prepend_name_scope( |
| # op_def.execute_in_critical_section_name, |
| # import_scope=import_scope)) |
| # return _ExecutionSignature( |
| # op=execution_op, |
| # exclusive_resource_access=op_def.exclusive_resource_access) |
| |
| # ops.register_proto_function( |
| # CRITICAL_SECTIONS, |
| # proto_type=critical_section_pb2.CriticalSectionDef, |
| # to_proto=CriticalSection.to_proto, |
| # from_proto=CriticalSection.from_proto) |
| |
| # ops.register_proto_function( |
| # CRITICAL_SECTION_EXECUTIONS, |
| # proto_type=critical_section_pb2.CriticalSectionExecutionDef, |
| # to_proto=_execution_to_proto_fn, |
| # from_proto=_execution_from_proto_fn) |