#include <gtest/gtest.h>

#include <string>
#include <unordered_map>
#include <vector>

#include <torch/csrc/jit/runtime/static/fusion.h>
#include <torch/csrc/jit/runtime/static/impl.h>

#include "deep_wide_pt.h"
#include "test_scripts.h"

using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using c10::IValue;

namespace {
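// Unwrap a single tensor from an IValue that may hold a plain Tensor, a
// one-element TensorList, or a one-element Tuple; any other payload is an
// error.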
static at::Tensor getTensor(const at::IValue& ival) {
  if (ival.isTensor()) {
    return ival.toTensor();
  } else if (ival.isTensorList()) {
    auto tensor_vec = ival.toTensorVector();
    TORCH_CHECK(tensor_vec.size() == 1);
    return tensor_vec[0];
  } else if (ival.isTuple()) {
    auto tuple = ival.toTuple();
    auto ivalue_vec = tuple->elements();
    TORCH_CHECK(ivalue_vec.size() == 1);
    return ivalue_vec[0].toTensor();
  } else {
    CAFFE_THROW("Unknown input IValue");
  }
}
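
// Element-wise comparison of two tensor lists. The size check is an ASSERT
// so that a length mismatch aborts the test before the loop indexes out of
// range.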
void compareTensorLists(
    const std::vector<IValue>& l, /* values */
    const std::vector<IValue>& r /* expects */) {
  ASSERT_EQ(l.size(), r.size());
  for (size_t i = 0; i < l.size(); ++i) {
    ASSERT_TRUE(l[i].isTensor());
    ASSERT_TRUE(r[i].isTensor());
    LOG(INFO) << "output " << i << ": \n" << l[i];
    LOG(INFO) << "expect " << i << ": \n" << r[i];
    EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
  }
}

void compareTensorLists(
    const std::vector<at::Tensor>& l, /* values */
    const std::vector<at::Tensor>& r /* expects */) {
  ASSERT_EQ(l.size(), r.size());
  for (size_t i = 0; i < l.size(); ++i) {
    LOG(INFO) << "output " << i << ": \n" << l[i];
    LOG(INFO) << "expect " << i << ": \n" << r[i];
    EXPECT_TRUE(l[i].equal(r[i]));
  }
}

// Given a model/function in JIT script, run it with both the JIT
// interpreter and the static runtime, then compare the results.
void testStaticRuntime(
    const std::string& jit_script,
    const std::vector<IValue>& args) {
  script::Module module("module");
  module.define(jit_script);

  auto expect = module.forward(args);

  StaticRuntime runtime(module);
  auto actual = runtime.run(args, {});

  if (expect.isTuple()) {
    compareTensorLists(
        expect.toTuple()->elements(), actual.toTuple()->elements());
  } else if (expect.isList()) {
    compareTensorLists(expect.toTensorVector(), actual.toTensorVector());
  } else {
    EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
  }
}
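
// Example usage (a minimal sketch; the inline script below is hypothetical
// and merely stands in for the *_script strings defined in test_scripts.h):
//
//   const auto sample_script = R"JIT(
//     def forward(self, a, b):
//         return a + b
//   )JIT";
//   testStaticRuntime(sample_script, {at::randn({2, 3}), at::ones({2, 3})});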
} // namespace
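
// The *_script sources exercised below are defined in test_scripts.h.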
TEST(StaticRuntime, IndividualOps_Binary) {
  auto a = at::randn({2, 3});
  auto b = at::ones({2, 3});

  std::vector<IValue> args{a, b};
  testStaticRuntime(add_script, args);
  testStaticRuntime(list_construct_script, args);
  testStaticRuntime(list_unpack_script, args);
  testStaticRuntime(tuple_construct_script, args);
}

TEST(StaticRuntime, LongModel) {
  torch::jit::Module mod = getLongScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  EXPECT_TRUE(output_1.equal(output_2));
}

TEST(StaticRuntime, TrivialModel) {
  torch::jit::Module mod = getTrivialScriptModel();
  auto a = torch::randn({2, 2});
  auto b = torch::randn({2, 2});
  auto c = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({a, b, c});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({a, b, c});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  EXPECT_TRUE(output_1.equal(output_2));
}

TEST(StaticRuntime, LeakyReLU) {
  torch::jit::Module mod = getLeakyReLUConstScriptModel();
  auto inputs = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({inputs});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<at::Tensor> input_tensors({inputs});
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);
  at::Tensor output_2 = runtime.run(input_tensors)[0];
  EXPECT_TRUE(output_1.equal(output_2));
}

TEST(StaticRuntime, DeepWide) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);
  torch::jit::StaticRuntime runtime(g);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(mod.forward(inputs));

      // run static runtime
      std::vector<at::Tensor> input_tensors({ad_emb_packed, user_emb, wide});
      at::Tensor output_2 = runtime.run(input_tensors)[0];
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}
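
// Same DeepWide model as above, but driven through the IValue-based
// run(args, kwargs) overload with positional arguments only.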
TEST(StaticRuntime, KWargsAPI_1) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  torch::jit::StaticRuntime runtime(module);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      at::Tensor output_1 = getTensor(module.forward(inputs));

      // run static runtime
      at::Tensor output_2 = getTensor(runtime.run(inputs, {}));
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}
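
// Same model again, this time passing every input as a keyword argument.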
TEST(StaticRuntime, KWargsAPI_2) {
  const int embedding_size = 32;
  const int num_features = 50;
  auto module = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(module);
  torch::jit::StaticRuntime runtime(g);

  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> args({ad_emb_packed, user_emb, wide});
      at::Tensor output_1 = getTensor(module.forward(args));

      std::unordered_map<std::string, c10::IValue> kwargs(
          {{"ad_emb_packed", ad_emb_packed},
           {"user_emb", user_emb},
           {"wide", wide}});

      // run static runtime
      at::Tensor output_2 = getTensor(runtime.run({}, kwargs));
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}
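
// Exercise every combination of StaticRuntimeOptions: whether intermediate
// memory is cleaned up between runs, and whether out-variant ops are enabled.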
TEST(StaticRuntime, CleanUpMemory) {
  const int embedding_size = 32;
  const int num_features = 50;
  torch::jit::Module mod = getDeepAndWideSciptModel();
  auto g = torch::jit::PrepareForStaticRuntime(mod);

  for (auto cleanup_memory : {true, false}) {
    for (auto enable_out_variant : {true, false}) {
      VLOG(1) << "cleanup_memory: " << cleanup_memory
              << ", enable_out_variant: " << enable_out_variant;
      torch::jit::StaticRuntimeOptions opts{cleanup_memory, enable_out_variant};
      torch::jit::StaticRuntime runtime(g, opts);

      for (int batch_size : {1, 8, 32}) {
        for (int i = 0; i < 2; ++i) {
          auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
          auto user_emb = torch::randn({batch_size, 1, embedding_size});
          auto wide = torch::randn({batch_size, num_features});

          // run jit graph executor
          std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
          auto output_1 = getTensor(mod.forward(inputs));

          // run static runtime
          std::vector<at::Tensor> input_tensors(
              {ad_emb_packed, user_emb, wide});
          at::Tensor output_2 = runtime.run(input_tensors)[0];
          EXPECT_TRUE(output_1.equal(output_2));
        }
      }
    }
  }
}
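
// Verify that fuseStaticSubgraphs() rewrites the forward graph to contain at
// least one prim::StaticSubgraph node, and that the fused module still
// produces the same output.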
TEST(StaticRuntime, FusionPass) {
  const int embedding_size = 32;
  const int num_features = 50;
  for (int batch_size : {1, 8, 32}) {
    for (int i = 0; i < 2; ++i) {
      torch::jit::Module module = getDeepAndWideSciptModel();
      auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
      auto user_emb = torch::randn({batch_size, 1, embedding_size});
      auto wide = torch::randn({batch_size, num_features});

      // run jit graph executor
      std::vector<at::IValue> inputs({ad_emb_packed, user_emb, wide});
      auto output_1 = getTensor(module.forward(inputs));

      Method method = module.get_method("forward");
      auto graph = method.graph();
      fuseStaticSubgraphs(graph);
      bool hit = false;
      for (const auto& n : module.get_method("forward").graph()->nodes()) {
        if (n->kind() == torch::jit::prim::StaticSubgraph) {
          hit = true;
        }
      }
      EXPECT_TRUE(hit);
      auto output_2 = getTensor(module.forward(inputs));
      EXPECT_TRUE(output_1.equal(output_2));
    }
  }
}