#include <torch/csrc/utils/python_stub.h>

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/variable_version.h>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

namespace torch {
namespace autograd {
/// A `Variable` augments a `Tensor` with the ability to interact in our
/// autograd machinery.
struct TORCH_API Variable : public at::Tensor {
  /// Default constructor.
  Variable() = default;

  // Factory functions (defined below).
  friend Variable make_variable_view(Variable base, at::Tensor data,
      bool is_differentiable, bool allow_tensor_metadata_change, Edge gradient_edge);
  friend Variable make_variable(at::Tensor data, bool requires_grad,
      bool allow_tensor_metadata_change);
  friend Variable make_variable_consuming(at::Tensor data, bool requires_grad,
      bool allow_tensor_metadata_change);
  friend Variable make_variable(at::Tensor data, Edge gradient_edge,
      bool allow_tensor_metadata_change);

  /// "Downcasts" a `Tensor` into a `Variable`; the tensor must actually be a
  /// Variable (or be undefined).
  /*implicit*/ Variable(at::Tensor const& rhs) : at::Tensor(rhs) {
    AT_CHECK(is_variable() || !defined(),
        "Tensor that was converted to Variable was not actually a Variable");
  }
  /*implicit*/ Variable(at::Tensor&& rhs) noexcept : at::Tensor(std::move(rhs)) {
    AT_CHECK(is_variable() || !defined(),
        "Tensor that was converted to Variable was not actually a Variable");
  }
  /// Gets the gradient function of the `Variable`. If this is a leaf variable,
  /// the pointer returned will be null.
  const std::shared_ptr<Function>& grad_fn() const;

  /// Gets the raw gradient function pointer, whatever it currently is.
  Function* grad_fn_unsafe() const;

  /// Sets the gradient accumulator of the `Variable`.
  void set_grad_accumulator(std::weak_ptr<Function> grad_accumulator);

  /// Attempts to get a pointer to the gradient accumulator of the `Variable`,
  /// if it still exists.
  std::shared_ptr<Function> try_get_grad_accumulator() const;

  /// Gets the gradient accumulator of the `Variable` if it has one, or else
  /// creates one on the fly and returns it.
  std::shared_ptr<Function> grad_accumulator() const;
  /// Returns the "canonical" gradient edge of this `Variable`: its `grad_fn`
  /// if it has one, or its gradient accumulator otherwise.
  Edge gradient_edge() const {
    if (const auto& gradient = grad_fn()) {
      return Edge(gradient, output_nr());
    } else {
      return Edge(grad_accumulator(), 0);
    }
  }
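  // Example (illustrative sketch, not part of the original header): a function
  // that consumes this Variable connects itself to the edge returned here --
  // the grad_fn for an interior Variable, or the grad accumulator for a leaf.
  //
  //   Edge edge = var.gradient_edge();   // `var` is a hypothetical Variable
  //   // edge.function : grad_fn or grad accumulator
  //   // edge.input_nr : the input index of that gradient function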
  /// Computes the gradient of this `Variable` with respect to graph leaves.
  void backward(
      c10::optional<at::Tensor> gradient,
      bool keep_graph,
      bool create_graph) const;
  /// Sets the gradient edge -- i.e. the `grad_fn` and the input index -- of
  /// this `Variable`.
  void set_gradient_edge(Edge edge) noexcept;

  /// Returns the input index of the gradient `Function` to which this
  /// `Variable` is connected.
  uint32_t output_nr() const noexcept;

  /// True if this `Variable` is a leaf and thus does not have a `grad_fn`.
  bool is_leaf() const noexcept;

  /// Increments the version count of this `Variable`.
  void bump_version() noexcept;
  void set_version_counter(const VariableVersion& version_counter) noexcept;
  const VariableVersion& version_counter() const noexcept;

  /// Retrieves the current value of the `Variable`'s version counter.
  uint32_t current_version() const noexcept;

  /// Updates the `grad_fn` of an existing Variable, e.g. after an in-place
  /// modification.
  void rebase_history(Edge gradient_edge);

  void add_hook(std::shared_ptr<FunctionPreHook> hook);
  const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
  void clear_hooks();

  /// Returns true if this `Variable` is a view of another `Variable`.
  bool is_view() const noexcept;

  void set_name(const std::string& name);
  const std::string& name() const noexcept;

  PyObject* pyobj() const noexcept;
  void set_pyobj(PyObject* pyobj) noexcept;

  Variable::AutogradMeta* get_autograd_meta() const noexcept;
  struct DifferentiableViewImpl;
  struct DifferentiableViewMeta;
};

/// The autograd-specific metadata attached to a `Variable`.
struct TORCH_API Variable::AutogradMeta : public c10::AutogradMetaInterface {
  std::string name;
  Variable grad_;
  std::shared_ptr<Function> grad_fn_;
  std::weak_ptr<Function> grad_accumulator_;
  VariableVersion version_counter_;
  std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
  bool requires_grad_;
  bool is_view_;
  uint32_t output_nr_;
  PyObject* pyobj_ = nullptr;
  /// Sets the `requires_grad` property of the `Variable`. Only Tensors of
  /// floating point dtype can require gradients.
  void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) override {
    AT_CHECK(
        !requires_grad || at::isFloatingType(at::typeMetaToScalarType(self_impl->dtype())),
        "Only Tensors of floating point dtype can require gradients");
    requires_grad_ = requires_grad;
  }

  bool requires_grad() const override { return requires_grad_ || grad_fn_; }

  /// Accesses the gradient `Variable` of this `Variable`.
  Variable& grad() override { return grad_; }
  const Variable& grad() const override { return grad_; }
};
struct TORCH_API Variable::DifferentiableViewMeta : public Variable::AutogradMeta {
  /// The base `Variable` (never a view).
  Variable base_;
  /// The value of the version_counter at the time grad_fn was created.
  uint32_t attr_version;

  bool requires_grad() const override {
    return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad());
  }
};
struct TORCH_API Variable::Impl : public at::TensorImpl {
  explicit Impl(
      at::Tensor data,
      std::unique_ptr<Variable::AutogradMeta> autograd_meta,
      bool requires_grad = false,
      Edge gradient_edge = Edge());

  int64_t numel() const override;
  bool is_contiguous() const override;
  int64_t size(int64_t d) const override;
  int64_t stride(int64_t d) const override;
  void resize_dim(int64_t ndim) override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

  int64_t dim() const override;
  bool has_storage() const override;
  void* slow_data() const override;

  void release_resources() override;

  Variable::AutogradMeta* get_autograd_meta() const {
    return static_cast<Variable::AutogradMeta*>(autograd_meta());
  }

  int64_t storage_offset() const override;
  int64_t get_device_slow() const override;

  /// The underlying data tensor for this Variable.
  at::Tensor data_;
};
/// The `Impl` used for differentiable view `Variable`s.
struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
  DifferentiableViewImpl(
      Variable base,
      at::Tensor data,
      Edge gradient_edge,
      std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta);

  void release_resources() override;
};
/// Creates a `Variable` that is a *view* of another (*base*) `Variable`.
inline Variable make_variable_view(
    Variable base,
    at::Tensor data,
    bool is_differentiable = true,
    bool allow_tensor_metadata_change = true,
    Edge gradient_edge = Edge()) {
  if (data.defined()) {
    if (is_differentiable) {
      // Differentiable view: track history with DifferentiableViewMeta.
      auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
      data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      auto data_copy = at::Tensor(data_impl_copy);
      auto diff_view_meta = c10::guts::make_unique<Variable::DifferentiableViewMeta>();
      return Variable(c10::make_intrusive<Variable::DifferentiableViewImpl>(
          std::move(base), std::move(data_copy), std::move(gradient_edge), std::move(diff_view_meta)));
    } else {
      // Non-differentiable view: only share the base's version counter.
      auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
      data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      auto data_copy = at::Tensor(data_impl_copy);
      auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
      auto var = Variable(c10::make_intrusive<Variable::Impl>(
          std::move(data_copy), std::move(autograd_meta), /*requires_grad=*/false, std::move(gradient_edge)));
      var.set_version_counter(base.version_counter());
      return var;
    }
  }
  return Variable();
}
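// Example (illustrative sketch, not part of the original header): a view
// operation would typically wrap its result like this so that autograd treats
// the output as a differentiable view of `self`; the reshape call is assumed.
//
//   Variable output = make_variable_view(
//       /*base=*/self, /*data=*/self.data().view({-1}), /*is_differentiable=*/true);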
/// Creates a new (leaf or history-less) `Variable` from the given `Tensor`.
inline Variable make_variable(
    at::Tensor data,
    bool requires_grad = false,
    bool allow_tensor_metadata_change = true) {
  AT_CHECK(
      !data.is_variable(),
      "Must not create a new variable from a variable, use its .data()");
  if (data.defined()) {
    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
    data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    auto data_copy = at::Tensor(data_impl_copy);
    auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
    return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta),
        requires_grad));
  }
  return Variable();
}
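// Example (illustrative sketch, not part of the original header): constructing
// a leaf Variable that will accumulate gradients. Only `at::ones` and the
// factory above are used.
//
//   Variable weight = make_variable(at::ones({2, 2}), /*requires_grad=*/true);
//   AT_ASSERT(weight.is_leaf());     // a leaf has no grad_fn
//   AT_ASSERT(!weight.grad_fn());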
inline Variable make_variable_consuming(
    at::Tensor data,
    bool requires_grad = false,
    bool allow_tensor_metadata_change = true) {
  AT_CHECK(
      !data.is_variable(),
      "Must not create a new variable from a variable, use its .data()");
  if (data.defined()) {
    AT_ASSERT(data.getIntrusivePtr().use_count() == 1);
    data.getIntrusivePtr()->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
    return Variable(c10::make_intrusive<Variable::Impl>(std::move(data), std::move(autograd_meta),
        requires_grad));
  }
  return Variable();
}
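// Note: make_variable_consuming assumes it holds the only reference to `data`
// (the AT_ASSERT above enforces this), so it can take ownership of the
// TensorImpl without the shallow copy that make_variable performs. A hedged
// usage sketch:
//
//   at::Tensor t = at::empty({3});
//   Variable v = make_variable_consuming(std::move(t), /*requires_grad=*/false);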
inline Variable make_variable(
    at::Tensor data,
    Edge gradient_edge,
    bool allow_tensor_metadata_change = true) {
  AT_CHECK(
      !data.is_variable(),
      "Must not create a new variable from a variable, use its .data()");
  if (data.defined()) {
    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
    data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    auto data_copy = at::Tensor(data_impl_copy);
    auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
    return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta),
        false, std::move(gradient_edge)));
  }
  return Variable();
}
/// Downcasts the `Tensor` reference to a `Variable` reference.
inline Variable& as_variable_ref(at::Tensor& tensor) {
  AT_CHECK(
      tensor.is_variable(),
      "Attempted to cast a Tensor to a Variable, but "
      "the dynamic type of the value is not Variable.");
  return static_cast<Variable&>(tensor);
}

inline const Variable& as_variable_ref(const at::Tensor& tensor) {
  AT_CHECK(
      tensor.is_variable(),
      "Attempted to cast a Tensor to a Variable, but "
      "the dynamic type of the value is not Variable.");
  return static_cast<const Variable&>(tensor);
}
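// Example (illustrative sketch, not part of the original header; `inspect` is
// a made-up helper): code that receives an at::Tensor known to be backed by a
// Variable can recover the Variable interface without copying.
//
//   void inspect(at::Tensor& t) {
//     Variable& var = as_variable_ref(t);
//     if (!var.is_leaf()) {
//       auto fn = var.grad_fn();   // the Function that produced this Variable
//     }
//   }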
inline const at::Tensor& Variable::data() const noexcept { return get()->data_; }
inline at::Tensor& Variable::data() noexcept { return get()->data_; }

inline Function* Variable::grad_fn_unsafe() const {
  return get_autograd_meta()->grad_fn_.get();
}

inline void Variable::set_grad_accumulator(std::weak_ptr<Function> grad_accumulator) {
  get_autograd_meta()->grad_accumulator_ = std::move(grad_accumulator);
}

inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const {
  return get_autograd_meta()->grad_accumulator_.lock();
}

inline Variable Variable::detach() const {
  auto var = make_variable_view(*this, get()->data_, /*is_differentiable=*/false,
      /*allow_tensor_metadata_change=*/false, Edge());
  return var;
}

inline void Variable::set_data(const at::Tensor& new_data) {
  get()->set_data(new_data);
}

inline void Variable::set_gradient_edge(Edge edge) noexcept {
  get_autograd_meta()->grad_fn_ = std::move(edge.function);
  get_autograd_meta()->output_nr_ = edge.input_nr;
}

inline uint32_t Variable::output_nr() const noexcept {
  return get_autograd_meta()->output_nr_;
}

inline bool Variable::is_leaf() const noexcept {
  return get_autograd_meta()->grad_fn_ == nullptr;
}
inline void Variable::set_version_counter(const VariableVersion& version_counter) noexcept {
  get_autograd_meta()->version_counter_ = version_counter;
}
inline void Variable::bump_version() noexcept {
  get_autograd_meta()->version_counter_.bump();
}
inline uint32_t Variable::current_version() const noexcept {
  return get_autograd_meta()->version_counter_.current_version();
}
inline const VariableVersion& Variable::version_counter() const noexcept {
  return get_autograd_meta()->version_counter_;
}

inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
  get_autograd_meta()->hooks_.push_back(std::move(hook));
}
inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks() const noexcept {
  return get_autograd_meta()->hooks_;
}
inline void Variable::clear_hooks() {
  get_autograd_meta()->hooks_.clear();
}
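// Example (illustrative sketch; `IdentityPreHook` is a made-up name, while
// FunctionPreHook comes from torch/csrc/autograd/function_hook.h): a pre-hook
// receives the gradients flowing to this Variable's grad_fn and may return
// (possibly modified) replacements.
//
//   struct IdentityPreHook : FunctionPreHook {
//     variable_list operator()(const variable_list& grads) override {
//       return grads;   // inspect or transform gradients here
//     }
//   };
//   var.add_hook(std::make_shared<IdentityPreHook>());  // `var` is a Variable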
inline bool Variable::is_view() const noexcept {
  return get_autograd_meta()->is_view_;
}

inline const Variable& Variable::base() const {
  if (is_view()) {
    auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
    return diff_view_meta->base_;
  } else {
    throw std::runtime_error("Can't get base of non-view Variable");
  }
}

inline void Variable::set_name(const std::string& name) {
  get_autograd_meta()->name = name;
}
inline const std::string& Variable::name() const noexcept {
  return get_autograd_meta()->name;
}
inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
  get_autograd_meta()->pyobj_ = pyobj;
}
inline PyObject* Variable::pyobj() const noexcept {
  return get_autograd_meta()->pyobj_;
}

inline Variable::AutogradMeta* Variable::get_autograd_meta() const noexcept {
  return get()->get_autograd_meta();
}

inline Variable::Impl* Variable::get() const {
  AT_CHECK(defined(), "Called Variable::get() on an undefined Variable");
  return static_cast<Variable::Impl*>(impl_.get());
}

}} // namespace torch::autograd