[JIT] Fix classes as attributes in recursive scripting
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32594
Test Plan: Imported from OSS
Differential Revision: D19562951
Pulled By: jamesr66a
fbshipit-source-id: 3d5491c1c23456f107390a78be16da687de951e6
diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py
index c87c291..fd778e9 100644
--- a/test/jit/test_recursive_script.py
+++ b/test/jit/test_recursive_script.py
@@ -465,7 +465,7 @@
def test_attributes(self):
@torch.jit.script
- class Inner(object):
+ class Inner2(object):
def __init__(self):
self.b = "a string"
@@ -473,16 +473,16 @@
class Foo(object):
def __init__(self):
self.a = 4
- self.inner = Inner()
+ self.inner = Inner2()
@torch.jit.script
class SFoo(object):
def __init__(self):
self.a = 4
- self.inner = Inner()
+ self.inner = Inner2()
def __setstate__(self, obj):
- # type: (Tuple[int, Inner]) -> None
+ # type: (Tuple[int, Inner2]) -> None
a, inner = obj
self.a = a
self.inner = inner
diff --git a/test/test_jit.py b/test/test_jit.py
index 06395a1..59d71e2 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4971,6 +4971,25 @@
scripted = torch.jit.script(foo)
self.getExportImportCopy(scripted)
+ def test_class_as_attribute(self):
+ @torch.jit.script
+ class Foo321(object):
+ def __init__(self):
+ self.x = 3
+
+ class FooBar1234(torch.nn.Module):
+ def __init__(self):
+ super(FooBar1234, self).__init__()
+ self.f = Foo321()
+
+ def forward(self, x):
+ return x + self.f.x
+
+ scripted = torch.jit.script(FooBar1234())
+ eic = self.getExportImportCopy(scripted)
+ x = torch.rand(3, 4)
+ self.assertEqual(scripted(x), eic(x))
+
def test_jitter_bug(self):
@torch.jit.script
def fn2(input, kernel_size):
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 381ee8a..15ee12c 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -153,6 +153,23 @@
return InferredType(IntType::get());
}
+ py::bool_ isClass =
+ py::module::import("inspect").attr("isclass")(input.get_type());
+ if (py::cast<bool>(isClass)) {
+ py::str qualifiedName = py::module::import("torch.jit")
+ .attr("_qualified_name")(input.get_type());
+ auto pyClass = py::module::import("torch.jit")
+ .attr("_get_script_class")(qualifiedName);
+ if (!pyClass.is_none()) {
+ auto cu = get_python_cu();
+ const auto classname =
+ c10::QualifiedName(py::cast<std::string>(qualifiedName));
+ auto class_type = cu->get_class(classname);
+ TORCH_INTERNAL_ASSERT(class_type);
+ return InferredType(class_type);
+ }
+ }
+
// Try container types
return tryToInferContainerType(input);
}
@@ -693,6 +710,13 @@
AT_ASSERT(classType);
auto pyClass =
py::module::import("torch.jit").attr("_get_script_class")(obj->name());
+ if (pyClass.is_none()) {
+ std::stringstream err;
+ err << "Unknown reference to ScriptClass ";
+ err << obj->name();
+ err << ". Did you forget to import it?)";
+ throw std::runtime_error(err.str());
+ }
auto pyObj = pyClass.attr("__new__")(pyClass);
const auto numAttrs = classType->numAttributes();
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 926ae27..91cdc34 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1973,8 +1973,7 @@
def _get_script_class(name):
global _script_classes
if name not in _script_classes:
- raise RuntimeError("Unknown reference to ScriptClass '{}'. "
- "Did you forget to import it?".format(name))
+ return None
return _script_classes[name]
# overloads are registered in _jit_internal and compiled here so that _overload