[JIT] Fix python pickle serialization for torchbind (#32878)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32878
ghstack-source-id: 97736045
Test Plan: Imported from OSS
Differential Revision: D19669879
fbshipit-source-id: 23ea91cffe7344d1eed014e2509983c281dd18d3
diff --git a/test/test_jit.py b/test/test_jit.py
index cae5d27..881239b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5058,6 +5058,17 @@
traced = torch.jit.trace(TryTracing123(), ())
self.assertEqual(torch.zeros(4, 4), traced())
+ @skipIfRocm
+ @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
+ def test_torchbind_pickle_serialization(self):
+ nt = torch.classes._TorchScriptTesting_PickleTester([3, 4])
+ b = io.BytesIO()
+ torch.save(nt, b)
+ b.seek(0)
+ nt_loaded = torch.load(b)
+ for exp in [7, 3, 3, 1]:
+ self.assertEqual(nt_loaded.pop(), exp)
+
def test_jitter_bug(self):
@torch.jit.script
def fn2(input, kernel_size):
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 1f2ba58..5337082 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -737,11 +737,63 @@
return bool(self.find_method(name));
})
.def(
- "_method_names", [](Object& self) {
+ "_method_names",
+ [](Object& self) {
return fmap(self.get_methods(), [](const Method& method) {
return method.name();
});
- });
+ })
+ .def(py::pickle(
+ [](const Object& self)
+ -> std::tuple<py::object, std::string> { // __getstate__
+ if (auto getstate_method = self.find_method("__getstate__")) {
+ auto object_state = toPyObject((*getstate_method)(Stack{}));
+ TORCH_INTERNAL_ASSERT(self.type()->name());
+ return std::make_tuple(
+ object_state, self.type()->name()->qualifiedName());
+ }
+ std::stringstream err;
+ err << "Tried to serialize object ";
+ if (auto qualname = self.type()->name()) {
+ err << qualname->qualifiedName() << " ";
+ }
+ err << "which does not have a __getstate__ method defined!";
+ throw std::runtime_error(err.str());
+ },
+ [](std::tuple<py::object, std::string> state_tup) -> Object {
+ py::object state;
+ std::string qualname;
+ std::tie(state, qualname) = state_tup;
+ auto class_type = classCU()->get_class(qualname);
+ TORCH_CHECK(
+ class_type,
+ "Tried to deserialize class ",
+ qualname,
+ " which is not known to the runtime. "
+ "If this is a custom C++ class, make "
+ "sure the appropriate code is linked.");
+
+ auto self = script::Object(c10::ivalue::Object::create(
+ c10::StrongTypePtr(classCU(), class_type), 1));
+ if (auto setstate_method = self.find_method("__setstate__")) {
+ auto setstate_schema = setstate_method->function().getSchema();
+ TORCH_INTERNAL_ASSERT(
+ setstate_schema.arguments().size() == 2,
+ "__setstate__ method for class ",
+ class_type->python_str(),
+ " must have exactly 2 arguments!");
+ auto state_type = setstate_schema.arguments().at(1).type();
+ (*setstate_method)(Stack{toIValue(state, state_type)});
+ return self;
+ }
+ std::stringstream err;
+ err << "Tried to deserialize object ";
+ if (auto qualname = class_type->name()) {
+ err << qualname->qualifiedName() << " ";
+ }
+ err << "which does not have a __setstate__ method defined!";
+ throw std::runtime_error(err.str());
+ }));
// torch.jit.ScriptModule is a subclass of this C++ object.
// Methods here are prefixed with _ since they should not be