[JIT] Fix python pickle serialization for torchbind (#32878)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32878
ghstack-source-id: 97736045

Test Plan: Imported from OSS

Differential Revision: D19669879

fbshipit-source-id: 23ea91cffe7344d1eed014e2509983c281dd18d3
diff --git a/test/test_jit.py b/test/test_jit.py
index cae5d27..881239b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5058,6 +5058,17 @@
         traced = torch.jit.trace(TryTracing123(), ())
         self.assertEqual(torch.zeros(4, 4), traced())
 
+    @skipIfRocm
+    @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
+    def test_torchbind_pickle_serialization(self):
+        nt = torch.classes._TorchScriptTesting_PickleTester([3, 4])
+        b = io.BytesIO()
+        torch.save(nt, b)
+        b.seek(0)
+        nt_loaded = torch.load(b)
+        for exp in [7, 3, 3, 1]:
+            self.assertEqual(nt_loaded.pop(), exp)
+
     def test_jitter_bug(self):
         @torch.jit.script
         def fn2(input, kernel_size):
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 1f2ba58..5337082 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -737,11 +737,63 @@
             return bool(self.find_method(name));
           })
       .def(
-          "_method_names", [](Object& self) {
+          "_method_names",
+          [](Object& self) {
             return fmap(self.get_methods(), [](const Method& method) {
               return method.name();
             });
-          });
+          })
+      .def(py::pickle(
+          [](const Object& self)
+              -> std::tuple<py::object, std::string> { // __getstate__
+            if (auto getstate_method = self.find_method("__getstate__")) {
+              auto object_state = toPyObject((*getstate_method)(Stack{}));
+              TORCH_INTERNAL_ASSERT(self.type()->name());
+              return std::make_tuple(
+                  object_state, self.type()->name()->qualifiedName());
+            }
+            std::stringstream err;
+            err << "Tried to serialize object ";
+            if (auto qualname = self.type()->name()) {
+              err << qualname->qualifiedName() << " ";
+            }
+            err << "which does not have a __getstate__ method defined!";
+            throw std::runtime_error(err.str());
+          },
+          [](std::tuple<py::object, std::string> state_tup) -> Object {
+            py::object state;
+            std::string qualname;
+            std::tie(state, qualname) = state_tup;
+            auto class_type = classCU()->get_class(qualname);
+            TORCH_CHECK(
+                class_type,
+                "Tried to deserialize class ",
+                qualname,
+                " which is not known to the runtime. "
+                "If this is a custom C++ class, make "
+                "sure the appropriate code is linked.");
+
+            auto self = script::Object(c10::ivalue::Object::create(
+                c10::StrongTypePtr(classCU(), class_type), 1));
+            if (auto setstate_method = self.find_method("__setstate__")) {
+              auto setstate_schema = setstate_method->function().getSchema();
+              TORCH_INTERNAL_ASSERT(
+                  setstate_schema.arguments().size() == 2,
+                  "__setstate__ method for class ",
+                  class_type->python_str(),
+                  " must have exactly 2 arguments!");
+              auto state_type = setstate_schema.arguments().at(1).type();
+              (*setstate_method)(Stack{toIValue(state, state_type)});
+              return self;
+            }
+            std::stringstream err;
+            err << "Tried to deserialize object ";
+            if (auto qualname = class_type->name()) {
+              err << qualname->qualifiedName() << " ";
+            }
+            err << "which does not have a __setstate__ method defined!";
+            throw std::runtime_error(err.str());
+          }));
 
   // torch.jit.ScriptModule is a subclass of this C++ object.
   // Methods here are prefixed with _ since they should not be