3 #include "test/cpp/jit/test_base.h" 4 #include "test/cpp/jit/test_utils.h" 11 constexpr
int batch_size = 4;
12 constexpr
int input_size = 256;
13 constexpr
int seq_len = 32;
15 int hidden_size = 2 * input_size;
17 auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
18 auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
19 auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
20 auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
21 auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
23 auto lstm_g = build_lstm();
24 Code lstm_function(lstm_g);
25 InterpreterState lstm_interp(lstm_function);
26 auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
27 std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
30 ASSERT_TRUE(exactlyEqual(outputs[0], hx));
31 ASSERT_TRUE(exactlyEqual(outputs[1], cx));