Run __setstate__ when cloning modules (#45858)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45858

When cloning a module that has __setstate__, __getstate__ methods.
We need to load these methods to initialize these modules.

Test Plan: Imported from OSS

Reviewed By: suo

Differential Revision: D24116524

Pulled By: bzinodev

fbshipit-source-id: a5111638e2dc903781f6468838c000850d1f9a74
diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp
index f645f73..e821443 100644
--- a/torch/csrc/jit/api/module.cpp
+++ b/torch/csrc/jit/api/module.cpp
@@ -243,6 +243,14 @@
     for (auto& fn : type()->methods()) {
       r.clone_method(*this, *fn, type_remap);
     }
+
+    // Execute __setstate__(__getstate__()) to initialize custom class members.
+    if (auto setstate_method = r.find_method("__setstate__")) {
+      auto getstate_method = r.find_method("__getstate__");
+      TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
+      auto state = (*getstate_method)(Stack{});
+      (*setstate_method)(Stack{state});
+    }
   }
   return r;
 }
diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp
index 50a2762..1b93d28 100644
--- a/torch/csrc/jit/passes/quantization/insert_observers.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp
@@ -164,6 +164,14 @@
       for (auto& fn : type->methods()) {
         clone_method(module, r, *fn, module_qconfig_map, type_remap);
       }
+      // Execute __setstate__(__getstate__()) to initialize custom class
+      // members.
+      if (auto setstate_method = r.find_method("__setstate__")) {
+        auto getstate_method = r.find_method("__getstate__");
+        TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
+        auto state = (*getstate_method)(Stack{});
+        (*setstate_method)(Stack{state});
+      }
     }
     return r;
   }