[custom_op] stop using nonlocals to store information (#128547) (#128616)
Fixes https://github.com/pytorch/pytorch/issues/128544
Fixes https://github.com/pytorch/pytorch/issues/128535
We had a problem with multithreading where the nonlocals were being
clobbered. We stored these nonlocals in the first place because we
wanted to ferry information from an autograd.Function.apply call to the
corresponding autograd.Function.forward.
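For reference, the old pattern looked roughly like the sketch below. This is a
simplified illustration, not the exact generated code; `run_op` is a hypothetical
stand-in for the real redispatch call.

    import torch

    def make_autograd_impl(op):
        # one nonlocal shared by every call to .apply -- this is the problem
        saved_keyset = None

        class Generated(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                nonlocal saved_keyset
                keyset = saved_keyset   # may already be clobbered by another thread
                saved_keyset = None
                return run_op(op, keyset, *args)   # stand-in for the real redispatch

            @staticmethod
            def backward(ctx, *grads):
                return grads

        def autograd_impl(keyset, *args):
            nonlocal saved_keyset
            saved_keyset = keyset       # two concurrent calls race on this write
            return Generated.apply(*args)

        return autograd_impl

Two threads calling autograd_impl at the same time both write saved_keyset before
either forward reads it, so one of them sees the wrong value (or None).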
Our new approach is:
- Pass the information directly as an extra input to the
autograd.Function.apply. This means that the autograd.Function.forward
will receive the information too (see the sketch after this list).
- This messes up ctx.needs_input_grad, which has one element per input to
forward, and the user should not see the additional information we passed.
We fix this by temporarily overriding ctx.needs_input_grad to the
right thing inside backward.
- This exposed a bug where ctx.needs_input_grad wasn't correct for
TensorList inputs. This PR fixes that too.
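Concretely, the new pattern looks roughly like this (a simplified sketch of the
idea, not the exact generated code; `run_op` and `user_backward` are hypothetical
stand-ins for the real redispatch and the user's registered backward):

    import dataclasses
    import torch

    @dataclasses.dataclass
    class Metadata:
        keyset: object
        keyword_only_args: dict

    class Generated(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            metadata = args[-1]     # ferried in as the last positional arg
            args = args[:-1]
            return run_op(metadata.keyset, *args, **metadata.keyword_only_args)

        @staticmethod
        def backward(ctx, *grads):
            prev = ctx.needs_input_grad
            try:
                # hide the extra metadata input from the user's backward
                ctx.needs_input_grad = ctx.needs_input_grad[:-1]
                result = user_backward(ctx, *grads)
            finally:
                ctx.needs_input_grad = prev
            # one extra None gradient for the metadata input
            if isinstance(result, tuple):
                return (*result, None)
            return result, None

    def autograd_impl(keyset, *args, **keyword_only_args):
        return Generated.apply(*args, Metadata(keyset, keyword_only_args))

Because the metadata travels with each .apply call instead of living in shared
state, concurrent and recursive calls no longer interfere with each other.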
Test Plan:
- existing and new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128547
Approved by: https://github.com/williamwen42, https://github.com/soulitzer
diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index e2af3ef..a3c37d3 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -2291,10 +2291,13 @@
class Stack(torch.autograd.Function):
@staticmethod
def forward(ctx, xs):
+ ctx.num_xs = len(xs)
return torch.stack(xs)
@staticmethod
def backward(ctx, grad):
+ expected = ([True] * ctx.num_xs,)
+ self.assertEqual(ctx.needs_input_grad, expected)
return list(grad.unbind(0))
# call two applys, do a backward on the first
@@ -2327,19 +2330,21 @@
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, xs):
- if len(xs) > 0:
+ if len(xs) > 1:
return Foo.apply(xs[1:])
ctx.len_xs = len(xs)
- return x.sin()
+ return xs[0].sin()
@staticmethod
def backward(ctx, grad):
- result = [None] * len_xs
+ result = [None] * ctx.len_xs
result[-1] = grad.cos()
return result
- with self.assertRaisesRegex(NotImplementedError, "Recursive call"):
- Foo.apply(xs)
+ # should work
+ result = Foo.apply(xs)
+ expected = xs[-1].sin()
+ self.assertEqual(result, expected)
# recursive on backward
@torch._library.autograd.supports_tensorlist
diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py
index 1ff5696..9001948 100644
--- a/torch/_library/autograd.py
+++ b/torch/_library/autograd.py
@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import dataclasses
-from typing import Any, Callable, Optional, Protocol
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Protocol
from .. import _C, _ops, autograd, Tensor
@@ -22,19 +23,20 @@
def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"
- saved_keyset = None
- saved_keyword_only_args = None
has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)
+ @dataclass
+ class Metadata:
+ keyset: _C.DispatchKeySet
+ keyword_only_args: Dict[str, Any]
+
def forward(ctx, *args):
+ metadata = args[-1]
+ args = args[:-1]
+
with _C._AutoDispatchBelowAutograd():
- nonlocal saved_keyset, saved_keyword_only_args
- keyset = saved_keyset
- assert keyset is not None, "Should have been set by autograd_impl"
- saved_keyset = None
- kwargs = saved_keyword_only_args
- assert kwargs is not None, "Should have been set by autograd_impl"
- saved_keyword_only_args = None
+ keyset = metadata.keyset
+ kwargs = metadata.keyword_only_args
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
if info._setup_context_fn:
# The Dispatcher will remove args that are equal to their default
@@ -59,8 +61,15 @@
def backward(ctx, *grads):
if info._backward_fn:
- result = info._backward_fn(ctx, *grads)
- return result
+ try:
+ prev_needs_input_grad = ctx.needs_input_grad
+ ctx.needs_input_grad = ctx.needs_input_grad[:-1]
+ result = info._backward_fn(ctx, *grads)
+ finally:
+ ctx.needs_input_grad = prev_needs_input_grad
+ if isinstance(result, tuple):
+ return (*result, None)
+ return result, None
raise RuntimeError(
f"Trying to backward through {op} but no autograd "
f"formula was registered. "
@@ -86,15 +95,7 @@
# The dispatcher passes any keyword-only-args as kwargs and the
# rest of the args (even if specified as kwargs) as args.
def autograd_impl(keyset, *args, **keyword_only_args):
- # We set a nonlocal to ferry keyset from here to the forward.
- # This supports recursive calls (we implement the forward carefully so
- # that it'll read saved_keyset before making a recursive call to the op).
- nonlocal saved_keyset, saved_keyword_only_args
- assert saved_keyset is None
- saved_keyset = keyset
- assert saved_keyword_only_args is None
- saved_keyword_only_args = keyword_only_args
- result = Generated.apply(*args) # type: ignore[attr-defined]
+ result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined]
return result
return autograd_impl
@@ -107,41 +108,38 @@
Tensors. Applying @supports_tensorlist enables an autograd.Function to support
autograd for List[Tensor] inputs and outputs.
"""
- # NB: All calls to the autograd.Function.apply shares these variables
- # We assume that only one call to .apply happens at a time. This means that
- # you cannot call the autograd.Function recursively (e.g. from its own forward).
- input_spec: Optional[spec_t] = None
- output_spec: Optional[spec_t] = None
- result_is_tuple = None
-
orig_forward = cls.forward
orig_backward = cls.backward
orig_apply = cls.apply
+ @dataclass
+ class Metadata:
+ input_spec: spec_t
+ output_spec: Optional[spec_t] = None
+ result_is_tuple: Optional[bool] = None
+
def new_forward(ctx, *args):
- if input_spec is None:
+ metadata = args[-1]
+ args = args[:-1]
+ if not isinstance(metadata, Metadata):
raise NotImplementedError(
"NYI: calling supports_tensorlist autograd.Function.forward directly. "
"You should probably be calling .apply instead. "
"Please file an issue if not."
)
- args = unflatten(list(args), input_spec)
+ args = unflatten(list(args), metadata.input_spec)
result = orig_forward(ctx, *args)
- nonlocal output_spec
- nonlocal result_is_tuple
- result_is_tuple = isinstance(result, tuple)
- if not result_is_tuple:
+ metadata.result_is_tuple = isinstance(result, tuple)
+ if not metadata.result_is_tuple:
result = (result,)
- nonlocal output_spec
flat_result, output_spec = flatten(result, not_list_of_tensor)
+ metadata.output_spec = output_spec
- # Save the input_spec/output_spec for backward because another call to
- # .apply will override the nonlocals.
if hasattr(ctx, "_pt_metadata"):
raise RuntimeError(
"Please don't set ctx._pt_metadata; PyTorch uses it to store info"
)
- ctx._pt_metadata = (input_spec, output_spec)
+ ctx._pt_metadata = metadata
return tuple(flat_result)
@@ -153,9 +151,24 @@
"Please file an issue if you need this."
)
- input_spec, output_spec = ctx._pt_metadata
- grads = unflatten(list(grads), output_spec)
- grad_inputs = orig_backward(ctx, *grads)
+ metadata = ctx._pt_metadata
+ grads = unflatten(list(grads), metadata.output_spec)
+
+ # If the user's input is ([x, y, z], w),
+ # then needs_input_grad is (bool, bool, bool, bool, bool).
+ # We need to
+ # 1. get rid of the additional bool (which comes from the extra
+ # `metadata` input)
+ # 2. unflatten to get the right structure.
+ prev_needs_input_grad = ctx.needs_input_grad
+ try:
+ ctx.needs_input_grad = unflatten(
+ list(ctx.needs_input_grad[:-1]), metadata.input_spec
+ )
+ grad_inputs = orig_backward(ctx, *grads)
+ finally:
+ ctx.needs_input_grad = prev_needs_input_grad
+
if not isinstance(grad_inputs, tuple):
grad_inputs = (grad_inputs,)
# Assume that any Nones in the backward are Tensors.
@@ -166,29 +179,21 @@
flat_grad_inputs, grad_inputs_spec = flatten(
grad_inputs, not_list_of_optional_tensor
)
- if grad_inputs_spec != input_spec:
+ if grad_inputs_spec != metadata.input_spec:
raise RuntimeError(
f"Expected the return from backward to be of the same structure "
f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
- f"{input_spec} (inputs)"
+ f"{metadata.input_spec} (inputs)"
)
- return tuple(flat_grad_inputs)
+ return tuple(flat_grad_inputs + [None])
def new_apply(*args):
- nonlocal input_spec
- if input_spec is not None:
- raise NotImplementedError(
- "NYI: Recursive call to autograd.Function decorated with "
- "`supports_tensorlist`. Please file an issue."
- )
- try:
- flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
- result = orig_apply(*flat_args) # type: ignore[misc]
- finally:
- input_spec = None
- assert output_spec is not None
- result = unflatten(list(result), output_spec)
- if not result_is_tuple:
+ flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
+ metadata = Metadata(input_spec)
+ result = orig_apply(*flat_args, metadata) # type: ignore[misc]
+ assert metadata.output_spec is not None
+ result = unflatten(list(result), metadata.output_spec)
+ if not metadata.result_is_tuple:
assert isinstance(result, tuple)
assert len(result) == 1
return result[0]
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index 33300b0..0227229 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -1718,7 +1718,7 @@
nullptr},
{"needs_input_grad",
&getObject<&THPFunction::needs_input_grad>,
- nullptr,
+ &setObject<&THPFunction::needs_input_grad>,
nullptr,
nullptr},
{"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},