// blob: 0f424e130ff028f1e8f5477b02614d7fb588f761 [file] [log] [blame]
#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"
#include <stdexcept>
namespace torch {
namespace jit {
void testTypeCheck() {
{
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%a.1 : Tensor,
%b.1 : Tensor):
%t0 : Float(2:2, 2:1, device=cpu, requires_grad=1), %t1 : Float(3:3, 3:1), %type_matched : bool = prim::TypeCheck(%a.1, %b.1)
return (%t0, %t1, %type_matched)
)IR",
&*graph,
vmap);
Code function(graph, "");
InterpreterState interp(function);
{
// TypeCheck yields to true! Shape, grad and device matches.
auto a = at::zeros({2, 2}, at::kFloat);
auto b = at::ones({3, 3}, at::kFloat);
a.set_requires_grad(true);
a = a.to(at::kCPU);
std::vector<IValue> stack({a, b});
interp.run(stack);
ASSERT_TRUE(exactlyEqual(stack[0].toTensor(), a));
ASSERT_TRUE(exactlyEqual(stack[1].toTensor(), b));
ASSERT_TRUE(stack[2].toBool());
}
{
auto a = at::zeros({2, 2}, at::kFloat);
auto b = at::ones({2, 2}, at::kFloat); // Size mismatch
a.set_requires_grad(true);
a = a.to(at::kCPU);
std::vector<IValue> stack({a, b});
interp.run(stack);
ASSERT_FALSE(stack[2].toBool());
}
{
auto a = at::zeros({2, 2}, at::kFloat);
auto b = at::ones({3, 3}, at::kFloat);
a = a.to(at::kCPU);
a.set_requires_grad(false); // Gradient mismatch
std::vector<IValue> stack({a, b});
interp.run(stack);
ASSERT_FALSE(stack[2].toBool());
}
{
auto a = at::zeros({2, 2}, at::kFloat);
auto b = at::ones({3, 3}, at::kFloat);
a = a.to(at::kCPU);
a.set_requires_grad(true);
a = a.to(at::kInt); // Scalar type mismatch
std::vector<IValue> stack({a, b});
interp.run(stack);
ASSERT_FALSE(stack[2].toBool());
}
{
auto a = at::zeros({2, 2}, at::kFloat);
auto b = at::ones({3, 3}, at::kFloat);
a.set_requires_grad(true);
a = a.to(at::kCUDA); // Device mismatch
std::vector<IValue> stack({a, b});
interp.run(stack);
ASSERT_FALSE(stack[2].toBool());
}
}
try { // Test empty Typecheck raises an internal assertion
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%a.1 : Tensor,
%b.1 : Tensor):
%type_matched : bool = prim::TypeCheck()
return (%type_matched)
)IR",
&*graph,
vmap);
ASSERT_TRUE(false);
} catch (const std::exception& e) {
}
try { // Test for assertion if num_inputs + 1 != num_outputs
auto graph = std::make_shared<Graph>();
std::unordered_map<std::string, Value*> vmap;
parseIR(
R"IR(
graph(%a.1 : Tensor,
%b.1 : Tensor):
%type_matched : bool = prim::TypeCheck(%a.1)
return (%type_matched)
)IR",
&*graph,
vmap);
ASSERT_TRUE(false);
} catch (const std::exception& e) {
}
}
void testInterp() {
constexpr int batch_size = 4;
constexpr int input_size = 256;
constexpr int seq_len = 32;
int hidden_size = 2 * input_size;
auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
auto lstm_g = build_lstm();
Code lstm_function(lstm_g, "");
InterpreterState lstm_interp(lstm_function);
auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
ASSERT_TRUE(exactlyEqual(outputs[0], hx));
ASSERT_TRUE(exactlyEqual(outputs[1], cx));
}
} // namespace jit
} // namespace torch