blob: e5d50c1fc60d770b3b3395e1644b0e0aa5a835a4 [file] [log] [blame]
#include <gtest/gtest.h>
#include <torch/nn/cursor.h>
#include <torch/nn/module.h>
#include <torch/tensor.h>
#include <torch/utils.h>
#include <test/cpp/api/support.h>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
using namespace torch::nn;
using namespace torch::detail;
struct TestModule : public torch::nn::Module {
TestModule(int64_t size) {
tensor1 = register_parameter("tensor1", torch::randn({size}));
tensor2 = register_parameter("tensor2", torch::randn({size}));
}
torch::Tensor tensor1;
torch::Tensor tensor2;
};
/// Test helper module that registers each constructor argument as a submodule
/// named by its zero-based position ("0", "1", ...). It also keeps the
/// registered pointers in `m` for positional access via operator[].
struct Container : public torch::nn::Module {
/// Accepts any mix of module values and `std::shared_ptr`s to modules.
/// Because this constructor is `explicit`, it does not take part in
/// copy-initialization. Passing a `Container` by value to `add` below therefore
/// uses the ordinary copy constructor instead of this template.
template <typename... Ms>
explicit Container(Ms&&... ms) {
add(0, ms...);
}
/// Recursion terminator for the variadic `add`.
void add(size_t) {}
/// Registers `head` under the name `std::to_string(index)`, then recurses
/// over `tail` with the next index.
template <typename Head, typename... Tail>
void add(size_t index, Head head, Tail... tail) {
add(std::to_string(index), std::move(head));
add(index + 1, tail...);
}
/// Registers a module passed by value by wrapping it in a new shared_ptr.
/// Only reached with rvalues (the variadic overload passes `std::move(head)`).
/// An lvalue would deduce `M` as a reference type, and
/// `std::make_shared<M&>` would not compile.
template <typename M>
void add(std::string name, M&& module) {
m.push_back(register_module(name, std::make_shared<M>(std::move(module))));
}
/// Registers an existing shared_ptr as-is, so the caller's pointer and the
/// registered child refer to the same object. For shared_ptr arguments this
/// overload is more specialized than the `M&&` one, so it is selected.
template <typename M>
void add(std::string name, std::shared_ptr<M>&& module) {
m.push_back(register_module(name, std::move(module)));
}
/// Returns the `index`-th directly registered module. Uses `vector::at`, so an
/// out-of-range index throws `std::out_of_range`.
Module& operator[](size_t index) {
return *m.at(index);
}
/// Direct submodules in registration order.
std::vector<std::shared_ptr<Module>> m;
};
struct ModuleCursorFlatTest : torch::test::SeedingFixture {
ModuleCursorFlatTest()
: model(TestModule(1), TestModule(2), TestModule(3)),
cursor(model.modules()) {}
Container model;
ModuleCursor cursor;
};
TEST_F(ModuleCursorFlatTest, IteratesInTheCorrectOrder) {
  // The cursor yields the direct submodules in registration order, then end().
  auto it = cursor.begin();
  ASSERT_EQ(&it->value, &model[0]);
  ++it;
  ASSERT_EQ(&it->value, &model[1]);
  ++it;
  ASSERT_EQ(&it->value, &model[2]);
  ++it;
  ASSERT_EQ(it, cursor.end());
}
TEST_F(ModuleCursorFlatTest, NamesAreFlat) {
  // In a single-level model, each key is just the submodule's position.
  auto it = cursor.begin();
  ASSERT_EQ(it->key, "0");
  ++it;
  ASSERT_EQ(it->key, "1");
  ++it;
  ASSERT_EQ(it->key, "2");
}
TEST_F(ModuleCursorFlatTest, Apply) {
  // apply() must call the callback once per module, in iteration order.
  size_t count = 0;
  cursor.apply([this, &count](Module& module) {
    ASSERT_EQ(&module, &model[count]);
    count += 1;
  });
  // Use an unsigned literal: comparing size_t with a plain int inside ASSERT_EQ
  // triggers -Wsign-compare in gtest's comparison helper.
  ASSERT_EQ(count, 3u);
}
TEST_F(ModuleCursorFlatTest, ApplyItems) {
  // apply_items() must pass each module together with its own key.
  size_t count = 0;
  cursor.apply_items([this, &count](const std::string& key, Module& module) {
    // Keys in a flat model are the positional names "0", "1", "2"; checking
    // them verifies that key and value are delivered as a matching pair.
    ASSERT_EQ(key, std::to_string(count));
    ASSERT_EQ(&module, &model[count]);
    count += 1;
  });
  // Unsigned literal avoids a signed/unsigned comparison inside ASSERT_EQ.
  ASSERT_EQ(count, 3u);
}
TEST_F(ModuleCursorFlatTest, Map) {
  // map() writes one transformed value per module through an output iterator.
  // First target: a pre-sized random-access range overwritten in place.
  std::vector<Module*> by_index(3);
  cursor.map(by_index.begin(), [](Module& module) { return &module; });
  ASSERT_EQ(by_index[0], &model[0]);
  ASSERT_EQ(by_index[1], &model[1]);
  ASSERT_EQ(by_index[2], &model[2]);

  // Second target: an inserting iterator into an initially empty container.
  // This shows map() only needs output-iterator semantics.
  std::list<Module*> linked;
  cursor.map(
      std::inserter(linked, linked.end()),
      [](Module& module) { return &module; });
  // Unsigned literal avoids a signed/unsigned comparison inside ASSERT_EQ.
  ASSERT_EQ(linked.size(), 3u);
  auto iterator = linked.begin();
  ASSERT_EQ(*iterator++, &model[0]);
  ASSERT_EQ(*iterator++, &model[1]);
  ASSERT_EQ(*iterator++, &model[2]);
  ASSERT_EQ(iterator, linked.end());
}
TEST_F(ModuleCursorFlatTest, MapItems) {
  // map_items() hands (key, module) pairs to the transform; collect them into
  // a map keyed by name.
  std::map<std::string, Module*> output;
  cursor.map_items(
      std::inserter(output, output.end()),
      [](const std::string& key, Module& module) {
        return std::make_pair(key, &module);
      });
  // Unsigned literal avoids a signed/unsigned comparison inside ASSERT_EQ.
  ASSERT_EQ(output.size(), 3u);
  ASSERT_TRUE(output.count("0"));
  ASSERT_TRUE(output.count("1"));
  ASSERT_TRUE(output.count("2"));
  // Use at() rather than operator[]: a lookup must never insert into the map
  // being verified.
  ASSERT_EQ(output.at("0"), &model[0]);
  ASSERT_EQ(output.at("1"), &model[1]);
  ASSERT_EQ(output.at("2"), &model[2]);
}
TEST_F(ModuleCursorFlatTest, Count) {
  // The cursor reports as many modules as the Container registered.
  const auto expected_size = model.m.size();
  ASSERT_EQ(cursor.size(), expected_size);
}
TEST_F(ModuleCursorFlatTest, FindReturnsTheCorrectModulesWhenGivenAValidKey) {
  // find() resolves every positional key to the module registered under it.
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_EQ(cursor.find(std::to_string(index)), &model[index]);
  }
}
TEST_F(ModuleCursorFlatTest, FindReturnsNullptrWhenGivenAnInvalidKey) {
  // Keys that were never registered yield a null pointer.
  for (const char* unknown_key : {"foo", "bar"}) {
    ASSERT_EQ(cursor.find(unknown_key), nullptr);
  }
}
TEST_F(
    ModuleCursorFlatTest,
    AtWithKeyReturnsTheCorrectModulesWhenGivenAValidKey) {
  // at(key) returns a reference to the module stored under that key.
  for (size_t index = 0; index < 3; ++index) {
    auto& found = cursor.at(std::to_string(index));
    ASSERT_EQ(&found, &model[index]);
  }
}
TEST_F(ModuleCursorFlatTest, AtWithKeyThrowsWhenGivenAnInvalidKey) {
  // Unlike find(), at(key) treats a missing key as an error, and the message
  // quotes the offending key.
  ASSERT_THROWS_WITH(
      cursor.at("foo"),
      "No such key: 'foo'");
  ASSERT_THROWS_WITH(
      cursor.at("bar"),
      "No such key: 'bar'");
}
TEST_F(
    ModuleCursorFlatTest,
    SubscriptOperatorWithKeyReturnsCorrectModulesWhenGivenAValidKey) {
  // operator[](key) resolves keys exactly like at(key).
  for (size_t index = 0; index < 3; ++index) {
    auto& found = cursor[std::to_string(index)];
    ASSERT_EQ(&found, &model[index]);
  }
}
TEST_F(ModuleCursorFlatTest, SubscriptOperatorWithKeyWhenGivenAnInvalidKey) {
  // operator[](key) throws with the same message as at(key) for unknown keys.
  ASSERT_THROWS_WITH(
      cursor["foo"],
      "No such key: 'foo'");
  ASSERT_THROWS_WITH(
      cursor["bar"],
      "No such key: 'bar'");
}
TEST_F(
    ModuleCursorFlatTest,
    AtWithIndexReturnsTheCorrectModulesWhenGivenAValidKey) {
  // at(index) returns the item at that position; its `value` is the module.
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_EQ(&cursor.at(index).value, &model[index]);
  }
}
TEST_F(ModuleCursorFlatTest, AtWithIndexThrowsWhenGivenAnInvalidKey) {
  // An out-of-range position throws, and the message includes both the index
  // and the cursor size.
  ASSERT_THROWS_WITH(
      cursor.at(5),
      "Index 5 is out of range for cursor of size 3");
  ASSERT_THROWS_WITH(
      cursor.at(123),
      "Index 123 is out of range for cursor of size 3");
}
TEST_F(
    ModuleCursorFlatTest,
    SubscriptOperatorWithIndexReturnsCorrectModulesWhenGivenAValidKey) {
  // operator[](index) behaves like at(index) for valid positions.
  for (size_t index = 0; index < 3; ++index) {
    ASSERT_EQ(&cursor[index].value, &model[index]);
  }
}
TEST_F(ModuleCursorFlatTest, SubscriptOperatorWithIndexWhenGivenAnInvalidKey) {
  // operator[](index) throws the same range error as at(index).
  ASSERT_THROWS_WITH(
      cursor[5],
      "Index 5 is out of range for cursor of size 3");
  ASSERT_THROWS_WITH(
      cursor[123],
      "Index 123 is out of range for cursor of size 3");
}
TEST_F(ModuleCursorFlatTest, ContainReturnsTrueWhenKeyIsPresent) {
  // Every positional key registered by the Container is reported as present.
  ASSERT_TRUE(cursor.contains("0"));
  ASSERT_TRUE(cursor.contains("1"));
  ASSERT_TRUE(cursor.contains("2"));
}

// Complementary case: keys that were never registered must not be reported.
// This includes a one-past-the-end position and a nested-looking name; the
// TestModules have no submodules, so a flat model has no "0.0".
TEST_F(ModuleCursorFlatTest, ContainsReturnsFalseWhenKeyIsAbsent) {
  ASSERT_FALSE(cursor.contains("foo"));
  ASSERT_FALSE(cursor.contains("3"));
  ASSERT_FALSE(cursor.contains("0.0"));
}
struct ModuleCursorDeepTest : torch::test::SeedingFixture {
ModuleCursorDeepTest()
: model(
Container(TestModule(1), TestModule(2)),
TestModule(3),
Container(TestModule(4), Container(TestModule(5), TestModule(6)))) {
}
Container model;
};
TEST_F(ModuleCursorDeepTest, IteratesInTheCorrectOrder) {
  // modules() walks the tree depth-first, visiting each Container before its
  // own children.
  auto cursor = model.modules();
  auto iterator = cursor.begin();
  ASSERT_EQ(&iterator->value, &model[0]);
  auto* container = dynamic_cast<Container*>(&model[0]);
  ASSERT_NE(container, nullptr);
  ASSERT_EQ(&(++iterator)->value, &(*container)[0]);
  ASSERT_EQ(&(++iterator)->value, &(*container)[1]);
  ASSERT_EQ(&(++iterator)->value, &model[1]);
  ASSERT_EQ(&(++iterator)->value, &model[2]);
  container = dynamic_cast<Container*>(&model[2]);
  ASSERT_NE(container, nullptr);
  ASSERT_EQ(&(++iterator)->value, &(*container)[0]);
  ASSERT_EQ(&(++iterator)->value, &(*container)[1]);
  // Descend into the innermost Container ("2.1").
  container = dynamic_cast<Container*>(&(*container)[1]);
  ASSERT_NE(container, nullptr);
  ASSERT_EQ(&(++iterator)->value, &(*container)[0]);
  ASSERT_EQ(&(++iterator)->value, &(*container)[1]);
  // Nine modules in total; nothing may follow the deepest leaf.
  ASSERT_EQ(++iterator, cursor.end());
}
TEST_F(ModuleCursorDeepTest, ChildrenReturnsOnlyTheFirstLevelOfSubmodules) {
  // children() exposes only the direct submodules, not nested ones.
  auto children = model.children();
  // Unsigned literal avoids a signed/unsigned comparison inside ASSERT_EQ.
  ASSERT_EQ(children.size(), 3u);
  ASSERT_EQ(&children.at("0"), &model[0]);
  ASSERT_EQ(&children.at("1"), &model[1]);
  ASSERT_EQ(&children.at("2"), &model[2]);
  ASSERT_FALSE(children.contains("0.0"));
  size_t count = 0;
  for (auto& child : children) {
    ASSERT_EQ(child.key, std::to_string(count));
    ASSERT_EQ(&child.value, &model[count]);
    count += 1;
  }
  // Range-for iteration must also stop after exactly the three direct children.
  ASSERT_EQ(count, 3u);
}
struct ParameterCursorFlatTest : torch::test::SeedingFixture {
ParameterCursorFlatTest()
: first(std::make_shared<TestModule>(1)),
second(std::make_shared<TestModule>(2)),
model(first, second),
cursor(model.parameters()) {}
std::shared_ptr<TestModule> first, second;
Container model;
ParameterCursor cursor;
};
TEST(ParameterCursorTest, IteratesInTheCorrectOrderOverSimpleModels) {
  torch::manual_seed(0);
  TestModule model(1);
  auto cursor = model.parameters();
  auto iterator = cursor.begin();
  // Parameters are yielded in registration order: tensor1, then tensor2.
  ASSERT_TRUE(iterator->value.equal(model.tensor1));
  ASSERT_TRUE((++iterator)->value.equal(model.tensor2));
  // A TestModule registers exactly two parameters, so iteration ends here.
  ASSERT_EQ(++iterator, cursor.end());
}
TEST_F(ParameterCursorFlatTest, IteratesInTheCorrectOrder) {
  // All of the first module's parameters come before any of the second's.
  auto iterator = cursor.begin();
  ASSERT_TRUE(iterator->value.equal(first->tensor1));
  ASSERT_TRUE((++iterator)->value.equal(first->tensor2));
  ASSERT_TRUE((++iterator)->value.equal(second->tensor1));
  ASSERT_TRUE((++iterator)->value.equal(second->tensor2));
  // Two modules with two parameters each: the cursor is now exhausted.
  ASSERT_EQ(++iterator, cursor.end());
}
TEST_F(ParameterCursorFlatTest, ApplyItemsWorks) {
  // apply_items() must deliver each parameter with its hierarchical key
  // ("<module index>.<parameter name>") in registration order.
  size_t count = 0;
  cursor.apply_items(
      [this, &count](const std::string& key, torch::Tensor& tensor) {
        switch (count) {
          case 0: {
            ASSERT_EQ(key, "0.tensor1");
            ASSERT_TRUE(tensor.equal(first->tensor1));
            break;
          }
          case 1: {
            ASSERT_EQ(key, "0.tensor2");
            ASSERT_TRUE(tensor.equal(first->tensor2));
            break;
          }
          case 2: {
            ASSERT_EQ(key, "1.tensor1");
            ASSERT_TRUE(tensor.equal(second->tensor1));
            break;
          }
          case 3: {
            ASSERT_EQ(key, "1.tensor2");
            ASSERT_TRUE(tensor.equal(second->tensor2));
            break;
          }
          default: {
            // Report unexpected extra items instead of silently skipping
            // them. Fall through to the increment so the final count check
            // fails as well.
            ADD_FAILURE() << "Unexpected parameter visited: " << key;
            break;
          }
        }
        count += 1;
      });
  // Unsigned literal avoids a signed/unsigned comparison inside ASSERT_EQ.
  ASSERT_EQ(count, 4u);
}
struct ParameterCursorDeepTest : torch::test::SeedingFixture {
std::vector<std::shared_ptr<TestModule>> make_modules() {
std::vector<std::shared_ptr<TestModule>> modules;
for (size_t i = 1; i <= 6; ++i) {
modules.push_back(std::make_shared<TestModule>(i));
}
return modules;
}
ParameterCursorDeepTest()
: modules(make_modules()),
model(
Container(modules[0], modules[1]),
modules[2],
Container(modules[3], Container(modules[4], modules[5]))) {}
std::vector<std::shared_ptr<TestModule>> modules;
Container model;
};
TEST_F(ParameterCursorDeepTest, IteratesInTheCorrectOrderOverDeepModels) {
  // Despite the nesting, the leaf modules' parameters come out in the same
  // order the modules were created (modules[0] .. modules[5]).
  auto cursor = model.parameters();
  auto iterator = cursor.begin();
  ASSERT_TRUE(iterator->value.equal(modules[0]->tensor1));
  ASSERT_TRUE((++iterator)->value.equal(modules[0]->tensor2));
  for (size_t index = 1; index < 6; ++index) {
    ASSERT_TRUE((++iterator)->value.equal(modules[index]->tensor1));
    ASSERT_TRUE((++iterator)->value.equal(modules[index]->tensor2));
  }
  // Six modules with two parameters each; nothing may follow.
  ASSERT_EQ(++iterator, cursor.end());
}
TEST_F(ParameterCursorDeepTest, NamesAreHierarchical) {
  // A parameter's key is the dot-joined path of submodule names, followed by
  // the parameter's own name.
  const std::vector<std::string> expected_keys = {
      "0.0.tensor1",   "0.0.tensor2",   "0.1.tensor1",   "0.1.tensor2",
      "1.tensor1",     "1.tensor2",     "2.0.tensor1",   "2.0.tensor2",
      "2.1.0.tensor1", "2.1.0.tensor2", "2.1.1.tensor1", "2.1.1.tensor2",
  };
  auto cursor = model.parameters();
  auto it = cursor.begin();
  for (const auto& expected : expected_keys) {
    ASSERT_EQ(it->key, expected);
    ++it;
  }
  ASSERT_EQ(it, cursor.end());
}
/// Fixture for cursor tests that only need deterministic seeding.
struct CursorTest : public torch::test::SeedingFixture {};
TEST_F(CursorTest, NonConstToConstConversion) {
  // Compile-time check: each mutable cursor type must convert to its const
  // counterpart, both by direct-initialization from a returned cursor and by
  // copy-initialization from a named one. Distinct names avoid shadowing.
  auto first = std::make_shared<TestModule>(1);
  auto second = std::make_shared<TestModule>(2);
  Container model(first, second);
  {
    ConstModuleCursor const_cursor(model.modules());
    {
      ModuleCursor cursor = model.modules();
      ConstModuleCursor converted = cursor;
    }
  }
  {
    ConstParameterCursor const_cursor(model.parameters());
    {
      ParameterCursor cursor = model.parameters();
      ConstParameterCursor converted = cursor;
    }
  }
  {
    ConstBufferCursor const_cursor(model.buffers());
    {
      BufferCursor cursor = model.buffers();
      ConstBufferCursor converted = cursor;
    }
  }
}
TEST_F(CursorTest, CanInvokeConstMethodOnConstCursor) {
  // Compile-time check: contains() must be callable through a const cursor.
  TestModule module(1);
  const auto params = module.parameters();
  ASSERT_TRUE(params.contains("tensor1"));
}