Fake: copy over grad attribute (#82593)
Copy over the grad attribute... let me any suggestions of other tests/opinfos/crossrefs I should be doing to make this more comprehensive.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82593
Approved by: https://github.com/ezyang
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 944b1b8..6ea0936 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -14,8 +14,8 @@
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch import nn
import unittest
-import contextlib
import torch._prims as prims
+import contextlib
import copy
class FakeTensorTest(TestCase):
@@ -138,6 +138,16 @@
with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
y = x[0]
+ def test_fake_grad_copy(self):
+ x = torch.rand([4, 4], requires_grad=True)
+ x.grad = torch.rand([4, 4])
+ mode = FakeTensorMode()
+ fake_x = mode.from_tensor(x)
+ prims.utils.compare_tensor_meta(fake_x, x)
+ prims.utils.compare_tensor_meta(fake_x.grad, x.grad)
+
+ self.assertTrue(isinstance(fake_x.grad, FakeTensor))
+
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_like_constructor(self):
with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index af16175..e4dbc6c 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -137,6 +137,8 @@
out = FakeTensor(fake_mode, self.meta_converter(t), existing_device)
if type(t) is torch.nn.Parameter:
out = torch.nn.Parameter(out, requires_grad=out.requires_grad) # type: ignore[assignment]
+ if t.grad is not None:
+ out.grad = self.from_real_tensor(fake_mode, t.grad)
self.set_tensor_memo(t, out)
return out