3 #include <torch/csrc/jit/testing/file_check.h> 4 #include "test/cpp/jit/test_base.h" 5 #include "torch/csrc/jit/autodiff.h" 6 #include "torch/csrc/jit/interpreter.h" 7 #include "torch/csrc/jit/symbolic_variable.h" 13 using Var = SymbolicVariable;
14 using tensor_list = std::vector<at::Tensor>;
// Builds an interpreter Stack from a list of tensors. The tensors are moved out
// of `list` (via std::make_move_iterator), not copied, so `list` should be
// treated as consumed afterwards.
// NOTE(review): the statement that opens the call these iterator arguments are
// passed to (presumably `return Stack(`) and the closing brace are not visible
// in this excerpt — confirm against the full file.
20 Stack createStack(std::vector<at::Tensor>&& list) {
22 std::make_move_iterator(list.begin()),
23 std::make_move_iterator(list.end()));
26 void assertAllClose(
const tensor_list& a,
const tensor_list& b) {
27 ASSERT_EQ(a.size(), b.size());
28 for (
size_t i = 0; i < a.size(); ++i) {
29 ASSERT_TRUE(a[i].is_same_size(b[i]));
30 ASSERT_TRUE(a[i].allclose(b[i]));
// Converts `inputs` into an IValue stack and returns every entry of that stack
// converted back to at::Tensor (IValue::toTensor on each element).
// NOTE(review): `interp` is not referenced in any line visible in this excerpt;
// presumably the interpreter is run on `stack` between building it and reading
// it back (a line appears to be missing here) — confirm against the full file.
34 std::vector<at::Tensor> run(
35 InterpreterState& interp,
36 const std::vector<at::Tensor>& inputs) {
// Each at::Tensor is implicitly wrapped as an IValue.
37 std::vector<IValue> stack(inputs.begin(), inputs.end());
// Unwrap every stack entry; toTensor() is applied to each IValue, so every
// entry is expected to hold a tensor.
39 return fmap(stack, [](
const IValue& i) {
return i.toTensor(); });
// Runs the forward graph (grad_spec.f) and then the backward graph
// (grad_spec.df) of a gradient spec, and returns
// {forward outputs, backward outputs}, each converted to a tensor list.
//   tensors_in      - inputs fed to the forward graph; also the source of the
//                     df inputs listed in grad_spec.df_input_captured_inputs
//   tensor_grads_in - incoming gradients inserted into the backward stack
//                     before the captured values
// NOTE(review): the declaration of `grad_spec`, the declaration of `df_stack`
// and the line opening the insert call are not visible in this excerpt.
42 std::pair<tensor_list, tensor_list> runGradient(
44 tensor_list& tensors_in,
45 tensor_list& tensor_grads_in) {
// Helper: converts every IValue on a stack to at::Tensor.
46 static const auto as_tensorlist = [](
const Stack& stack) {
47 return fmap(stack, [](
const IValue& i) {
return i.toTensor(); });
// One Code object and one interpreter each for the forward (f) and backward (df)
// graphs.
49 Code f_code{grad_spec.f}, df_code{grad_spec.df};
50 InterpreterState f_interpreter{f_code}, df_interpreter{df_code};
// Forward pass: the stack starts as the inputs and is run in place, after which
// it holds f's outputs.
52 auto f_stack = fmap<IValue>(tensors_in);
53 f_interpreter.run(f_stack);
57 df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
// After the gradients, append the forward inputs and forward outputs that df
// needs, selected by the index lists stored in grad_spec.
58 for (
auto offset : grad_spec.df_input_captured_inputs)
59 df_stack.push_back(tensors_in[offset]);
60 for (
auto offset : grad_spec.df_input_captured_outputs)
61 df_stack.push_back(f_stack[offset]);
// Backward pass, run in place on df_stack.
62 df_interpreter.run(df_stack);
// Keep only the first grad_spec.f_real_outputs entries of f's stack. The
// remaining entries are dropped; presumably they are extra outputs that exist
// only to be captured for df — confirm.
65 f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
66 return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
// Builds one LSTM cell step with symbolic variables (Var = SymbolicVariable)
// and returns the tuple (hy, cy).
// NOTE(review): the parameter list is not visible in this excerpt. The body
// reads input, hx, cx, w_ih and w_hh, and a call site passes
// (g, input, hx, cx, w_ih, w_hh) — confirm the exact parameter types.
69 std::tuple<Var, Var> build_lstm_body(
// Gate pre-activations: input.mm(w_ih) + hx.mm(w_hh), split into 4 equal
// chunks along dim 1.
76 auto gates = input.mm(w_ih);
77 gates = gates + hx.mm(w_hh);
78 auto outputs = gates.chunk(4, 1);
// Chunk order: input gate, forget gate, cell candidate, output gate.
79 auto ingate = outputs[0];
80 auto forgetgate = outputs[1];
81 auto cellgate = outputs[2];
82 auto outgate = outputs[3];
// sigmoid on the input/output/forget gates, tanh on the cell candidate.
83 ingate = ingate.sigmoid();
84 outgate = outgate.sigmoid();
85 cellgate = cellgate.tanh();
86 forgetgate = forgetgate.sigmoid();
// New cell state: cy = forgetgate * cx + ingate * cellgate.
88 auto cy = forgetgate * cx;
89 cy = cy + ingate * cellgate;
// New hidden state: hy = outgate * tanh(cy).
90 auto hy = outgate * cy.tanh();
92 return std::make_tuple(hy, cy);
// Creates a new Graph, registers five graph inputs (input, hx, cx, w_ih, w_hh)
// and emits one LSTM cell step on it via build_lstm_body.
// NOTE(review): the lines binding `g` (presumably to *r), declaring hy/cy, and
// the remainder of the function (output registration and return) are not
// visible in this excerpt — confirm against the full file.
95 std::shared_ptr<Graph> build_lstm() {
96 auto r = std::make_shared<Graph>();
// Input order matters: it must match the argument order passed to
// build_lstm_body below.
98 Value* input = g.addInput();
99 Value* hx = g.addInput();
100 Value* cx = g.addInput();
101 Value* w_ih = g.addInput();
102 Value* w_hh = g.addInput();
106 std::tie(hy, cy) = build_lstm_body(g, input, hx, cx, w_ih, w_hh);
125 bool checkRtol(
const at::Tensor& diff,
const std::vector<at::Tensor> inputs) {
126 double maxValue = 0.0;
127 for (
auto& tensor : inputs) {
128 maxValue = fmax(tensor.abs().max().item<
float>(), maxValue);
130 return diff.abs().max().item<
float>() < 2e-6 * maxValue;
// Body fragment (enclosing signature not visible in this excerpt): compares a
// and b with a relative tolerance by passing their difference, together with
// both tensors as magnitude references, to checkRtol.
133 return checkRtol(a - b, {a, b});
// Body fragment (enclosing signature not visible in this excerpt): true only
// when the largest absolute element of a - b, read back as float, is exactly 0.
// A NaN maximum compares unequal to 0.f, so it yields false.
137 return (a - b).abs().max().item<
float>() == 0.f;
// Eager at::Tensor version of one LSTM cell step. It uses the same gate
// layout and activations as the symbolic build_lstm_body.
// NOTE(review): the parameter list, the definition of t_use, and the end of the
// function (including its return statement) are outside this excerpt.
// Presumably it returns {hy, cy} — confirm against the full file.
140 std::pair<at::Tensor, at::Tensor> lstm(
// Gate pre-activations; both weight matrices pass through t_use (defined
// elsewhere) before the matmuls.
146 auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh));
// Split into 4 equal chunks along dim 1: input, forget, cell candidate, output.
148 auto chunked_gates = gates.chunk(4, 1);
149 auto ingate = chunked_gates[0];
150 auto forgetgate = chunked_gates[1];
151 auto cellgate = chunked_gates[2];
152 auto outgate = chunked_gates[3];
// sigmoid on the input/output/forget gates, tanh on the cell candidate.
154 ingate = ingate.sigmoid();
155 outgate = outgate.sigmoid();
156 cellgate = cellgate.tanh();
157 forgetgate = forgetgate.sigmoid();
// New cell state and hidden state.
159 auto cy = (forgetgate * cx) + (ingate * cellgate);
160 auto hy = outgate * cy.tanh();