Caffe2 - C++ API
A deep learning, cross platform ML framework
autodiff.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <torch/csrc/jit/ir.h>
5 
6 #include <ATen/ATen.h>
7 
8 #include <memory>
9 #include <vector>
10 
11 namespace torch {
12 namespace jit {
13 
14 using value_list = std::vector<Value*>;
15 // clang-format off
16 // Example showcasing how Gradient is constructed:
17 //
18 // Let's assume we have a function f, `m` and `n` do not require grad
19 // (`n` can depend only on `m`):
20 // y, n = f(x, m)
21 //
22 // Now, let's assume that the reverse of f (called f') needs to use values of `x`, `t` and `y`.
23 // `t` is an intermediate value produced in the body of f, and let's assume that it requires
24 // grad too.
25 //
26 // In this case differentiate(f) will return this:
27 // y, n, t = f(x, m) // `t` is appended to the output list
28 // dx = f'(dy, dt, x, t, y) // No `dm` or `dn` because they do not require gradient
29 //                          // All needed values from f are appended to the input list (after the vjps dy, dt)
30 //
31 // f_real_outputs = 2 // Only first two outputs were present in f originally
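32 // df_input_vjps = {0, 2} // i.e. connect grad_fn of y and t variables produced by f,
33 //                  y  t  // with y's output_nr = 0 and t's output_nr = 1
34 // df_input_captures = {I0, O2, O0} // Order matches the trailing inputs to df (I<k> = k-th input of f, O<k> = k-th output of f);
35 //                      x   t   y   // in Gradient these are split into df_input_captured_inputs / df_input_captured_outputs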
36 // df_output_vjps = {0} // i.e. connect next_edge[0] of grad_fn to x's (grad_fn, output_nr).
37 //
38 // Terminology: vjp = vector-jacobian product
39 // clang-format on
40 
41 struct Gradient { // Pairs a graph f with its reverse graph df, plus the index maps needed to wire the two together.
42  explicit operator bool() const { // True iff df is non-null.
43  return df != nullptr;
44  }
45  std::shared_ptr<Graph> f; // Graph of f; its returns may include extra trailing values saved for df.
46  std::shared_ptr<Graph> df; // Graph of the reverse of f (f' in the example above this struct).
47 
48  // Describes how to construct outputs of f from what its graph will return.
49  // This is necessary because some trailing outputs are intermediates produced
50  // only to be saved for df (and should be ignored by callers of f).
51  size_t f_real_outputs = 0; // initialized for safety.
52 
53  // df inputs are split into two sections: vjps (aka grad_outputs) and
54  // captures. VJPs are "seeds" for the gradient computation given for each
55  // input capture of an Output kind. Captures are values that need to be saved
56  // when f is run. We handle inputs specially, because this allows us to avoid
57  // adding extra vjps as df inputs.
58 
59  std::vector<size_t> df_input_vjps; // Offsets into f's outputs.
60  // A capture can come from either the inputs or the outputs of f.
61  std::vector<size_t> df_input_captured_inputs; // Offsets into f's inputs
62  std::vector<size_t> df_input_captured_outputs; // Offsets into f's outputs
63 
64  // df will produce vjps for a subset of inputs of f that required grad.
65  // df_output_vjps[idx] == inp_idx means that idx-th output of df produces a
66  // vjp for inp_idx-th input of f.
67  std::vector<size_t> df_output_vjps; // Offsets into f's inputs.
68 
69  // How to use gradient to implement a differentiable autograd function:
70  // When running f:
71  //   - Unwrap input Variables
72  //   - Run f's graph
73  //   - Create grad_fn
74  //   - Wrap outputs in Variables (assume we have a tensor_outputs array):
75  //       outputs = map(Variable, tensor_outputs)
76  //       for i, offset in enumerate(df_input_vjps):
77  //         outputs[offset].set_grad_fn(grad_fn, output_nr=i)
78  //   - Use df_output_vjps to connect next_edges of grad_fn:
79  //       for idx in df_output_vjps:
80  //         grad_fn.add_next_edge(inputs[idx].gradient_edge())
81  //   - Save captures for df (care needs to be taken to use SavedVariables for
82  //     inputs and outputs that we will actually return)
83  //   - Return outputs[:f_real_outputs]
84  //
85  // When running df:
86  //   - Concatenate received vjps and captured Variables (vjps first)
87  //   - Interpret df
88  //   - Wrap outputs of df into Variables (that don't require grad)
89 };
90 TORCH_API Gradient differentiate(std::shared_ptr<Graph>& graph); // Returns the Gradient for `graph` (see the example above Gradient). NOTE(review): `graph` is taken by non-const reference -- confirm whether it is modified or consumed.
91 
92 // Can we take a derivative of this node symbolically?
93 TORCH_API bool isDifferentiable(Node* n);
94 TORCH_API bool isDifferentiable(Graph& g); // Graph overload; presumably asks the same question of a whole graph -- TODO confirm against the definition.
95 TORCH_API bool isZero(Value* v); // NOTE(review): semantics not visible in this header; presumably reports whether `v` is known to be zero -- confirm.
96 
97 } // namespace jit
98 } // namespace torch
Definition: jit_type.h:17