[JIT] Fix classes as attributes in recursive scripting

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32594

Test Plan: Imported from OSS

Differential Revision: D19562951

Pulled By: jamesr66a

fbshipit-source-id: 3d5491c1c23456f107390a78be16da687de951e6
diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py
index c87c291..fd778e9 100644
--- a/test/jit/test_recursive_script.py
+++ b/test/jit/test_recursive_script.py
@@ -465,7 +465,7 @@
 
     def test_attributes(self):
         @torch.jit.script
-        class Inner(object):
+        class Inner2(object):
             def __init__(self):
                 self.b = "a string"
 
@@ -473,16 +473,16 @@
         class Foo(object):
             def __init__(self):
                 self.a = 4
-                self.inner = Inner()
+                self.inner = Inner2()
 
         @torch.jit.script
         class SFoo(object):
             def __init__(self):
                 self.a = 4
-                self.inner = Inner()
+                self.inner = Inner2()
 
             def __setstate__(self, obj):
-                # type: (Tuple[int, Inner]) -> None
+                # type: (Tuple[int, Inner2]) -> None
                 a, inner = obj
                 self.a = a
                 self.inner = inner
diff --git a/test/test_jit.py b/test/test_jit.py
index 06395a1..59d71e2 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4971,6 +4971,25 @@
         scripted = torch.jit.script(foo)
         self.getExportImportCopy(scripted)
 
+    def test_class_as_attribute(self):
+        @torch.jit.script
+        class Foo321(object):
+            def __init__(self):
+                self.x = 3
+
+        class FooBar1234(torch.nn.Module):
+            def __init__(self):
+                super(FooBar1234, self).__init__()
+                self.f = Foo321()
+
+            def forward(self, x):
+                return x + self.f.x
+
+        scripted = torch.jit.script(FooBar1234())
+        eic = self.getExportImportCopy(scripted)
+        x = torch.rand(3, 4)
+        self.assertEqual(scripted(x), eic(x))
+
     def test_jitter_bug(self):
         @torch.jit.script
         def fn2(input, kernel_size):
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 381ee8a..15ee12c 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -153,6 +153,23 @@
     return InferredType(IntType::get());
   }
 
+  py::bool_ isClass =
+      py::module::import("inspect").attr("isclass")(input.get_type());
+  if (py::cast<bool>(isClass)) {
+    py::str qualifiedName = py::module::import("torch.jit")
+                                .attr("_qualified_name")(input.get_type());
+    auto pyClass = py::module::import("torch.jit")
+                       .attr("_get_script_class")(qualifiedName);
+    if (!pyClass.is_none()) {
+      auto cu = get_python_cu();
+      const auto classname =
+          c10::QualifiedName(py::cast<std::string>(qualifiedName));
+      auto class_type = cu->get_class(classname);
+      TORCH_INTERNAL_ASSERT(class_type);
+      return InferredType(class_type);
+    }
+  }
+
   // Try container types
   return tryToInferContainerType(input);
 }
@@ -693,6 +710,13 @@
     AT_ASSERT(classType);
     auto pyClass =
         py::module::import("torch.jit").attr("_get_script_class")(obj->name());
+    if (pyClass.is_none()) {
+      std::stringstream err;
+      err << "Unknown reference to ScriptClass ";
+      err << obj->name();
+      err << ". Did you forget to import it?)";
+      throw std::runtime_error(err.str());
+    }
     auto pyObj = pyClass.attr("__new__")(pyClass);
 
     const auto numAttrs = classType->numAttributes();
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 926ae27..91cdc34 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1973,8 +1973,7 @@
 def _get_script_class(name):
     global _script_classes
     if name not in _script_classes:
-        raise RuntimeError("Unknown reference to ScriptClass '{}'. "
-                           "Did you forget to import it?".format(name))
+        return None
     return _script_classes[name]
 
 # overloads are registered in _jit_internal and compiled here so that _overload