Always get attr static out (#95771)
Discussion: https://github.com/pytorch/pytorch/issues/95630#issuecomment-1449596766
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95771
Approved by: https://github.com/jansel
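
For context, a minimal sketch of the unwrapping idiom this patch adds to VariableBuilder. The Wrapped class below is a hypothetical stand-in for a dynamo-wrapped callable; _torchdynamo_inline is the attribute the patch actually reads. inspect.getattr_static fetches it without triggering __getattr__ or descriptor protocols, and its third argument is a fallback, so ordinary values pass through the type-based dispatch below unchanged:

    import inspect

    class Wrapped:
        # Hypothetical stand-in: dynamo's function wrappers stash the
        # original callable on this attribute.
        def __init__(self, fn):
            self._torchdynamo_inline = fn

    def orig_fn():
        return "original"

    wrapped = Wrapped(orig_fn)

    # Unwraps when the attribute is present ...
    assert inspect.getattr_static(wrapped, "_torchdynamo_inline", wrapped) is orig_fn
    # ... and is a no-op (returns the default) when it is not.
    assert inspect.getattr_static(orig_fn, "_torchdynamo_inline", orig_fn) is orig_fn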
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 12d5bdc..3c988bb 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -4675,6 +4675,22 @@
torch._dynamo.optimize(counter)(my_dyn_fn)(x)
self.assertEqual(counter.frame_count, 3)
+ def test_torch_compile_ctx_on_forward_and_training_step(self):
+ class MyModel(torch.nn.Module):
+ def forward(self):
+ ...
+
+ def training_step(self):
+ self()
+
+ model = MyModel()
+ compiled_model = torch.compile(model)
+
+ model.forward = compiled_model.dynamo_ctx(model.forward)
+ model.training_step = compiled_model.dynamo_ctx(model.training_step)
+
+ model.training_step()
+
class CustomFunc1(torch.autograd.Function):
@staticmethod
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 173d5d4..4a8a2e2 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -320,6 +320,10 @@
if id_dispatch is not None:
return id_dispatch(self, value)
+ # Note - some values arrive wrapped (e.g. by a dynamo context) with a
+ # type that no longer matches; unwrap them so we wrap the underlying callable.
+ value = inspect.getattr_static(value, "_torchdynamo_inline", value)
+
# Everything else (NB: order matters!)
if istype(value, config.traceable_tensor_subclasses):
return self.wrap_tensor(value)
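
For reference, the scenario the new test covers, as a standalone script. dynamo_ctx is the context object torch.compile attaches to the module it returns; re-wrapping bound methods with it is the pattern from the linked discussion. Before this patch, tracing training_step encountered the already-wrapped forward, whose type did not match what the builder's dispatch expected; with the getattr_static unwrapping above, the underlying function is wrapped instead:

    import torch

    class MyModel(torch.nn.Module):
        def forward(self):
            ...

        def training_step(self):
            # Calls the (re-wrapped) forward from inside another
            # compiled frame via self().
            self()

    model = MyModel()
    compiled_model = torch.compile(model)

    # Rebind both methods through the same dynamo context.
    model.forward = compiled_model.dynamo_ctx(model.forward)
    model.training_step = compiled_model.dynamo_ctx(model.training_step)

    model.training_step()  # succeeds with this patch applied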