blob: 4e2e3934039795b931391aa8b11fae37ca118a84 [file] [log] [blame]
import torch
import torch.utils._pytree as pytree
from collections import namedtuple
import functools
# NOTE [CustomOp autograd kernel indirection]
# We register `inner` as the autograd kernel for this custom_op.
# `inner` either calls the autograd formula registered by the user,
# or goes into an `autograd_not_implemented` kernel.
#
# The reason why this indirection exists is
# so that we can swap out the autograd kernel (the PyTorch dispatcher
# doesn't actually allow us to do this). By default, we want
# the `autograd_not_implemented` behavior, but then the user may come
# and register something that is actually a backward formula
def autograd_kernel_indirection(custom_op):
    """Return the callable that is registered as ``custom_op``'s autograd kernel.

    The returned ``inner`` function picks an implementation every time it is
    called:

    * if an 'autograd' impl exists, its ``func`` is invoked;
    * if only one of 'backward' / 'save_for_backward' exists, a RuntimeError
      naming the missing registration is raised;
    * otherwise the ``autograd_not_implemented`` kernel is used.
    """
    fallback_kernel = autograd_not_implemented(custom_op)

    def inner(*args, **kwargs):
        if custom_op._has_impl('autograd'):
            return custom_op._get_impl('autograd').func(*args, **kwargs)

        has_backward = custom_op._has_impl('backward')
        if not (has_backward or custom_op._has_impl('save_for_backward')):
            # Nothing autograd-related was registered at all.
            return fallback_kernel(*args, **kwargs)

        # As explained in NOTE ["backward", "save_for_backward", and "autograd"],
        # the "autograd" impl is generated once the user has given us both
        # "backward" and "save_for_backward". Reaching this point means only
        # one of the two was provided, which is a user error.
        if has_backward:
            found, missing = 'backward', 'save_for_backward'
        else:
            found, missing = 'save_for_backward', 'backward'
        loc = custom_op._get_impl(found).location
        raise RuntimeError(
            f"We found a '{found}' registration for {custom_op} at "
            f"{loc} but were unable to find a '{missing}' registration. "
            f"To use the CustomOp API to register a backward formula, "
            f"please provide us both a backward function and a "
            f"'save for backward' function via `impl_backward` and "
            f"`impl_save_for_backward` respectively.")

    return inner
# TODO(#101191): Use the actual C++ autograd not implemented fallback,
# or change the default autograd fallback to the autograd not implemented fallback.
def autograd_not_implemented(custom_op):
def kernel(*args, **kwargs):
if torch.is_grad_enabled() and pytree.tree_any(
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
):
raise RuntimeError("Autograd has not been implemented for operator")
with torch._C._AutoDispatchBelowAutograd():
return custom_op(*args, **kwargs)
return kernel
def construct_autograd_kernel(
        schema,
        output_differentiability,
        forward_op,
        save_for_backward_fn,
        backward_fn):
    """Build an autograd kernel from "save_for_backward" and "backward" functions.

    Args:
        schema: Operator schema; forwarded to ``namedtuple_args`` so the
            positional args can be addressed by name.
        output_differentiability: ``None``, or a sequence with one entry per
            output; a falsy entry marks that output as non-differentiable.
        forward_op: Callable that computes the forward result. Its ``_opname``
            attribute is used to name the generated autograd.Function.
        save_for_backward_fn: Called as ``save_for_backward_fn(inputs, output)``
            where ``inputs`` is a namedtuple of the args. Whatever it returns
            is stored on the ctx and handed to ``backward_fn``.
        backward_fn: Called as ``backward_fn(ctx, saved, *grads)``. Its return
            value is checked by ``validate_grad_inputs_dict``.

    Returns:
        ``apply(*args)``, which flattens ``args``, generates an
        autograd.Function wrapping the pieces above, and applies it.
    """

    def apply(*args):
        # The args are flattened before being handed to the generated
        # autograd.Function and rebuilt (using `spec`) inside `forward`.
        flat_args, spec = pytree.tree_flatten(args)

        def forward(ctx, *flat_args):
            ctx.set_materialize_grads(True)
            args = pytree.tree_unflatten(list(flat_args), spec)
            with torch._C._AutoDispatchBelowAutograd():
                output = forward_op(*args)

            # We use the info about args to give better error messages in backward
            args_info = namedtuple_args(schema, pytree.tree_map(type, args))

            save_for_backward_fn_inputs = namedtuple_args(schema, args)
            to_save = save_for_backward_fn(save_for_backward_fn_inputs, output)

            save_pytree_for_backward(ctx, (to_save, args_info))

            # Output must be one or more Tensors, no TensorList (yet)
            if output_differentiability is not None:
                outputs = output if isinstance(output, tuple) else (output,)
                assert len(output_differentiability) == len(outputs)
                non_differentiable = [
                    out
                    for differentiable, out in zip(output_differentiability, outputs)
                    if not differentiable
                ]
                # ctx.mark_non_differentiable must be called at most once:
                # every call replaces the previously recorded outputs, so all
                # non-differentiable outputs are passed in a single call.
                if non_differentiable:
                    ctx.mark_non_differentiable(*non_differentiable)

            return output

        def backward(ctx, *grads):
            saved, args_info = unpack_saved(ctx)
            # There is nothing on the ctx object for now, it is just there so
            # that we can add additional things in the future.
            inner_ctx = object()
            grad_inputs_dict = backward_fn(inner_ctx, saved, *grads)

            # Massage the grad_inputs_dict to a form acceptable by
            # autograd.Function.
            validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info)
            return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info)

        generated_cls = gen_autograd_function(
            forward_op._opname + '_customop', forward, backward)

        return generated_cls.apply(*flat_args)

    return apply
def gen_autograd_function(name, forward, backward):
    """Create a ``torch.autograd.Function`` subclass called ``name``.

    ``forward`` and ``backward`` are plain functions taking ``ctx`` as their
    first parameter; they are installed on the new class as static methods.
    """
    namespace = {
        'forward': staticmethod(forward),
        'backward': staticmethod(backward),
    }
    bases = (torch.autograd.Function,)
    return type(name, bases, namespace)
@functools.lru_cache
def namedtuple_args_cls(schema):
    """Return (and cache per schema) a namedtuple class for the schema's args.

    The class is named ``"<schema.name>_args"`` and has one field for every
    entry of ``schema.arguments.flat_all``, in that order.
    """
    field_names = [argument.name for argument in schema.arguments.flat_all]
    cls_name = f"{schema.name!s}_args"
    # mypy doesn't support dynamic namedtuple name
    return namedtuple(cls_name, field_names)  # type: ignore[misc]
def namedtuple_args(schema, args):
    """Wrap the positional ``args`` tuple in the namedtuple class for ``schema``."""
    assert isinstance(args, tuple)
    return namedtuple_args_cls(schema)(*args)
def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info):
    """Check the dict returned by a user's backward function.

    Args:
        grad_inputs_dict: Value returned by the backward function; expected to
            be a dict mapping argument names to gradients.
        forward_op: The op being differentiated. ``forward_op._schema`` gives
            the expected keys and ``forward_op._get_impl('backward').location``
            is quoted in error messages.
        args_info: Namedtuple holding, for every argument, the type of the
            value passed in forward (or a list of types for list arguments).

    Raises:
        RuntimeError: if the value is not a dict, if its keys differ from the
            tensor-like argument names, or if a gradient has the wrong
            structure or type for its argument.
    """
    def error(what):
        backward = forward_op._get_impl('backward')
        raise RuntimeError(
            f"In the backward function defined for {forward_op} at "
            f"{backward.location} using the CustomOp API, {what}")

    def is_tensor_type(info):
        # Tensor subclasses (e.g. nn.Parameter) are Tensors too and may
        # receive a Tensor gradient; a plain `!= torch.Tensor` rejects them.
        return isinstance(info, type) and issubclass(info, torch.Tensor)

    if not isinstance(grad_inputs_dict, dict):
        error(f"expected the output of the backward function to be a dict but "
              f"got {type(grad_inputs_dict)}")

    expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all
                     if arg.type.is_tensor_like()}
    actual_keys = grad_inputs_dict.keys()
    if expected_keys != actual_keys:
        error(f"expected the returned grad_input dict to have keys "
              f"{expected_keys} but got {actual_keys}. The backward "
              f"function must return a gradient (can be None) for each arg "
              f"to the CustomOp that may be a Tensor or Sequence[Tensor]. "
              f"Args declared to be non-Tensor-like types should not appear "
              f"in the grad_input dict")

    for name, grad in grad_inputs_dict.items():
        arg_info = getattr(args_info, name)

        if isinstance(arg_info, list):
            # List argument: expect one gradient (or None) per element.
            if not isinstance(grad, (tuple, list)):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of gradients but got object of type "
                      f"{type(grad)}.")
            if len(grad) != len(arg_info):
                error(f"for input '{name}' expected the grad_input dict to "
                      f"hold a list of {len(arg_info)} gradients but got "
                      f"{len(grad)}")
            for idx, (g, info) in enumerate(zip(grad, arg_info)):
                if g is None:
                    continue
                if not isinstance(g, torch.Tensor):
                    error(f"for input '{name}' expected the grad_input dict to "
                          f"hold a list of None or Tensor gradients but got "
                          f"object of {type(g)} at index {idx}")
                if not is_tensor_type(info):
                    # Report the type of the offending element, not the
                    # whole list of element types.
                    error(f"for input '{name}', got a Tensor as the gradient "
                          f"for the {idx}-th value but expected None because "
                          f"the {idx}-th value was not a Tensor (it was "
                          f"type {info})")
            continue

        if grad is None:
            continue
        if not isinstance(grad, torch.Tensor):
            error(f"got object of type {type(grad)} as the gradient for input "
                  f"'{name}', "
                  f"but expected the gradient to be either None or a Tensor")
        if not is_tensor_type(arg_info):
            error(f"got a Tensor as the gradient for input '{name}' but "
                  f"expected None as the gradient because input '{name}' "
                  f"was not a Tensor (it was type {arg_info}).")
def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info):
    """Turn the name -> gradient dict into a flat tuple of gradients.

    Arguments are visited in ``args_info`` field order. An argument with no
    entry in ``grad_inputs_dict`` contributes ``None`` for each of its leaves.
    The per-argument results are then flattened into one tuple.
    """
    def grads_for(name, arg_info):
        if name in grad_inputs_dict:
            return grad_inputs_dict[name]
        return pytree.tree_map(lambda _: None, arg_info)

    per_arg = [grads_for(name, arg_info)
               for name, arg_info in args_info._asdict().items()]
    flat_grads, _ = pytree.tree_flatten(per_arg)
    return tuple(flat_grads)
# Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it.
# autograd.Function prefers that users use ctx.save_for_backward to
# save Tensors (to avoid reference cycles) and for non-Tensors to go onto the
# ctx object.
def save_pytree_for_backward(ctx, stuff):
    """Store the pytree ``stuff`` on ``ctx`` so ``unpack_saved`` can rebuild it.

    The flattened leaves are split into Tensors, which go through
    ``ctx.save_for_backward``, and everything else, which is stored directly
    on ``ctx``. The original position of each leaf is recorded alongside.
    """
    leaves, spec = pytree.tree_flatten(stuff)

    tensors, tensor_idxs = [], []
    non_tensors, non_tensor_idxs = [], []
    for position, leaf in enumerate(leaves):
        if isinstance(leaf, torch.Tensor):
            tensors.append(leaf)
            tensor_idxs.append(position)
        else:
            non_tensors.append(leaf)
            non_tensor_idxs.append(position)

    ctx.spec = spec
    ctx.num_elts = len(leaves)
    ctx.save_for_backward(*tensors)
    ctx.tensor_idxs = tensor_idxs
    ctx.saved_non_tensors = non_tensors
    ctx.non_tensor_idxs = non_tensor_idxs
# Inverse operation to save_pytree_for_backward
def unpack_saved(ctx):
    """Rebuild the pytree that ``save_pytree_for_backward`` stored on ``ctx``.

    Saved Tensors and saved non-Tensors are written back to their recorded
    positions in a flat list, which is then unflattened with ``ctx.spec``.
    """
    leaves = [None] * ctx.num_elts
    sources = (
        (ctx.saved_tensors, ctx.tensor_idxs),
        (ctx.saved_non_tensors, ctx.non_tensor_idxs),
    )
    for values, positions in sources:
        for value, position in zip(values, positions):
            leaves[position] = value
    return pytree.tree_unflatten(leaves, ctx.spec)