| #include <ATen/ATen.h> |
| #include "ATen/core/ivalue.h" |
| #include <test/cpp/jit/test_base.h> |
| #include <test/cpp/jit/test_utils.h> |
| #include <torch/torch.h> |
| |
| namespace torch { |
| namespace jit { |
| |
| using namespace torch::jit::script; |
| |
| void testIValue() { |
| c10::List<int64_t> foo({3, 4, 5}); |
| ASSERT_EQ(foo.use_count(), 1); |
| IValue bar{foo}; |
| ASSERT_EQ(foo.use_count(), 2); |
| auto baz = bar; |
| ASSERT_EQ(foo.use_count(), 3); |
| auto foo2 = std::move(bar); |
| ASSERT_EQ(foo.use_count(), 3); |
| ASSERT_TRUE(foo2.isIntList()); |
| ASSERT_TRUE(bar.isNone()); |
| foo2 = IValue(4.0); |
| ASSERT_TRUE(foo2.isDouble()); |
| ASSERT_EQ(foo2.toDouble(), 4.0); |
| ASSERT_EQ(foo.use_count(), 2); |
| ASSERT_TRUE(baz.toIntVector() == std::vector<int64_t>({3, 4, 5})); |
| |
| auto move_it = std::move(baz).toIntList(); |
| ASSERT_EQ(foo.use_count(), 2); |
| ASSERT_TRUE(baz.isNone()); |
| IValue i(4); |
| ASSERT_TRUE(i.isInt()); |
| ASSERT_EQ(i.toInt(), 4); |
| IValue dlist(c10::List<double>({3.5})); |
| ASSERT_TRUE(dlist.isDoubleList()); |
| ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.5})); |
| std::move(dlist).toDoubleList(); |
| ASSERT_TRUE(dlist.isNone()); |
| dlist = IValue(c10::List<double>({3.4})); |
| ASSERT_TRUE(dlist.toDoubleVector() == std::vector<double>({3.4})); |
| IValue the_list( |
| at::ivalue::Tuple::create({IValue(3.4), IValue(4), IValue(foo)})); |
| ASSERT_EQ(foo.use_count(), 3); |
| ASSERT_TRUE(the_list.isTuple()); |
| auto first = the_list.toTuple()->elements()[1]; |
| ASSERT_EQ(first.toInt(), 4); |
| at::Tensor tv = at::rand({3, 4}); |
| IValue ten(tv); |
| ASSERT_EQ(tv.use_count(), 2); |
| auto ten2 = ten; |
| ASSERT_EQ(tv.use_count(), 3); |
| ASSERT_TRUE(ten2.toTensor().equal(ten.toTensor())); |
| std::move(ten2).toTensor(); |
| ASSERT_EQ(tv.use_count(), 2); |
| |
| { |
| std::tuple<int64_t, at::Tensor> t = std::make_tuple(123, at::randn({1})); |
| auto iv = IValue(t); |
| auto t_ = iv.to<std::tuple<int64_t, at::Tensor>>(); |
| ASSERT_EQ(std::get<0>(t_), 123); |
| ASSERT_EQ( |
| std::get<1>(t_).item().to<float>(), std::get<1>(t).item().to<float>()); |
| } |
| |
| // unsafeRemoveAttr in ivalue::Object |
| auto cu = std::make_shared<CompilationUnit>(); |
| auto cls = ClassType::create("foo.bar", cu); |
| cls->addAttribute("attr1", TensorType::get()); |
| cls->addAttribute("attr2", TensorType::get()); |
| auto obj = c10::ivalue::Object::create( |
| c10::StrongTypePtr(cu, cls), cls->numAttributes()); |
| obj->unsafeRemoveAttr("attr1"); |
| // attr1 is not removed in the type |
| ASSERT_TRUE(cls->hasAttribute("attr1")); |
| ASSERT_TRUE(cls->hasAttribute("attr2")); |
| ASSERT_TRUE(obj->slots().size() == 1); |
| |
| // Test tuple print |
| { |
| IValue tp = std::make_tuple(3); |
| std::stringstream ss; |
| ss << tp; |
| ASSERT_EQ(ss.str(), "(3,)"); |
| } |
| |
| { |
| IValue tp = std::make_tuple(3, 3); |
| std::stringstream ss; |
| ss << tp; |
| ASSERT_EQ(ss.str(), "(3, 3)"); |
| } |
| } |
| |
| } // namespace jit |
| } // namespace torch |