| from __future__ import absolute_import, division, print_function, unicode_literals |
| import torch |
| import warnings |
| |
| |
def detach_variable(inputs):
    """Return a copy of ``inputs`` with every tensor detached from the autograd graph.

    Each tensor element is replaced by ``tensor.detach()`` with its original
    ``requires_grad`` flag restored, so a recomputation (as done in
    ``CheckpointFunction.backward``) builds a fresh graph rooted at these
    tensors. Non-tensor elements are passed through unchanged.

    Args:
        inputs (tuple): the values to detach.

    Returns:
        tuple: a new tuple of the same length as ``inputs``.

    Raises:
        RuntimeError: if ``inputs`` is not a tuple.
    """
    if not isinstance(inputs, tuple):
        # Concatenate so the message reads as one string (passing the type name
        # as a second exception argument produced a tuple-like repr).
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: "
            + type(inputs).__name__)

    out = []
    for inp in inputs:
        if not isinstance(inp, torch.Tensor):
            out.append(inp)
            continue
        x = inp.detach()
        # detach() yields a tensor with requires_grad=False; restore the original flag
        # so gradients can be collected on the recomputed graph's leaves.
        x.requires_grad = inp.requires_grad
        out.append(x)
    return tuple(out)
| |
| |
def check_backward_validity(inputs):
    """Warn when no tensor in ``inputs`` has ``requires_grad=True``.

    Non-tensor entries are skipped. If every tensor (or no tensor at all) is
    non-differentiable, a warning is issued because gradients will be None.
    """
    tensor_inputs = (item for item in inputs if isinstance(item, torch.Tensor))
    if all(not t.requires_grad for t in tensor_inputs):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
| |
| |
| # We can't know if the run_fn will internally move some args to different devices, |
| # which would require logic to preserve rng states for those devices as well. |
| # We could paranoidly stash and restore ALL the rng states for all visible devices, |
| # but that seems very wasteful for most cases. Compromise: Stash the RNG state for |
| # the device of all Tensor args. |
| # |
| # To consider: maybe get_device_states and set_device_states should reside in torch/random.py? |
def get_device_states(*args):
    """Snapshot the CUDA RNG state of every device holding a CUDA tensor in ``args``.

    Non-tensor arguments and CPU tensors are ignored; the ``isinstance`` check
    short-circuits before ``is_cuda`` is read, so nothing errors for them, and
    no CUDA call is made when ``args`` contains no CUDA tensors.

    Returns:
        tuple(list, list): ``(devices, states)`` where ``states[i]`` is the
        result of ``torch.cuda.get_rng_state()`` taken on ``devices[i]``.
    """
    cuda_device_ids = {
        arg.get_device()
        for arg in args
        if isinstance(arg, torch.Tensor) and arg.is_cuda
    }
    devices = list(cuda_device_ids)

    states = []
    for dev in devices:
        # Make `dev` current so get_rng_state() reads that device's generator.
        with torch.cuda.device(dev):
            states.append(torch.cuda.get_rng_state())
    return devices, states
| |
| |
def set_device_states(devices, states):
    """Restore per-device CUDA RNG states, e.g. those returned by ``get_device_states``.

    ``devices`` and ``states`` are paired positionally; surplus entries in the
    longer sequence are ignored (``zip`` semantics).
    """
    for pair in zip(devices, states):
        target_device, rng_state = pair
        # Make the target device current so set_rng_state() writes its generator.
        with torch.cuda.device(target_device):
            torch.cuda.set_rng_state(rng_state)
| |
| |
class CheckpointFunction(torch.autograd.Function):
    """Autograd function that recomputes ``run_function`` during backward.

    Forward runs ``run_function`` under ``torch.no_grad()`` so no intermediate
    activations are kept; only the inputs (and, optionally, RNG state) are saved.
    Backward re-runs ``run_function`` with grad enabled on detached copies of the
    saved inputs and back-propagates through that freshly built graph.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # torch.autograd.backward raises if any given tensor does not require
        # grad, so only back-propagate through recomputed outputs that are part
        # of the graph, paired positionally with their incoming gradients.
        outputs_with_grad = []
        grads_for_outputs = []
        for out, grad in zip(outputs, args):
            if isinstance(out, torch.Tensor) and out.requires_grad:
                outputs_with_grad.append(out)
                grads_for_outputs.append(grad)
        if not outputs_with_grad:
            raise RuntimeError(
                "None of the outputs of the checkpointed function have "
                "requires_grad=True; checkpoint() is not necessary here")
        torch.autograd.backward(outputs_with_grad, grads_for_outputs)

        # Autograd expects None (not the input object) for non-tensor inputs.
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)
        # The first two forward arguments (run_function, preserve_rng_state)
        # are not differentiable.
        return (None, None) + grads
| |
| |
def checkpoint(function, *args, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        If :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True): If ``True``, stash the
            RNG state in forward and restore it before the backward
            recomputation. Pass ``False`` to omit stashing and restoring the
            RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`

    Raises:
        ValueError: if any keyword argument other than ``preserve_rng_state``
            is given.
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))

    return CheckpointFunction.apply(function, preserve, *args)
| |
| |
| # TODO(sublee): When releasing PyTorch 1.3, |
| # fix the function signature to not accept variadic arguments. |
| # See also: https://github.com/pytorch/pytorch/issues/19260 |
def checkpoint_sequential(functions, segments, *inputs, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model. Must be at least 1
            and at most the number of functions.
        inputs: tuple of Tensors that are inputs to :attr:`functions`
        preserve_rng_state(bool, optional, default=True): If ``True``, stash and
            restore the RNG state around each checkpointed segment. Pass
            ``False`` to omit stashing and restoring the RNG state during each
            checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Raises:
        ValueError: on unexpected keyword arguments, or if :attr:`segments` is
            less than 1 or greater than the number of functions.

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))

    # To accept variadic arguments is not consistent with nn.Sequential.
    # This interface will be changed at PyTorch 1.3.
    # See also: https://github.com/pytorch/pytorch/issues/19260
    if not inputs:
        warnings.warn('Giving no input to checkpoint_sequential has been deprecated, '
                      'a TypeError will be raised after PyTorch 1.3',
                      DeprecationWarning)
    elif len(inputs) > 1:
        warnings.warn('multiple inputs to checkpoint_sequential has been deprecated, '
                      'a TypeError will be raised after PyTorch 1.3',
                      DeprecationWarning)

    def run_function(start, end, functions):
        # Build a callable that applies functions[start..end] (inclusive) in order.
        def forward(*inputs):
            for j in range(start, end + 1):
                # The first call receives the unpacked inputs tuple; later calls
                # receive whatever the previous function returned, unpacked only
                # if it is a tuple.
                if isinstance(inputs, tuple):
                    inputs = functions[j](*inputs)
                else:
                    inputs = functions[j](inputs)
            return inputs
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    # Validate up front: segments == 0 used to raise ZeroDivisionError, a
    # segment_size of 0 made range() fail with an opaque message, and negative
    # values silently disabled segmentation.
    if segments < 1:
        raise ValueError(
            "segments must be at least 1, got {}".format(segments))
    segment_size = len(functions) // segments
    if segment_size == 0:
        raise ValueError(
            "segments ({}) must not exceed the number of functions ({})".format(
                segments, len(functions)))

    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs,
                            preserve_rng_state=preserve)
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)