| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Gradient tape utilities.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import contextlib |
| |
| from tensorflow.python import pywrap_tfe |
| from tensorflow.python.util.lazy_loader import LazyLoader |
| |
# This module, ops.py, and distribution_strategy_context depend on one another
# in a cycle, so the module is bound through LazyLoader instead of a direct
# import statement.
# TODO(b/117329403): Remove this circular dependency.
distribution_strategy_context = LazyLoader(
    "distribution_strategy_context",
    globals(),
    "tensorflow.python.distribute.distribution_strategy_context")
| |
| |
class Tape(object):
  """Represents a gradient propagation trace.

  Holds an underlying tape object (for example the value returned by
  `pywrap_tfe.TFE_Py_TapeSetNew`) in the `_tape` attribute so that it can be
  handed back to other `pywrap_tfe` calls.
  """

  def __init__(self, tape):
    """Wraps `tape`.

    Args:
      tape: the underlying tape object; stored as-is on `self._tape`.
    """
    self._tape = tape

  def watched_variables(self):
    """Returns `TFE_Py_TapeWatchedVariables` applied to the wrapped tape."""
    wrapped = self._tape
    return pywrap_tfe.TFE_Py_TapeWatchedVariables(wrapped)
| |
| |
def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Pushes a new tape onto the tape stack.

  Args:
    persistent: forwarded unchanged to `TFE_Py_TapeSetNew`.
    watch_accessed_variables: forwarded unchanged to `TFE_Py_TapeSetNew`.

  Returns:
    A `Tape` wrapping the object returned by `TFE_Py_TapeSetNew`.
  """
  return Tape(
      pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables))
| |
| |
def push_tape(tape):
  """Pushes an existing tape onto the tape stack.

  Args:
    tape: object exposing a `_tape` attribute (e.g. a `Tape`); that attribute
      is what gets passed to `TFE_Py_TapeSetAdd`.
  """
  wrapped = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetAdd(wrapped)
| |
| |
def watch(tape, tensor):
  """Marks this tensor to be watched by the given tape.

  Args:
    tape: object exposing a `_tape` attribute (e.g. a `Tape`).
    tensor: the tensor to watch; passed through to `TFE_Py_TapeWatch`.
  """
  wrapped = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeWatch(wrapped, tensor)
| |
| |
class VariableWatcher(object):
  """A scope that tracks all trainable variable accesses within it.

  This explicitly ignores variables that are not marked as trainable.

  Sample usage:

  var = tf.Variable(0.0)
  with VariableWatcher() as variable_watcher:
    var.assign_add(1.0)

  assert variable_watcher.watched_variables() == (var,)

  Note that `watched_variables` is a method and returns a tuple. The handle
  created on entry is kept on the instance after the `with` block exits, which
  is why the sample above queries it outside the scope.
  """

  def __init__(self):
    # No underlying watcher exists until the scope is entered.
    self._variable_watcher = None

  def __enter__(self):
    """Creates the underlying watcher and returns this object."""
    self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
    return self

  def __exit__(self, typ, value, traceback):
    """Hands the watcher to `TFE_Py_VariableWatcherRemove`.

    Returns None, so an exception raised inside the scope keeps propagating.
    """
    pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)

  def watched_variables(self):
    """Returns a tuple of variables accessed under this scope."""
    return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
        self._variable_watcher)
| |
| |
def watch_variable(tape, variable):
  """Marks this variable to be watched by the given tape.

  When a replica context is active the variable is mapped through
  `strategy.extended.value_container`; otherwise the variables come from
  `strategy.experimental_local_results(variable)`. Each resulting variable is
  registered with the tape and also reported to the variable watchers.

  Args:
    tape: object exposing a `_tape` attribute (e.g. a `Tape`).
    variable: the variable to watch.
  """
  strategy, replica_context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if not replica_context:
    targets = strategy.experimental_local_results(variable)
  else:
    targets = [strategy.extended.value_container(variable)]
  for target in targets:
    pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, target)  # pylint: disable=protected-access
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(target)
| |
| |
def variable_accessed(variable):
  """Notifies all tapes in the stack that a variable has been accessed.

  When a replica context is active the variable is mapped through
  `strategy.extended.value_container`; otherwise the variables come from
  `strategy.experimental_local_results(variable)`. Each resulting variable is
  reported both to the tapes and to the variable watchers.

  Args:
    variable: variable to be watched.
  """
  strategy, replica_context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if not replica_context:
    targets = strategy.experimental_local_results(variable)
  else:
    targets = [strategy.extended.value_container(variable)]
  for target in targets:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(target)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(target)
| |
| |
def variables_accessed(variables):
  """Notifies all tapes in the stack that variables have been accessed.

  Only trainable variables are marked as accessed.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  strategy, replica_context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  in_replica_context = bool(replica_context)

  # Collect every target first; notifications are only sent once the whole
  # input has been resolved.
  targets = []
  for candidate in variables:
    if not candidate.trainable:
      continue
    if in_replica_context:
      targets.append(strategy.extended.value_container(candidate))
    else:
      targets.extend(strategy.experimental_local_results(candidate))

  for target in targets:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(target)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(target)
| |
| |
def pop_tape(tape):
  """Pops the given tape in the stack.

  Args:
    tape: object exposing a `_tape` attribute (e.g. a `Tape`); that attribute
      is what gets passed to `TFE_Py_TapeSetRemove`.
  """
  wrapped = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetRemove(wrapped)
| |
| |
@contextlib.contextmanager
def stop_recording():
  """Stop all gradient recording (backprop and forwardprop).

  If recording is already stopped when the context is entered, the body runs
  without any stop/restart calls. Otherwise recording is stopped on entry and
  restarted on exit, including when the body raises.
  """
  already_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped()
  if already_stopped:
    # Some outer scope owns the stopped state; leave it untouched.
    yield
    return
  try:
    pywrap_tfe.TFE_Py_TapeSetStopOnThread()
    yield
  finally:
    pywrap_tfe.TFE_Py_TapeSetRestartOnThread()
| |
| |
def should_record_backprop(tensors):
  """Returns true if any tape in the stack watches any of these tensors.

  Only takes GradientTapes into account, not forward accumulators.

  Args:
    tensors: Tensors to check, typically inputs to an operation.

  Returns:
    Boolean, whether any tape watches any of `tensors`.
  """
  watched = pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)
  return watched
| |
| |
def record_operation(op_type, output_tensors, input_tensors, backward_function,
                     forward_function=None):
  """Records the operation on all tapes in the stack.

  All arguments are forwarded positionally, in order, to
  `TFE_Py_TapeSetRecordOperation`.

  Args:
    op_type: the operation type.
    output_tensors: tensors output by the operation.
    input_tensors: input tensors of the recorded operation.
    backward_function: the backward function to record for the operation.
    forward_function: optional; passed through as-is, `None` by default.
  """
  pywrap_tfe.TFE_Py_TapeSetRecordOperation(
      op_type,
      output_tensors,
      input_tensors,
      backward_function,
      forward_function)
| |
| |
def record_operation_backprop_only(op_type, output_tensors, input_tensors,
                                   backward_function):
  """Records the operation on all backward tapes in the stack.

  All arguments are forwarded positionally, in order, to
  `TFE_Py_TapeSetRecordOperationBackprop`.

  Args:
    op_type: the operation type.
    output_tensors: tensors output by the operation.
    input_tensors: input tensors of the recorded operation.
    backward_function: the backward function to record for the operation.
  """
  pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(
      op_type,
      output_tensors,
      input_tensors,
      backward_function)
| |
| |
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
                                      backward_function,
                                      forwardprop_output_indices):
  """Records the operation on all forward accumulators in the stack.

  Args:
    op_type: a string for the operation type, used in the backprop code
    output_tensors: a list of Python Tensor objects output by the operation
    input_tensors: a list of input Tensors to the recorded operation
    backward_function: the function to be called to, given the gradients of the
      output tensors, produce the gradients of the input tensors. This function
      is automatically transposed to produce output gradients given input
      gradients.
    forwardprop_output_indices: indicates any output_tensors which contain JVPs.
      Typically these will have come from TFE_Py_PackForwardGradients. May be
      None or an empty sequence if there are no JVP outputs from the operation.
  """
  pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(
      op_type,
      output_tensors,
      input_tensors,
      backward_function,
      forwardprop_output_indices)
| |
| |
def delete_trace(tensor_id):
  """Deletes traces for this Tensor from all tapes in the stack.

  Args:
    tensor_id: identifier of the tensor; passed through to
      `TFE_Py_TapeSetDeleteTrace`.
  """
  pywrap_tfe.TFE_Py_TapeSetDeleteTrace(
      tensor_id)
| |
| |
def could_possibly_record():
  """Returns True if any tape is active.

  Computed as the negation of `TFE_Py_TapeSetIsEmpty()`.
  """
  tape_set_is_empty = pywrap_tfe.TFE_Py_TapeSetIsEmpty()
  return not tape_set_is_empty