#include <torch/csrc/autograd/variable.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/autograd/variable_version.h>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

namespace torch {
namespace autograd {

Variable::Impl::Impl(
    at::Tensor data,
    std::unique_ptr<Variable::AutogradMeta> autograd_meta,
    bool requires_grad,
    Edge gradient_edge)
    : TensorImpl(data.type_id(), data.dtype(), nullptr, true),
      data_(std::move(data)) {
  autograd_meta->grad_fn_ = std::move(gradient_edge.function);
  autograd_meta->requires_grad_ = false;
  autograd_meta->is_view_ = false;
  autograd_meta->output_nr_ = gradient_edge.input_nr;
  autograd_meta->pyobj_ = nullptr;

  // set_requires_grad also checks error conditions.
  autograd_meta->set_requires_grad(requires_grad, this);
  AT_CHECK(
      !autograd_meta->grad_fn_ || !autograd_meta->requires_grad_,
      "requires_grad should be false if grad_fn is set");
  if (!data_.defined()) {
    throw std::runtime_error("data is undefined");
  }

  set_autograd_meta(std::move(autograd_meta));
}
Variable::Impl::~Impl() = default;
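// The size/stride/storage accessors below forward straight to the wrapped
// `data_` tensor: a Variable reports exactly the metadata of the tensor it
// wraps.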
int64_t Variable::Impl::numel() const {
  return data_.numel();
}

IntArrayRef Variable::Impl::sizes() const {
  return data_.sizes();
}

IntArrayRef Variable::Impl::strides() const {
  return data_.strides();
}

bool Variable::Impl::is_contiguous() const {
  return data_.is_contiguous();
}

int64_t Variable::Impl::dim() const {
  return data_.dim();
}

int64_t Variable::Impl::size(int64_t d) const {
  return data_.size(d);
}

int64_t Variable::Impl::stride(int64_t d) const {
  return data_.stride(d);
}
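// Variables do not support mutating tensor metadata in place through the
// TensorImpl interface, so the setters below unconditionally raise an error.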
void Variable::Impl::resize_dim(int64_t ndim) {
  AT_ERROR("variable impl does not have resize_dim");
}

void Variable::Impl::set_size(int64_t dim, int64_t new_size) {
  AT_ERROR("variable impl does not have set_size");
}

void Variable::Impl::set_stride(int64_t dim, int64_t new_stride) {
  AT_ERROR("variable impl does not have set_stride");
}

void Variable::Impl::set_storage_offset(int64_t storage_offset) {
  AT_ERROR("variable impl does not have set_storage_offset");
}
void* Variable::Impl::slow_data() const {
  return data_.unsafeGetTensorImpl()->slow_data();
}

bool Variable::Impl::has_storage() const {
  return data_.has_storage();
}

const at::Storage& Variable::Impl::storage() const {
  return data_.storage();
}

int64_t Variable::Impl::storage_offset() const {
  return data_.storage_offset();
}

int64_t Variable::Impl::get_device_slow() const {
  return data_.get_device();
}
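// grad_accumulator() lazily creates the AccumulateGrad node that sinks
// gradients into a leaf Variable's grad. The node is cached through a
// weak_ptr, so repeated calls return the same accumulator while it is alive.
// Illustrative usage only (not part of this file):
//
//   auto w = make_variable(at::ones({2, 2}), /*requires_grad=*/true);
//   auto acc = w.grad_accumulator();  // same node on subsequent calls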
std::shared_ptr<Function> Variable::grad_accumulator() const {
  auto autograd_meta = get_autograd_meta();
  if (autograd_meta->grad_fn_) {
    throw std::logic_error(
        "grad_accumulator() should be only called on leaf Variables");
  }
  if (!autograd_meta->requires_grad_) {
    return nullptr;
  }
  std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
  auto result = autograd_meta->grad_accumulator_.lock();
  if (result)
    return result;
  // Hand ownership of one reference on this impl to the AccumulateGrad node.
  c10::raw::intrusive_ptr::incref(unsafeGetTensorImpl());
  auto intrusive_from_this =
      c10::intrusive_ptr<at::TensorImpl>::reclaim(unsafeGetTensorImpl());
  result = std::make_shared<AccumulateGrad>(Variable(std::move(intrusive_from_this)));
  autograd_meta->grad_accumulator_ = result;
  return result;
}
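// detach_() removes this Variable from the autograd graph in place, turning
// it into a leaf that no longer requires grad. In-place detach of a view is
// rejected; callers are directed to detach() instead.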
void Variable::detach_() {
  if (is_view()) {
    AT_ERROR("Can't detach views in-place. Use detach() instead");
  }
  auto autograd_meta = get_autograd_meta();
  autograd_meta->set_requires_grad(false, unsafeGetTensorImpl());
  autograd_meta->grad_fn_.reset();
  autograd_meta->output_nr_ = 0;
}
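// backward() seeds the engine with a single edge pointing at this Variable's
// grad_fn. When no explicit gradient is supplied, at::ones_like(data()) is
// used as the implicit seed.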
void Variable::backward(
    c10::optional<Tensor> gradient,
    bool keep_graph,
    bool create_graph) const {
  auto autograd_meta = get_autograd_meta();
  std::vector<Edge> edges;
  edges.emplace_back(autograd_meta->grad_fn_, autograd_meta->output_nr_);

  std::vector<Variable> inputs;
  if (!gradient.has_value()) {
    gradient = make_variable(at::ones_like(data()), /*requires_grad=*/false);
  }
  inputs.push_back(std::move(as_variable_ref(*gradient)));
  Engine::get_default_engine().execute(edges, inputs, keep_graph, create_graph);
}
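// set_data() swaps the tensor payload underneath an existing Variable. If the
// new tensor's type or device no longer matches what the cached gradient
// accumulator saw, the accumulator is dropped so it can be rebuilt lazily.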
void Variable::Impl::set_data(const at::Tensor& new_data) {
  auto autograd_meta = get_autograd_meta();
  std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
  auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
  if (prior_accumulator) {
    const auto prior_device = prior_accumulator->input_metadata(0).device();
    const auto new_device = new_data.device();

    if (new_data.type() != data_.type() || prior_device != new_device) {
      autograd_meta->grad_accumulator_.reset();
    }
  }

  data_type_ = new_data.type().typeMeta();
  type_id_ = new_data.type().type_id();

  auto new_data_copy = at::Tensor(new_data.getIntrusivePtr()->shallow_copy_and_detach());
  data_ = std::move(new_data_copy);
}
void Variable::Impl::release_resources() {
  autograd_meta_.reset();
  data_.reset();
}
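// A DifferentiableViewImpl backs a Variable that is a differentiable view of
// another Variable (its base). The view shares the base's version counter so
// that in-place modifications of either tensor can be detected later when
// grad_fn() is queried.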
Variable::DifferentiableViewImpl::DifferentiableViewImpl(
    Variable base,
    at::Tensor data,
    Edge gradient_edge,
    std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta)
    : Variable::Impl(std::move(data), std::move(autograd_meta), false, std::move(gradient_edge)) {
  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
  diff_view_meta->base_ = std::move(base);
  AT_CHECK(diff_view_meta->base_.defined(), "base is undefined");
  if (diff_view_meta->base_.is_view()) {
    diff_view_meta->base_ = diff_view_meta->base_.base();
  }
  diff_view_meta->is_view_ = true;
  diff_view_meta->version_counter_ = diff_view_meta->base_.version_counter();
  diff_view_meta->attr_version = diff_view_meta->version_counter_.current_version();
}
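// For views, grad_fn() is recomputed lazily: if the shared version counter
// has advanced since the grad_fn was last built, the stale function is
// replaced by an AsStridedBackward node that reconstructs the view from the
// base using the view's current sizes, strides, and storage offset.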
const std::shared_ptr<Function>& Variable::grad_fn() const {
  if (is_view()) {
    auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
    std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
    if (!diff_view_meta->grad_fn_ && !diff_view_meta->base_.requires_grad()) {
      return diff_view_meta->grad_fn_;
    }
    auto current_version = diff_view_meta->version_counter_.current_version();
    if (diff_view_meta->attr_version != current_version) {
      AT_ASSERT(diff_view_meta->output_nr_ == 0);
      auto fn = std::make_shared<generated::AsStridedBackward>();
      fn->self_geometry = at::TensorGeometry(diff_view_meta->base_);
      fn->size = sizes().vec();
      fn->stride = strides().vec();
      fn->storage_offset = data().storage_offset();
      fn->set_next_edges(collect_next_edges(diff_view_meta->base_));
      fn->add_input_metadata(
          diff_view_meta->base_.type(),
          sizes(),
          diff_view_meta->base_.device());
      diff_view_meta->grad_fn_ = std::move(fn);
      diff_view_meta->attr_version = current_version;
    }
    return diff_view_meta->grad_fn_;
  } else {
    return get_autograd_meta()->grad_fn_;
  }
}
void Variable::DifferentiableViewImpl::release_resources() {
  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
  diff_view_meta->base_.reset();
  Variable::Impl::release_resources();
}
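// rebase_history() points an existing Variable at a new gradient edge. For a
// differentiable view, the in-place function is wrapped in a CopySlices node
// installed on the base, so the base's backward pass accounts for the
// modification made through the view.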
void Variable::rebase_history(Edge gradient_edge) {
  AT_ASSERT(gradient_edge.function != nullptr);
  if (is_view()) {
    auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
    AT_ASSERT(gradient_edge.input_nr == 0);
    AT_CHECK(
        gradient_edge.function->num_inputs() == 1,
        "Functions which modify views in-place must return a single Variable");
    diff_view_meta->output_nr_ = gradient_edge.input_nr;
    auto copy_slices = std::make_shared<CopySlices>(
        diff_view_meta->base_, at::TensorGeometry(data()), std::move(gradient_edge.function));
    diff_view_meta->base_.set_gradient_edge({std::move(copy_slices), 0});
    grad_fn(); // trigger an update to the view's grad_fn
  } else {
    set_gradient_edge(std::move(gradient_edge));
  }
}

}} // namespace torch::autograd