blob: 5932f5676f76aef019ca8a569e1a6da50bf6056d [file] [log] [blame]
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/ops.h>
using namespace torch::jit;
namespace {

// Compiles `script` as the body of a fresh torch::jit::Module and wraps
// the result in a StaticModule for use by the tests below.
StaticModule makeStaticModuleFromScript(const std::string& script) {
  Module scripted("module");
  scripted.define(script);
  return StaticModule(scripted);
}

} // namespace
/**
 * Test that StaticModule::value_group groups values of the graph into
 * 1) Inputs/Constants and their aliases 2) Outputs and their aliases.
 */
TEST(StaticModule, ValueGroup) {
const std::string src = R"IR(
graph(%input0 : Tensor, %input1 : Tensor):
# Constants.
%0 : int = prim::Constant[value=1]()
# Internal values.
%1 : Tensor = aten::add(%input0, %input1, %0)
# This includes aliases of output.
%2 : Tensor = aten::add(%input0, %1, %0)
# This includes output.
%3 : (Tensor) = prim::TupleConstruct(%2)
return (%3)
)IR";
auto input_graph = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(src, input_graph.get());
torch::jit::StaticModule sm(input_graph);
const Graph& graph = sm.graph();
// Snapshot the nodes in topological order so they can be indexed below:
// nodes[0] = prim::Constant (%0), nodes[1] = first aten::add (%1),
// nodes[2] = second aten::add (%2), nodes[3] = prim::TupleConstruct (%3).
// NOTE(review): this assumes StaticModule does not insert/reorder nodes in
// this graph — the EXPECTs below depend on these positions.
std::vector<const Node*> nodes(graph.nodes().begin(), graph.nodes().end());
const auto& value_group = sm.value_group();
// Graph inputs and the constant %0 belong to the "external alias" group
// (inputs/constants and their aliases).
std::vector<const Value*> expected_input_aliases{graph.inputs()[0], graph.inputs()[1], nodes[0]->output()};
for (auto* value : expected_input_aliases) {
EXPECT_TRUE(value_group.isExternalAlias(value));
}
// %2 feeds the returned tuple, so both the graph output and %2 are
// classified as output aliases.
std::vector<const Value*> expected_output_aliases{graph.outputs()[0], nodes[2]->output()};
for (auto* value : expected_output_aliases) {
EXPECT_TRUE(value_group.isOutputAlias(value));
}
// %1 is purely internal: neither an input/constant alias nor an output
// alias, so its storage does not have to outlive the whole run.
EXPECT_FALSE(value_group.isAlwaysAlive(nodes[1]->output()));
// Inputs, constants, outputs, and their aliases must stay alive for the
// entire execution.
EXPECT_TRUE(value_group.isAlwaysAlive(graph.inputs()[0]));
EXPECT_TRUE(value_group.isAlwaysAlive(graph.inputs()[1]));
EXPECT_TRUE(value_group.isAlwaysAlive(graph.outputs()[0]));
}
TEST(StaticModule, IsOptimizableContainerType_NonOptimizableInputs) {
  // Out variants cannot be used for the list/tuple constructions here:
  // their inputs are graph inputs (or a view of one), not outputs of
  // nodes that themselves have out variants.
  const std::string src = R"JIT(
        def forward(self, a, b):
            a_alias = a.view(a.size())
            non_optimizable_list = [a_alias]
            non_optimizable_tuple = (b, )
            return non_optimizable_list, non_optimizable_tuple
    )JIT";
  auto static_module = makeStaticModuleFromScript(src);
  // No node in this graph should be treated as an optimizable container.
  for (const Node* node : static_module.graph().nodes()) {
    EXPECT_FALSE(static_module.is_optimizable_container_type(node));
  }
}
TEST(StaticModule, IsOptimizableContainerType_WrongType) {
  // Out variants cannot be used for the list/tuple constructions here:
  // the contained values are ints, not Tensors.
  const std::string src = R"JIT(
        def forward(self, x: int, y: int):
            a = 1 + x
            b = 2 + y
            non_optimizable_list = [a]
            non_optimizable_tuple = (b, )
            return non_optimizable_list, non_optimizable_tuple
    )JIT";
  auto static_module = makeStaticModuleFromScript(src);
  // No node in this graph should be treated as an optimizable container.
  for (const Node* node : static_module.graph().nodes()) {
    EXPECT_FALSE(static_module.is_optimizable_container_type(node));
  }
}
TEST(StaticModule, IsOptimizableContainerType_CanUseOutVariant) {
  // The ListConstruct here should be optimizable: its input is produced
  // by torch.relu, which has an out variant, and the list holds Tensors.
  // (The original comment named aten::add, but the script uses relu.)
  const std::string src = R"JIT(
        def forward(self, x):
            a = torch.relu(x)
            optimizable_list = [a]
            return optimizable_list
    )JIT";
  auto static_module = makeStaticModuleFromScript(src);
  for (const Node* node : static_module.graph().nodes()) {
    // Only the prim::ListConstruct node qualifies; everything else does not.
    const bool expected = node->kind() == prim::ListConstruct;
    EXPECT_EQ(static_module.is_optimizable_container_type(node), expected);
  }
}