[jit] fix traced training attribute (#47211)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47211
The training attribute is getting shadowed by the default one that nn.Module.__init__
sets on every module, and the __setattr__ override on the TracedModule object prevents
setting it correctly afterwards. Repro:
    import torch

    inp = torch.zeros(1, 3, 224, 224)
    model = torch.hub.load('pytorch/vision:v0.6.0', 'mobilenet_v2', pretrained=True)
    model.eval()
    print(model.training)    # False

    with torch.no_grad():
        traced = torch.jit.trace(model, inp)

    # Before this fix the traced module did not retain the original module's
    # training state, and neither eval() nor assigning to .training could
    # correct it:
    print(traced.training)
    traced.eval()
    print(traced.training)
    traced.training = False
    print(traced.training)

    # freeze() requires an eval-mode module, so this failed as well:
    torch.jit.freeze(traced)
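
For context, the shadowing works roughly like this (a minimal standalone
sketch, not the actual TracedModule code; Inner/Wrapper are made-up names
for illustration):

    class Inner:
        """Stands in for the real script module that owns the flag."""
        def __init__(self):
            self.training = True


    class Wrapper:
        """Stands in for TracedModule: forwards attribute reads to inner."""
        def __init__(self, inner):
            self.__dict__["inner"] = inner
            # A base-class __init__ (like nn.Module's) leaves a default
            # `training` entry in this instance's __dict__. Instance
            # attributes win over __getattr__, so reads never reach inner.
            self.__dict__["training"] = True

        def __getattr__(self, name):
            # Only consulted when normal attribute lookup fails.
            return getattr(self.__dict__["inner"], name)


    inner = Inner()
    inner.training = False      # the "real" state we want to observe
    w = Wrapper(inner)
    print(w.training)           # True: the stale instance attribute shadows it

    delattr(w, "training")      # analogous to the delattr in this fix
    print(w.training)           # False: lookup now falls through to inner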
Test Plan: Imported from OSS
Reviewed By: suo
Differential Revision: D24686690
Pulled By: zdevito
fbshipit-source-id: 9c1678dc68e9bf83176e9f5a20fa8f6bff5d69a0
diff --git a/test/test_jit.py b/test/test_jit.py
index f4c3a8d..378c88e 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -413,6 +413,15 @@
         self.assertEqual(origin_result, m3(input.cpu()))
         self.assertEqual(origin_result, m4(input.cuda(0)))
 
+    def test_trace_retains_train(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x
+        m = M()
+        m.eval()
+        tm = torch.jit.trace(m, (torch.rand(3)))
+        self.assertEqual(tm.training, m.training)
+
     @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
     def test_restore_shared_storage_on_cuda(self):
         class Foo(torch.jit.ScriptModule):
diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py
index 74ee0ba..b9120f5 100644
--- a/torch/jit/_trace.py
+++ b/torch/jit/_trace.py
@@ -1010,7 +1010,6 @@
                     "TracedModules don't support parameter sharing between modules"
                 )
             id_set.add(param)
-
         tmp_module.training = orig.training
 
         for name, param in orig._parameters.items():
@@ -1046,7 +1045,7 @@
         self.__dict__["_name"] = type(orig).__name__
         self.__dict__["_actual_script_module"] = script_module
 
-        for name in ("_parameters", "_buffers", "_modules"):
+        for name in ("_parameters", "_buffers", "_modules", "training"):
             delattr(self, name)
 
     def forward(self, *args, **kwargs):
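
With "training" added to the delattr loop above, the stale instance
attribute is gone, so reads and writes of traced.training fall through to
TracedModule's __getattr__/__setattr__ forwarding and reach the underlying
script module. A quick sanity check mirroring the new test plus the
eval()/assignment/freeze paths from the repro (expected behavior with this
change applied):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x

    m = M()
    m.eval()
    tm = torch.jit.trace(m, torch.rand(3))
    assert tm.training == m.training    # False, retained from the original

    tm.eval()
    assert not tm.training              # eval() now takes effect
    tm.training = False                 # direct assignment sticks as well
    assert not tm.training

    torch.jit.freeze(tm)                # succeeds now that tm is in eval mode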