|  | import torch | 
|  | import torch.utils._pytree as pytree | 
|  | from collections import namedtuple | 
|  | import functools | 
|  |  | 
|  |  | 
|  | # NOTE [CustomOp autograd kernel indirection] | 
|  | # We register `inner` as the autograd kernel for this custom_op. | 
|  | # `inner` either calls the autograd formula registered by the user, | 
|  | # or goes into an `autograd_not_implemented` kernel. | 
|  | # | 
|  | # The reason why this indirection exists is | 
|  | # so that we can swap out the autograd kernel (the PyTorch dispatcher | 
|  | # doesn't actually allow us to do this). By default, we want | 
|  | # the `autograd_not_implemented` behavior, but then the user may come | 
|  | # and register something that is actually a backward formula | 
|  | def autograd_kernel_indirection(custom_op): | 
|  | # TODO(#101191): Use the actual C++ autograd not implemented fallback, | 
|  | # or change the default autograd fallback to the autograd not implemented fallback. | 
|  | def autograd_not_implemented(*args, **kwargs) -> None: | 
|  | if torch.is_grad_enabled() and pytree.tree_any( | 
|  | lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs) | 
|  | ): | 
|  | raise RuntimeError("Autograd has not been implemented for operator") | 
|  | guard = torch._C._AutoDispatchBelowAutograd() | 
|  | try: | 
|  | return custom_op(*args, **kwargs) | 
|  | finally: | 
|  | del guard | 
|  |  | 
|  | def inner(*args, **kwargs): | 
|  | if custom_op._has_impl('autograd'): | 
|  | kernel = custom_op._get_impl('autograd').func | 
|  | return kernel(*args, **kwargs) | 
|  | # As explained in NOTE ["backward", "save_for_backward", and "autograd"], | 
|  | # after the user gives us "backward" and "save_for_backward", we generate | 
|  | # the "autograd" impl. If the user only provided one, then we tell | 
|  | # the user they've done something wrong. | 
|  | if custom_op._has_impl('save_for_backward') or custom_op._has_impl('backward'): | 
|  | missing = ( | 
|  | 'save_for_backward' if custom_op._has_impl('backward') | 
|  | else 'backward' | 
|  | ) | 
|  | found = 'save_for_backward' if missing == 'backward' else 'backward' | 
|  | loc = custom_op._get_impl(found).location | 
|  | raise RuntimeError( | 
|  | f"We found a '{found}' registration for {custom_op} at " | 
|  | f"{loc} but were unable to find a '{missing}' registration. " | 
|  | f"To use the CustomOp API to register a backward formula, " | 
|  | f"please provide us both a backward function and a " | 
|  | f"'save for backward' function via `impl_backward` and " | 
|  | f"`impl_save_for_backward` respectively.") | 
|  | return autograd_not_implemented(*args, **kwargs) | 
|  | return inner | 
|  |  | 
|  |  | 
|  | def construct_autograd_kernel( | 
|  | schema, | 
|  | output_differentiability, | 
|  | forward_op, | 
|  | save_for_backward_fn, | 
|  | backward_fn): | 
|  |  | 
|  | def apply(*args): | 
|  | flat_args, spec = pytree.tree_flatten(args) | 
|  |  | 
|  | def forward(ctx, *flat_args): | 
|  | ctx.set_materialize_grads(True) | 
|  | args = pytree.tree_unflatten(list(flat_args), spec) | 
|  | guard = torch._C._AutoDispatchBelowAutograd() | 
|  | try: | 
|  | output = forward_op(*args) | 
|  | finally: | 
|  | del guard | 
|  |  | 
|  | # We use the info about args to give better error messages in backward | 
|  | args_info = namedtuple_args( | 
|  | schema, pytree.tree_map(lambda arg: type(arg), args)) | 
|  |  | 
|  | save_for_backward_fn_inputs = namedtuple_args(schema, args) | 
|  | to_save = save_for_backward_fn(save_for_backward_fn_inputs, output) | 
|  |  | 
|  | save_pytree_for_backward(ctx, (to_save, args_info)) | 
|  |  | 
|  | # Output must be one or more Tensors, no TensorList (yet) | 
|  | if output_differentiability is not None: | 
|  | if isinstance(output, tuple): | 
|  | assert len(output_differentiability) == len(output) | 
|  | for differentiable, out in zip(output_differentiability, output): | 
|  | if not differentiable: | 
|  | ctx.mark_non_differentiable(out) | 
|  | else: | 
|  | assert len(output_differentiability) == 1 | 
|  | if not output_differentiability[0]: | 
|  | ctx.mark_non_differentiable(output) | 
|  |  | 
|  | return output | 
|  |  | 
|  | def backward(ctx, *grads): | 
|  | saved, args_info = unpack_saved(ctx) | 
|  | # There is nothing on the ctx object for now, it is just there so | 
|  | # that we can add additional things in the future. | 
|  | inner_ctx = object() | 
|  | grad_inputs_dict = backward_fn(inner_ctx, saved, *grads) | 
|  |  | 
|  | # Massage the grad_inputs_dict to a form acceptable by | 
|  | # autograd.Function. | 
|  | validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info) | 
|  | return grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info) | 
|  |  | 
|  | generated_cls = gen_autograd_function( | 
|  | forward_op._opname + '_customop', forward, backward) | 
|  |  | 
|  | return generated_cls.apply(*flat_args) | 
|  | return apply | 
|  |  | 
|  |  | 
|  | def gen_autograd_function(name, forward, backward): | 
|  | generated_cls = type( | 
|  | name, | 
|  | (torch.autograd.Function,), | 
|  | { | 
|  | 'forward': staticmethod(forward), | 
|  | 'backward': staticmethod(backward), | 
|  | } | 
|  | ) | 
|  | return generated_cls | 
|  |  | 
|  |  | 
|  | @functools.lru_cache | 
|  | def namedtuple_args_cls(schema): | 
|  | attribs = [arg.name for arg in schema.arguments.flat_all] | 
|  | name = str(schema.name) + "_args" | 
|  | # mypy doesn't support dynamic namedtuple name | 
|  | tuple_cls = namedtuple(name, attribs)  # type: ignore[misc] | 
|  | return tuple_cls | 
|  |  | 
|  |  | 
|  | def namedtuple_args(schema, args): | 
|  | assert isinstance(args, tuple) | 
|  | tuple_cls = namedtuple_args_cls(schema) | 
|  | return tuple_cls(*args) | 
|  |  | 
|  |  | 
|  | def validate_grad_inputs_dict(grad_inputs_dict, forward_op, args_info): | 
|  | def error(what): | 
|  | backward = forward_op._get_impl('backward') | 
|  | raise RuntimeError( | 
|  | f"In the backward function defined for {forward_op} at " | 
|  | f"{backward.location} using the CustomOp API, {what}") | 
|  |  | 
|  | if not isinstance(grad_inputs_dict, dict): | 
|  | error(f"expected the output of the backward function to be a dict but " | 
|  | f"got {type(grad_inputs_dict)}") | 
|  |  | 
|  | expected_keys = {arg.name for arg in forward_op._schema.arguments.flat_all | 
|  | if arg.type.is_tensor_like()} | 
|  | actual_keys = grad_inputs_dict.keys() | 
|  | if expected_keys != actual_keys: | 
|  | error(f"expected the returned grad_input dict to have keys " | 
|  | f"{expected_keys} but got {actual_keys}. The backward " | 
|  | f"function must return a gradient (can be None) for each arg " | 
|  | f"to the CustomOp that may be a Tensor or Sequence[Tensor]. " | 
|  | f"Args declared to be non-Tensor-like types should not appear " | 
|  | f"in the grad_input dict") | 
|  |  | 
|  | for name, grad in grad_inputs_dict.items(): | 
|  | arg_info = getattr(args_info, name) | 
|  |  | 
|  | if isinstance(arg_info, list): | 
|  | if not isinstance(grad, (tuple, list)): | 
|  | error(f"for input '{name}' expected the grad_input dict to " | 
|  | f"hold a list of gradients but got object of type " | 
|  | f"{type(grad)}.") | 
|  | if not len(grad) == len(arg_info): | 
|  | error(f"for input '{name}' expected the grad_input dict to " | 
|  | f"hold a list of {len(arg_info)} gradients but got " | 
|  | f"{len(grad)}") | 
|  | for idx, (g, info) in enumerate(zip(grad, arg_info)): | 
|  | if g is None: | 
|  | continue | 
|  | if not isinstance(g, torch.Tensor): | 
|  | error(f"for input '{name}' expected the grad_input dict to " | 
|  | f"hold a list of None or Tensor gradients but got " | 
|  | f"object of {type(g)} at index {idx}") | 
|  | if info != torch.Tensor: | 
|  | error(f"for input '{name}', got a Tensor as the gradient " | 
|  | f"for the {idx}-th value but expected None because " | 
|  | f"the {idx}-th value was not a Tensor (it was " | 
|  | f"type {arg_info}") | 
|  | continue | 
|  |  | 
|  | if grad is None: | 
|  | continue | 
|  | if not isinstance(grad, torch.Tensor): | 
|  | error(f"got object of type {type(grad)} as the gradient for input " | 
|  | f"'{name}', " | 
|  | f"but expected the gradient to be either None or a Tensor") | 
|  | if arg_info != torch.Tensor: | 
|  | error(f"got a Tensor as the gradient for input '{name}' but " | 
|  | f"expected None as the gradient because input '{name}' " | 
|  | f"was not a Tensor (it was type {arg_info}).") | 
|  |  | 
|  | def grad_inputs_dict_to_flat_tuple(grad_inputs_dict, args_info): | 
|  | result = [] | 
|  | for name, arg_info in args_info._asdict().items(): | 
|  | if name not in grad_inputs_dict: | 
|  | result.append(pytree.tree_map(lambda x: None, arg_info)) | 
|  | continue | 
|  | result.append(grad_inputs_dict[name]) | 
|  | return tuple(pytree.tree_flatten(result)[0]) | 
|  |  | 
|  | # Saves "stuff" (a pytree) onto the ctx object. Use unpack_saved to unpack it. | 
|  | # autograd.Function prefers that users use ctx.save_for_backward to | 
|  | # save Tensors (to avoid reference cycles) and for non-Tensors to go onto the | 
|  | # ctx object. | 
|  | def save_pytree_for_backward(ctx, stuff): | 
|  | flat_stuff, spec = pytree.tree_flatten(stuff) | 
|  | num_elts = len(flat_stuff) | 
|  | tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) | 
|  | if isinstance(thing, torch.Tensor)] | 
|  | non_tensor_idxs = [idx for idx, thing in enumerate(flat_stuff) | 
|  | if not isinstance(thing, torch.Tensor)] | 
|  | tensors = [thing for thing in flat_stuff if isinstance(thing, torch.Tensor)] | 
|  | non_tensors = [thing for thing in flat_stuff if not isinstance(thing, torch.Tensor)] | 
|  |  | 
|  | ctx.spec = spec | 
|  | ctx.num_elts = num_elts | 
|  | ctx.save_for_backward(*tensors) | 
|  | ctx.tensor_idxs = tensor_idxs | 
|  | ctx.saved_non_tensors = non_tensors | 
|  | ctx.non_tensor_idxs = non_tensor_idxs | 
|  |  | 
|  |  | 
|  | # Inverse operation to save_pytree_for_backward | 
|  | def unpack_saved(ctx): | 
|  | flat_stuff = [None] * ctx.num_elts | 
|  | for tensor, idx in zip(ctx.saved_tensors, ctx.tensor_idxs): | 
|  | flat_stuff[idx] = tensor | 
|  | for non_tensor, idx in zip(ctx.saved_non_tensors, ctx.non_tensor_idxs): | 
|  | flat_stuff[idx] = non_tensor | 
|  | stuff = pytree.tree_unflatten(flat_stuff, ctx.spec) | 
|  | return stuff |