[jit] `__copy__` for `RecursiveScriptModule` (#36830)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/36830
Test Plan:
build/bin/test_jit
Imported from OSS
Differential Revision: D21431012
fbshipit-source-id: 13a1bf9744ec95ea59622226c8d8a8d55ec3f0b0
diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp
index 6a57973..cf75f46 100644
--- a/aten/src/ATen/core/ivalue.cpp
+++ b/aten/src/ATen/core/ivalue.cpp
@@ -564,6 +564,14 @@
slots_.resize(type()->numAttributes());
}
+c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy() const {
+ auto object = ivalue::Object::create(c10::StrongTypePtr(type_.cu_, type()), type()->numAttributes());
+ for (auto i = 0; i < slots_.size(); ++i) {
+ object->setSlot(i, slots_[i]);
+ }
+ return object;
+}
+
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy() const {
IValue::HashAliasedIValueMap memo;
return deepcopy(memo);
diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h
index b4cae48..6a2888d 100644
--- a/aten/src/ATen/core/ivalue_inl.h
+++ b/aten/src/ATen/core/ivalue_inl.h
@@ -433,7 +433,10 @@
return type_.cu_;
}
+ c10::intrusive_ptr<Object> copy() const;
+
c10::intrusive_ptr<Object> deepcopy() const;
+
c10::intrusive_ptr<Object> deepcopy(IValue::HashAliasedIValueMap& memo) const;
private:
diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp
index 20cb85e..b51c944 100644
--- a/test/cpp/jit/test_module_api.cpp
+++ b/test/cpp/jit/test_module_api.cpp
@@ -34,7 +34,7 @@
ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3);
}
-void testModuleCloneInstance() {
+void testModuleCopy() {
auto cu = std::make_shared<CompilationUnit>();
auto cls = ClassType::create("foo.bar", cu, true);
auto attr_name = "attr";
@@ -44,7 +44,8 @@
m.register_attribute(attr_name, IntType::get(), v, false);
Module m2 = m.clone();
- Module m3 = m.clone_instance();
+ Module m3 = m.copy();
+
// Make sure copy works
ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
ASSERT_EQ(m3.attr(attr_name).toInt(), 2);
@@ -52,7 +53,7 @@
// clone will copy both type and data, therefore we'll have a
// different type
ASSERT_NE(m.type(), m2.type());
- // clone_instance only copies data, type is shared
+ // copy only copies data, type is shared
ASSERT_EQ(m.type(), m3.type());
// change value of copied instance
@@ -81,7 +82,7 @@
m.setattr(tensor_list_attr, list);
Module m2 = m.deepcopy();
- Module m3 = m.clone_instance();
+ Module m3 = m.copy();
// Make sure copy works
ASSERT_EQ(m2.attr(int_attr).toInt(), 2);
ASSERT_EQ(m3.attr(int_attr).toInt(), 2);
@@ -90,13 +91,14 @@
ASSERT_TRUE(!IValue(m2._ivalue()).overlaps(IValue(m._ivalue())));
ASSERT_TRUE(IValue(m3._ivalue()).overlaps(IValue(m._ivalue())));
- // Both deepcopy and clone_instance will preserve the type
+ // Both deepcopy and copy will preserve the type
ASSERT_EQ(m.type(), m2.type());
ASSERT_EQ(m.type(), m3.type());
// change int value of copied instances
m2.setattr(int_attr, IValue(3));
m3.setattr(int_attr, IValue(4));
+
// Verify value of original instance doesn't change
ASSERT_EQ(m.attr(int_attr).toInt(), 2);
ASSERT_EQ(m2.attr(int_attr).toInt(), 3);
@@ -106,8 +108,8 @@
at::Tensor t1 = m.attr(tensor_attr).toTensor();
at::Tensor t2 =
m2.attr(tensor_attr).toTensor(); // deepcopy will copy the Tensor
- at::Tensor t3 = m3.attr(tensor_attr)
- .toTensor(); // clone_instance will not copy the Tensor
+ at::Tensor t3 =
+ m3.attr(tensor_attr).toTensor(); // copy will not copy the Tensor
// check copy works
ASSERT_TRUE(t1.equal(t2));
ASSERT_TRUE(t1.equal(t3));
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 04db682..5062ccf 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -55,9 +55,9 @@
_(SubgraphMatching) \
_(SubgraphRewriter) \
_(ModuleClone) \
- _(ModuleCloneInstance) \
_(ModuleConstant) \
_(ModuleParameter) \
+ _(ModuleCopy) \
_(ModuleDeepcopy) \
_(ModuleDeepcopyString) \
_(ModuleDeepcopyAliasing) \
diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp
index a511e33..7cb5cb2 100644
--- a/torch/csrc/jit/api/module.cpp
+++ b/torch/csrc/jit/api/module.cpp
@@ -163,6 +163,10 @@
return clone_method(orig, orig.get_method(name).function(), type_remap);
}
+Module Module::copy() const {
+ return Module(_ivalue()->copy());
+}
+
Module Module::deepcopy() const {
return Module(_ivalue()->deepcopy());
}
@@ -228,22 +232,7 @@
}
Module Module::clone_instance() const {
- Module r(_ivalue()->compilation_unit(), type());
-
- // Copy slots. If a slot is a module - recursively clone it.
- size_t N = type()->numAttributes();
- for (size_t i = 0; i < N; ++i) {
- IValue s = _ivalue()->getSlot(i);
- if (type()->getAttribute(i)->is_module()) {
- const Module& orig = Module(s.toObject());
- Module cloned = orig.clone_instance();
- r._ivalue()->setAttr(type()->getAttributeName(i), cloned._ivalue());
- } else {
- r._ivalue()->setAttr(type()->getAttributeName(i), s);
- }
- }
-
- return r;
+ return Module(_ivalue()->copy());
}
void Module::train(bool on) {
diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h
index 9c60fa1..e754ff4 100644
--- a/torch/csrc/jit/api/module.h
+++ b/torch/csrc/jit/api/module.h
@@ -221,6 +221,8 @@
const std::string& filename,
const ExtraFilesMap& extra_files = ExtraFilesMap()) const;
+ Module copy() const;
+
Module deepcopy() const;
// Clones both the underlying `ClassType` and the module instance(data), this
diff --git a/torch/csrc/jit/api/object.cpp b/torch/csrc/jit/api/object.cpp
index 994f66b..003cbc3 100644
--- a/torch/csrc/jit/api/object.cpp
+++ b/torch/csrc/jit/api/object.cpp
@@ -35,6 +35,10 @@
*type()->name(), src, resolver ? resolver : nativeResolver(), &self);
}
+Object Object::copy() const {
+ return Object(_ivalue()->copy());
+}
+
Object Object::deepcopy() const {
c10::IValue::HashAliasedIValueMap memo;
return deepcopy(memo);
diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h
index b46bc7a..12e9cf6 100644
--- a/torch/csrc/jit/api/object.h
+++ b/torch/csrc/jit/api/object.h
@@ -124,6 +124,9 @@
return _ivalue()->slots().size();
}
+ // shallow copy the object
+ Object copy() const;
+
// Copies all the attributes of the object recursively without creating new
// `ClassType`, including deepcopy of Tensors
Object deepcopy() const;
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index f16ee36..462c720 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -934,6 +934,7 @@
.def("apply", &Module::apply)
.def("_clone", &Module::clone)
.def("_clone_instance", &Module::clone_instance)
+ .def("copy", &Module::copy)
.def("deepcopy", &Module::deepcopy)
.def_property_readonly("qualified_name", [](const Module& self) {
return self.type()->name()->qualifiedName();