| #include "test/cpp/jit/test_base.h" |
| #include "test/cpp/jit/test_utils.h" |
| |
| namespace torch { |
| namespace jit { |
| |
| void testInterp() { |
| constexpr int batch_size = 4; |
| constexpr int input_size = 256; |
| constexpr int seq_len = 32; |
| |
| int hidden_size = 2 * input_size; |
| |
| auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA); |
| auto hx = at::randn({batch_size, hidden_size}, at::kCUDA); |
| auto cx = at::randn({batch_size, hidden_size}, at::kCUDA); |
| auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA)); |
| auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA)); |
| |
| auto lstm_g = build_lstm(); |
| Code lstm_function(lstm_g); |
| InterpreterState lstm_interp(lstm_function); |
| auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh}); |
| std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh); |
| |
| // std::cout << almostEqual(outputs[0],hx) << "\n"; |
| ASSERT_TRUE(exactlyEqual(outputs[0], hx)); |
| ASSERT_TRUE(exactlyEqual(outputs[1], cx)); |
| } |
| } // namespace jit |
| } // namespace torch |