| import torch |
| import torch._C as _C |
| import torch.utils.hooks as hooks |
| from torch._six import with_metaclass |
| import functools |
| import warnings |
| from collections import OrderedDict |
| |
| |
| class _ContextMethodMixin(object): |
| |
| def save_for_backward(self, *tensors): |
| """Saves given tensors for a future call to :func:`~Function.backward`. |
| |
| **This should be called at most once, and only from inside the** |
| :func:`forward` **method.** |
| |
        Later, saved tensors can be accessed through the :attr:`saved_tensors`
        attribute, or, if the corresponding Variables are needed (e.g. for
        double backward), through the :attr:`saved_variables` attribute.
        Before they are returned to the user, a check is made to ensure they
        weren't used in any in-place operation that modified their content.
| |
| Arguments can also be ``None``. |
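
        Example (a minimal sketch; ``Mul`` is an illustrative subclass, not
        defined in this module)::

            >>> class Mul(Function):
            >>>
            >>>     @staticmethod
            >>>     def forward(ctx, x, y):
            >>>         ctx.save_for_backward(x, y)
            >>>         return x * y
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_output):
            >>>         x, y = ctx.saved_variables
            >>>         return grad_output * y, grad_output * x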
| """ |
| self.to_save = tensors |
| |
| def mark_dirty(self, *args): |
| """Marks given tensors as modified in an in-place operation. |
| |
| **This should be called at most once, only from inside the** |
| :func:`forward` **method, and all arguments should be inputs.** |
| |
| Every tensor that's been modified in-place in a call to :func:`forward` |
| should be given to this function, to ensure correctness of our checks. |
| It doesn't matter whether the function is called before or after |
| modification. |
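
        Example (a minimal sketch; ``AddInplace`` is illustrative)::

            >>> class AddInplace(Function):
            >>>
            >>>     @staticmethod
            >>>     def forward(ctx, x, y):
            >>>         x.add_(y)          # x is modified in-place...
            >>>         ctx.mark_dirty(x)  # ...so it must be marked dirty
            >>>         return x
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_output):
            >>>         return grad_output, grad_output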
| """ |
| self.dirty_tensors = args |
| |
| def mark_shared_storage(self, *pairs): |
| warnings.warn( |
| 'mark_shared_storage is deprecated. ' |
| 'Tensors with shared storages are automatically tracked. Note ' |
| 'that calls to `set_()` are not tracked') |
| |
| def mark_non_differentiable(self, *args): |
| """Marks outputs as non-differentiable. |
| |
| **This should be called at most once, only from inside the** |
| :func:`forward` **method, and all arguments should be outputs.** |
| |
| This will mark outputs as not requiring gradients, increasing the |
| efficiency of backward computation. You still need to accept a gradient |
| for each output in :meth:`~Function.backward`, but it's always going to |
| be ``None``. |
| |
| This is used e.g. for indices returned from a max :class:`Function`. |
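
        Example (a minimal sketch; ``AbsSign`` is illustrative)::

            >>> class AbsSign(Function):
            >>>
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         sign = x.sign()
            >>>         ctx.mark_non_differentiable(sign)
            >>>         ctx.save_for_backward(sign)
            >>>         return x.abs(), sign
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_abs, grad_sign):
            >>>         sign, = ctx.saved_variables
            >>>         # grad_sign is always None for the marked output
            >>>         return grad_abs * sign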
| """ |
| self.non_differentiable = args |
| |
| |
| class _HookMixin(object): |
| |
| @staticmethod |
| def _register_hook(backward_hooks, hook): |
| if backward_hooks is None: |
| backward_hooks = OrderedDict() |
| handle = hooks.RemovableHandle(backward_hooks) |
| backward_hooks[handle.id] = hook |
| return backward_hooks, handle |
| |
| |
| class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin): |
| _is_legacy = False |
| |
| def apply(self, *args): |
| return self._forward_cls.backward(self, *args) |
| |
| |
| class FunctionMeta(type): |
| """Function metaclass. |
| |
| This metaclass sets up the following properties: |
| _is_legacy: True if forward is not defined as a static method. |
| _backward_cls: The Function class corresponding to the differentiated |
| version of this function (which is generated on the fly by this |
| metaclass). |
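
    For example, for a subclass declared as ``class Exp(Function)`` with a
    static :func:`forward`, this metaclass generates an ``ExpBackward`` class
    (a subclass of :class:`BackwardCFunction`) and stores it in
    ``Exp._backward_cls``; instances of it serve as the ``grad_fn`` of outputs
    produced by applying ``Exp``.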
| """ |
| |
| def __init__(cls, name, bases, attrs): |
| for super_cls in cls.mro(): |
| forward = super_cls.__dict__.get('forward') |
| if forward is not None: |
                has_static_forward = isinstance(forward, (staticmethod, classmethod))
| break |
| |
| setattr(cls, '_is_legacy', not has_static_forward) |
| |
| # old-style functions |
| if not has_static_forward: |
| return super(FunctionMeta, cls).__init__(name, bases, attrs) |
| |
| backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls}) |
| setattr(cls, '_backward_cls', backward_fn) |
| |
| return super(FunctionMeta, cls).__init__(name, bases, attrs) |
| |
| |
| class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): |
| """Records operation history and defines formulas for differentiating ops. |
| |
    Every operation performed on :class:`Variable` s creates a new function
    object that performs the computation and records that it happened. The
    history is retained in the form of a DAG of functions, with edges denoting
    data dependencies (``input <- output``). Then, when backward is called,
    the graph is processed in topological order, by calling the
    :func:`backward` methods of each :class:`Function` object and passing
    returned gradients on to the next :class:`Function` s.
| |
    Normally, the only way users interact with functions is by creating
    subclasses and defining new operations. This is the recommended way of
    extending torch.autograd.
| |
| Each function is meant to be used only once (in the forward pass). |
| |
| Attributes: |
        requires_grad: Boolean indicating whether :func:`backward` will ever
            need to be called.
| |
| Examples:: |
| |
| >>> class Exp(Function): |
| >>> |
| >>> @staticmethod |
| >>> def forward(ctx, i): |
| >>> result = i.exp() |
| >>> ctx.save_for_backward(result) |
| >>> return result |
| >>> |
| >>> @staticmethod |
| >>> def backward(ctx, grad_output): |
| >>> result, = ctx.saved_variables |
| >>> return grad_output * result |
| """ |
| |
| # only for backward compatibility |
| __call__ = _C._FunctionBase._do_forward |
| |
| # for the tracer |
| is_traceable = False |
| |
| @staticmethod |
| def forward(ctx, *args, **kwargs): |
| """Performs the operation. |
| |
        This function is to be overridden by all subclasses.

        It must accept a context ``ctx`` as the first argument, followed by
        any number of arguments (tensors or other types).

        The context can be used to store variables that can then be retrieved
        during the backward pass.
| """ |
| raise NotImplementedError |
| |
| @staticmethod |
| def backward(ctx, *grad_outputs): |
| """Defines a formula for differentiating the operation. |
| |
        This function is to be overridden by all subclasses.

        It must accept a context ``ctx`` as the first argument, followed by as
        many outputs as :func:`forward` returned, and it should return as many
        tensors as there were inputs to :func:`forward`. Each argument is the
        gradient w.r.t. the given output, and each returned value should be
        the gradient w.r.t. the corresponding input.
| |
| The context can be used to retrieve variables saved during the forward |
| pass. |
| """ |
| raise NotImplementedError |
| |
| |
| def once_differentiable(fn): |
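    """Wraps a :func:`backward` implementation so that it runs with gradient
    computation disabled and raises an error if its result is differentiated
    again (i.e. on double backward).

    A minimal usage sketch (``MyExp`` is illustrative, not part of this
    module): decorate a static :func:`backward` whose formula does not
    support double backward::

        >>> class MyExp(Function):
        >>>
        >>>     @staticmethod
        >>>     def forward(ctx, i):
        >>>         result = i.exp()
        >>>         ctx.save_for_backward(result)
        >>>         return result
        >>>
        >>>     @staticmethod
        >>>     @once_differentiable
        >>>     def backward(ctx, grad_output):
        >>>         result, = ctx.saved_variables
        >>>         return grad_output * result
    """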
| from .variable import Variable |
| |
| @functools.wraps(fn) |
| def wrapper(ctx, *args): |
| with torch.no_grad(): |
| outputs = fn(ctx, *args) |
| |
| if not torch.is_grad_enabled(): |
| return outputs |
| |
| # If any of the inputs have requires_grad=True, we force the outputs |
| # to have requires_grad=True but point to a grad_fn which throws an |
| # error message during (double) back-propagation. |
| # XXX: this is only an approximation of requires_grad - there's no way |
| # to figure out if fn didn't use ctx.saved_variables and as a result |
| # some Variables might require grad, even if no args do. |
| # Unfortunately, this leads to unexpected error messages ("no nodes |
| # require computing gradients"), but I don't have a better idea. |
| # These functions would raise an error in backward anyway. |
| requires_grad = any(isinstance(arg, Variable) and arg.requires_grad |
| for arg in args) |
| if not requires_grad: |
| return outputs |
| |
| err_fn = torch._C._functions.DelayedError( |
| b"trying to differentiate twice a function that was marked" |
| b"with @once_differentiable") |
| |
| if not isinstance(outputs, tuple): |
| outputs = (outputs,) |
| |
| # Create aliases of each output that has requires_grad=True. We need |
| # at least one of the inputs to err_fn to require grad so that the |
| # output will have a grad_fn. |
| def fake_requires_grad(var): |
| if var is not None: |
| var = var.detach() |
| var.requires_grad = True |
| return var |
| |
| return err_fn(*[fake_requires_grad(v) for v in outputs]) |
| return wrapper |
| |
| |
| def traceable(fn_cls): |
| """Marks Function as traceable for the JIT. |
| |
    Traceable functions have additional restrictions: they can't pass any
    data-dependent values to backward (e.g. Prod passes the output, which
    makes it non-traceable), and their backward should be implemented entirely
    in terms of operations on autograd Variables in all cases.

    DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE
    HANDLED WITH CARE (it can give incorrect results otherwise).
| """ |
| fn_cls.is_traceable = True |
| return fn_cls |
| |
| |
| class InplaceFunction(Function): |
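    """A :class:`Function` whose constructor takes an ``inplace`` flag,
    indicating whether it should modify its input in-place."""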
| |
| def __init__(self, inplace=False): |
| super(InplaceFunction, self).__init__() |
| self.inplace = inplace |
| |
| |
| def _nested_map(condition, fn, condition_msg=None): |
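    """Returns a function that applies ``fn`` to every element of a (possibly
    nested) list/tuple structure for which ``condition`` holds, preserving the
    nesting; ``None`` entries are passed through, and any other type raises a
    ``ValueError``."""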
| def _map(obj): |
| if condition(obj): |
| return fn(obj) |
| elif obj is None: |
| return None |
| elif isinstance(obj, (list, tuple)): |
| return type(obj)(_map(x) for x in obj) |
| else: |
| raise ValueError("Auto nesting doesn't know how to process " |
| "an input object of type " + torch.typename(obj) + |
| (". Accepted types: " + condition_msg + |
| ", or lists/tuples of them" |
| if condition_msg else "")) |
| |
| return _map |
| |
| |
| def _iter_filter(condition, skip_unknown=False, condition_msg=None): |
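    """Returns a generator function that yields, in depth-first order, every
    element of a (possibly nested) list/tuple structure satisfying
    ``condition``. ``None`` entries are skipped; other types raise a
    ``ValueError`` unless ``skip_unknown`` is set, in which case they are
    silently ignored."""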
| def _iter(obj): |
| if condition(obj): |
| yield obj |
| elif obj is None: |
| return |
| elif isinstance(obj, (list, tuple)): |
| for o in obj: |
| for var in _iter(o): |
| yield var |
| elif not skip_unknown: |
| raise ValueError("Auto nesting doesn't know how to process " |
| "an input object of type " + torch.typename(obj) + |
| (". Accepted types: " + condition_msg + |
| ", or lists/tuples of them" |
| if condition_msg else "")) |
| |
| return _iter |
| |
| |
| def _unflatten(input, proto): |
| # unflatten a list or tuple input into a nested list/tuple structure |
| # specified by proto |
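    # e.g. _unflatten((a, b, c), [None, (None, None)]) returns [a, (b, c)]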
| def unflatten_helper(input, proto): |
| res = [] |
| if not isinstance(proto, (list, tuple)): |
| return input[0], input[1:] |
| for e in proto: |
| res_e, input = unflatten_helper(input, e) |
| res.append(res_e) |
| return type(proto)(res), input |
| |
| return unflatten_helper(input, proto)[0] |
| |
| |
| _iter_variables = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable), condition_msg="Variables") |
| _iter_variables_permissive = _iter_filter(lambda o: isinstance(o, torch.autograd.Variable), skip_unknown=True) |
| _iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value), |
| condition_msg="jit's Values or None") |
| _iter_tensors = _iter_filter(torch.is_tensor, condition_msg="Tensors") |
| _iter_None_tensors = _iter_filter( |
| lambda o: o is None or torch.is_tensor(o) or isinstance(o, torch.autograd.Variable), |
| condition_msg="Tensors or None") |
| _map_variable_tensor = _nested_map(lambda o: isinstance(o, torch.autograd.Variable), |
| lambda o: o.data, condition_msg="Variables") |
| |
| |
| class NestedIOFunction(Function): |
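    """An old-style :class:`Function` that accepts and returns arbitrarily
    nested lists/tuples of Variables. Subclasses implement
    :meth:`forward_extended` and :meth:`backward_extended` instead of
    :meth:`forward` and :meth:`backward`; the flattening and unflattening of
    arguments is handled here."""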
| |
| def _do_forward(self, *input): |
| self._nested_input = input |
| flat_input = tuple(_iter_variables(input)) |
| flat_output = super(NestedIOFunction, self)._do_forward(*flat_input) |
        nested_variables = _unflatten(flat_output, self._nested_output)
| return nested_variables |
| |
| def _do_backward(self, gradients, retain_variables): |
| self.retain_variables = retain_variables |
| result = super(NestedIOFunction, self)._do_backward(gradients, retain_variables) |
| if not retain_variables: |
| del self._nested_output |
| del self._to_save_nested |
| return result |
| |
| def backward(self, *gradients): |
| nested_gradients = _unflatten(gradients, self._nested_output) |
| result = self.backward_extended(*nested_gradients) |
| return tuple(_iter_None_tensors(result)) |
| |
| __call__ = _do_forward |
| |
| def forward(self, *args): |
| nested_tensors = _map_variable_tensor(self._nested_input) |
| result = self.forward_extended(*nested_tensors) |
| del self._nested_input |
| self._nested_output = result |
| return tuple(_iter_tensors(result)) |
| |
| def save_for_backward(self, *args): |
| self.to_save = tuple(_iter_tensors(args)) |
| self._to_save_nested = args |
| |
| @property |
| def saved_tensors(self): |
| flat_tensors = super(NestedIOFunction, self).saved_tensors |
| return _unflatten(flat_tensors, self._to_save_nested) |
| |
| def mark_dirty(self, *args, **kwargs): |
| self.dirty_tensors = tuple(_iter_tensors((args, kwargs))) |
| |
| def mark_non_differentiable(self, *args, **kwargs): |
| self.non_differentiable = tuple(_iter_tensors((args, kwargs))) |
| |
| def forward_extended(self, *input): |
| raise NotImplementedError |
| |
| def backward_extended(self, *grad_output): |
| raise NotImplementedError |