[JIT] Constant prop getattr (#49806)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49806

Fix for https://github.com/pytorch/pytorch/issues/47089

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D25696791

Pulled By: eellison

fbshipit-source-id: 914c17b8effef7f4f341775ac2b8150ee4703efd
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index 7b7490e..28fc594 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1331,3 +1331,33 @@
         m.eval()
         with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"):
             mf = torch._C._freeze_module(m._c)
+
+
+    def test_freeze_non_module_class_getattr(self):
+        class BoxCoder(object):
+            def __init__(self, bbox_xform_clip):
+                # type: (float) -> None
+                self.bbox_xform_clip = bbox_xform_clip
+
+            def decode(self, input):
+                return input * self.bbox_xform_clip
+
+        class MyModule(torch.nn.Module):
+            __annotations__ = {
+                'box_coder': BoxCoder,
+            }
+
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.box_coder = BoxCoder(50.)
+
+            def forward(self, input):
+                return self.box_coder.decode(input)
+
+        model = MyModule()
+        model.eval()
+        script_model = torch.jit.freeze(torch.jit.script(model))
+        inp = torch.randn([4, 4])
+        output_eager = model(inp)
+        self.assertEqual(model(inp), script_model(inp))
+        FileCheck().check_not("GetAttr").run(script_model.graph)
diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp
index b2f314c..da9d551 100644
--- a/torch/csrc/jit/passes/constant_propagation.cpp
+++ b/torch/csrc/jit/passes/constant_propagation.cpp
@@ -54,6 +54,10 @@
     case prim::CreateObject: {
       createObject(stack, n->output()->type()->expect<ClassType>());
     } break;
+    case prim::GetAttr: {
+      auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
+      push(stack, attr);
+    } break;
     case prim::isinstance: {
       isinstance(stack, n->tys(attr::types));
     } break;