| #include <catch.hpp> |
| |
| #include <torch/torch.h> |
| #include <torch/utils.h> |
| #include <torch/nn/modules/any.h> |
| |
| #include <algorithm> |
| #include <string> |
| |
| using namespace torch::nn; |
| using namespace torch::detail; |
| |
| using Catch::Contains; |
| using Catch::StartsWith; |
| |
| TEST_CASE("any-module") { |
| torch::manual_seed(0); |
| SECTION("int()") { |
| struct M : torch::nn::Module { |
| int forward() { |
| return 123; |
| } |
| }; |
| AnyModule any(M{}); |
| REQUIRE(any.forward().get<int>() == 123); |
| } |
| SECTION("int(int)") { |
| struct M : torch::nn::Module { |
| int forward(int x) { |
| return x; |
| } |
| }; |
| AnyModule any(M{}); |
| REQUIRE(any.forward(5).get<int>() == 5); |
| } |
| SECTION("const char*(const char*)") { |
| struct M : torch::nn::Module { |
| const char* forward(const char* x) { |
| return x; |
| } |
| }; |
| AnyModule any(M{}); |
| REQUIRE(any.forward("hello").get<const char*>() == std::string("hello")); |
| } |
| |
| SECTION("string(int, const double)") { |
| struct M : torch::nn::Module { |
| std::string forward(int x, const double f) { |
| return std::to_string(static_cast<int>(x + f)); |
| } |
| }; |
| AnyModule any(M{}); |
| int x = 4; |
| REQUIRE(any.forward(x, 3.14).get<std::string>() == std::string("7")); |
| } |
| |
| SECTION("Tensor(string, const string&, string&&)") { |
| struct M : torch::nn::Module { |
| torch::Tensor forward( |
| std::string a, |
| const std::string& b, |
| std::string&& c) { |
| const auto s = a + b + c; |
| return torch::ones({static_cast<int64_t>(s.size())}); |
| } |
| }; |
| AnyModule any(M{}); |
| REQUIRE( |
| any.forward(std::string("a"), std::string("ab"), std::string("abc")) |
| .get<torch::Tensor>() |
| .sum() |
| .toCInt() == 6); |
| } |
| SECTION("wrong argument type") { |
| struct M : torch::nn::Module { |
| int forward(float x) { |
| return x; |
| } |
| }; |
| AnyModule any(M{}); |
| REQUIRE_THROWS_WITH( |
| any.forward(5.0), |
| StartsWith("Expected argument #0 to be of type float, " |
| "but received value of type double")); |
| } |
| SECTION("wrong number of arguments") { |
| struct M : torch::nn::Module { |
| int forward(int a, int b) { |
| return a + b; |
| } |
| }; |
| AnyModule any(M{}); |
| REQUIRE_THROWS_WITH( |
| any.forward(), |
| Contains("M's forward() method expects 2 arguments, but received 0")); |
| REQUIRE_THROWS_WITH( |
| any.forward(5), |
| Contains("M's forward() method expects 2 arguments, but received 1")); |
| REQUIRE_THROWS_WITH( |
| any.forward(1, 2, 3), |
| Contains("M's forward() method expects 2 arguments, but received 3")); |
| } |
| SECTION("get()") { |
| struct M : torch::nn::Module { |
| explicit M(int value_) : torch::nn::Module("M"), value(value_) {} |
| int value; |
| int forward(float x) { |
| return x; |
| } |
| }; |
| AnyModule any(M{5}); |
| |
| SECTION("good cast") { |
| REQUIRE(any.get<M>().value == 5); |
| } |
| |
| SECTION("bad cast") { |
| struct N : torch::nn::Module {}; |
| REQUIRE_THROWS_WITH(any.get<N>(), StartsWith("Attempted to cast module")); |
| } |
| } |
| SECTION("ptr()") { |
| struct M : torch::nn::Module { |
| explicit M(int value_) : torch::nn::Module("M"), value(value_) {} |
| int value; |
| int forward(float x) { |
| return x; |
| } |
| }; |
| AnyModule any(M{5}); |
| |
| SECTION("base class cast") { |
| auto ptr = any.ptr(); |
| REQUIRE(ptr != nullptr); |
| REQUIRE(ptr->name() == "M"); |
| } |
| |
| SECTION("good downcast") { |
| auto ptr = any.ptr<M>(); |
| REQUIRE(ptr != nullptr); |
| REQUIRE(ptr->value == 5); |
| } |
| |
| SECTION("bad downcast") { |
| struct N : torch::nn::Module {}; |
| REQUIRE_THROWS_WITH(any.ptr<N>(), StartsWith("Attempted to cast module")); |
| } |
| } |
| SECTION("default state is empty") { |
| struct M : torch::nn::Module { |
| explicit M(int value_) : value(value_) {} |
| int value; |
| int forward(float x) { |
| return x; |
| } |
| }; |
| AnyModule any; |
| REQUIRE(any.is_empty()); |
| any = std::make_shared<M>(5); |
| REQUIRE(!any.is_empty()); |
| REQUIRE(any.get<M>().value == 5); |
| } |
| SECTION("all methods throw for empty AnyModule") { |
| struct M : torch::nn::Module { |
| int forward(int x) { |
| return x; |
| } |
| }; |
| AnyModule any; |
| REQUIRE(any.is_empty()); |
| REQUIRE_THROWS_WITH( |
| any.get<M>(), StartsWith("Cannot call get() on an empty AnyModule")); |
| REQUIRE_THROWS_WITH( |
| any.ptr<M>(), StartsWith("Cannot call ptr() on an empty AnyModule")); |
| REQUIRE_THROWS_WITH( |
| any.ptr(), StartsWith("Cannot call ptr() on an empty AnyModule")); |
| REQUIRE_THROWS_WITH( |
| any.type_info(), |
| StartsWith("Cannot call type_info() on an empty AnyModule")); |
| REQUIRE_THROWS_WITH( |
| any.forward<int>(5), |
| StartsWith("Cannot call forward() on an empty AnyModule")); |
| } |
| SECTION("can move assign differentm modules") { |
| struct M : torch::nn::Module { |
| std::string forward(int x) { |
| return std::to_string(x); |
| } |
| }; |
| struct N : torch::nn::Module { |
| int forward(float x) { |
| return 3 + x; |
| } |
| }; |
| AnyModule any; |
| REQUIRE(any.is_empty()); |
| any = std::make_shared<M>(); |
| REQUIRE(!any.is_empty()); |
| REQUIRE(any.forward(5).get<std::string>() == "5"); |
| any = std::make_shared<N>(); |
| REQUIRE(!any.is_empty()); |
| REQUIRE(any.forward(5.0f).get<int>() == 8); |
| } |
| SECTION("constructs from ModuleHolder") { |
| struct MImpl : torch::nn::Module { |
| explicit MImpl(int value_) : torch::nn::Module("M"), value(value_) {} |
| int value; |
| int forward(float x) { |
| return x; |
| } |
| }; |
| |
| struct M : torch::nn::ModuleHolder<MImpl> { |
| using torch::nn::ModuleHolder<MImpl>::ModuleHolder; |
| using torch::nn::ModuleHolder<MImpl>::get; |
| }; |
| |
| AnyModule any(M{5}); |
| REQUIRE(any.get<MImpl>().value == 5); |
| REQUIRE(any.get<M>()->value == 5); |
| } |
| SECTION("converts at::Tensor to torch::Tensor correctly") { |
| struct M : torch::nn::Module { |
| torch::Tensor forward(torch::Tensor input) { |
| return input; |
| } |
| }; |
| struct N : torch::nn::Module { |
| at::Tensor forward(at::Tensor input) { |
| return input; |
| } |
| }; |
| { |
| // When you get an at::Tensor by performing an operation on a |
| // torch::Tensor, the tensor should be converted back to torch::Tensor |
| // before being passed to the function (to avoid a type mismatch). |
| AnyModule any(M{}); |
| at::Tensor tensor_that_is_actually_a_variable = torch::ones(5) * 2; |
| REQUIRE( |
| any.forward(tensor_that_is_actually_a_variable) |
| .get<torch::Tensor>() |
| .sum() |
| .toCFloat() == 10); |
| // But tensors that are really tensors should just error. |
| REQUIRE_THROWS_WITH( |
| any.forward(at::ones(5)), |
| StartsWith( |
| "Expected argument #0 to be of type torch::autograd::Variable, " |
| "but received value of type at::Tensor")); |
| } |
| { |
| // If the function does really accept an `at::Tensor`, this should still |
| // work. |
| AnyModule any(N{}); |
| REQUIRE(any.forward(at::ones(5)).get<at::Tensor>().sum().toCFloat() == 5); |
| } |
| } |
| } |
| |
| namespace torch { |
| namespace nn { |
| struct TestValue { |
| template <typename T> |
| explicit TestValue(T&& value) : value_(std::forward<T>(value)) {} |
| AnyModule::Value operator()() { |
| return std::move(value_); |
| } |
| AnyModule::Value value_; |
| }; |
| template <typename T> |
| AnyModule::Value make_value(T&& value) { |
| return TestValue(std::forward<T>(value))(); |
| } |
| } // namespace nn |
| } // namespace torch |
| |
| TEST_CASE("any-value") { |
| torch::manual_seed(0); |
| SECTION("gets the correct value for the right type") { |
| SECTION("int") { |
| auto value = make_value(5); |
| // const and non-const types have the same typeid() |
| REQUIRE(value.try_get<int>() != nullptr); |
| REQUIRE(value.try_get<const int>() != nullptr); |
| REQUIRE(value.get<int>() == 5); |
| } |
| SECTION("const int") { |
| auto value = make_value(5); |
| REQUIRE(value.try_get<const int>() != nullptr); |
| REQUIRE(value.try_get<int>() != nullptr); |
| REQUIRE(value.get<const int>() == 5); |
| } |
| SECTION("const char*") { |
| auto value = make_value("hello"); |
| REQUIRE(value.try_get<const char*>() != nullptr); |
| REQUIRE(value.get<const char*>() == std::string("hello")); |
| } |
| SECTION("std::string") { |
| auto value = make_value(std::string("hello")); |
| REQUIRE(value.try_get<std::string>() != nullptr); |
| REQUIRE(value.get<std::string>() == "hello"); |
| } |
| SECTION("pointers") { |
| std::string s("hello"); |
| std::string* p = &s; |
| auto value = make_value(p); |
| REQUIRE(value.try_get<std::string*>() != nullptr); |
| REQUIRE(*value.get<std::string*>() == "hello"); |
| } |
| SECTION("references") { |
| std::string s("hello"); |
| const std::string& t = s; |
| auto value = make_value(t); |
| REQUIRE(value.try_get<std::string>() != nullptr); |
| REQUIRE(value.get<std::string>() == "hello"); |
| } |
| } |
| SECTION("try_get returns nullptr for the wrong type") { |
| auto value = make_value(5); |
| REQUIRE(value.try_get<int>() != nullptr); |
| REQUIRE(value.try_get<float>() == nullptr); |
| REQUIRE(value.try_get<long>() == nullptr); |
| REQUIRE(value.try_get<std::string>() == nullptr); |
| } |
| SECTION("get throws for the wrong type") { |
| auto value = make_value(5); |
| REQUIRE(value.try_get<int>() != nullptr); |
| REQUIRE_THROWS_WITH( |
| value.get<float>(), |
| StartsWith("Attempted to cast Value to float, " |
| "but its actual type is int")); |
| REQUIRE_THROWS_WITH( |
| value.get<long>(), |
| StartsWith("Attempted to cast Value to long, " |
| "but its actual type is int")); |
| } |
| SECTION("move is allowed") { |
| auto value = make_value(5); |
| SECTION("construction") { |
| auto copy = make_value(std::move(value)); |
| REQUIRE(copy.try_get<int>() != nullptr); |
| REQUIRE(copy.get<int>() == 5); |
| } |
| SECTION("assignment") { |
| auto copy = make_value(10); |
| copy = std::move(value); |
| REQUIRE(copy.try_get<int>() != nullptr); |
| REQUIRE(copy.get<int>() == 5); |
| } |
| } |
| SECTION("type_info is correct") { |
| SECTION("int") { |
| auto value = make_value(5); |
| REQUIRE(value.type_info().hash_code() == typeid(int).hash_code()); |
| } |
| SECTION("const char") { |
| auto value = make_value("hello"); |
| REQUIRE(value.type_info().hash_code() == typeid(const char*).hash_code()); |
| } |
| SECTION("std::string") { |
| auto value = make_value(std::string("hello")); |
| REQUIRE(value.type_info().hash_code() == typeid(std::string).hash_code()); |
| } |
| } |
| } |