Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_interpreter.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_base.h"
4 #include "test/cpp/jit/test_utils.h"
5 
6 namespace torch {
7 namespace jit {
8 namespace test {
9 
10 void testInterp() {
11  constexpr int batch_size = 4;
12  constexpr int input_size = 256;
13  constexpr int seq_len = 32;
14 
15  int hidden_size = 2 * input_size;
16 
17  auto input = at::randn({seq_len, batch_size, input_size}, at::kCUDA);
18  auto hx = at::randn({batch_size, hidden_size}, at::kCUDA);
19  auto cx = at::randn({batch_size, hidden_size}, at::kCUDA);
20  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCUDA));
21  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCUDA));
22 
23  auto lstm_g = build_lstm();
24  Code lstm_function(lstm_g);
25  InterpreterState lstm_interp(lstm_function);
26  auto outputs = run(lstm_interp, {input[0], hx, cx, w_ih, w_hh});
27  std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
28 
29  // std::cout << almostEqual(outputs[0],hx) << "\n";
30  ASSERT_TRUE(exactlyEqual(outputs[0], hx));
31  ASSERT_TRUE(exactlyEqual(outputs[1], cx));
32 }
33 } // namespace test
34 } // namespace jit
35 } // namespace torch
Definition: module.cpp:17
Definition: jit_type.h:17