[dynamo] Support if cond on NNModuleVariable (#89095)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89095
Approved by: https://github.com/yanboliang, https://github.com/mlazos
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index e27f7bc..8f79f24 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2885,6 +2885,34 @@
self.assertTrue(same(ref, res))
self.assertTrue(same(x, x1))
+ def test_if_cond_nn_mod(self):
+ class MockModule(torch.nn.Module):
+ def __init__(self, output_relu=True):
+ super(MockModule, self).__init__()
+ self.relu = torch.nn.ReLU() if output_relu else None
+
+ def forward(self, x):
+ x = torch.sin(x)
+ if self.relu:
+ x = self.relu(x)
+ return x
+
+ model = MockModule()
+ opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+
+ x = torch.rand(4)
+ ref = model(x)
+ res = opt_model(x)
+ self.assertTrue(same(ref, res))
+
+ model = MockModule(output_relu=False)
+ opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+
+ x = torch.rand(4)
+ ref = model(x)
+ res = opt_model(x)
+ self.assertTrue(same(ref, res))
+
class CustomFunc(torch.autograd.Function):
@staticmethod
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index d5c05f7..d2bc533 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -252,6 +252,11 @@
+ if_next
+ if_jump
)
+ elif isinstance(value, NNModuleVariable):
+ # Equivant of "self.nn_module is not None"
+ if truth_fn(value):
+ push and self.push(value)
+ self.jump(inst)
elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
self
):