#include <gtest/gtest.h>

#include <unordered_map>

#include <torch/csrc/jit/runtime/static/fusion.h>
#include <torch/csrc/jit/runtime/static/impl.h>

#include "deep_wide_pt.h"
#include "test_scripts.h"

using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using c10::IValue;

namespace {
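// Extract the single tensor from an IValue that holds either a tensor, a
// one-element tensor list, or a one-element tuple.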
static at::Tensor getTensor(const at::IValue& ival) {
  if (ival.isTensor()) {
    return ival.toTensor();
  } else if (ival.isTensorList()) {
    auto tensor_vec = ival.toTensorVector();
    TORCH_CHECK(tensor_vec.size() == 1);
    return tensor_vec[0];
  } else if (ival.isTuple()) {
    auto tuple = ival.toTuple();
    auto ivalue_vec = tuple->elements();
    TORCH_CHECK(ivalue_vec.size() == 1);
    return ivalue_vec[0].toTensor();
  } else {
    CAFFE_THROW("Unknown input IValue");
  }
}
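
// Compare two lists of tensors (wrapped in IValues) element-wise; an
// undefined tensor is only considered equal to another undefined tensor.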
void compareTensorLists(
    const std::vector<IValue>& l, /* expects */
    const std::vector<IValue>& r /* values */) {
  EXPECT_EQ(l.size(), r.size());
  for (size_t i = 0; i < l.size(); ++i) {
    ASSERT_TRUE(l[i].isTensor());
    ASSERT_TRUE(r[i].isTensor());
    VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl;
    VLOG(2) << "output " << i << ": \n" << r[i] << std::endl;
    if (!l[i].toTensor().defined()) {
      EXPECT_TRUE(!r[i].toTensor().defined());
    } else {
      EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
    }
  }
}

// Overload for plain tensor lists.
void compareTensorLists(
    const std::vector<at::Tensor>& l, /* expects */
    const std::vector<at::Tensor>& r /* values */) {
  EXPECT_EQ(l.size(), r.size());
  for (size_t i = 0; i < l.size(); ++i) {
    VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl;
    VLOG(2) << "output " << i << ": \n" << r[i] << std::endl;
    if (!l[i].defined()) {
      EXPECT_TRUE(!r[i].defined());
    } else {
      EXPECT_TRUE(l[i].equal(r[i]));
    }
  }
}

// Given a model/function in TorchScript, run it with both the jit
// interpreter and the static runtime, and compare the results.
void testStaticRuntime(
    const std::string& jit_script,
    const std::vector<IValue>& args) {
  script::Module module("module");
  module.define(jit_script);

  // keep deep copies of the tensor inputs so we can verify below that the
  // static runtime did not modify them in place
  std::vector<IValue> args_tensors, args_copy;
  for (const auto& ival : args) {
    if (ival.isTensor()) {
      args_tensors.emplace_back(ival);
      const at::Tensor& t = ival.toTensor();
      args_copy.emplace_back(t.clone());
    }
  }

  auto expect = module.forward(args);

  torch::jit::StaticModule smodule(module);
  auto actual = smodule(args, {});
  smodule.runtime().check_for_memory_leak();

  if (expect.isTuple()) {
    compareTensorLists(
        expect.toTuple()->elements(), actual.toTuple()->elements());
  } else if (expect.isList()) {
    compareTensorLists(expect.toTensorVector(), actual.toTensorVector());
  } else {
    EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
  }

  // make sure the inputs were not modified
  compareTensorLists(args_tensors, args_copy);
}
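
// Example usage of testStaticRuntime (a minimal sketch; `simple_add_script`
// is a hypothetical inline script, not one of the scripts defined in
// test_scripts.h):
//
//   const std::string simple_add_script = R"JIT(
//     def forward(self, a, b):
//         return a + b
//   )JIT";
//   testStaticRuntime(
//       simple_add_script, {at::randn({2, 3}), at::ones({2, 3})});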
} // namespace

TEST(StaticRuntime, UnaryOps) {
  auto a = at::ones({2, 3});
  std::vector<IValue> args{a};

  testStaticRuntime(aten_sum, args);
  testStaticRuntime(aten_sum_0, args);
  testStaticRuntime(aten_sum_1, args);
  testStaticRuntime(aten_sum_0_true, args);
  testStaticRuntime(aten_sum_1_true, args);
}

TEST(StaticRuntime, EmbeddingBag) {
  at::Tensor weight = torch::ones({3, 11}, at::ScalarType::Float);
  at::Tensor input = torch::tensor({0, 1, 0, 2});
  at::Tensor offset = torch::tensor({0, 2, 4});
  std::vector<IValue> args{weight, input, offset};

  testStaticRuntime(embedding_bag_default, args);
  testStaticRuntime(embedding_bag_mean, args);
  testStaticRuntime(embedding_bag_max, args);
  testStaticRuntime(embedding_bag_sum_last_offset, args);
  testStaticRuntime(embedding_bag_mean_last_offset, args);
  testStaticRuntime(embedding_bag_max_last_offset, args);
}

TEST(StaticRuntime, IndividualOps_Binary) {
  auto a = at::randn({2, 3});
  auto b = at::ones({2, 3});
  std::vector<IValue> args{a, b};

  testStaticRuntime(add_script, args);
  testStaticRuntime(list_construct_script, args);
  testStaticRuntime(list_construct_script_2, args);
  testStaticRuntime(list_construct_script_3, args);
  testStaticRuntime(list_unpack_script, args);
  testStaticRuntime(list_unpack_script_2, args);
  testStaticRuntime(tuple_construct_script, args);
  testStaticRuntime(tuple_construct_script_2, args);
}

TEST(StaticRuntime, IndividualOps_Reshape) {
  auto a = at::randn({2, 3});
  auto b = std::vector<int64_t>({3, 2});
  std::vector<IValue> args{a, b};

  testStaticRuntime(reshape_script_1, args);
  testStaticRuntime(reshape_script_2, args);
  testStaticRuntime(reshape_script_3, args);
  testStaticRuntime(reshape_script_4, args);
  testStaticRuntime(reshape_script_5, args);
  testStaticRuntime(reshape_inplace_script, args);
}

TEST(StaticRuntime, IndividualOps_flatten) {
  auto test_flatten =
      [](std::vector<int64_t> shape, int64_t start_dim, int64_t end_dim) {
        auto a = at::randn(shape);
        std::vector<IValue> args{a, start_dim, end_dim};
        testStaticRuntime(flatten_script_1, args);
        if (shape.size() > 2) {
          testStaticRuntime(flatten_script_2, args);
        }
      };

  test_flatten({2, 3}, 0, 1);
  test_flatten({2, 1, 3}, 1, 2);
  test_flatten({0, 1, 3, 0}, 1, 2);
  test_flatten({2, 3}, 1, 1);
  test_flatten({}, 0, 0);
}

TEST(StaticRuntime, IndividualOps_pow) {
  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});

  std::vector<IValue> args0{a, 4};
  testStaticRuntime(pow_script_ten_sca, args0);

  std::vector<IValue> args1{at::abs(a), b};
  testStaticRuntime(pow_script_ten_ten, args1);

  std::vector<IValue> args2{5, b};
  testStaticRuntime(pow_script_sca_ten, args2);
}

TEST(StaticRuntime, IndividualOps_to) {
  // parameters follow the aten::to signature:
  // (dtype, non_blocking, copy, memory_format)
  auto test_to = [](at::ScalarType dtype,
                    bool non_blocking,
                    bool copy,
                    c10::MemoryFormat memory_format) {
    auto a = at::randn({2, 3});
    std::vector<IValue> args0{a, dtype, non_blocking, copy, memory_format};
    std::vector<IValue> args1{a, dtype, non_blocking, copy};
    testStaticRuntime(to_script_0, args0);
    testStaticRuntime(to_script_1, args1);
  };

  test_to(at::ScalarType::Float, true, true, c10::MemoryFormat::Contiguous);
  test_to(at::ScalarType::Half, true, false, c10::MemoryFormat::Preserve);
  test_to(at::ScalarType::Float, false, false, c10::MemoryFormat::Contiguous);
  test_to(at::ScalarType::Half, false, true, c10::MemoryFormat::Preserve);
}
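
// The following tests run full scripted models through both the jit
// interpreter and the static runtime, and compare the outputs with allclose.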

TEST(StaticRuntime, LongModel) {
  torch::jit::Module mod = getLongScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors)[0];
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}

TEST(StaticRuntime, TrivialModel) {
  torch::jit::Module mod = getTrivialScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors)[0];
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}

TEST(StaticRuntime, LeakyReLU) {
  torch::jit::Module mod = getLeakyReLUConstScriptModel();
  auto inputs = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({inputs});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({inputs});
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors)[0];
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}

TEST(StaticRuntime, DeepWide) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(mod);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(mod.forward(inputs));

      // run static runtime
      std::vector<at::Tensor> input_tensors({ad_emb_packed, user_emb, wide});
      at::Tensor output_2 = smod(input_tensors)[0];
      smod.runtime().check_for_memory_leak();
      EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
    }
  }
}
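
// The next two tests exercise the IValue-based operator() of StaticModule:
// KWargsAPI_1 passes the inputs positionally, KWargsAPI_2 passes the same
// inputs by keyword. Both also verify that the runtime holds no stray
// references to its inputs or outputs.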

TEST(StaticRuntime, KWargsAPI_1) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});
      {
        std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});

        // run jit graph executor
        at::Tensor output_1 = getTensor(module.forward(inputs));

        // run static runtime
        c10::IValue output_ivalue = smod(inputs, {});
        smod.runtime().check_for_memory_leak();

        at::Tensor output_2 = getTensor(output_ivalue);
        EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));

        // check for output aliasing: the runtime should not retain any
        // reference to the output it returned
        EXPECT_EQ(output_ivalue.use_count(), 1);
        output_ivalue = IValue();

        // once the IValue wrapper is released, output_2 should be the sole
        // owner of the underlying tensor
        EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
      }
      // check for input aliasing (deep & wide does not have ops
      // that create aliases of input tensors)
      EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(user_emb.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(wide.getIntrusivePtr().use_count(), 1);
    }
  }
}

TEST(StaticRuntime, KWargsAPI_2) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});
      {
        // run jit graph executor
        std::vector<at::IValue> args({ad_emb_packed, user_emb, wide});
        at::Tensor output_1 = getTensor(module.forward(args));

        std::unordered_map<std::string, c10::IValue> kwargs(
            {{"ad_emb_packed", ad_emb_packed},
             {"user_emb", user_emb},
             {"wide", wide}});

        // run static runtime with keyword arguments only
        c10::IValue output_ivalue = smod({}, kwargs);
        smod.runtime().check_for_memory_leak();

        at::Tensor output_2 = getTensor(output_ivalue);
        EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));

        // check for output aliasing: once the returned IValue is released,
        // output_2 should be the sole owner of the underlying tensor
        EXPECT_EQ(output_ivalue.use_count(), 1);
        output_ivalue = IValue();
        EXPECT_EQ(output_2.getIntrusivePtr().use_count(), 1);
      }
      // check for input aliasing
      EXPECT_EQ(ad_emb_packed.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(user_emb.getIntrusivePtr().use_count(), 1);
      EXPECT_EQ(wide.getIntrusivePtr().use_count(), 1);
    }
  }
}
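
// Run the deep & wide model under every combination of the cleanup_memory
// and enable_out_variant options and check the results against the jit
// interpreter.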

TEST(StaticRuntime, CleanUpMemory) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();

  for (auto cleanup_memory : {true, false}) {
    for (auto enable_out_variant : {true, false}) {
      VLOG(1) << "cleanup_memory: " << cleanup_memory
              << ", enable_out_variant: " << enable_out_variant;
      torch::jit::StaticModuleOptions opts{cleanup_memory, enable_out_variant};
      torch::jit::StaticModule smod(mod, opts);

      for (int batch_size : {1, 8, 32}) {
        for (int i = 0; i < 2; ++i) {
          auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
          auto user_emb = torch::randn({batch_size, 1, embedding_size});
          auto wide = torch::randn({batch_size, num_features});

          // run jit graph executor
          std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
          auto output_1 = getTensor(mod.forward(inputs));

          // run static runtime
          std::vector<at::Tensor> input_tensors(
              {ad_emb_packed, user_emb, wide});
          at::Tensor output_2 = smod(input_tensors)[0];
          smod.runtime().check_for_memory_leak();
          EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
        }
      }
    }
  }
}
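
// fuseStaticSubgraphs should replace eligible subgraphs with
// prim::StaticSubgraph nodes without changing the model's output.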

TEST(StaticRuntime, FusionPass) {
  const int embedding_size = 32;
  const int num_features = 50;
  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      torch::jit::Module module = getDeepAndWideSciptModel();
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(module.forward(inputs));

      Method method = module.get_method("forward");
      auto graph = method.graph();
      fuseStaticSubgraphs(graph, 2);

      // the fusion pass should have created at least one StaticSubgraph node
      bool hit = false;
      for (const auto& n : module.get_method("forward").graph()->nodes()) {
        if (n->kind() == torch::jit::prim::StaticSubgraph) {
          hit = true;
        }
      }
      EXPECT_TRUE(hit);

      // the fused graph must still produce the same output
      auto output_2 = getTensor(module.forward(inputs));
      EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
    }
  }
}