|  | # mypy: allow-untyped-defs | 
|  | import contextlib | 
|  | import platform | 
|  | import uuid | 
|  | import warnings | 
|  | import weakref | 
|  | from collections import defaultdict | 
|  | from typing import *  # noqa: F403 | 
|  | import enum | 
|  | from weakref import ReferenceType | 
|  |  | 
|  | import torch | 
|  | import torch.fx.traceback as fx_traceback | 
|  | from torch._functorch._aot_autograd.functional_utils import is_fun | 
|  | from torch.utils._pytree import tree_map | 
|  | from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode | 
|  | from torch.utils._python_dispatch import TorchDispatchMode | 
|  |  | 
|  | __all__ = [ | 
|  | "checkpoint", | 
|  | "checkpoint_sequential", | 
|  | "CheckpointError", | 
|  | "CheckpointFunction", | 
|  | "check_backward_validity", | 
|  | "detach_variable", | 
|  | "get_device_states", | 
|  | "set_device_states", | 
|  | "noop_context_fn", | 
|  | "set_checkpoint_early_stop", | 
|  | "DefaultDeviceType", | 
|  | "set_checkpoint_debug_enabled", | 
|  | "CheckpointPolicy", | 
|  | "SelectiveCheckpointContext", | 
|  | "create_selective_checkpoint_contexts", | 
|  | "SAC_IGNORED_OPS", | 
|  | ] | 
|  |  | 
|  | _DEFAULT_DETERMINISM_MODE = "default" | 
|  |  | 
|  | _checkpoint_debug_enabled: Optional[bool] = None | 
|  |  | 
|  |  | 
|  | @contextlib.contextmanager | 
|  | def set_checkpoint_debug_enabled(enabled: Optional[bool]): | 
|  | """ | 
|  | Context manager that sets whether checkpoint should print additional debug | 
|  | information when running. See the ``debug`` flag for | 
|  | :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that | 
|  | when set, this context manager overrides the value of ``debug`` passed to | 
|  | checkpoint. To defer to the local setting, pass ``None`` to this context. | 
|  |  | 
|  | Args: | 
enabled (bool or None): Whether checkpoint should print debug information.
Default is ``None``.
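
Example (a minimal sketch; ``fn`` and ``inp`` are placeholder names)::

>>> # xdoctest: +SKIP("stub")
>>> with set_checkpoint_debug_enabled(True):
...     # Checkpoints in this scope behave as if ``debug=True`` had been
...     # passed, regardless of the local ``debug`` argument.
...     out = checkpoint(fn, inp, use_reentrant=False)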
|  | """ | 
|  | global _checkpoint_debug_enabled | 
|  | try: | 
|  | prev = _checkpoint_debug_enabled | 
|  | _checkpoint_debug_enabled = enabled | 
|  | yield | 
|  | finally: | 
|  | _checkpoint_debug_enabled = prev | 
|  |  | 
|  |  | 
|  | def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: | 
|  | if isinstance(inputs, tuple): | 
|  | out = [] | 
|  | for inp in inputs: | 
|  | if not isinstance(inp, torch.Tensor): | 
|  | out.append(inp) | 
|  | continue | 
|  |  | 
|  | x = inp.detach() | 
|  | x.requires_grad = inp.requires_grad | 
|  | out.append(x) | 
|  | return tuple(out) | 
|  | else: | 
|  | raise RuntimeError( | 
|  | "Only tuple of tensors is supported. Got Unsupported input type: ", | 
|  | type(inputs).__name__, | 
|  | ) | 
|  |  | 
|  |  | 
|  | def check_backward_validity(inputs: Iterable[Any]) -> None: | 
|  | if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): | 
|  | warnings.warn( | 
|  | "None of the inputs have requires_grad=True. Gradients will be None" | 
|  | ) | 
|  |  | 
|  |  | 
|  | def _get_device_module(device="cuda"): | 
|  | device_module = getattr(torch, device) | 
|  | return device_module | 
|  |  | 
|  |  | 
|  | class DefaultDeviceType: | 
|  | r""" | 
|  | A class that manages the default device type for checkpointing. | 
|  |  | 
|  | If no non-CPU tensors are present, the default device type will | 
|  | be used. The default value is 'cuda'. The device type is used in | 
|  | the checkpointing process when determining which device states | 
|  | to save and restore for recomputation. | 
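
Example (a minimal sketch of the getter/setter)::

>>> # xdoctest: +SKIP("stub")
>>> DefaultDeviceType.get_device_type()
'cuda'
>>> # Pick which device's RNG state checkpoint should stash when it sees no
>>> # non-CPU tensors (the device string here is illustrative):
>>> DefaultDeviceType.set_device_type("cuda")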
|  | """ | 
|  |  | 
|  | _default_device_type = "cuda" | 
|  |  | 
|  | @staticmethod | 
|  | def set_device_type(device: str = "cuda"): | 
|  | """ | 
|  | Set the default device type for checkpointing. | 
|  |  | 
|  | Args: | 
|  | device (str): The device type to be set as default. Default is 'cuda'. | 
|  | """ | 
|  | DefaultDeviceType._default_device_type = device | 
|  |  | 
|  | @staticmethod | 
|  | def get_device_type() -> str: | 
|  | """ | 
|  | Get the current default device type for checkpointing. | 
|  |  | 
|  | Returns: | 
|  | str: The current default device type. | 
|  | """ | 
|  | return DefaultDeviceType._default_device_type | 
|  |  | 
|  |  | 
|  | def _infer_device_type(*args): | 
|  | device_types = [] | 
|  |  | 
|  | def add_device_types(arg): | 
|  | nonlocal device_types | 
|  | if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu": | 
|  | device_types.append(arg.device.type) | 
|  | tree_map(add_device_types, args) | 
|  |  | 
|  | device_types_set = set(device_types) | 
|  | if len(device_types_set) > 1: | 
|  | warnings.warn( | 
|  | "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. " | 
|  | "Device state will only be saved for devices of a single device type, and the remaining " | 
|  | "devices will be ignored. Consequently, if any checkpointed functions involve randomness, " | 
|  | "this may result in incorrect gradients. (Note that if CUDA devices are among the devices " | 
|  | "detected, it will be prioritized; otherwise, the first device encountered will be selected.)" | 
|  | f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}" | 
|  | ) | 
|  | if len(device_types) == 0: | 
|  | return DefaultDeviceType.get_device_type() | 
|  | elif "cuda" in device_types_set: | 
|  | return "cuda" | 
|  | else: | 
|  | return device_types[0] | 
|  |  | 
|  |  | 
|  | # We can't know if the run_fn will internally move some args to different devices, | 
|  | # which would require logic to preserve rng states for those devices as well. | 
|  | # We could paranoically stash and restore ALL the rng states for all visible devices, | 
|  | # but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for | 
|  | # the device of all Tensor args. | 
|  | # | 
|  | # To consider:  maybe get_device_states and set_device_states should reside in torch/random.py? | 
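#
# A minimal usage sketch as comments (x and y are hypothetical tensors; this
# mirrors how CheckpointFunction uses the two helpers to replay forward-time
# RNG during recomputation):
#
#     fwd_devices, fwd_states = get_device_states(x, y)  # stash per-device RNG state
#     ...  # forward runs; later, just before recomputing:
#     set_device_states(fwd_devices, fwd_states)         # restore the stashed state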
|  | def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: | 
|  | # This will not error out if "arg" is a CPU tensor or a non-tensor type because | 
|  | # the conditionals short-circuit. | 
|  | fwd_device_ids = [] | 
|  |  | 
|  | def add_device_ids(arg): | 
|  | nonlocal fwd_device_ids | 
|  | if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu": | 
|  | fwd_device_ids.append(arg.get_device()) | 
|  | tree_map(add_device_ids, args) | 
|  |  | 
|  | fwd_device_states = [] | 
|  | device_module = _get_device_module(_infer_device_type(*args)) | 
|  |  | 
|  | for device_id in fwd_device_ids: | 
|  | with device_module.device(device_id): | 
|  | fwd_device_states.append(device_module.get_rng_state()) | 
|  |  | 
|  | return fwd_device_ids, fwd_device_states | 
|  |  | 
|  |  | 
|  | def set_device_states(devices, states) -> None: | 
|  | device_module = _get_device_module(_infer_device_type(*states)) | 
|  | for device, state in zip(devices, states): | 
|  | with device_module.device(device): | 
|  | device_module.set_rng_state(state) | 
|  |  | 
|  |  | 
|  | def _get_autocast_kwargs(device="cuda"): | 
|  | if torch.amp.is_autocast_available(device): | 
|  | device_autocast_kwargs = { | 
|  | "enabled": torch.is_autocast_enabled(device), | 
|  | "dtype": torch.get_autocast_dtype(device), | 
|  | "cache_enabled": torch.is_autocast_cache_enabled(), | 
|  | } | 
|  | else: | 
|  | device_autocast_kwargs = None | 
|  |  | 
|  | cpu_autocast_kwargs = { | 
|  | "enabled": torch.is_autocast_enabled('cpu'), | 
|  | "dtype": torch.get_autocast_dtype('cpu'), | 
|  | "cache_enabled": torch.is_autocast_cache_enabled(), | 
|  | } | 
|  |  | 
|  | return device_autocast_kwargs, cpu_autocast_kwargs | 
|  |  | 
|  |  | 
|  | class CheckpointFunction(torch.autograd.Function): | 
|  | @staticmethod | 
|  | def forward(ctx, run_function, preserve_rng_state, *args): | 
|  | check_backward_validity(args) | 
|  | ctx.run_function = run_function | 
|  | ctx.preserve_rng_state = preserve_rng_state | 
|  | # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. | 
|  | ctx.device = _infer_device_type(*args) | 
|  | ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs( | 
|  | ctx.device | 
|  | ) | 
|  | if preserve_rng_state: | 
|  | ctx.fwd_cpu_state = torch.get_rng_state() | 
|  | # Don't eagerly initialize the cuda context by accident. | 
|  | # (If the user intends that the context is initialized later, within their | 
|  | # run_function, we SHOULD actually stash the cuda state here.  Unfortunately, | 
|  | # we have no way to anticipate this will happen before we run the function.) | 
|  | ctx.had_device_in_fwd = False | 
|  | device_module = _get_device_module(ctx.device) | 
|  | if getattr(device_module, "_initialized", False): | 
|  | ctx.had_device_in_fwd = True | 
|  | ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) | 
|  |  | 
|  | # Save non-tensor inputs in ctx, keep a placeholder None for tensors | 
|  | # to be filled out during the backward. | 
|  | ctx.inputs = [] | 
|  | ctx.tensor_indices = [] | 
|  | tensor_inputs = [] | 
|  | for i, arg in enumerate(args): | 
|  | if torch.is_tensor(arg): | 
|  | tensor_inputs.append(arg) | 
|  | ctx.tensor_indices.append(i) | 
|  | ctx.inputs.append(None) | 
|  | else: | 
|  | ctx.inputs.append(arg) | 
|  |  | 
|  | ctx.save_for_backward(*tensor_inputs) | 
|  |  | 
|  | with torch.no_grad(): | 
|  | outputs = run_function(*args) | 
|  | return outputs | 
|  |  | 
|  | @staticmethod | 
|  | def backward(ctx, *args): | 
|  | if not torch.autograd._is_checkpoint_valid(): | 
|  | raise RuntimeError( | 
|  | "When use_reentrant=True, torch.utils.checkpoint is incompatible" | 
|  | " with .grad() or passing an `inputs` parameter to .backward()." | 
|  | " To resolve this error, you can either set use_reentrant=False," | 
|  | " or call .backward() without passing the `inputs` argument." | 
|  | ) | 
|  | # Copy the list to avoid modifying original list. | 
|  | inputs = list(ctx.inputs) | 
|  | tensor_indices = ctx.tensor_indices | 
|  | tensors = ctx.saved_tensors | 
|  | device_module = _get_device_module(ctx.device) | 
|  |  | 
|  | # Fill in inputs with appropriate saved tensors. | 
|  | for i, idx in enumerate(tensor_indices): | 
|  | inputs[idx] = tensors[i] | 
|  |  | 
|  | # Stash the surrounding rng state, and mimic the state that was | 
|  | # present at this time during forward.  Restore the surrounding state | 
|  | # when we're done. | 
|  | rng_devices = [] | 
|  | if ctx.preserve_rng_state and ctx.had_device_in_fwd: | 
|  | rng_devices = ctx.fwd_devices | 
|  | with torch.random.fork_rng( | 
|  | devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device | 
|  | ): | 
|  | if ctx.preserve_rng_state: | 
|  | torch.set_rng_state(ctx.fwd_cpu_state) | 
|  | if ctx.had_device_in_fwd: | 
|  | set_device_states(ctx.fwd_devices, ctx.fwd_device_states) | 
|  | detached_inputs = detach_variable(tuple(inputs)) | 
|  |  | 
|  | device_autocast_ctx = torch.amp.autocast( | 
|  | device_type=ctx.device, **ctx.device_autocast_kwargs | 
|  | ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext() | 
|  | with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined] | 
|  | outputs = ctx.run_function(*detached_inputs) | 
|  |  | 
|  | if isinstance(outputs, torch.Tensor): | 
|  | outputs = (outputs,) | 
|  |  | 
# run backward() with only the tensors that require grad
|  | outputs_with_grad = [] | 
|  | args_with_grad = [] | 
|  | for i in range(len(outputs)): | 
|  | if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: | 
|  | outputs_with_grad.append(outputs[i]) | 
|  | args_with_grad.append(args[i]) | 
|  | if len(outputs_with_grad) == 0: | 
|  | raise RuntimeError( | 
|  | "none of output has requires_grad=True," | 
|  | " this checkpoint() is not necessary" | 
|  | ) | 
|  | torch.autograd.backward(outputs_with_grad, args_with_grad) | 
|  | grads = tuple( | 
|  | inp.grad if isinstance(inp, torch.Tensor) else None | 
|  | for inp in detached_inputs | 
|  | ) | 
|  |  | 
|  | return (None, None) + grads | 
|  |  | 
|  |  | 
|  | def noop_context_fn(): | 
|  | return contextlib.nullcontext(), contextlib.nullcontext() | 
|  |  | 
# TorchDynamo does not step inside utils.checkpoint function.  The flow
# looks like this:
|  | #  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by | 
|  | #     speculatively checking if the forward function is safe to trace. | 
|  | #  2) If yes, then Dynamo-generated Fx graph has the wrapped higher | 
|  | #     order op. As a result, TorchDynamo does not look inside utils.checkpoint. | 
|  | #  3) If not, then TorchDynamo falls back to eager by performing a graph | 
|  | #     break. And here, the following disable wrapper ensures that | 
|  | #     TorchDynamo does not trigger again on the frames created by | 
|  | #     utils.checkpoint innards. | 
|  | @torch._disable_dynamo | 
|  | def checkpoint( | 
|  | function, | 
|  | *args, | 
|  | use_reentrant: Optional[bool] = None, | 
|  | context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, | 
|  | determinism_check: str = _DEFAULT_DETERMINISM_MODE, | 
|  | debug: bool = False, | 
|  | **kwargs | 
|  | ): | 
|  | r"""Checkpoint a model or part of the model. | 
|  |  | 
|  | Activation checkpointing is a technique that trades compute for memory. | 
|  | Instead of keeping tensors needed for backward alive until they are used in | 
|  | gradient computation during backward, forward computation in checkpointed | 
|  | regions omits saving tensors for backward and recomputes them during the | 
|  | backward pass. Activation checkpointing can be applied to any part of a | 
|  | model. | 
|  |  | 
|  | There are currently two checkpointing implementations available, determined | 
|  | by the :attr:`use_reentrant` parameter. It is recommended that you use | 
``use_reentrant=False``. Please refer to the note below for a discussion of
|  | their differences. | 
|  |  | 
|  | .. warning:: | 
|  |  | 
|  | If the :attr:`function` invocation during the backward pass differs | 
|  | from the forward pass, e.g., due to a global variable, the checkpointed | 
version may not be equivalent, potentially causing an error to be raised
or leading to silently incorrect gradients.
|  |  | 
|  | .. warning:: | 
|  |  | 
|  | The ``use_reentrant`` parameter should be passed explicitly. In version | 
|  | 2.4 we will raise an exception if ``use_reentrant`` is not passed. | 
|  | If you are using the ``use_reentrant=True`` variant, please refer to the | 
|  | note below for important considerations and potential limitations. | 
|  |  | 
|  | .. note:: | 
|  |  | 
|  | The reentrant variant of checkpoint (``use_reentrant=True``) and | 
|  | the non-reentrant variant of checkpoint (``use_reentrant=False``) | 
|  | differ in the following ways: | 
|  |  | 
|  | * Non-reentrant checkpoint stops recomputation as soon as all needed | 
|  | intermediate activations have been recomputed. This feature is enabled | 
|  | by default, but can be disabled with :func:`set_checkpoint_early_stop`. | 
|  | Reentrant checkpoint always recomputes :attr:`function` in its | 
|  | entirety during the backward pass. | 
|  |  | 
|  | * The reentrant variant does not record the autograd graph during the | 
|  | forward pass, as it runs with the forward pass under | 
|  | :func:`torch.no_grad`. The non-reentrant version does record the | 
|  | autograd graph, allowing one to perform backward on the graph within | 
|  | checkpointed regions. | 
|  |  | 
|  | * The reentrant checkpoint only supports the | 
|  | :func:`torch.autograd.backward` API for the backward pass without its | 
|  | `inputs` argument, while the non-reentrant version supports all ways | 
|  | of performing the backward pass. | 
|  |  | 
|  | * At least one input and output must have ``requires_grad=True`` for the | 
|  | reentrant variant. If this condition is unmet, the checkpointed part | 
|  | of the model will not have gradients. The non-reentrant version does | 
|  | not have this requirement. | 
|  |  | 
|  | * The reentrant version does not consider tensors in nested structures | 
|  | (e.g., custom objects, lists, dicts, etc) as participating in | 
|  | autograd, while the non-reentrant version does. | 
|  |  | 
|  | * The reentrant checkpoint does not support checkpointed regions with | 
|  | detached tensors from the computational graph, whereas the | 
|  | non-reentrant version does. For the reentrant variant, if the | 
|  | checkpointed segment contains tensors detached using ``detach()`` or | 
|  | with :func:`torch.no_grad`, the backward pass will raise an error. | 
|  | This is because ``checkpoint`` makes all the outputs require gradients | 
|  | and this causes issues when a tensor is defined to have no gradient in | 
|  | the model. To avoid this, detach the tensors outside of the | 
|  | ``checkpoint`` function. | 
|  |  | 
|  | Args: | 
|  | function: describes what to run in the forward pass of the model or | 
|  | part of the model. It should also know how to handle the inputs | 
|  | passed as the tuple. For example, in LSTM, if user passes | 
|  | ``(activation, hidden)``, :attr:`function` should correctly use the | 
|  | first input as ``activation`` and the second input as ``hidden`` | 
preserve_rng_state(bool, optional):  If ``False``, omit stashing and restoring
the RNG state during each checkpoint. Note that under torch.compile,
this flag doesn't take effect and we always preserve RNG state.
|  | Default: ``True`` | 
|  | use_reentrant(bool): | 
|  | specify whether to use the activation checkpoint variant that | 
|  | requires reentrant autograd. This parameter should be passed | 
|  | explicitly. In version 2.4 we will raise an exception if | 
|  | ``use_reentrant`` is not passed. If ``use_reentrant=False``, | 
|  | ``checkpoint`` will use an implementation that does not require | 
|  | reentrant autograd. This allows ``checkpoint`` to support additional | 
|  | functionality, such as working as expected with | 
|  | ``torch.autograd.grad`` and support for keyword arguments input into | 
|  | the checkpointed function. | 
|  | context_fn(Callable, optional): A callable returning a tuple of two | 
|  | context managers. The function and its recomputation will be run | 
|  | under the first and second context managers respectively. | 
|  | This argument is only supported if ``use_reentrant=False``. | 
|  | determinism_check(str, optional): A string specifying the determinism | 
|  | check to perform. By default it is set to ``"default"`` which | 
|  | compares the shapes, dtypes, and devices of the recomputed tensors | 
against those of the saved tensors. To turn off this check, specify
|  | ``"none"``. Currently these are the only two supported values. | 
|  | Please open an issue if you would like to see more determinism | 
checks. This argument is only supported if ``use_reentrant=False``;
if ``use_reentrant=True``, the determinism check is always disabled.
|  | debug(bool, optional): If ``True``, error messages will also include | 
a trace of the operators run during the original forward computation
|  | as well as the recomputation. This argument is only supported if | 
|  | ``use_reentrant=False``. | 
|  | args: tuple containing inputs to the :attr:`function` | 
|  |  | 
|  | Returns: | 
|  | Output of running :attr:`function` on :attr:`*args` | 
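
Example (a minimal usage sketch; ``MyBlock`` is a placeholder module)::

>>> # xdoctest: +SKIP("stub")
>>> block = MyBlock()
>>> inp = torch.randn(2, 16, requires_grad=True)
>>> # Activations inside ``block`` are recomputed during backward instead of
>>> # being kept alive after the forward pass.
>>> out = checkpoint(block, inp, use_reentrant=False)
>>> out.sum().backward()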
|  | """ | 
|  | if use_reentrant is None: | 
|  | warnings.warn( | 
|  | "torch.utils.checkpoint: the use_reentrant parameter should be " | 
|  | "passed explicitly. In version 2.4 we will raise an exception " | 
|  | "if use_reentrant is not passed. use_reentrant=False is " | 
|  | "recommended, but if you need to preserve the current default " | 
|  | "behavior, you can pass use_reentrant=True. Refer to docs for more " | 
|  | "details on the differences between the two variants.", | 
|  | stacklevel=2 | 
|  | ) | 
|  | use_reentrant = True | 
|  |  | 
|  | # Hack to mix *args with **kwargs in a python 2.7-compliant way | 
|  | preserve = kwargs.pop("preserve_rng_state", True) | 
|  | if kwargs and use_reentrant: | 
|  | raise ValueError( | 
|  | "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) | 
|  | ) | 
|  |  | 
|  | if use_reentrant: | 
|  | if context_fn is not noop_context_fn or debug is not False: | 
|  | raise ValueError( | 
|  | "Passing `context_fn` or `debug` is only supported when " | 
|  | "use_reentrant=False." | 
|  | ) | 
|  | return CheckpointFunction.apply(function, preserve, *args) | 
|  | else: | 
|  | gen = _checkpoint_without_reentrant_generator( | 
|  | function, preserve, context_fn, determinism_check, debug, *args, **kwargs | 
|  | ) | 
|  | # Runs pre-forward logic | 
|  | next(gen) | 
|  | ret = function(*args, **kwargs) | 
|  | # Runs post-forward logic | 
|  | try: | 
|  | next(gen) | 
|  | except StopIteration: | 
|  | return ret | 
|  |  | 
|  |  | 
|  | def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs): | 
|  | r"""Checkpoint a sequential model to save memory. | 
|  |  | 
|  | Sequential models execute a list of modules/functions in order | 
|  | (sequentially). Therefore, we can divide such a model in various segments | 
|  | and checkpoint each segment. All segments except the last will not store | 
|  | the intermediate activations. The inputs of each checkpointed segment will | 
|  | be saved for re-running the segment in the backward pass. | 
|  |  | 
|  | .. warning:: | 
|  | The ``use_reentrant`` parameter should be passed explicitly. In version | 
|  | 2.4 we will raise an exception if ``use_reentrant`` is not passed. | 
If you are using the ``use_reentrant=True`` variant, please see
|  | :func:`~torch.utils.checkpoint.checkpoint` for | 
|  | the important considerations and limitations of this variant. It is | 
|  | recommended that you use ``use_reentrant=False``. | 
|  |  | 
.. warning::
|  | Since PyTorch 1.4, it allows only one Tensor as the input and | 
|  | intermediate outputs, just like :class:`torch.nn.Sequential`. | 
|  |  | 
|  | Args: | 
|  | functions: A :class:`torch.nn.Sequential` or the list of modules or | 
|  | functions (comprising the model) to run sequentially. | 
|  | segments: Number of chunks to create in the model | 
|  | input: A Tensor that is input to :attr:`functions` | 
|  | preserve_rng_state(bool, optional):  Omit stashing and restoring | 
|  | the RNG state during each checkpoint. | 
|  | Default: ``True`` | 
|  | use_reentrant(bool): | 
|  | specify whether to use the activation checkpoint variant that | 
|  | requires reentrant autograd. This parameter should be passed | 
|  | explicitly. In version 2.4 we will raise an exception if | 
|  | ``use_reentrant`` is not passed. If ``use_reentrant=False``, | 
|  | ``checkpoint`` will use an implementation that does not require | 
|  | reentrant autograd. This allows ``checkpoint`` to support additional | 
|  | functionality, such as working as expected with | 
|  | ``torch.autograd.grad`` and support for keyword arguments input into | 
|  | the checkpointed function. | 
|  |  | 
|  | Returns: | 
Output of running :attr:`functions` sequentially on :attr:`input`
|  |  | 
|  | Example: | 
|  | >>> # xdoctest: +SKIP("stub") | 
|  | >>> model = nn.Sequential(...) | 
|  | >>> input_var = checkpoint_sequential(model, chunks, input_var) | 
|  | """ | 
|  | if use_reentrant is None: | 
|  | warnings.warn( | 
|  | "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " | 
|  | "parameter should be passed explicitly. " | 
|  | "In version 2.4 we will raise an exception if use_reentrant " | 
|  | "is not passed. use_reentrant=False is " | 
|  | "recommended, but if you need to preserve the current default " | 
|  | "behavior, you can pass use_reentrant=True. Refer to docs for more " | 
|  | "details on the differences between the two variants." | 
|  | ) | 
|  | use_reentrant = True | 
|  |  | 
|  | # Hack for keyword-only parameter in a python 2.7-compliant way | 
|  | preserve = kwargs.pop("preserve_rng_state", True) | 
|  | if kwargs: | 
|  | raise ValueError( | 
|  | "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) | 
|  | ) | 
|  |  | 
|  | def run_function(start, end, functions): | 
|  | def forward(input): | 
|  | for j in range(start, end + 1): | 
|  | input = functions[j](input) | 
|  | return input | 
|  |  | 
|  | return forward | 
|  |  | 
|  | if isinstance(functions, torch.nn.Sequential): | 
|  | functions = list(functions.children()) | 
|  |  | 
|  | segment_size = len(functions) // segments | 
|  | # the last chunk has to be non-volatile | 
|  | end = -1 | 
|  | for start in range(0, segment_size * (segments - 1), segment_size): | 
|  | end = start + segment_size - 1 | 
|  | input = checkpoint( | 
|  | run_function(start, end, functions), | 
|  | input, | 
|  | use_reentrant=use_reentrant, | 
|  | preserve_rng_state=preserve, | 
|  | ) | 
|  | return run_function(end + 1, len(functions) - 1, functions)(input) | 
|  |  | 
|  |  | 
|  | def _internal_assert(cond): | 
|  | if not cond: | 
|  | raise AssertionError( | 
|  | "Something went unexpectedly wrong in activation checkpoint. " | 
|  | "Please report this bug by filing an issue to PyTorch." | 
|  | ) | 
|  |  | 
|  |  | 
|  | # NOTE [ Nestable Checkpoint ] | 
|  | # | 
|  | # The semantics of nested checkpoint can be defined by two basic rules. | 
|  | # Following the two rules leads to an important implication that is central | 
|  | # to motivating the design. | 
|  | # | 
|  | # Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden | 
|  | #         from any outer layers of checkpoint. | 
|  | # | 
|  | # Rule 2. The inputs of inner checkpoints are treated as tensors saved to its | 
|  | #         parent checkpoint. | 
|  | # | 
|  | # Implication: To recompute any given saved tensor, we need to recompute all of | 
|  | #              the checkpoints wrapping it. | 
|  | # | 
|  | # Why is this implied? To unpack a saved tensor X during backward we need to | 
|  | # recompute the inner-most checkpoint (#1), and in order to recompute that | 
# checkpoint we need to have its inputs, which are managed by that checkpoint's
|  | # parent (#2), which thus also needs to be recomputed first. Continue this line | 
|  | # of reasoning and we realize that in order to unpack X, all checkpoints that | 
|  | # were active at the time X was saved need to be recomputed. (unless we have | 
|  | # already done so in that backward for some other saved tensor). | 
|  | # | 
|  | # In practice, we use a noop autograd Function to save inputs as saved tensors. | 
# During unpack, calling ctx.saved_tensors triggers the parent checkpoint to
# recompute.
|  | # | 
|  | # Rule 3. We should start recomputation as if there are no checkpoints currently | 
|  | #         active. Checkpoints encountered during recomputation are still | 
|  | #         respected. | 
|  | # | 
|  | # When we start recomputation, we push the saved variable hook meant for | 
|  | # recomputation on the stack. See examples in Rule 6 for more context. | 
|  | # | 
|  | #                                  * * * * | 
|  | # | 
|  | # Beyond the basic semantics specific to nested checkpoint, we impose several | 
|  | # more constraints that may apply to checkpointing in general. | 
|  | # | 
|  | # Rule 4. Lifetime of recomputed tensors | 
|  | # | 
|  | #         Recomputed tensors are considered specific to particular invocations | 
#         of backward and are always cleared immediately as they are unpacked.
#         In particular, we require this to happen even if retain_graph=True.
|  | # | 
|  | # [ Implementation details of Rule 4 ] | 
|  | # | 
|  | # If we were okay with recomputed tensors staying alive after backward is run | 
|  | # with retain_graph=True, we would store recomputed variables as the values of a | 
|  | # WeakKeyDictionary and pack strong references to the keys, so that as we | 
|  | # backward, those packed keys would be cleared as long as retain_graph=False. | 
|  | # Clearing the packed key clears the corresponding entry in the WKD. | 
|  | # | 
|  | # If we wish recomputed variables to be immediately cleared as we unpack them in | 
|  | # the retain_graph=True case, we cannot rely on the packed keys to be cleared by | 
|  | # backward automatically. Instead of packing the strong reference to the key | 
|  | # directly, we pack a container object, which we manually clear as we unpack. | 
|  | # | 
|  | # An important detail is that if a second backward happens, the second | 
|  | # recomputation needs to reset the container with a newly created key. | 
|  | # | 
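# A compressed sketch of that pattern (illustration only; the real code lives
# in _checkpoint_hook and _recomputation_hook below):
#
#     recomputed = weakref.WeakKeyDictionary()    # handle -> recomputed tensor
#     holder.handles[gid] = handle = _Handle()    # container we "pack" indirectly
#     recomputed[handle] = recomputed_tensor
#     ...
#     ret = recomputed[holder.handles[gid]]       # unpack: read the value, then
#     holder.handles[gid] = None                  # drop the key so the entry dies
#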
|  | # Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we | 
|  | #         know we need. | 
|  | # | 
|  | # [ Implementation details of Rule 5 ] | 
|  | # | 
|  | # During recomputation, raise an exception if the number of recomputed tensors | 
|  | # matches the number of tensors that we expected to recompute. We wrap the | 
|  | # recomputation call with a try-catch to catch this specific exception. See | 
|  | # Rule #6 below for some examples. | 
|  | # | 
|  | # Rule 6. We support doing backward inside checkpoint context | 
|  | # | 
# [ retain_graph is True ]
|  | # | 
|  | # def fn(x): | 
|  | #   y = x.sin() | 
|  | #   z = y.cos() | 
#   gx, = torch.autograd.grad(z, x, retain_graph=True)
|  | #   return gx, z | 
|  | # | 
# out = checkpoint(fn, inp)
|  | # out.backward() | 
|  | # | 
|  | # Because z is saved by cos while checkpoint is enabled, it would not be | 
|  | # actually saved, and so the .grad() call inside must trigger a recomputation. | 
|  | # | 
|  | # During recomputation the "inner pack hook" has two responsibilities: | 
|  | # | 
|  | # 1) As usual, populating the WeakKeyDictionary storing recomputed tensors | 
|  | # 2) Pack the actual tensor (detached) so that one may perform backward on the | 
|  | #    recomputed graph. The tensors saved to this graph will live until the end | 
|  | #    of recomputation, or die earlier if someone performs backward with | 
|  | #    retain_graph=False. | 
|  | # | 
|  | # More generally performing backward on the recomputed graph occurs in the | 
|  | # following cases: | 
|  | # - If backward is performed inside forward, | 
|  | #   - During the original forward IF early-stop is disabled | 
|  | #   - During the original backward | 
|  | # - If there are multiple .grad()/.backward() calls, we would perform backward | 
|  | #   on the recomputed graph even if early-stop is enabled (see the example below) | 
|  | # | 
|  | # [ retain_graph is False ] | 
|  | # | 
|  | # The example below shows what happens if during recomputation we find that some | 
|  | # of the tensors we are trying to recompute have already been cleared. | 
|  | # | 
|  | # Spoiler: we don't do anything special, we just skip over them! | 
|  | # | 
|  | # def fn(x): | 
|  | #   y = x.sin()                           # (1) | 
|  | #   z = y.cos()                           # (2) | 
|  | #   gx, = torch.autograd.grad(z, x)       # (3) | 
|  | #   return x.cos() * gx                   # (4) | 
|  | # | 
# out = checkpoint(fn, inp)
|  | # out.backward()                          # (5) | 
|  | # | 
|  | # 1, 2. Don't save x and y since we are inside a checkpoint. | 
|  | # 3. Trigger a recompute of fn since x and y weren't saved. | 
|  | #    And depending on whether early stop is enabled, either stop at (2) or | 
|  | #    continue running the function. | 
|  | #    Because we are running backward with retain_graph=False, we clear x and y's | 
|  | #    holders. | 
|  | # 4. Don't save x since we are inside a checkpoint. | 
|  | # 5. Calling backward triggers another recompute of fn. During recompute, we see | 
|  | #    that x and y have already been cleared in the original graph as indicated | 
|  | #    by holder=None. We skip over them. We still save x at (4) (since its holder | 
|  | #    is still alive.) | 
|  |  | 
|  | _enable_checkpoint_early_stop = True | 
|  |  | 
|  |  | 
|  | @contextlib.contextmanager | 
|  | def set_checkpoint_early_stop(enable: bool): | 
|  | """Context manager that sets whether checkpoint should stop recomputation early. | 
|  |  | 
|  | By default, non-reentrant checkpoint stops recomputation as soon as it | 
|  | has computed all needed Tensors. This context manager can be used to disable | 
|  | that feature if it is problematic for your specific application. | 
|  |  | 
|  | This context manager only needs to be active when forward is run. It does | 
|  | not need to be active during backward. | 
|  |  | 
|  | Example:: | 
|  |  | 
|  | >>> # xdoctest: +SKIP(failing) | 
|  | >>> message = "saved tensors default hooks are disabled" | 
|  | >>> with set_checkpoint_early_stop(False): | 
|  | ...     # Any checkpoint under this context manager will respect this | 
|  | ...     # context manager, even if its backward is performed outside. | 
|  | ...     out = checkpoint(fn, inputs) | 
|  | ... | 
|  | >>> out.backward() | 
|  | """ | 
|  | global _enable_checkpoint_early_stop | 
|  | try: | 
|  | prev = _enable_checkpoint_early_stop | 
|  | _enable_checkpoint_early_stop = enable | 
|  | yield | 
|  | finally: | 
|  | _enable_checkpoint_early_stop = prev | 
|  |  | 
|  |  | 
|  | class _Handle: | 
|  | pass | 
|  |  | 
|  |  | 
|  | class _Holder: | 
|  | def __init__(self): | 
|  | self.handles: Dict[int, Optional[_Handle]] = dict() | 
|  |  | 
|  |  | 
|  | class _NoopSaveInputs(torch.autograd.Function): | 
|  | @staticmethod | 
|  | def forward(*args): | 
|  | return torch.empty((0,)) | 
|  |  | 
|  | @staticmethod | 
|  | def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: | 
|  | # Only tensors can be saved with ctx.save_for_backward, everything else | 
|  | # is captured by get_args, which is saved directly on ctx | 
|  | tensor_indices, tensors = zip( | 
|  | *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)] | 
|  | ) | 
|  | idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)} | 
|  | # args but with tensors replaced with None as placeholders | 
|  | args = [None if isinstance(o, torch.Tensor) else o for o in inputs] | 
|  |  | 
|  | def get_args(saved_tensors): | 
|  | # restore the placeholders with the original tensors grabbed from | 
|  | # ctx.saved_tensors (which may be saved on a parent checkpoint if | 
|  | # this checkpoint is nested, and that would trigger a recursive | 
|  | # unpack!) | 
|  | ret = [ | 
|  | saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o | 
|  | for i, o in enumerate(args) | 
|  | ] | 
|  | # grab the tail since we also saved the dummy to avoid having to explicitly | 
|  | # handle the case where there are no tensor inputs | 
|  | return ret[1:] | 
|  |  | 
|  | ctx.get_args = get_args | 
|  | ctx.save_for_backward(*tensors) | 
|  |  | 
|  | @staticmethod | 
|  | def backward(ctx, *grad_outputs): | 
|  | raise AssertionError("Did not expect to backward on this graph") | 
|  |  | 
|  |  | 
|  | class _CheckpointFrame: | 
|  | def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn): | 
|  | self.recompute_fn = recompute_fn | 
|  | self.input_saver = None | 
|  | self.weak_holders: List[ReferenceType] = [] | 
|  | # We store this as a weakkeydictionary so that in the case of a partial | 
|  | # backward, the entries in the dict are cleared alongside the Holder | 
|  | # which will be removed when the SavedVariable is cleared. | 
|  | self.recomputed: DefaultDict[ | 
|  | int, weakref.WeakKeyDictionary[_Handle, torch.Tensor] | 
|  | ] = defaultdict(weakref.WeakKeyDictionary) | 
|  | # We need both recomp_counter and recomputed since they can diverge | 
|  | # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885 | 
|  | self.recomp_counter: DefaultDict[int, int] = defaultdict(int) | 
|  | self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool) | 
|  |  | 
|  | # See Rule 5 | 
|  | self.early_stop = early_stop | 
|  |  | 
|  | # Debugging | 
|  | self.metadata_fn = metadata_fn | 
|  | self.unpack_error_cb = unpack_error_cb | 
|  | self.x_metadatas = [] | 
|  | self.forward_completed = False | 
|  | self.ignore_saved_mismatch = False | 
|  |  | 
|  | def check_recomputed_tensors_match(self, gid): | 
|  | if self.ignore_saved_mismatch: | 
|  | # TODO: we can probably make this check stricter by checking that | 
|  | #       the metadata of the first tensors still match. | 
|  | return | 
|  | # NOTE [ Error handling for checkpoint ] | 
|  | # | 
|  | # At a high level, we need to check that the tensors saved | 
# during the original forward match the tensors saved during recompute.
|  | # This means handling 3 cases: | 
|  | # | 
|  | # 1. During recompute, more tensors were saved. | 
|  | # | 
#    Usually this case is hidden by the StopRecomputationError, but if early
#    stop is not enabled, we would have errored anyway because there aren't
#    enough weak_holders. Either way, we want to raise a nice error. See the
#    _recomputation_hook for details.
|  | if not len(self.weak_holders) == self.recomp_counter[gid]: | 
|  | # 2. During recompute, fewer tensors were saved | 
|  | # | 
# We know that every time we save something during the original forward
# we append to weak_holders, and every time we save a tensor
# during recompute we increment recomp_counter.
|  | raise CheckpointError( | 
|  | "torch.utils.checkpoint: A different number of tensors was saved " | 
|  | "during the original forward and recomputation.\n" | 
|  | f"Number of tensors saved during forward: {len(self.weak_holders)}\n" | 
|  | f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}" | 
|  | ) | 
|  |  | 
|  | # 3. During recompute, the same tensors were saved, but they | 
|  | #    have different metadata | 
|  | nb_meta_different = [] | 
|  | for idx, weak_holder in enumerate(self.weak_holders): | 
|  | holder = weak_holder() | 
|  | if holder is None: | 
|  | continue | 
|  | # We've seen all holders since we iterate over them in order | 
|  | # For every holder that is still alive now, it must've been | 
|  | # alive when we saw it during recompute, therefore, the | 
|  | # gid must be set. | 
|  | _internal_assert(gid in holder.handles) | 
|  | # We know this is the first unpack, so it couldn't have been set | 
|  | # to None yet. | 
|  | _internal_assert(holder.handles[gid] is not None) | 
|  | # We always set these together in the recomputation hook | 
|  | _internal_assert(holder.handles[gid] in self.recomputed[gid]) | 
# See pack_hook; x_metadatas is 1:1 with weak_holders.
|  | x_meta = self.x_metadatas[idx] | 
|  | recomputed_x = self.recomputed[gid][holder.handles[gid]] | 
|  | if x_meta != self.metadata_fn(recomputed_x): | 
|  | nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x))) | 
|  |  | 
|  | if len(nb_meta_different) > 0: | 
|  | mismatched_tensors = "" | 
|  | for idx, x_meta, recomputed_meta in nb_meta_different: | 
|  | mismatched_tensors += ( | 
|  | f"tensor at position {idx}:\n" | 
|  | f"saved metadata: {x_meta}\n" | 
|  | f"recomputed metadata: {recomputed_meta}\n" | 
|  | ) | 
|  | raise CheckpointError( | 
|  | "torch.utils.checkpoint: Recomputed values for the following tensors " | 
|  | "have different metadata than during the forward pass.\n" | 
|  | f"{mismatched_tensors}" | 
|  | ) | 
|  |  | 
|  |  | 
|  | _checkpoint_error_template = """ \ | 
|  | An error happened while unpacking tensors; dumping logs of latest computation | 
|  | because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`. | 
|  | Scroll all the way down for guidance on how to navigate these logs. | 
|  |  | 
|  | +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ | 
|  | |        1. Stack traces of the operators that ran in the original forward     | | 
|  | +------------------------------------------------------------------------------+ | 
|  |  | 
|  | {forward_traces} | 
|  | +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ | 
|  | |        2. Stack traces of the operators that ran during recomputation        | | 
|  | +------------------------------------------------------------------------------+ | 
|  |  | 
|  | {recompute_traces} | 
|  | +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ | 
|  | |       3. Log of operators in the original forward and recomputation          | | 
|  | +------------------------------------------------------------------------------+ | 
|  | (Scroll up to correlate stack traces with each operation listed below. This | 
|  | helps identify their source in the code.) | 
|  |  | 
|  | IMPORTANT: Differences in "detach" calls between the original forward and the | 
|  | recomputation are expected. They are introduced by the checkpointing | 
|  | mechanism and can be ignored. | 
|  |  | 
|  | Operations executed during the original forward: | 
|  |  | 
|  | {forward_ops} | 
|  |  | 
|  | Operations executed during recomputation: | 
|  |  | 
|  | {recompute_ops} | 
|  |  | 
|  | +------------------------------------------------------------------------------+ | 
|  | ERROR: Detected non-determinism while running activation checkpointing | 
|  |  | 
You are seeing this error because you passed `debug=True` to checkpoint and
the tensors saved during the original forward differ from those saved
during recomputation. This can happen if different operators were run in the
original forward and in the recomputation.
|  |  | 
|  | To identify where the mismatch may be coming from, you can do the following: | 
|  |  | 
1) Compare the operators run during the original forward and recomputation to
|  | see where they differ. These operators are printed above in the order they | 
|  | were executed. | 
|  |  | 
|  | 2) Review the stack trace for each operator to locate its invocation source. | 
Each operator's stack trace is printed in execution order.
|  |  | 
|  | Note that the logs can be quite long. Here's how they are structured: | 
|  | (Tip: you can Ctrl-f for these headers) | 
|  |  | 
|  | 1. Stack traces of the operators that ran in the original forward | 
|  | 2. Stack traces of the operators that ran during recomputation | 
|  | 3. Log of operators in the original forward and recomputation | 
|  | 4. Error message                                             <--- You are here | 
|  | -------------------------------------------------------------------------------- | 
|  | """ | 
|  |  | 
|  | class CheckpointError(RuntimeError): | 
|  | pass | 
|  |  | 
|  |  | 
|  | def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]: | 
|  | # This function returns the context_fn and error_cb to be used by the | 
|  | # checkpointing mechanism. error_cb is invoked when an error is detected | 
|  | # during unpack. | 
|  |  | 
# record_context_cpp is not supported on non-Linux, non-x86_64 platforms
|  | cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux' | 
|  |  | 
|  | class CaptureLogs: | 
|  | def __init__(self): | 
|  | self.logs = None | 
|  | self.tbs = None | 
|  |  | 
|  | def get_context_manager(self): | 
|  | @contextlib.contextmanager | 
|  | def logging_mode(): | 
|  | with LoggingTensorMode(), \ | 
|  | capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: | 
|  | self.logs, self.tbs = logs_and_tb | 
|  | yield logs_and_tb | 
|  | return logging_mode() | 
|  |  | 
|  | capture_logs_fwd = CaptureLogs() | 
|  | capture_logs_recompute = CaptureLogs() | 
|  |  | 
|  | def unpack_error_cb(e: CheckpointError): | 
|  | def get_str_tb(label, capture_logs): | 
|  | out = "" | 
|  | total_len = len(capture_logs.logs) | 
|  | for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)): | 
|  | out += f"{log}   ({i + 1} of {total_len} in {label})\n\n" | 
|  | found_torch_dispatch = False | 
|  | for line in tb: | 
|  | # Start printing stack trace only after __torch_dispatch__ is found | 
|  | is_torch_dispatch = line['name'] == '__torch_dispatch__' | 
|  | if not found_torch_dispatch and not is_torch_dispatch: | 
|  | continue | 
|  | elif is_torch_dispatch: | 
|  | found_torch_dispatch = True | 
|  | continue | 
|  | out += f"{line['filename']}:{line['line']}:{line['name']}\n" | 
|  | out += "\n\n" | 
|  | return out | 
|  | assert capture_logs_fwd.logs is not None | 
|  | assert capture_logs_recompute.logs is not None | 
|  | raise CheckpointError( | 
|  | _checkpoint_error_template.format( | 
|  | forward_traces=get_str_tb("original", capture_logs_fwd), | 
|  | recompute_traces=get_str_tb("recompute", capture_logs_recompute), | 
|  | forward_ops="\n".join(capture_logs_fwd.logs), | 
|  | recompute_ops="\n".join(capture_logs_recompute.logs) | 
|  | ) | 
|  | ) from e | 
|  |  | 
|  | def context_fn(): | 
|  | return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager() | 
|  |  | 
|  | return context_fn, unpack_error_cb | 
|  |  | 
|  | def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: | 
|  | # These properties are fast to check, easy to understand | 
|  | return { | 
|  | "shape": x.shape, | 
|  | "dtype": x.dtype, | 
|  | "device": x.device | 
|  | } | 
|  |  | 
|  | _allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = { | 
|  | _DEFAULT_DETERMINISM_MODE: _default_meta_extractor, | 
|  | "none": lambda _: None, | 
|  | } | 
|  |  | 
|  | # See Rule 5 | 
|  | class _StopRecomputationError(Exception): | 
|  | pass | 
|  |  | 
|  |  | 
|  | class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): | 
|  | def __init__(self, target_frame_ref: ReferenceType, gid: int): | 
|  | def pack_hook(x): | 
|  | target_frame = target_frame_ref() | 
|  | assert target_frame is not None  # appease mypy | 
|  | recomp_idx = target_frame.recomp_counter[gid] | 
|  | target_frame.recomp_counter[gid] += 1 | 
|  |  | 
|  | if recomp_idx >= len(target_frame.weak_holders): | 
|  | assert not target_frame.early_stop | 
|  | if not target_frame.forward_completed: | 
# We run into this case when early stop is not enabled and we call
# grad within checkpoint.
|  | # We need to set this flag, so we don't error out later when | 
|  | # we check if the number of tensors saved during forward and | 
|  | # recomputation match. | 
|  | target_frame.ignore_saved_mismatch = True | 
|  | return x.detach() | 
|  | raise CheckpointError( | 
|  | "torch.utils.checkpoint: trying to save more tensors during " | 
|  | "recomputation than during the original forward pass." | 
|  | ) | 
|  |  | 
|  | holder = target_frame.weak_holders[recomp_idx]() | 
|  |  | 
|  | # This holder may have been cleared because someone may have called | 
|  | # backward within forward. If so, we don't need to save. | 
|  | if holder is not None: | 
|  | _internal_assert(holder.handles.get(gid, None) is None) | 
|  | holder.handles[gid] = _Handle() | 
|  | target_frame.recomputed[gid][holder.handles[gid]] = x.detach() | 
|  |  | 
|  | if target_frame.early_stop and target_frame.recomp_counter[gid] == len( | 
|  | target_frame.weak_holders | 
|  | ): | 
|  | raise _StopRecomputationError | 
|  | # See Rule 6: [ retain_graph is True ] above | 
|  | return x.detach() | 
|  |  | 
|  | def unpack_hook(x): | 
|  | # See Rule 6: [ retain_graph is True ] above for an example of when | 
|  | # the graph created during recomputation could be backwarded. | 
|  | return x | 
|  |  | 
|  | super().__init__(pack_hook, unpack_hook) | 
|  |  | 
|  |  | 
|  | class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): | 
|  | def __init__(self, frame): | 
|  | def pack_hook(x): | 
|  | # See Rule 4 above | 
|  | holder = _Holder() | 
|  | frame.weak_holders.append(weakref.ref(holder)) | 
|  | # Save metadata to detect non-determinism | 
|  | if frame.metadata_fn is not None: | 
|  | with torch.no_grad(): | 
|  | frame.x_metadatas.append(frame.metadata_fn(x)) | 
|  | return holder | 
|  |  | 
|  | def unpack_hook(holder): | 
|  | gid = torch._C._current_graph_task_id() | 
|  | if gid == -1: | 
|  | # generate a temporary id if we trigger unpack outside of a backward call | 
|  | gid = int(uuid.uuid4()) | 
|  |  | 
|  | if not frame.is_recomputed[gid]: | 
|  | ctx = frame.input_saver.grad_fn | 
|  | args = ctx.get_args(ctx.saved_tensors) | 
|  |  | 
|  | try: | 
|  | with _recomputation_hook( | 
|  | weakref.ref(frame), gid | 
|  | ), torch.autograd.enable_grad(): | 
|  | frame.recompute_fn(*args) | 
|  | except _StopRecomputationError: | 
|  | pass | 
|  | frame.is_recomputed[gid] = True | 
|  | frame.check_recomputed_tensors_match(gid) | 
|  |  | 
|  | _internal_assert(gid in holder.handles) | 
|  |  | 
|  | if holder.handles[gid] is None: | 
|  | raise CheckpointError( | 
|  | "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already " | 
|  | "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do " | 
|  | "so only once. Otherwise please open an issue with details on your use case." | 
|  | ) | 
|  | _internal_assert(holder.handles[gid] in frame.recomputed[gid]) | 
|  | ret = frame.recomputed[gid][holder.handles[gid]] | 
|  | holder.handles[gid] = None | 
|  | return ret | 
|  |  | 
|  | if frame.unpack_error_cb is not None: | 
|  | def unpack_hook_with_error_cb(holder): | 
|  | try: | 
|  | return unpack_hook(holder) | 
|  | except CheckpointError as e: | 
|  | frame.unpack_error_cb(e) | 
|  | super().__init__(pack_hook, unpack_hook_with_error_cb) | 
|  | else: | 
|  | super().__init__(pack_hook, unpack_hook) | 
|  |  | 
|  |  | 
|  | def _is_compiling(func, args, kwargs): | 
|  | # Check if we are under AOTAutograd tracing | 
|  | # There should probably be a better way to do this... | 
|  | # TODO: unify _is_compiling across all compile stacks | 
|  | for arg in args: | 
|  | if isinstance(arg, torch.Tensor) and is_fun(arg): | 
|  | return True | 
|  | return False | 
|  |  | 
|  |  | 
|  | class _VersionWrapper: | 
|  | # Check that cached tensors are not mutated. | 
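# For example (hypothetical scenario): if an op's output is cached during the
# forward and the user later mutates that output in place (e.g. ``out.add_(1)``),
# its ``_version`` no longer matches the recorded one and get_val() raises,
# unless ``allow_cache_entry_mutation=True`` was passed.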
|  | def __init__(self, val): | 
|  | self.val: Union[torch.Tensor, Any] = val | 
|  | self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None | 
|  |  | 
|  | def get_val(self, allow_cache_entry_mutation): | 
|  | if self.version is not None and not allow_cache_entry_mutation: | 
|  | if self.val._version != self.version: | 
|  | # Can we give user a stack trace of where the mutation happened? | 
|  | raise RuntimeError( | 
|  | "Tensor cached during selective activation checkpoint has been mutated" | 
|  | ) | 
|  | return self.val | 
|  |  | 
|  |  | 
|  | def _maybe_detach(x, any_ret_has_alias_info): | 
|  | # We detach for two separate reasons: | 
|  | # - For view ops, we need to ensure that when the tensor is returned from | 
|  | #   CachedDispatchMode, as_view sees that the AutogradMeta is nullptr | 
|  | # - Avoid reference cycles | 
|  | # For case 1, it is not enough to check whether x has differentiable dtype | 
|  | # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. | 
|  | # when the tensor is a view. | 
|  | if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): | 
|  | with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): | 
|  | # Ensure that view performed beneath autograd properly propagates | 
|  | # version counter. TODO: Use reentrant_dispatch instead of | 
|  | # manually manipulating dispatch keys. Using reentrant_dispatch | 
|  | # would respect inference_mode, though that is not relevant for | 
|  | # this case. | 
|  | x = x.detach() | 
|  | return x | 
|  |  | 
|  |  | 
|  | class SelectiveCheckpointContext: | 
|  | """ | 
|  | Context passed to policy function during selective checkpointing. | 
|  |  | 
|  | This class is used to pass relevant metadata to the policy function during | 
|  | selective checkpointing. The metadata includes whether the current invocation | 
|  | of the policy function is during recomputation or not. | 
|  |  | 
|  | Example: | 
|  | >>> # xdoctest: +SKIP(stub) | 
|  | >>> | 
|  | >>> def policy_fn(ctx, op, *args, **kwargs): | 
|  | >>>    print(ctx.is_recompute) | 
|  | >>> | 
|  | >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) | 
|  | >>> | 
|  | >>> out = torch.utils.checkpoint.checkpoint( | 
|  | >>>     fn, x, y, | 
|  | >>>     use_reentrant=False, | 
|  | >>>     context_fn=context_fn, | 
|  | >>> ) | 
|  | """ | 
|  | def __init__(self, *, is_recompute): | 
|  | self.is_recompute = is_recompute | 
|  |  | 
|  |  | 
|  | class CheckpointPolicy(enum.Enum): | 
|  | """ | 
|  | Enum for specifying the policy for checkpointing during backpropagation. | 
|  |  | 
|  | The following policies are supported: | 
|  |  | 
|  | - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward | 
|  | pass and will not be recomputed during the backward pass | 
|  | - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the | 
|  | forward pass and will be recomputed during the backward pass | 
|  |  | 
|  | Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden | 
|  | by other subsystems like `torch.compile`. | 
|  |  | 
|  | .. note:: | 
|  | A policy function that always returns ``PREFER_RECOMPUTE`` is | 
|  | equivalent to vanilla checkpointing. | 
|  |  | 
A policy function that returns ``PREFER_SAVE`` for every op is
|  | NOT equivalent to not using checkpointing. Using such a policy would | 
|  | save additional tensors not limited to ones that are actually needed for | 
|  | gradient computation. | 
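
Example (a minimal policy sketch; ``ops_to_save`` is a placeholder list)::

>>> # xdoctest: +SKIP("stub")
>>> def policy_fn(ctx, op, *args, **kwargs):
...     if op in ops_to_save:
...         return CheckpointPolicy.MUST_SAVE
...     return CheckpointPolicy.PREFER_RECOMPUTE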
|  | """ | 
|  | MUST_SAVE = 0 | 
|  | PREFER_SAVE = 1 | 
|  | MUST_RECOMPUTE = 2 | 
|  | PREFER_RECOMPUTE = 3 | 
|  |  | 
|  |  | 
|  | SAC_IGNORED_OPS = { | 
|  | # AC inserts different number of detach during forward and recompute. | 
|  | torch.ops.aten.detach.default, | 
|  | # AC's determinism check invokes additional metadata ops during forward. | 
# With subclasses involved, these metadata ops become dispatchable; this
# can result in incorrectness if these ops are selected to be cached.
|  | torch.ops.prim.device.default, | 
|  | } | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) | 
|  |  | 
|  |  | 
|  | class _CachingTorchDispatchMode(TorchDispatchMode): | 
|  | # Used together with _CachedTorchDispatchMode to implement SAC. | 
|  | def __init__(self, policy_fn, storage): | 
|  | self.policy_fn = policy_fn | 
|  | self.storage = storage | 
|  |  | 
|  | def __torch_dispatch__(self, func, types, args=(), kwargs=None): | 
|  | if func in SAC_IGNORED_OPS: | 
|  | return func(*args, **kwargs) | 
|  |  | 
|  | kwargs = {} if kwargs is None else kwargs | 
|  | policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False), | 
|  | func, *args, **kwargs) | 
|  | is_compiling = _is_compiling(func, args, kwargs) | 
|  |  | 
|  | if is_compiling and policy == CheckpointPolicy.MUST_SAVE: | 
|  | fx_traceback.current_meta["recompute"] = 0 | 
|  |  | 
|  | out = func(*args, **kwargs) | 
|  |  | 
|  | any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) | 
|  |  | 
|  | if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: | 
|  | self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) | 
|  | return out | 
|  |  | 
|  | class _CachedTorchDispatchMode(TorchDispatchMode): | 
# Used together with _CachingTorchDispatchMode to implement SAC.
|  | def __init__(self, policy_fn, storage, allow_cache_entry_mutation): | 
|  | self.policy_fn = policy_fn | 
|  | self.storage = storage | 
|  | self.allow_cache_entry_mutation = allow_cache_entry_mutation | 
|  |  | 
|  | def __torch_dispatch__(self, func, types, args=(), kwargs=None): | 
|  | if func in SAC_IGNORED_OPS: | 
|  | return func(*args, **kwargs) | 
|  |  | 
|  | kwargs = {} if kwargs is None else kwargs | 
|  | policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True), | 
|  | func, *args, **kwargs) | 
|  | is_compiling = _is_compiling(func, args, kwargs) | 
|  |  | 
|  | if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: | 
|  | storage = self.storage.get(func) | 
|  | if storage is None: | 
|  | raise RuntimeError(f"{func} encountered during backward, but not found in storage") | 
|  | if len(storage) == 0: | 
|  | raise RuntimeError( | 
|  | "Trying to backward an extra time. You are only allowed to backward once " | 
|  | "on any region computed under selective activation checkpoint." | 
|  | ) | 
|  | out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) | 
|  | else: | 
|  | out = func(*args, **kwargs) | 
|  | return out | 
|  |  | 
|  |  | 
|  | def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): | 
|  | """ | 
|  | Helper to avoid recomputing certain ops during activation checkpointing. | 
|  |  | 
|  | Use this with `torch.utils.checkpoint.checkpoint` to control which | 
|  | operations are recomputed during the backward pass. | 
|  |  | 
|  | Args: | 
|  | policy_fn_or_list (Callable or List): | 
|  | - If a policy function is provided, it should accept a | 
|  | :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and | 
|  | kwargs to the op, and return a :class:`CheckpointPolicy` enum value | 
|  | indicating whether the execution of the op should be recomputed or not. | 
|  | - If a list of operations is provided, it is equivalent to a policy | 
|  | returning `CheckpointPolicy.MUST_SAVE` for the specified | 
|  | operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other | 
|  | operations. | 
|  | allow_cache_entry_mutation (bool, optional): To ensure correctness, an | 
|  | error is raised by default if any tensors cached by selective activation | 
|  | checkpoint are mutated. If set to `True`, this check is disabled. | 
|  | Returns: | 
|  | A tuple of two context managers. | 
|  |  | 
|  | Example: | 
|  | >>> # xdoctest: +REQUIRES(LINUX) | 
|  | >>> import functools | 
|  | >>> | 
|  | >>> x = torch.rand(10, 10, requires_grad=True) | 
|  | >>> y = torch.rand(10, 10, requires_grad=True) | 
|  | >>> | 
|  | >>> ops_to_save = [ | 
|  | >>>    torch.ops.aten.mm.default, | 
|  | >>> ] | 
|  | >>> | 
|  | >>> def policy_fn(ctx, op, *args, **kwargs): | 
|  | >>>    if op in ops_to_save: | 
|  | >>>        return CheckpointPolicy.MUST_SAVE | 
|  | >>>    else: | 
|  | >>>        return CheckpointPolicy.PREFER_RECOMPUTE | 
|  | >>> | 
|  | >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) | 
|  | >>> | 
|  | >>> # or equivalently | 
|  | >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) | 
|  | >>> | 
|  | >>> def fn(x, y): | 
|  | >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y | 
|  | >>> | 
|  | >>> out = torch.utils.checkpoint.checkpoint( | 
|  | >>>     fn, x, y, | 
|  | >>>     use_reentrant=False, | 
|  | >>>     context_fn=context_fn, | 
|  | >>> ) | 
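|  | >>> | 
|  | >>> # During backward, outputs of the saved ops (here ``mm``) are reused from | 
|  | >>> # the cache while the remaining ops are recomputed. | 
|  | >>> out.sum().backward() | 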
|  | """ | 
|  | # NB: If grad_mode is disabled, checkpoint would not run forward under | 
|  | #     context_fn anyway, so proceed as usual. | 
|  | if isinstance(policy_fn_or_list, list): | 
|  | for op in policy_fn_or_list: | 
|  | if not isinstance(op, torch._ops.OpOverload): | 
|  | _extra_msg = ( | 
|  | "Please update the OpOverloadPacket to a specific OpOverload." | 
|  | "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." | 
|  | ) if isinstance(op, torch._ops.OpOverloadPacket) else "" | 
|  | raise ValueError( | 
|  | f"Expected op in `op_list` to be an OpOverload but got: {op} " | 
|  | f"of type {type(op)}. {_extra_msg}" | 
|  | ) | 
|  |  | 
|  | def policy_fn(ctx, op, *args, **kwargs): | 
|  | if op in policy_fn_or_list: | 
|  | return CheckpointPolicy.MUST_SAVE | 
|  | else: | 
|  | return CheckpointPolicy.PREFER_RECOMPUTE | 
|  | elif callable(policy_fn_or_list): | 
|  | policy_fn = policy_fn_or_list | 
|  | else: | 
|  | raise TypeError("policy_fn_or_list must be either a function or a list of ops.") | 
|  |  | 
|  | storage: Dict[Any, List[Any]] = defaultdict(list) | 
|  | return ( | 
|  | _CachingTorchDispatchMode(policy_fn, storage), | 
|  | _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), | 
|  | ) | 
|  |  | 
|  | # NB: this helper wraps fn before calling checkpoint_impl. kwargs and the | 
|  | #     saving/restoring of global state are handled here. | 
|  |  | 
|  | def _checkpoint_without_reentrant_generator( | 
|  | fn, | 
|  | preserve_rng_state=True, | 
|  | context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, | 
|  | determinism_check: str = _DEFAULT_DETERMINISM_MODE, | 
|  | debug: bool = False, | 
|  | *args, | 
|  | **kwargs | 
|  | ): | 
|  | """Checkpointing without reentrant autograd. | 
|  |  | 
|  | Args: | 
|  | fn: describes what to run in the forward pass of the model or | 
|  | part of the model. It should also know how to handle the inputs | 
|  | passed as the tuple. For example, in LSTM, if the user passes | 
|  | ``(activation, hidden)``, :attr:`fn` should correctly use the | 
|  | first input as ``activation`` and the second input as ``hidden``. | 
|  | preserve_rng_state(bool, optional):  If ``False``, omit stashing and | 
|  | restoring the RNG state during each checkpoint. | 
|  | Default: ``True`` | 
|  | context_fn(Callable, optional): A callable returning a tuple of two | 
|  | context managers. The function and its recomputation will be run | 
|  | under the first and second context managers respectively. | 
|  | determinism_check(str, optional): A string specifying the determinism | 
|  | check to perform. By default it is set to ``"default"`` which | 
|  | compares the shapes, dtypes, and devices of the recomputed tensors | 
|  | against those of the saved tensors. To turn off this check, specify | 
|  | ``"none"``. Currently these are the only two supported values. | 
|  | Please open an issue if you would like to see more determinism | 
|  | checks. | 
|  | debug(bool, optional): If ``True``, error messages will also include | 
|  | a trace of the operators run during the original forward computation | 
|  | as well as the recomputation. | 
|  | *args: Arguments to pass in to the given ``function``. | 
|  | **kwargs: Keyword arguments to pass into the given ``function``. | 
|  | """ | 
|  | unpack_error_cb = None | 
|  |  | 
|  | debug_enabled = debug if _checkpoint_debug_enabled is None else _checkpoint_debug_enabled | 
|  | if debug_enabled: | 
|  | if context_fn != noop_context_fn: | 
|  | raise ValueError( | 
|  | "debug=True is incompatible with non-default context_fn" | 
|  | ) | 
|  | context_fn, unpack_error_cb = _get_debug_context_and_cb() | 
|  |  | 
|  | if determinism_check in _allowed_determinism_checks_to_fns: | 
|  | metadata_fn = _allowed_determinism_checks_to_fns[determinism_check] | 
|  | else: | 
|  | raise ValueError( | 
|  | f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, " | 
|  | f"but got {determinism_check}" | 
|  | ) | 
|  |  | 
|  | device = _infer_device_type(*args) | 
|  | device_module = _get_device_module(device) | 
|  | forward_context, recompute_context = context_fn() | 
|  | if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn: | 
|  | assert ( | 
|  | isinstance(forward_context, TorchDispatchMode) and | 
|  | isinstance(recompute_context, TorchDispatchMode) | 
|  | ), \ | 
|  | "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \ | 
|  | "must generate a tuple of two `TorchDispatchMode`s." | 
|  | # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. | 
|  | device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device=device) | 
|  |  | 
|  | if preserve_rng_state: | 
|  | fwd_cpu_state = torch.get_rng_state() | 
|  | # Don't eagerly initialize the cuda context by accident. | 
|  | # (If the user intends that the context is initialized later, within their | 
|  | # run_function, we SHOULD actually stash the cuda state here.  Unfortunately, | 
|  | # we have no way to anticipate this will happen before we run the function. | 
|  | # If they do so, we raise an error.) | 
|  | had_device_in_fwd = False | 
|  | if getattr(device_module, "_initialized", False): | 
|  | had_device_in_fwd = True | 
|  | fwd_devices, fwd_device_states = get_device_states(*args) | 
|  |  | 
|  | def recompute_fn(*inputs): | 
|  | kwargs, *args = inputs | 
|  | # This will be called later during recomputation. This wrapping enables | 
|  | # the necessary global state to be captured. | 
|  | rng_devices = [] | 
|  | if preserve_rng_state and had_device_in_fwd: | 
|  | rng_devices = fwd_devices | 
|  | with torch.random.fork_rng( | 
|  | devices=rng_devices, enabled=preserve_rng_state, device_type=device | 
|  | ): | 
|  | if preserve_rng_state: | 
|  | torch.set_rng_state(fwd_cpu_state) | 
|  | if had_device_in_fwd: | 
|  | set_device_states(fwd_devices, fwd_device_states) | 
|  |  | 
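|  | # Reapply the autocast settings that were active for the original forward so | 
|  | # that recomputation produces tensors with matching dtypes. | 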
|  | device_autocast_ctx = torch.amp.autocast( | 
|  | device_type=device, **device_autocast_kwargs | 
|  | ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext() | 
|  | with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined] | 
|  | fn(*args, **kwargs) | 
|  |  | 
|  | new_frame = _CheckpointFrame( | 
|  | recompute_fn, | 
|  | _enable_checkpoint_early_stop, | 
|  | unpack_error_cb, | 
|  | metadata_fn | 
|  | ) | 
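|  | # The dummy tensor has requires_grad=True so that _NoopSaveInputs produces a | 
|  | # grad_fn whenever grad mode is enabled, even if no real input requires grad; | 
|  | # if grad mode is disabled, grad_fn stays None and checkpointing is skipped | 
|  | # below. | 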
|  | dummy = torch.empty((0,), requires_grad=True) | 
|  | new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args) | 
|  |  | 
|  | # When ambient grad_mode is False | 
|  | if new_frame.input_saver.grad_fn is None: | 
|  | yield | 
|  | return | 
|  |  | 
|  | with _checkpoint_hook(new_frame), forward_context: | 
|  | yield | 
|  | new_frame.forward_completed = True | 
|  |  | 
|  | if getattr(device_module, "_initialized", False) and \ | 
|  | preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined] | 
|  | # Device was not initialized before running the forward, so we didn't | 
|  | # stash the device state. | 
|  | raise RuntimeError( | 
|  | "PyTorch's device state was initialized in the forward pass " | 
|  | "of a Checkpoint, which is not allowed. Please open an issue " | 
|  | "if you need this feature." | 
|  | ) | 
|  |  | 
|  | return |