Run __setstate__ when cloning modules (#45858)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45858
When cloning a module that has __setstate__, __getstate__ methods.
We need to load these methods to initialize these modules.
Test Plan: Imported from OSS
Reviewed By: suo
Differential Revision: D24116524
Pulled By: bzinodev
fbshipit-source-id: a5111638e2dc903781f6468838c000850d1f9a74
diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp
index f645f73..e821443 100644
--- a/torch/csrc/jit/api/module.cpp
+++ b/torch/csrc/jit/api/module.cpp
@@ -243,6 +243,14 @@
for (auto& fn : type()->methods()) {
r.clone_method(*this, *fn, type_remap);
}
+
+ // Execute __setstate__(__getstate__()) to initialize custom class members.
+ if (auto setstate_method = r.find_method("__setstate__")) {
+ auto getstate_method = r.find_method("__getstate__");
+ TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
+ auto state = (*getstate_method)(Stack{});
+ (*setstate_method)(Stack{state});
+ }
}
return r;
}
diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp
index 50a2762..1b93d28 100644
--- a/torch/csrc/jit/passes/quantization/insert_observers.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp
@@ -164,6 +164,14 @@
for (auto& fn : type->methods()) {
clone_method(module, r, *fn, module_qconfig_map, type_remap);
}
+ // Execute __setstate__(__getstate__()) to initialize custom class
+ // members.
+ if (auto setstate_method = r.find_method("__setstate__")) {
+ auto getstate_method = r.find_method("__getstate__");
+ TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
+ auto state = (*getstate_method)(Stack{});
+ (*setstate_method)(Stack{state});
+ }
}
return r;
}