Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_graph_executor.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_base.h"
4 #include "test/cpp/jit/test_utils.h"
5 #include "torch/csrc/jit/graph_executor.h"
6 
7 namespace torch {
8 namespace jit {
9 namespace test {
10 
11 void testGraphExecutor() {
12  constexpr int batch_size = 4;
13  constexpr int input_size = 256;
14 
15  int hidden_size = 2 * input_size;
16 
17  auto v = [](at::Tensor t) { return autograd::make_variable(t, false); };
18 
19  auto input = at::randn({batch_size, input_size}, at::kCUDA);
20  auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
21  auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
22  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
23  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
24 
25  auto g = build_lstm();
26  GraphExecutor executor(g);
27  auto stack = createStack({v(input), v(hx), v(cx), v(w_ih), v(w_hh)});
28  executor.run(stack);
29  ASSERT_EQ(stack.size(), 2);
30  at::Tensor r0, r1;
31  std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
32  ASSERT_TRUE(almostEqual(Variable(stack[0].toTensor()).data(), r0));
33  ASSERT_TRUE(almostEqual(Variable(stack[1].toTensor()).data(), r1));
34 }
35 
36 } // namespace test
37 } // namespace jit
38 } // namespace torch
Definition: module.cpp:17
Definition: jit_type.h:17