[JIT] Constant prop getattr (#49806)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49806
Fix for https://github.com/pytorch/pytorch/issues/47089
Test Plan: Imported from OSS
Reviewed By: navahgar
Differential Revision: D25696791
Pulled By: eellison
fbshipit-source-id: 914c17b8effef7f4f341775ac2b8150ee4703efd
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index 7b7490e..28fc594 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1331,3 +1331,33 @@
m.eval()
with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"):
mf = torch._C._freeze_module(m._c)
+
+
+ def test_freeze_non_module_class_getattr(self):
+ class BoxCoder(object):
+ def __init__(self, bbox_xform_clip):
+ # type: (float) -> None
+ self.bbox_xform_clip = bbox_xform_clip
+
+ def decode(self, input):
+ return input * self.bbox_xform_clip
+
+ class MyModule(torch.nn.Module):
+ __annotations__ = {
+ 'box_coder': BoxCoder,
+ }
+
+ def __init__(self):
+ super(MyModule, self).__init__()
+ self.box_coder = BoxCoder(50.)
+
+ def forward(self, input):
+ return self.box_coder.decode(input)
+
+ model = MyModule()
+ model.eval()
+ script_model = torch.jit.freeze(torch.jit.script(model))
+ inp = torch.randn([4, 4])
+ output_eager = model(inp)
+ self.assertEqual(model(inp), script_model(inp))
+ FileCheck().check_not("GetAttr").run(script_model.graph)
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index b2f314c..da9d551 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -54,6 +54,10 @@
case prim::CreateObject: {
createObject(stack, n->output()->type()->expect<ClassType>());
} break;
+ case prim::GetAttr: {
+ auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
+ push(stack, attr);
+ } break;
case prim::isinstance: {
isinstance(stack, n->tys(attr::types));
} break;