blob: 17a65e2268521dd0a7be5547f0926719bd43e8ca [file] [log] [blame]
#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/torch.h>
namespace torch {
namespace jit {
using namespace torch::jit::script;
void testModuleClone() {
auto cu = std::make_shared<CompilationUnit>();
auto parent = ClassType::create("parent", cu, true);
// creating child module
auto child = ClassType::create("child", cu, true);
auto attr_name = "attr";
child->addAttribute(attr_name, IntType::get());
Module c1(cu, child);
auto v1 = IValue(2);
c1.register_attribute(attr_name,
IntType::get(),
v1,
false);
Module c2(cu, child);
auto v2 = IValue(3);
c2.register_attribute(attr_name,
IntType::get(),
v2,
false);
// attach two child module instance to parent that shares
// ClassType
Module p(cu, parent);
p.register_attribute("c1", c1.type(), c1._ivalue(), false);
p.register_attribute("c2", c2.type(), c2._ivalue(), false);
// clone parent
Module p2 = p.clone();
// check the two child module has the same ClassType
ASSERT_EQ(p2.attr("c1").type(), p2.attr("c2").type());
// but different instances
ASSERT_EQ(Module(p2.attr("c1").toObject()).attr(attr_name).toInt(), 2);
ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3);
}
// Contrasts Module::clone() with Module::clone_instance():
//  - clone() copies both the ClassType and the attribute data, so the copy
//    has a different type;
//  - clone_instance() copies only the data, so the copy shares the
//    original's ClassType.
// It also checks that writing an attribute on the clone_instance() copy
// leaves both the original module and the clone() copy unchanged.
void testModuleCloneInstance() {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu, true);
  auto attr_name = "attr";
  cls->addAttribute(attr_name, IntType::get());
  Module m(cu, cls);
  auto v = IValue(2);
  m.register_attribute(attr_name, IntType::get(), v, false);

  Module m2 = m.clone();
  Module m3 = m.clone_instance();
  // Make sure both kinds of copy carry the attribute value over.
  ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
  ASSERT_EQ(m3.attr(attr_name).toInt(), 2);

  // clone() copies both type and data, so the copy has a different type.
  ASSERT_NE(m.type(), m2.type());
  // clone_instance() copies only data; the type is shared with the original.
  ASSERT_EQ(m.type(), m3.type());

  // Change the value on the clone_instance() copy.
  m3.register_attribute(attr_name, IntType::get(), IValue(3), false);
  // The original instance must not observe the write, even though it shares
  // its ClassType with m3. The previous version checked only m2 here, so it
  // never verified this.
  ASSERT_EQ(m.attr(attr_name).toInt(), 2);
  // The independent clone() copy must not observe it either.
  ASSERT_EQ(m2.attr(attr_name).toInt(), 2);
  // Only m3 itself sees the new value.
  ASSERT_EQ(m3.attr(attr_name).toInt(), 3);
}
void testModuleConstant() {
auto cu = std::make_shared<CompilationUnit>();
auto cls = ClassType::create("foo.bar", cu, true);
auto attr_name = "attr";
auto const_name = "const";
cls->addAttribute(attr_name, IntType::get());
cls->addConstant(const_name, IValue(3));
Module m(cu, cls);
auto v = IValue(2);
m.register_attribute(attr_name,
IntType::get(),
v,
false);
ASSERT_TRUE(m.hasattr(attr_name));
ASSERT_TRUE(m.hasattr(const_name));
ASSERT_EQ(m.attr(attr_name).toInt(), 2);
ASSERT_EQ(m.attr(const_name).toInt(), 3);
}
} // namespace jit
} // namespace torch