Caffe2 - C++ API
A deep learning, cross-platform ML framework
saved_variable.cpp
#include <torch/csrc/autograd/saved_variable.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/Tensor.h>

#include <cstdint>
#include <list>
#include <memory>

namespace torch { namespace autograd {

SavedVariable::SavedVariable(const Variable& variable, bool is_output) {
  if (variable.defined()) {
    was_default_constructed_ = false;
    output_nr_ = variable.output_nr();
    requires_grad_ = variable.requires_grad();
    has_grad_fn_ = !variable.is_leaf();
    // These copies are all shared_ptr copies, so slightly more expensive.
    // Do them here instead of in the init list in case data is undefined.
    data_ = variable.data();
    if (variable.is_leaf()) {
      grad_accumulator_ = variable.grad_accumulator();
    } else if (!is_output) {
      grad_fn_ = variable.grad_fn();
    }
    version_counter_ = variable.version_counter();
    saved_version_ = version_counter_.current_version();
  }
}
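
// --- Illustrative sketch (not part of this file) -----------------------------
// What the constructor above captures, assuming `v` is a defined non-leaf
// Variable. Saving `v` as an input of the consuming function keeps its
// grad_fn; saving it as that function's own output drops the grad_fn to avoid
// a reference cycle, so unpack() must later be handed the owning function via
// `saved_for`.
//
//   SavedVariable saved_input(v, /*is_output=*/false);   // grad_fn_ kept
//   SavedVariable saved_output(v, /*is_output=*/true);   // grad_fn_ dropped
// ------------------------------------------------------------------------------
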
Variable SavedVariable::unpack(std::shared_ptr<Function> saved_for) const {
  if (!data_.defined()) {
    if (!was_default_constructed_) {
      throw std::runtime_error(ERR_BACKWARD_TWICE);
    }
    return Variable();
  }

  if (saved_version_ != version_counter_.current_version()) {
    throw std::runtime_error(
        "one of the variables needed for gradient computation has been "
        "modified by an inplace operation");
  }

  auto grad_fn = grad_fn_;
  if (has_grad_fn_ && !grad_fn) {
    if (!saved_for) {
      // If saving the grad_fn would create a circular reference, then it must
      // be passed in to the unpack function.
      throw std::runtime_error("No grad_fn for non-leaf saved variable");
    }
    grad_fn = std::move(saved_for);
  }

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var;
  if (grad_fn) {
    var = make_variable(data_, Edge(std::move(grad_fn), output_nr_));
  } else {
    var = make_variable(data_, requires_grad_);
  }
  var.set_version_counter(saved_version_);

  // If a Variable is a leaf (no grad_fn saved) and it requires_grad, then we
  // should have saved the grad accumulator. Even if the Variable is no longer
  // alive, the accumulator should be kept alive by the references in the
  // graph.
  if (requires_grad_ && !var.grad_fn() && grad_accumulator_.expired())
    throw std::logic_error("No grad accumulator for a saved leaf!");
  var.set_grad_accumulator(grad_accumulator_);

  return var;
}
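
// --- Illustrative sketch (not part of this file) -----------------------------
// Round trip through unpack(), assuming `fn` is the std::shared_ptr<Function>
// that owns `saved_output` from the sketch above:
//
//   Variable unpacked = saved_output.unpack(fn);  // autograd metadata rebuilt
//
// Had `v` been mutated in place after being saved, its version counter would
// no longer match saved_version_ and unpack() would throw the
// "modified by an inplace operation" error instead.
// ------------------------------------------------------------------------------
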
const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time, but the buffers have "
    "already been freed. Specify retain_graph=True when calling backward "
    "the first time.";

}} // namespace torch::autograd
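
The two early returns at the top of unpack() are easy to conflate. A minimal sketch (not part of the file above, assuming the header defaults was_default_constructed_ to true and saved_for to nullptr) of the benign case:

  SavedVariable empty;              // default-constructed: data_ is undefined
  Variable out = empty.unpack();    // no error: returns an undefined Variable

By contrast, a SavedVariable whose data was released by the engine after a first backward pass is not default-constructed, so a second unpack() takes the ERR_BACKWARD_TWICE branch, which is why the message suggests retain_graph=True.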