Caffe2 - C++ API
A deep learning, cross platform ML framework
saved_variable.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <torch/csrc/autograd/variable_version.h>
5 
6 #include <ATen/ATen.h>
7 
8 #include <cstdint>
9 #include <list>
10 #include <memory>
11 
12 namespace torch { namespace autograd {
13 
14 struct Variable;
15 struct Function;
16 
17 TORCH_API extern const char* ERR_BACKWARD_TWICE;
18 
21 class TORCH_API SavedVariable {
22  public:
23  SavedVariable() = default;
24  SavedVariable(const Variable& variable, bool is_output);
25  SavedVariable(SavedVariable&&) = default;
26  SavedVariable& operator=(SavedVariable&&) = default;
27 
31  Variable unpack(std::shared_ptr<Function> saved_for = nullptr) const;
32 
33  void reset_data() {
34  return data_.reset();
35  }
36 
37  void reset_grad_function() {
38  grad_fn_.reset();
39  }
40 
41  private:
42  at::Tensor data_;
43 
44  // The gradient function associated with this node. If has_grad_fn
45  // is false, then this is a leaf node. Note that the grad_fn is not saved if
46  // it would create a circular reference. In that case, the grad_fn must be
47  // passed in to the unpack function when reconstructing the Variable.
48  std::shared_ptr<Function> grad_fn_;
49  std::weak_ptr<Function> grad_accumulator_;
50  VariableVersion version_counter_;
51 
52  uint32_t saved_version_ = 0;
53  uint32_t output_nr_ = 0;
54  bool was_default_constructed_ = true;
55  bool requires_grad_ = false;
56  bool has_grad_fn_ = false;
57 };
58 }} // namespace torch::autograd
A snapshot of a variable at a certain version.
Variable: A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17