Caffe2 - C++ API
A deep learning, cross-platform ML framework
function.cpp
1 #include <torch/csrc/autograd/function.h>
2 
3 #include <torch/csrc/autograd/engine.h>
4 #include <torch/csrc/autograd/variable.h>
5 #include <torch/csrc/jit/ir.h>
6 
7 #include <ATen/ATen.h>
8 
9 #include <algorithm>
10 #include <cstdint>
11 #include <memory>
12 #include <stdexcept>
13 #include <string>
14 #include <utility>
15 #include <vector>
16 #include <deque>
17 
18 namespace torch { namespace autograd {
19 
22 thread_local uint64_t Function_next_sequence_nr_ = 0;
23 
24 uint64_t Function::peek_at_next_sequence_nr() {
25  return Function_next_sequence_nr_;
26 }
27 
28 uint64_t& Function::get_next_sequence_nr() {
29  return Function_next_sequence_nr_;
30 }
31 
32 auto Function::name() const -> std::string {
33  return c10::demangle(typeid(*this).name());
34 }
35 
36 AnomalyMetadata* Function::metadata() noexcept {
37  if (!anomaly_metadata_) {
38  anomaly_metadata_ = Engine::get_default_engine().make_anomaly_metadata();
39  }
40  return anomaly_metadata_.get();
41 }
42 
43 static void gatherFunctions(
44  Function* func,
45  std::vector<std::shared_ptr<Function>>& stack) {
46  func->release_variables();
47 
48  for (auto& edge : func->next_edges()) {
49  if (edge.function.use_count() == 1) {
50  stack.emplace_back(std::move(edge.function));
51  } else {
52  edge.function.reset();
53  }
54  }
55 }
56 
57 /*
58  * Fix for #5534: prevent stack overflow on deletion of deep computation graph
59  *
60  * Sometimes one can end up with a very big computation graph of Functions
61  * and Edges. Each std::shared_ptr<Function> contains a list of Edge, and
62  * each Edge contains a std::shared_ptr<Function>. Deleting a
63  * std::shared_ptr<Function> can trigger the recursive deletion of other
64  * std::shared_ptr<Function>'s: this can stack overflow if the graph
65  * is deep enough. Here is an example of such a graph:
66  *
67  * shared_ptr<Function> -> Edge -> shared_ptr<Function> -> Edge -> ... -> shared_ptr<Function>
68  *
69  * The solution here is to detect when we are decrementing away the last
70  * reference to a Function, and when doing so to buffer up the Function's
71  * that will be recursively decremented. We can then decrement (and free)
 * the original Function without causing a recursive cascade, before
 * draining the buffer, applying the same behavior. This is, in effect,
74  * converting recursion to a loop, using a heap buffer in place of the
75  * recursive call stack.
76  */
77 void deleteFunction(Function* function) {
78  // To avoid stack overflow on large computational graphs,
79  // we need to track reference decrementing and freeing
80  // on the heap.
81  function->release_variables();
82  std::vector<std::shared_ptr<Function>> stack;
83  gatherFunctions(function, stack);
84  delete function;
85 
86  while (!stack.empty()) {
87  auto func = std::move(stack.back());
88  stack.pop_back();
89  gatherFunctions(func.get(), stack);
90  // Reference count is decremented on the loop backedge.
91  }
92 }
93 
94 }} // namespace torch::autograd
static Engine & get_default_engine()
Returns a reference to a static Engine instance.
Definition: engine.cpp:642
Definition: jit_type.h:17
std::string demangle(const char *name)
Utility to demangle a C++ symbol name.
Definition: Type.cpp:23