3 #include "test/cpp/jit/test_base.h" 4 #include "test/cpp/jit/test_utils.h" 5 #include "torch/csrc/jit/graph_executor.h" 11 void testGraphExecutor() {
12 constexpr
int batch_size = 4;
13 constexpr
int input_size = 256;
15 int hidden_size = 2 * input_size;
17 auto v = [](
at::Tensor t) {
return autograd::make_variable(t,
false); };
19 auto input = at::randn({batch_size, input_size}, at::kCUDA);
20 auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
21 auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
22 auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
23 auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
25 auto g = build_lstm();
26 GraphExecutor executor(g);
27 auto stack = createStack({v(input), v(hx), v(cx), v(w_ih), v(w_hh)});
29 ASSERT_EQ(stack.size(), 2);
31 std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
32 ASSERT_TRUE(almostEqual(Variable(stack[0].toTensor()).data(), r0));
33 ASSERT_TRUE(almostEqual(Variable(stack[1].toTensor()).data(), r1));