Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_utils.h
#pragma once

#include <torch/csrc/jit/testing/file_check.h>
#include "test/cpp/jit/test_base.h"
#include "torch/csrc/jit/autodiff.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/symbolic_variable.h"

namespace torch {
namespace jit {
namespace test {

using Var = SymbolicVariable;
using tensor_list = std::vector<at::Tensor>;
using namespace torch::autograd;

// Work around the fact that variable_tensor_list doesn't duplicate all of
// std::vector's constructors; most constructors are never used in the
// implementation, just in our tests.
Stack createStack(std::vector<at::Tensor>&& list) {
  return Stack(
      std::make_move_iterator(list.begin()),
      std::make_move_iterator(list.end()));
}

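// Assert that two tensor lists have the same length and that corresponding
// tensors agree in shape and (approximately) in value.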
void assertAllClose(const tensor_list& a, const tensor_list& b) {
  ASSERT_EQ(a.size(), b.size());
  for (size_t i = 0; i < a.size(); ++i) {
    ASSERT_TRUE(a[i].is_same_size(b[i]));
    ASSERT_TRUE(a[i].allclose(b[i]));
  }
}

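// Push the inputs onto an IValue stack, run the interpreter over it, and
// unwrap the resulting stack back into tensors.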
std::vector<at::Tensor> run(
    InterpreterState& interp,
    const std::vector<at::Tensor>& inputs) {
  std::vector<IValue> stack(inputs.begin(), inputs.end());
  interp.run(stack);
  return fmap(stack, [](const IValue& i) { return i.toTensor(); });
}

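// Execute the forward graph (grad_spec.f), then build the backward stack out
// of the incoming output gradients plus whatever inputs/outputs f captured,
// and execute the backward graph (grad_spec.df). Returns the real outputs of
// f and the outputs of df.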
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in) {
  static const auto as_tensorlist = [](const Stack& stack) {
    return fmap(stack, [](const IValue& i) { return i.toTensor(); });
  };
  Code f_code{grad_spec.f}, df_code{grad_spec.df};
  InterpreterState f_interpreter{f_code}, df_interpreter{df_code};

  auto f_stack = fmap<IValue>(tensors_in);
  f_interpreter.run(f_stack);

  Stack df_stack;
  df_stack.insert(
      df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
  for (auto offset : grad_spec.df_input_captured_inputs)
    df_stack.push_back(tensors_in[offset]);
  for (auto offset : grad_spec.df_input_captured_outputs)
    df_stack.push_back(f_stack[offset]);
  df_interpreter.run(df_stack);

  // Outputs of f need to be sliced off: everything past f_real_outputs was
  // only captured for df.
  f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
  return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
}

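// Emit a single LSTM cell into g via SymbolicVariable ops: the four gates
// come from the input/hidden matmuls and are combined into the new cell
// state cy and hidden state hy.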
std::tuple<Var, Var> build_lstm_body(
    Graph& g,
    Var input,
    Var hx,
    Var cx,
    Var w_ih,
    Var w_hh) {
  auto gates = input.mm(w_ih);
  gates = gates + hx.mm(w_hh);
  auto outputs = gates.chunk(4, 1);
  auto ingate = outputs[0];
  auto forgetgate = outputs[1];
  auto cellgate = outputs[2];
  auto outgate = outputs[3];
  ingate = ingate.sigmoid();
  outgate = outgate.sigmoid();
  cellgate = cellgate.tanh();
  forgetgate = forgetgate.sigmoid();

  auto cy = forgetgate * cx;
  cy = cy + ingate * cellgate;
  auto hy = outgate * cy.tanh();

  return std::make_tuple(hy, cy);
}

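// Build a standalone graph holding one LSTM cell, with inputs
// (input, hx, cx, w_ih, w_hh) and outputs (hy, cy).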
std::shared_ptr<Graph> build_lstm() {
  auto r = std::make_shared<Graph>();
  auto& g = *r;
  Value* input = g.addInput();
  Value* hx = g.addInput();
  Value* cx = g.addInput();
  Value* w_ih = g.addInput();
  Value* w_hh = g.addInput();

  Var hy;
  Var cy;
  std::tie(hy, cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh);

  hy.addAsOutput();
  cy.addAsOutput();
  g.lint();

  return r;
}

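// Identity vs. transpose helpers: lstm() below wraps its weights in t_use,
// so a test variant can substitute t_def to exercise transposed weight
// layouts.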
at::Tensor t_use(at::Tensor x) {
  return x;
}
at::Tensor t_def(at::Tensor x) {
  return x.t();
}

// Given the elementwise difference between an output and its expected tensor,
// check whether that difference is within a relative tolerance: a standard
// way of matching tensor values up to a certain precision.
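// For example, if the largest input magnitude is 10, any elementwise
// difference below 2e-5 passes.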
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor>& inputs) {
  double maxValue = 0.0;
  for (auto& tensor : inputs) {
    maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
  }
  return diff.abs().max().item<float>() < 2e-6 * maxValue;
}
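
// Approximate elementwise equality under the relative tolerance above.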
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
  return checkRtol(a - b, {a, b});
}

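// Exact elementwise equality, with no tolerance.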
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
  return (a - b).abs().max().item<float>() == 0.f;
}

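// Eager-mode reference implementation of the same LSTM cell that
// build_lstm() emits as a graph; used to validate interpreter output.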
std::pair<at::Tensor, at::Tensor> lstm(
    at::Tensor input,
    at::Tensor hx,
    at::Tensor cx,
    at::Tensor w_ih,
    at::Tensor w_hh) {
  auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh));

  auto chunked_gates = gates.chunk(4, 1);
  auto ingate = chunked_gates[0];
  auto forgetgate = chunked_gates[1];
  auto cellgate = chunked_gates[2];
  auto outgate = chunked_gates[3];

  ingate = ingate.sigmoid();
  outgate = outgate.sigmoid();
  cellgate = cellgate.tanh();
  forgetgate = forgetgate.sigmoid();

  auto cy = (forgetgate * cx) + (ingate * cellgate);
  auto hy = outgate * cy.tanh();

  return {hy, cy};
}

} // namespace test
} // namespace jit
} // namespace torch
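
These helpers typically combine in a test that builds the LSTM graph, runs it
through the interpreter, and checks the result against the eager reference.
A minimal sketch, assuming the ASSERT_* macros from test/cpp/jit/test_base.h;
the function name and the shapes (batch 3, input size 10, hidden size 16) are
illustrative assumptions:

void testLSTMInterpreter() {
  constexpr int64_t batch = 3, input_size = 10, hidden_size = 16;
  auto input = at::randn({batch, input_size});
  auto hx = at::randn({batch, hidden_size});
  auto cx = at::randn({batch, hidden_size});
  // Weights are laid out so that x.mm(w) produces the four stacked gates.
  auto w_ih = at::randn({input_size, 4 * hidden_size});
  auto w_hh = at::randn({hidden_size, 4 * hidden_size});

  // Run the graph-built cell through the interpreter.
  auto graph = build_lstm();
  Code code{graph};
  InterpreterState interp{code};
  auto outputs = run(interp, {input, hx, cx, w_ih, w_hh});

  // Compare against the eager-mode reference.
  at::Tensor hy, cy;
  std::tie(hy, cy) = lstm(input, hx, cx, w_ih, w_hh);
  ASSERT_TRUE(almostEqual(outputs[0], hy));
  ASSERT_TRUE(almostEqual(outputs[1], cy));
}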