// blob: 5c998661be236852455ac2b2bd0fe9e259151f08 [file] [log] [blame]
#include <catch.hpp>
#include <torch/nn/cursor.h>
#include <torch/nn/module.h>
#include <torch/tensor.h>
#include <torch/utils.h>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
using namespace torch::nn;
using namespace torch::detail;
using Catch::StartsWith;
struct TestModule : public torch::nn::Module {
TestModule(int64_t size) {
tensor1 = register_parameter("tensor1", torch::randn({size}));
tensor2 = register_parameter("tensor2", torch::randn({size}));
}
torch::Tensor tensor1;
torch::Tensor tensor2;
};
// Test-local helper; the unnamed namespace keeps this generically named type
// from colliding (ODR) with same-named helpers in other test translation units.
namespace {
/// Test-only module that registers each constructor argument as a child named
/// after its position ("0", "1", ...), in argument order, and keeps a pointer
/// to every registered child in `m` so tests can check object identity.
///
/// Arguments may be module values (moved into a fresh `shared_ptr`) or
/// `std::shared_ptr`s to modules (registered as-is, so the caller keeps a
/// handle to the very same instance).
struct Container : public torch::nn::Module {
  template <typename... Ms>
  explicit Container(Ms&&... ms) {
    add(0, ms...);
  }

  /// Recursion terminator: nothing left to register.
  void add(size_t) {}

  /// Registers `head` under `std::to_string(index)`, then recurses on the
  /// remaining modules with the next index.
  // NOTE: parameters are taken by value on purpose. Copy-initialization does
  // not consider the `explicit` variadic constructor above, so a `Container`
  // argument is copied via its copy constructor rather than being captured by
  // that greedy forwarding constructor.
  template <typename Head, typename... Tail>
  void add(size_t index, Head head, Tail... tail) {
    add(std::to_string(index), std::move(head));
    add(index + 1, tail...);
  }

  /// Registers a module value under `name`, moving it into a new shared_ptr.
  /// Expects an rvalue: with an lvalue, `M` deduces to a reference type and
  /// `std::make_shared<M>` does not compile.
  template <typename M>
  void add(std::string name, M&& module) {
    m.push_back(register_module(
        std::move(name), std::make_shared<M>(std::move(module))));
  }

  /// Registers an already-shared module under `name`, preserving its identity.
  template <typename M>
  void add(std::string name, std::shared_ptr<M>&& module) {
    m.push_back(register_module(std::move(name), std::move(module)));
  }

  /// Returns the child registered at position `index`; throws
  /// `std::out_of_range` (from `std::vector::at`) if there is none.
  Module& operator[](size_t index) {
    return *m.at(index);
  }

  /// Registered children, in registration order.
  std::vector<std::shared_ptr<Module>> m;
};
} // namespace
// Module cursors: `modules()` yields every descendant in depth-first order
// (the root itself is not yielded), while `children()` yields only the first
// level. Also covers apply/map helpers and key/index lookup, including the
// exact error messages thrown for missing entries.
TEST_CASE("cursor/module") {
  torch::manual_seed(0);

  SECTION("Works for flat models (depth = 1)") {
    Container model(TestModule(1), TestModule(2), TestModule(3));
    auto cursor = model.modules();

    SECTION("Iterates in the correct order") {
      auto iterator = cursor.begin();
      REQUIRE(&iterator->value == &model[0]);
      REQUIRE(&(++iterator)->value == &model[1]);
      REQUIRE(&(++iterator)->value == &model[2]);
      REQUIRE(++iterator == cursor.end());
    }

    SECTION("names are flat") {
      auto iterator = cursor.begin();
      REQUIRE(iterator->key == "0");
      REQUIRE((++iterator)->key == "1");
      REQUIRE((++iterator)->key == "2");
      REQUIRE(++iterator == cursor.end());
    }

    SECTION("Apply works") {
      size_t count = 0;
      cursor.apply([&count, &model](Module& module) {
        REQUIRE(&module == &model[count]);
        count += 1;
      });
      REQUIRE(count == 3);
    }

    SECTION("Apply_items works") {
      size_t count = 0;
      cursor.apply_items(
          [&count, &model](const std::string& key, Module& module) {
            // Direct children are keyed by their registration names.
            REQUIRE(key == std::to_string(count));
            REQUIRE(&module == &model[count]);
            count += 1;
          });
      REQUIRE(count == 3);
    }

    SECTION("Map works") {
      // Writing through a plain iterator into pre-sized storage.
      std::vector<Module*> as_vector(3);
      cursor.map(as_vector.begin(), [](Module& module) { return &module; });
      REQUIRE(as_vector[0] == &model[0]);
      REQUIRE(as_vector[1] == &model[1]);
      REQUIRE(as_vector[2] == &model[2]);

      // Writing through an insert iterator into an initially empty container.
      std::list<Module*> as_list;
      cursor.map(std::inserter(as_list, as_list.end()), [](Module& module) {
        return &module;
      });
      REQUIRE(as_list.size() == 3);
      auto iterator = as_list.begin();
      REQUIRE(*iterator++ == &model[0]);
      REQUIRE(*iterator++ == &model[1]);
      REQUIRE(*iterator++ == &model[2]);
      REQUIRE(iterator == as_list.end());
    }

    SECTION("Map_items works") {
      std::map<std::string, Module*> output;
      cursor.map_items(
          std::inserter(output, output.end()),
          [](const std::string& key, Module& module) {
            return std::make_pair(key, &module);
          });
      REQUIRE(output.size() == 3);
      REQUIRE(output.count("0"));
      REQUIRE(output.count("1"));
      REQUIRE(output.count("2"));
      REQUIRE(output["0"] == &model[0]);
      REQUIRE(output["1"] == &model[1]);
      REQUIRE(output["2"] == &model[2]);
    }

    SECTION("Count works for flat models") {
      REQUIRE(cursor.size() == model.m.size());
    }

    SECTION("find() finds the correct modules when given a valid key") {
      REQUIRE(cursor.find("0") == &model[0]);
      REQUIRE(cursor.find("1") == &model[1]);
      REQUIRE(cursor.find("2") == &model[2]);
    }

    SECTION("find() returns nullptr when given an invalid key") {
      REQUIRE(cursor.find("foo") == nullptr);
      REQUIRE(cursor.find("bar") == nullptr);
    }

    SECTION("at(key) returns the correct modules when given a valid key") {
      REQUIRE(&cursor.at("0") == &model[0]);
      REQUIRE(&cursor.at("1") == &model[1]);
      REQUIRE(&cursor.at("2") == &model[2]);
    }

    SECTION("at(key) throws when given an invalid key") {
      REQUIRE_THROWS_WITH(cursor.at("foo"), StartsWith("No such key: 'foo'"));
      REQUIRE_THROWS_WITH(cursor.at("bar"), StartsWith("No such key: 'bar'"));
    }

    SECTION(
        "operator[key] returns the correct modules when given a valid key") {
      REQUIRE(&cursor["0"] == &model[0]);
      REQUIRE(&cursor["1"] == &model[1]);
      REQUIRE(&cursor["2"] == &model[2]);
    }

    SECTION("operator[key] throws when given an invalid key") {
      REQUIRE_THROWS_WITH(cursor["foo"], StartsWith("No such key: 'foo'"));
      REQUIRE_THROWS_WITH(cursor["bar"], StartsWith("No such key: 'bar'"));
    }

    SECTION("at(index) returns the correct modules when given a valid index") {
      REQUIRE(&cursor.at(0).value == &model[0]);
      REQUIRE(&cursor.at(1).value == &model[1]);
      REQUIRE(&cursor.at(2).value == &model[2]);
    }

    SECTION("at(index) throws when given an invalid index") {
      REQUIRE_THROWS_WITH(
          cursor.at(5),
          StartsWith("Index 5 is out of range for cursor of size 3"));
      REQUIRE_THROWS_WITH(
          cursor.at(123),
          StartsWith("Index 123 is out of range for cursor of size 3"));
    }

    SECTION(
        "operator[index] returns the correct modules when given a valid index") {
      REQUIRE(&cursor[0].value == &model[0]);
      REQUIRE(&cursor[1].value == &model[1]);
      REQUIRE(&cursor[2].value == &model[2]);
    }

    SECTION("operator[index] throws when given an invalid index") {
      REQUIRE_THROWS_WITH(
          cursor[5],
          StartsWith("Index 5 is out of range for cursor of size 3"));
      REQUIRE_THROWS_WITH(
          cursor[123],
          StartsWith("Index 123 is out of range for cursor of size 3"));
    }

    SECTION("contains() is correct") {
      REQUIRE(cursor.contains("0"));
      REQUIRE(cursor.contains("1"));
      REQUIRE(cursor.contains("2"));
      REQUIRE(!cursor.contains("foo"));
      REQUIRE(!cursor.contains("bar"));
    }
  }

  SECTION("Works for deeper hierarchies (depth > 1)") {
    // clang-format off
    Container model(
        Container(
            TestModule(1),
            TestModule(2)),
        TestModule(3),
        Container(
            TestModule(4),
            Container(
                TestModule(5),
                TestModule(6))
        ));
    // clang-format on
    auto cursor = model.modules();

    // This is sufficient for the hierarchical case
    // (other tests build on top)
    SECTION("Iterates in the correct order") {
      auto iterator = cursor.begin();
      REQUIRE(&iterator->value == &model[0]);
      auto* inner = dynamic_cast<Container*>(&model[0]);
      REQUIRE(inner != nullptr);
      REQUIRE(&(++iterator)->value == &(*inner)[0]);
      REQUIRE(&(++iterator)->value == &(*inner)[1]);
      REQUIRE(&(++iterator)->value == &model[1]);
      REQUIRE(&(++iterator)->value == &model[2]);
      inner = dynamic_cast<Container*>(&model[2]);
      REQUIRE(inner != nullptr);
      REQUIRE(&(++iterator)->value == &(*inner)[0]);
      REQUIRE(&(++iterator)->value == &(*inner)[1]);
      inner = dynamic_cast<Container*>(&(*inner)[1]);
      REQUIRE(inner != nullptr);
      REQUIRE(&(++iterator)->value == &(*inner)[0]);
      REQUIRE(&(++iterator)->value == &(*inner)[1]);
      // All nine descendants have been visited; nothing else may follow.
      REQUIRE(++iterator == cursor.end());
    }

    SECTION("children() returns only the first level of submodules") {
      auto children = model.children();
      REQUIRE(children.size() == 3);
      REQUIRE(&children.at("0") == &model[0]);
      REQUIRE(&children.at("1") == &model[1]);
      REQUIRE(&children.at("2") == &model[2]);
      // Grandchildren must not leak into the first-level view.
      REQUIRE(!children.contains("0.0"));
      REQUIRE(!children.contains("2.1"));
      size_t count = 0;
      for (auto& child : children) {
        REQUIRE(child.key == std::to_string(count));
        REQUIRE(&child.value == &model[count]);
        count += 1;
      }
    }
  }
}
// Parameter cursors: `parameters()` yields every registered parameter of the
// module tree depth-first, keyed by the dot-joined path of registration names
// (e.g. "2.1.0.tensor1").
TEST_CASE("cursor/parameter") {
  torch::manual_seed(0);

  SECTION("Works for single models") {
    TestModule model(1);
    auto cursor = model.parameters();

    SECTION("Iterates in the correct order") {
      auto iterator = cursor.begin();
      REQUIRE(iterator->value.equal(model.tensor1));
      REQUIRE((++iterator)->value.equal(model.tensor2));
      REQUIRE(++iterator == cursor.end());
    }

    SECTION("names are the registered parameter names") {
      auto iterator = cursor.begin();
      REQUIRE(iterator->key == "tensor1");
      REQUIRE((++iterator)->key == "tensor2");
      REQUIRE(++iterator == cursor.end());
    }
  }

  SECTION("Works for flat models (depth = 1)") {
    auto first = std::make_shared<TestModule>(1);
    auto second = std::make_shared<TestModule>(2);
    Container model(first, second);
    auto cursor = model.parameters();

    SECTION("Iterates in the correct order") {
      auto iterator = cursor.begin();
      REQUIRE(iterator->value.equal(first->tensor1));
      REQUIRE((++iterator)->value.equal(first->tensor2));
      REQUIRE((++iterator)->value.equal(second->tensor1));
      REQUIRE((++iterator)->value.equal(second->tensor2));
      REQUIRE(++iterator == cursor.end());
    }

    SECTION("Apply_items works") {
      // Expected (key, tensor) pairs, in iteration order.
      const std::vector<std::pair<std::string, torch::Tensor>> expected = {
          {"0.tensor1", first->tensor1},
          {"0.tensor2", first->tensor2},
          {"1.tensor1", second->tensor1},
          {"1.tensor2", second->tensor2}};
      size_t count = 0;
      cursor.apply_items(
          [&count, &expected](const std::string& key, torch::Tensor& tensor) {
            // Guard the lookup so an unexpected extra item fails cleanly.
            REQUIRE(count < expected.size());
            REQUIRE(key == expected[count].first);
            REQUIRE(tensor.equal(expected[count].second));
            count += 1;
          });
      REQUIRE(count == expected.size());
    }
    // The remaining cursor operations (map, find, at, ...) are covered by the
    // module-cursor tests; here we only check iteration and apply_items over
    // parameters.
  }

  SECTION("Works for deeper hierarchies (depth > 1)") {
    std::vector<std::shared_ptr<TestModule>> modules;
    for (int64_t size = 1; size <= 6; ++size) {
      modules.push_back(std::make_shared<TestModule>(size));
    }
    // clang-format off
    Container model(
        Container(
            modules[0],
            modules[1]),
        modules[2],
        Container(
            modules[3],
            Container(
                modules[4],
                modules[5])
        ));
    // clang-format on
    auto cursor = model.parameters();

    SECTION("Iterates in the correct order") {
      // The leaves were nested so that depth-first order matches the order in
      // which they were created; each contributes tensor1 then tensor2.
      auto iterator = cursor.begin();
      for (const auto& module : modules) {
        REQUIRE(iterator->value.equal(module->tensor1));
        ++iterator;
        REQUIRE(iterator->value.equal(module->tensor2));
        ++iterator;
      }
      REQUIRE(iterator == cursor.end());
    }

    SECTION("names are hierarchical") {
      auto iterator = cursor.begin();
      REQUIRE(iterator->key == "0.0.tensor1");
      REQUIRE((++iterator)->key == "0.0.tensor2");
      REQUIRE((++iterator)->key == "0.1.tensor1");
      REQUIRE((++iterator)->key == "0.1.tensor2");
      REQUIRE((++iterator)->key == "1.tensor1");
      REQUIRE((++iterator)->key == "1.tensor2");
      REQUIRE((++iterator)->key == "2.0.tensor1");
      REQUIRE((++iterator)->key == "2.0.tensor2");
      REQUIRE((++iterator)->key == "2.1.0.tensor1");
      REQUIRE((++iterator)->key == "2.1.0.tensor2");
      REQUIRE((++iterator)->key == "2.1.1.tensor1");
      REQUIRE((++iterator)->key == "2.1.1.tensor2");
      REQUIRE(++iterator == cursor.end());
    }
  }
}
// Checks that each non-const cursor converts to its const counterpart, both by
// direct construction from a temporary and by copy-initialization from a
// named cursor. The conversion must also keep every item.
TEST_CASE("cursor/non-const-to-const-conversion") {
  torch::manual_seed(0);
  auto first = std::make_shared<TestModule>(1);
  auto second = std::make_shared<TestModule>(2);
  Container model(first, second);
  {
    ConstModuleCursor direct(model.modules());
    ModuleCursor cursor = model.modules();
    ConstModuleCursor converted = cursor;
    REQUIRE(direct.size() == cursor.size());
    REQUIRE(converted.size() == cursor.size());
  }
  {
    ConstParameterCursor direct(model.parameters());
    ParameterCursor cursor = model.parameters();
    ConstParameterCursor converted = cursor;
    REQUIRE(direct.size() == cursor.size());
    REQUIRE(converted.size() == cursor.size());
  }
  {
    ConstBufferCursor direct(model.buffers());
    BufferCursor cursor = model.buffers();
    ConstBufferCursor converted = cursor;
    REQUIRE(direct.size() == cursor.size());
    REQUIRE(converted.size() == cursor.size());
  }
}
// Compile-time check that lookup methods are callable on a `const` cursor,
// plus a runtime check that they answer correctly for hits and misses.
TEST_CASE("cursor/can-invoke-const-method-on-const-cursor") {
  torch::manual_seed(0);
  TestModule model(1);
  /// This will only compile if `Cursor` has the appropriate const methods.
  const auto cursor = model.parameters();
  REQUIRE(cursor.contains("tensor1"));
  REQUIRE(!cursor.contains("foo"));
}