import torch
import warnings
import weakref
from weakref import ReferenceType
from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
from collections import defaultdict
import uuid
import contextlib

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
    "set_device_states", "noop_context_fn", "set_checkpoint_early_stop"
]

def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: " + type(inputs).__name__)


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
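#
# A minimal usage sketch (hypothetical tensors; the device indices below simply
# reflect wherever the CUDA tensor args happen to live):
#
#   a = torch.randn(2, device="cuda:0")
#   b = torch.randn(2, device="cuda:1")
#   devices, states = get_device_states(a, b, "non-tensor arg")  # e.g. devices == [0, 1]
#   ...  # run code that advances the CUDA RNG on those devices
#   set_device_states(devices, states)  # per-device RNG states are restored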
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list({arg.get_device() for arg in args
                            if isinstance(arg, torch.Tensor) and arg.is_cuda})

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)

def _get_autocast_kwargs():
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    return gpu_autocast_kwargs, cpu_autocast_kwargs

class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() only with tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads


def noop_context_fn():
    return contextlib.nullcontext(), contextlib.nullcontext()


def checkpoint(
    function,
    *args,
    use_reentrant: bool = True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    **kwargs
):
    r"""Checkpoint a model or part of the model

    Checkpointing is a technique that trades compute for memory. Instead of
    storing all intermediate activations of the entire computation graph for
    the backward pass, the checkpointed part omits saving intermediate
    activations and recomputes them during the backward pass. This can be
    applied to any part of a model.

    There are currently two checkpointing implementations available, determined
    by the :attr:`use_reentrant` parameter. It is recommended that you use
    ``use_reentrant=False``. Please refer to the note below for a discussion of
    their differences.

    .. warning::

        If the :attr:`function` invocation during the backward pass differs
        from the forward pass, e.g., due to a global variable, the checkpointed
        version may not be equivalent, potentially causing an error to be
        raised or leading to silently incorrect gradients.

    .. warning::

        If you are using the ``use_reentrant=True`` variant (this is currently
        the default), please refer to the note below for important
        considerations and potential limitations.

    .. note::

        The reentrant variant of checkpoint (``use_reentrant=True``) and
        the non-reentrant variant of checkpoint (``use_reentrant=False``)
        differ in the following ways:

        * Non-reentrant checkpoint stops recomputation as soon as all needed
          intermediate activations have been recomputed. This feature is enabled
          by default, but can be disabled with :func:`set_checkpoint_early_stop`.
          Reentrant checkpoint always recomputes :attr:`function` in its
          entirety during the backward pass.

        * The reentrant variant does not record the autograd graph during the
          forward pass, as it runs with the forward pass under
          :func:`torch.no_grad`. The non-reentrant version does record the
          autograd graph, allowing one to perform backward on the graph within
          checkpointed regions.

        * The reentrant checkpoint only supports the
          :func:`torch.autograd.backward` API for the backward pass without its
          `inputs` argument, while the non-reentrant version supports all ways
          of performing the backward pass.

        * At least one input and output must have ``requires_grad=True`` for the
          reentrant variant. If this condition is unmet, the checkpointed part
          of the model will not have gradients. The non-reentrant version does
          not have this requirement.

        * The reentrant version does not consider tensors in nested structures
          (e.g., custom objects, lists, dicts, etc) as participating in
          autograd, while the non-reentrant version does.

        * The reentrant checkpoint does not support checkpointed regions with
          detached tensors from the computational graph, whereas the
          non-reentrant version does. For the reentrant variant, if the
          checkpointed segment contains tensors detached using ``detach()`` or
          with :func:`torch.no_grad`, the backward pass will raise an error.
          This is because ``checkpoint`` makes all the outputs require gradients
          and this causes issues when a tensor is defined to have no gradient in
          the model. To avoid this, detach the tensors outside of the
          ``checkpoint`` function.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in an LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Whether to stash and restore
            the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments passed into the checkpointed function. Note that future
            versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
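
    Example (an illustrative sketch; the module, shapes, and flag values below
    are arbitrary and not required by the API):

        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
        >>> inp = torch.randn(2, 10, requires_grad=True)
        >>> out = checkpoint(model, inp, use_reentrant=False)
        >>> out.sum().backward()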
| """ |
| # Hack to mix *args with **kwargs in a python 2.7-compliant way |
| preserve = kwargs.pop('preserve_rng_state', True) |
| if kwargs and use_reentrant: |
| raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) |
| |
| if use_reentrant: |
| if context_fn is not noop_context_fn: |
| raise ValueError("Passing context_fn is only supported when use_reentrant=False.") |
| return CheckpointFunction.apply(function, preserve, *args) |
| else: |
| return _checkpoint_without_reentrant( |
| function, |
| preserve, |
| context_fn, |
| *args, |
| **kwargs, |
| ) |
| |
| |
def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will not store
    the intermediate activations. The inputs of each checkpointed segment will
    be saved for re-running the segment in the backward pass.

    .. warning::
        If you are using the ``use_reentrant=True`` variant (this is the
        default), please see :func:`~torch.utils.checkpoint.checkpoint` for
        the important considerations and limitations of this variant. It is
        recommended that you use ``use_reentrant=False``.

    .. warning::
        Since PyTorch 1.4, ``checkpoint_sequential`` accepts only a single
        Tensor as the input and as the intermediate outputs, just like
        :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): Whether to stash and restore
            the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments passed into the checkpointed function.
            Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last segment runs outside of checkpoint so that its activations are
    # stored normally for the backward pass
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)

# NOTE [ Nestable Checkpoint ]
#
# The semantics of nested checkpoint can be defined by two basic rules.
# Following the two rules leads to an important implication that is central
# to motivating the design.
#
# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
#         from any outer layers of checkpoint.
#
# Rule 2. The inputs of inner checkpoints are treated as tensors saved to their
#         parent checkpoint.
#
# Implication: To recompute any given saved tensor, we need to recompute all of
#              the checkpoints wrapping it.
#
# Why is this implied? To unpack a saved tensor X during backward we need to
# recompute the inner-most checkpoint (#1), and in order to recompute that
# checkpoint we need to have its inputs, which are managed by that checkpoint's
# parent (#2), which thus also needs to be recomputed first. Continue this line
# of reasoning and we realize that in order to unpack X, all checkpoints that
# were active at the time X was saved need to be recomputed (unless we have
# already done so in that backward for some other saved tensor).
#
# In practice, we use a noop autograd Function to save inputs as saved tensors.
# During unpack, calling ctx.saved_tensors triggers the parent checkpoint to
# recompute.
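#
# For example (a hypothetical sketch; the functions and ops are illustrative):
#
#   def inner(x):
#       return x.sin()
#
#   def outer(x):
#       y = x * 2
#       z = checkpoint(inner, y, use_reentrant=False)
#       return z.cos()
#
#   out = checkpoint(outer, inp, use_reentrant=False)
#
# By Rule 1, the tensor saved by sin inside `inner` is managed by the inner
# checkpoint only; by Rule 2, the inner checkpoint's input y is treated as a
# tensor saved to the outer checkpoint. To unpack sin's saved tensor during
# backward, we must first recompute `outer` (to reproduce y) and then `inner`.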
#
# Rule 3. We should start recomputation as if there are no checkpoints currently
#         active. Checkpoints encountered during recomputation are still
#         respected.
#
# When we start recomputation, we push the saved variable hook meant for
# recomputation on the stack. See examples in Rule 6 for more context.
#
#  * * * *
#
# Beyond the basic semantics specific to nested checkpoint, we impose several
# more constraints that may apply to checkpointing in general.
#
# Rule 4. Lifetime of recomputed tensors
#
#         Recomputed tensors are considered specific to particular invocations
#         of backward and are always cleared immediately as they are unpacked.
#         In particular, we require this to happen even if retain_graph=True.
#
# [ Implementation details of Rule 4 ]
#
# If we were okay with recomputed tensors staying alive after backward is run
# with retain_graph=True, we would store recomputed variables as the values of a
# WeakKeyDictionary and pack strong references to the keys, so that as we
# backward, those packed keys would be cleared as long as retain_graph=False.
# Clearing the packed key clears the corresponding entry in the WKD.
#
# If we wish recomputed variables to be immediately cleared as we unpack them in
# the retain_graph=True case, we cannot rely on the packed keys to be cleared by
# backward automatically. Instead of packing the strong reference to the key
# directly, we pack a container object, which we manually clear as we unpack.
#
# An important detail is that if a second backward happens, the second
# recomputation needs to reset the container with a newly created key.
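#
# A sketch of that mechanism (the names mirror _Handle/_Holder defined below;
# `recomputed_tensor` stands in for whatever value recomputation produced):
#
#   wkd = weakref.WeakKeyDictionary()
#   handle = _Handle()                 # the key
#   wkd[handle] = recomputed_tensor    # entry lives only as long as the key does
#   holder.handles[gid] = handle       # strong reference packed into the graph
#   ...
#   holder.handles[gid] = None         # manually dropping the key on unpack clears
#                                      # the WKD entry even if retain_graph=True
#                                      # keeps the Holder itself alive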
#
# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
#         know we need.
#
# [ Implementation details of Rule 5 ]
#
# During recomputation, raise an exception if the number of recomputed tensors
# matches the number of tensors that we expected to recompute. We wrap the
# recomputation call with a try-except to catch this specific exception. See
# Rule #6 below for some examples.
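#
# A sketch of that control flow (it mirrors _checkpoint_hook's unpack_hook below):
#
#   try:
#       with _recomputation_hook(weakref.ref(frame), gid), torch.autograd.enable_grad():
#           frame.recompute_fn(*args)
#   except _StopRecomputationError:
#       pass  # the pack hook raised once all expected tensors were recomputed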
#
# Rule 6. We support doing backward inside checkpoint context
#
# This section is just a bunch of random examples that we'd like to support,
# and comments on how that forced us to make certain design decisions.
#
# [ Basic case ]
#
#   def fn(x):
#       y = x.sin()
#       z = y.cos()
#       gx, = torch.autograd.grad(z, x, retain_graph=True)
#       return gx, z
#
#   gx, z = checkpoint(fn, inp)
#   z.backward()
#
# Because z is saved by cos while checkpoint is enabled, it would not be
# actually saved, and so the .grad() call inside must trigger a recomputation.
#
# During recomputation the "inner pack hook" has two responsibilities:
#
# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
# 2) Pack the actual tensor (detached) so that one may perform backward on the
#    recomputed graph. The tensors saved to this graph will live until the end
#    of recomputation, or die earlier if someone performs backward with
#    retain_graph=False.
#
# More generally performing backward on the recomputed graph occurs in the
# following cases:
# - If backward is performed inside forward,
#   - During the original forward IF early-stop is disabled
#   - During the original backward
# - If there are multiple .grad()/.backward() calls, we would perform backward
#   on the recomputed graph even if early-stop is enabled (see the example below)
#
# [ Multiple backwards ]
#
# The example below shows what happens if during recomputation we find that some
# of the tensors we are trying to recompute have already been cleared.
#
# Spoiler: we don't do anything special, we just skip over them!
#
#   def fn(x):
#       y = x.sin()                          # (1)
#       z = y.cos()                          # (2)
#       gx, = torch.autograd.grad(z, x)      # (3)
#       w = x.sin()                          # (4)
#       v = w.cos()                          # (5)
#       gx2, = torch.autograd.grad(v, x)     # (6)
#       return x * gx * gx2
#
#   out = checkpoint(fn, inp)
#
# In the code above fn is computed (potentially partially) 4 times in total.
#
# 1. Don't save x and y since we are inside a checkpoint.
# 2. Trigger a recompute of fn as we reach (3) since x and y weren't saved.
# 3. If early stop is enabled, stop at (2).
# 4. Continue original forward at (4), not saving x and w.
# 5. The grad call at (6) triggers another recompute of fn.
# 6. During recompute, we see that x and y in the original graph have already
#    been cleared, since the backward at (3) ran without retain_graph=True.
#    We save x and w, however.
# 7. Continue with returning from fn.

_enable_checkpoint_early_stop = True

@contextlib.contextmanager
def set_checkpoint_early_stop(enable: bool):
    """Context manager that sets whether checkpoint should stop recomputation
    early.

    By default, non-reentrant checkpoint stops recomputation as soon as it
    has computed all needed Tensors. This context manager can be used to disable
    that feature if it is problematic for your specific application.

    This context manager only needs to be active when forward is run. It does
    not need to be active during backward.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with set_checkpoint_early_stop(False):
        ...     # Any checkpoint under this context manager will respect this
        ...     # context manager, even if its backward is performed outside.
        ...     out = checkpoint(fn, inputs)
        ...
        >>> out.backward()
    """
    global _enable_checkpoint_early_stop
    try:
        prev = _enable_checkpoint_early_stop
        _enable_checkpoint_early_stop = enable
        yield
    finally:
        _enable_checkpoint_early_stop = prev

class _Handle():
    pass

class _Holder():
    def __init__(self):
        self.handles: Dict[int, Optional[_Handle]] = dict()

class _NoopSaveInputs(torch.autograd.Function):
    @staticmethod
    def forward(*args):
        return torch.empty((0,))

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
        # Only tensors can be saved with ctx.save_for_backward, everything else
        # is captured by get_args, which is saved directly on ctx
        tensor_indices, tensors = zip(*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)])
        idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
        # args but with tensors replaced with None as placeholders
        args = [None if isinstance(o, torch.Tensor) else o for o in inputs]

        def get_args(saved_tensors):
            # restore the placeholders with the original tensors grabbed from
            # ctx.saved_tensors (which may be saved on a parent checkpoint if
            # this checkpoint is nested, and that would trigger a recursive
            # unpack!)
            ret = [saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o for i, o in enumerate(args)]
            # grab the tail since we also saved the dummy to avoid having to explicitly
            # handle the case where there are no tensor inputs
            return ret[1:]

        ctx.get_args = get_args
        ctx.save_for_backward(*tensors)

    @staticmethod
    def backward(ctx, *grad_outputs):
        raise AssertionError("Did not expect to backward on this graph")

class _CheckpointFrame():
    def __init__(self, recompute_fn, early_stop):
        self.recompute_fn = recompute_fn
        self.input_saver = None
        self.weak_holders: List[ReferenceType] = []
        # We store this as a weakkeydictionary so that in the case of a partial
        # backward, the entries in the dict are cleared alongside the Holder
        # which will be removed when the SavedVariable is cleared.
        self.recomputed: DefaultDict[int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]] = \
            defaultdict(weakref.WeakKeyDictionary)
        # We need both recomp_counter and recomputed since they can diverge
        # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
        self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
        self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)

        # See Rule 5
        self.early_stop = early_stop

# See Rule 5
class _StopRecomputationError(Exception):
    pass

class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, target_frame_ref: ReferenceType, gid: int):
        def pack_hook(x):
            target_frame = target_frame_ref()
            assert target_frame is not None
            recomp_idx = target_frame.recomp_counter[gid]
            target_frame.recomp_counter[gid] += 1

            if recomp_idx >= len(target_frame.weak_holders):
                # We run into this case when early stop is not enabled and
                # grad is called within checkpoint.
                return x.detach()
            holder = target_frame.weak_holders[recomp_idx]()

            if holder is not None:
                # See Rule 6: [ Multiple backwards ] above
                if holder.handles.get(gid, None) is None:
                    holder.handles[gid] = _Handle()
                target_frame.recomputed[gid][holder.handles[gid]] = x.detach()

            if target_frame.early_stop and \
               target_frame.recomp_counter[gid] == len(target_frame.weak_holders):
                raise _StopRecomputationError()
            # See Rule 6: [ Basic case ] above
            return x.detach()

        def unpack_hook(x):
            # See Rule 6: [ Basic case ] above for an example of when the graph
            # created during recomputation could be backwarded.
            return x

        super().__init__(pack_hook, unpack_hook)

class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, frame):
        def pack_hook(_unused_x):
            # See Rule 4 above
            holder = _Holder()
            frame.weak_holders.append(weakref.ref(holder))
            return holder

        def unpack_hook(holder):
            gid = torch._C._current_graph_task_id()
            if gid == -1:
                # generate a temporary id if we trigger unpack outside of a backward call
                gid = int(uuid.uuid4())

            if not frame.is_recomputed[gid]:
                ctx = frame.input_saver.grad_fn
                args = ctx.get_args(ctx.saved_tensors)

                try:
                    with _recomputation_hook(weakref.ref(frame), gid), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                    if frame.early_stop:
                        raise AssertionError("if early stop is enabled, we don't expect to reach here")
                except _StopRecomputationError:
                    pass
                frame.is_recomputed[gid] = True

            if holder.handles[gid] is None:
                raise RuntimeError(
                    "If you are calling ctx.saved_tensors in backward, make sure to do so only once. "
                    "Otherwise please open an issue with details on your use case."
                )
            if holder.handles[gid] not in frame.recomputed[gid]:
                raise RuntimeError(
                    "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                    " recomputation being triggered in between, this is not currently supported. Please"
                    " open an issue with details on your use case."
                )
            ret = frame.recomputed[gid][holder.handles[gid]]
            holder.handles[gid] = None
            return ret

        super().__init__(pack_hook, unpack_hook)

# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
#     saving/restoring of global state is handled here.
def _checkpoint_without_reentrant(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    *args,
    **kwargs
):
    """Checkpointing without re-entrant autograd.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in an LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Whether to stash and restore
            the RNG state during each checkpoint.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    forward_context, recompute_context = context_fn()
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    def recompute_fn(*inputs):
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_cuda_in_fwd:
            rng_devices = fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_cuda_in_fwd:
                    set_device_states(fwd_gpu_devices, fwd_gpu_states)

            with torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                 recompute_context:
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(recompute_fn, _enable_checkpoint_early_stop)
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        return fn(*args, **kwargs)

    with _checkpoint_hook(new_frame), \
         forward_context:
        ret = fn(*args, **kwargs)

    if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
        # Cuda was not initialized before running the forward, so we didn't
        # stash the CUDA state.
        raise RuntimeError(
            "PyTorch's CUDA state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature.")

    return ret