Caffe2 - C++ API
A deep learning, cross platform ML framework
function.h
1 #pragma once
2 
3 #include <torch/csrc/autograd/edge.h>
4 #include <torch/csrc/autograd/grad_mode.h>
5 #include <torch/csrc/autograd/anomaly_mode.h>
6 #include <torch/csrc/autograd/profiler.h>
7 #include <torch/csrc/autograd/saved_variable.h>
8 #include <torch/csrc/autograd/input_metadata.h>
9 #include <torch/csrc/autograd/variable.h>
10 #include <torch/csrc/utils/python_stub.h>
11 #include <torch/csrc/utils/variadic.h>
12 
13 #include <ATen/ATen.h>
14 #include <c10/util/Exception.h>
15 
16 #include <algorithm>
17 #include <cstdint>
18 #include <initializer_list>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 namespace torch { namespace autograd {
25 
26 struct Edge;
27 struct FunctionPostHook;
28 struct FunctionPreHook;
29 
30 using tensor_list = std::vector<at::Tensor>;
31 using variable_list = std::vector<Variable>;
32 using edge_list = std::vector<Edge>;
33 using saved_variable_list = std::vector<SavedVariable>;
34 using IndexRange = std::pair<size_t, size_t>;
35 
36 // Custom deleter to prevent stack overflows.
37 void deleteFunction(Function* function);
38 
87 struct TORCH_API Function : std::enable_shared_from_this<Function> {
88  public:
92  explicit Function(
93  uint64_t sequence_nr,
94  edge_list&& next_edges = edge_list())
95  : sequence_nr_(sequence_nr),
96  next_edges_(std::move(next_edges)) {
97  if (AnomalyMode::is_enabled()) {
98  metadata()->store_stack();
99  }
100  }
101 
102  explicit Function(edge_list&& next_edges = edge_list())
103  : Function(get_next_sequence_nr()++, std::move(next_edges)) {}
104 
106  Function(const Function& other) = delete;
107  Function(Function&& other) = delete;
108  Function& operator=(const Function& other) = delete;
109  Function& operator=(Function&& other) = delete;
110  virtual ~Function() = default;
111 
114  variable_list operator()(variable_list&& inputs) {
115  profiler::RecordFunction rec(this);
116  return apply(std::move(inputs));
117  }
118 
119  // Graph Connectivity API
120  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
121 
122  // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the
123  // forward function.
124 
125  // Marker for expected undefined input
126  struct undefined_input {};
127 
130  uint32_t add_input_metadata(
131  const at::Type& type
132  , at::IntArrayRef shape
133  , at::Device device) noexcept {
134  uint32_t input_nr = input_metadata_.size();
135  input_metadata_.emplace_back(type, shape, device);
136  return input_nr;
137  }
138 
139  uint32_t add_input_metadata(const at::Tensor& t) noexcept {
140  uint32_t input_nr = input_metadata_.size();
141  input_metadata_.emplace_back(t);
142  return input_nr;
143  }
144 
146  uint32_t add_input_metadata(undefined_input u) noexcept {
147  uint32_t input_nr = input_metadata_.size();
148  input_metadata_.emplace_back();
149  return input_nr;
150  }
151 
152  uint32_t num_inputs() const noexcept {
153  return input_metadata_.size();
154  }
155 
156  const InputMetadata& input_metadata(size_t index) const {
157  return input_metadata_[index];
158  }
159 
160  void clear_input_metadata() {
161  input_metadata_.clear();
162  }
163 
164  // Outputs ("Next Edges")
165 
166  const Edge& next_edge(size_t index) const noexcept {
167  return next_edges_[index];
168  }
169 
170  void set_next_edge(size_t index, Edge edge) {
171  next_edges_[index] = std::move(edge);
172  }
173 
174  void add_next_edge(Edge edge) {
175  next_edges_.push_back(std::move(edge));
176  }
177 
178  void set_next_edges(edge_list&& next_edges) {
179  next_edges_ = std::move(next_edges);
180  }
181 
182  const edge_list& next_edges() const noexcept {
183  return next_edges_;
184  }
185 
186  edge_list& next_edges() noexcept {
187  return next_edges_;
188  }
189 
190  uint32_t num_outputs() const noexcept {
191  return next_edges_.size();
192  }
193 
194  // Miscellaneous Methods
195  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
196 
198  uint64_t sequence_nr() const noexcept {
199  return sequence_nr_;
200  }
201 
205  virtual std::shared_ptr<Function> get_shared_ptr() {
206  return shared_from_this();
207  }
208 
210  virtual std::string name() const;
211 
214  bool should_compute_output(size_t output_edge_index) const {
215  AT_CHECK(output_edge_index < num_outputs(), "Index out of range");
216  return next_edges_[output_edge_index].is_valid();
217  }
218 
220  bool should_compute_output(std::initializer_list<IndexRange> idxs) const {
221  return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) {
222  for (auto i = range.first; i < range.second; i++) {
223  if (should_compute_output(i))
224  return true;
225  }
226  return false;
227  });
228  }
229 
232  PyObject* pyobj() const noexcept {
233  return pyobj_;
234  }
235 
237  void set_pyobj(PyObject* pyobj) noexcept {
238  pyobj_ = pyobj;
239  }
240 
243  AnomalyMetadata* metadata() noexcept;
244 
245  // Hook API
246  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
247 
248  void add_post_hook(std::unique_ptr<FunctionPostHook>&& post_hook) {
249  post_hooks_.push_back(std::move(post_hook));
250  }
251 
252  const std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() const
253  noexcept {
254  return post_hooks_;
255  }
256 
257  std::vector<std::unique_ptr<FunctionPostHook>>& post_hooks() noexcept {
258  return post_hooks_;
259  }
260 
261  void add_pre_hook(std::unique_ptr<FunctionPreHook>&& pre_hook) {
262  pre_hooks_.push_back(std::move(pre_hook));
263  }
264 
265  const std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() const
266  noexcept {
267  return pre_hooks_;
268  }
269 
270  std::vector<std::unique_ptr<FunctionPreHook>>& pre_hooks() noexcept {
271  return pre_hooks_;
272  }
273 
274  // Customization Points for Subclasses
275  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
276 
278  virtual void release_variables() {}
279 
283  virtual void will_release_variables() {}
284 
288  virtual bool is_traceable() {
289  return false;
290  }
291 
301  virtual bool passes_state_transparently() {
302  return false;
303  }
304 
305  static uint64_t peek_at_next_sequence_nr();
306 
307  protected:
308  static uint64_t& get_next_sequence_nr();
309 
311  virtual variable_list apply(variable_list&& inputs) = 0;
312 
314  variable_list traced_apply(variable_list inputs);
315 
316  // Since `Function`s are neither copyable nor moveable, we can have const
317  // fields.
318  const uint64_t sequence_nr_;
319 
320  edge_list next_edges_;
321  PyObject* pyobj_ = nullptr; // weak reference
322  std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;
323  std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
324  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
325  at::SmallVector<InputMetadata, 2> input_metadata_;
326 };
327 
329 struct TraceableFunction : public Function {
330  using Function::Function;
331  bool is_traceable() final {
332  return true;
333  }
334 };
335 
336 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
337 // Associated Free Functions
338 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
339 
340 namespace detail {
341 // Implementation of `collect_next_edges` (see below).
342 struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
343  edge_list next_edges;
344  using IterArgs<MakeNextFunctionList>::operator();
345  void operator()(const Variable& variable) {
346  if (variable.defined()) {
347  next_edges.push_back(variable.gradient_edge());
348  } else {
349  next_edges.emplace_back();
350  }
351  }
352 };
353 } // namespace detail
354 
366 inline void create_gradient_edge(
367  Variable& variable,
368  std::shared_ptr<Function> function) {
369  // Copy before move.
370  const auto input_nr = function->add_input_metadata(variable);
371  variable.set_gradient_edge({std::move(function), input_nr});
372 }
373 
375 inline bool any_variable_requires_grad(const variable_list& variables) {
376  return std::any_of(
377  variables.begin(), variables.end(), [](const Variable& variable) {
378  return variable.defined() && variable.requires_grad();
379  });
380 }
381 
383 template <typename... Variables>
384 edge_list collect_next_edges(Variables&&... variables) {
385  if (!GradMode::is_enabled())
386  return {};
387  detail::MakeNextFunctionList make;
388  make.apply(std::forward<Variables>(variables)...);
389  return std::move(make.next_edges);
390 }
391 }} // namespace torch::autograd
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
Definition: SmallVector.h:939
Represents a compute device on which a tensor is located.
Definition: Device.h:30
Definition: jit_type.h:17