Allow differentiating through NumPy code (#114608)
With this PR it is possible to differentiate through NumPy code, modulo
the usual caveats that apply to differentiation:
- There are no graph breaks
- The decomposition in `torch._numpy` is differentiable
@ev-br and I were somewhat careful to achieve the second point, but
it is not tested through and through, so YMMV
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114608
Approved by: https://github.com/voznesenskym
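As an illustration of the intended usage, here is a minimal sketch distilled from the new inductor test below (the exact gradient shown assumes the compiled region hits no graph breaks and only differentiable `torch._numpy` decompositions):

```python
import numpy as np
import torch

@torch.compile
def fn(x):
    # Inside a compiled region, .numpy() no longer detaches unconditionally,
    # so the NumPy ops below are traced as differentiable torch._numpy decompositions.
    x = x.numpy()
    y = np.sin(x) ** 2
    return torch.as_tensor(y).sum()

x = torch.arange(4, dtype=torch.float32, requires_grad=True)
fn(x).backward()
print(x.grad)  # expected: 2 * sin(x) * cos(x), i.e. sin(2 * x)
```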
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 61928d4..945bbe4 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -1692,6 +1692,7 @@
opt_fn = torch._dynamo.optimize(cnts)(fn)
x = torch.randn(3)
res = opt_fn(x)
+ self.assertEqual(type(res), np.ndarray)
self.assertEqual(cnts.frame_count, 1)
def fn(x):
@@ -1701,6 +1702,7 @@
opt_fn = torch._dynamo.optimize(cnts)(fn)
x = torch.randn(3, requires_grad=True)
res = opt_fn(x)
+ self.assertEqual(type(res), np.ndarray)
self.assertEqual(cnts.frame_count, 1)
def test_numpy_recompilation_scalar(self):
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 5eaff6d..82b68b3 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -8166,6 +8166,32 @@
self.assertEqual(type(r), np.ndarray)
self.assertEqual(r, np.sin(x))
+ def test_numpy_autograd(self):
+ def my_torch(x):
+ y = torch.cat([torch.sin(x) ** 2, torch.max(x)[None]])
+ return y.sum()
+
+ def my_np(x):
+ y = np.concatenate([np.sin(x) ** 2, np.max(x)[None]])
+ return np.sum(y)
+
+ @torch.compile
+ def wrapper(x):
+ x = x.numpy()
+ y = my_np(x)
+ return torch.as_tensor(y)
+
+ x_np = torch.arange(8, dtype=torch.float32, requires_grad=True)
+ x = torch.arange(8, dtype=torch.float32, requires_grad=True)
+
+ out_np = wrapper(x_np)
+ out = my_torch(x)
+ self.assertEqual(out, out_np)
+
+ out_np.backward()
+ out.backward()
+ self.assertEqual(x.grad, x_np.grad)
+
# Disable constant propagation, so we isolate value range analysis
@patch.object(config, "constant_and_index_propagation", False)
@patch.object(config, "joint_graph_constant_folding", False)
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index ca8d349..d12872a 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -523,19 +523,18 @@
f"can't convert {self.layout} layout tensor to numpy. Use Tensor.dense() first"
)
# We don't check that the tensor is on CPU when force is False, as this
- # allows us to execute NumPy code on CUDA.
- # We don't check that requires_grad=False as we are currently doing an
- # unconditional detach.
- # TODO: We may want to avoid detaching if `requires_grad=True`
- # and `force=False` to allow computing gradients.
+ # allows us to execute NumPy code on CUDA. Same for requires_grad=True
force = "force" in kwargs and kwargs["force"].as_python_constant()
- proxy = tx.output.create_proxy(
- "call_method", "detach", *proxy_args_kwargs([self], {})
- )
if force:
- # TODO Add resolve_conj and resolve_neg once we support complex tensors
+ # If the user set force=True we try to preserve the semantics (no gradients, move to CPU...)
+ t = self.call_method(tx, "detach", [], {})
proxy = tx.output.create_proxy(
- "call_method", "cpu", *proxy_args_kwargs([self], {})
+ "call_method", "cpu", (t.as_proxy(),), {}
+ )
+ else:
+ # Hacky way to create a view of self that will be marked as NumpyNdarrayVariable
+ proxy = tx.output.create_proxy(
+ "call_method", "view_as", *proxy_args_kwargs([self, self], {})
)
return NumpyNdarrayVariable.create(tx, proxy)
elif name == "tolist":
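For context on the `tensor.py` change above: under `torch.compile`, a plain `.numpy()` call is now traced as a `view_as(self)` view so that autograd can propagate through it, while `.numpy(force=True)` keeps the detach-and-move-to-CPU semantics. A rough sketch of the observable difference (the expected outputs are assumptions based on that code path, not taken from the PR itself):

```python
import numpy as np
import torch

@torch.compile
def differentiable(x):
    # Plain .numpy(): traced as a view of x, so gradients flow back to x.
    return torch.as_tensor(np.sin(x.numpy())).sum()

@torch.compile
def detached(x):
    # force=True: detach (and move to CPU), cutting the result off from autograd.
    return torch.as_tensor(np.sin(x.numpy(force=True))).sum()

x = torch.ones(3, requires_grad=True)
differentiable(x).backward()
print(x.grad)                      # expected: cos(1.0) in every entry
print(detached(x).requires_grad)   # expected: False, since force=True detaches
```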