Caffe2 - C++ API
A deep learning, cross-platform ML framework
variable.h
1 #pragma once
2 
3 #include <torch/csrc/utils/python_stub.h>
4 
5 #include <torch/csrc/WindowsTorchApiMacro.h>
6 #include <torch/csrc/autograd/edge.h>
7 #include <torch/csrc/autograd/function_hook.h>
8 #include <torch/csrc/autograd/variable_version.h>
9 
10 #include <ATen/ATen.h>
11 #include <c10/util/Exception.h>
12 
13 #include <list>
14 #include <memory>
15 #include <mutex>
16 #include <stdexcept>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 
21 namespace torch { namespace autograd {
22 
23 struct Function;
24 
84 
85 struct TORCH_API Variable : public at::Tensor {
87  Variable() = default;
88 
89  // Factory Functions
90  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
91 
92  // NOTE: These factory functions have to be friends to access the
93  // `Variable::Impl`. As a side effect, it allows us to keep them in the class.
94 
100  friend Variable make_variable_view(
101  Variable base,
102  at::Tensor data,
103  bool is_differentiable,
104  bool allow_tensor_metadata_change,
105  Edge gradient_edge);
106 
112  friend Variable make_variable(
113  at::Tensor data,
114  bool requires_grad,
115  bool allow_tensor_metadata_change);
116 
122  friend Variable make_variable_consuming(
123  at::Tensor data,
124  bool requires_grad,
125  bool allow_tensor_metadata_change);
126 
131  friend Variable make_variable(
132  at::Tensor data,
133  Edge gradient_edge,
134  bool allow_tensor_metadata_change);
135 
136  // Tensor Conversions
137  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
138 
139  // "Downcasts" a `Tensor` into a `Variable`. Only call this on tensors you
140  // know are Variables.
141  /*implicit*/ Variable(at::Tensor const& rhs) : at::Tensor(rhs) {
142  AT_CHECK(
143  is_variable() || !defined(),
144  "Tensor that was converted to Variable was not actually a Variable");
145  }
146 
147  /*implicit*/ Variable(at::Tensor&& rhs)
148  : at::Tensor(std::move(rhs)) {
149  AT_CHECK(
150  is_variable() || !defined(),
151  "Tensor that was converted to Variable was not actually a Variable");
152  }
153 
154  // NOTE: Assignment operators to Tensor come for free from the constructors.
155 
156  const at::Tensor& data() const noexcept;
157  at::Tensor& data() noexcept;
158 
159  // Gradient Function and Edges
160  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
161 
169  const std::shared_ptr<Function>& grad_fn() const;
170 
172  Function* grad_fn_unsafe() const;
173 
176  void set_grad_accumulator(std::weak_ptr<Function> grad_accumulator);
177 
181  std::shared_ptr<Function> try_get_grad_accumulator() const;
182 
185  std::shared_ptr<Function> grad_accumulator() const;
186 
195  Edge gradient_edge() const {
196  // If grad_fn is null (as is the case for a leaf node), we instead
197  // interpret the gradient function to be a gradient accumulator, which will
198  // accumulate its inputs into the grad property of the variable. These
199  // nodes get suppressed in some situations, see "suppress gradient
200  // accumulation" below. Note that only variables which have `requires_grad =
201  // True` can have gradient accumulators.
202  if (const auto& gradient = grad_fn()) {
203  return Edge(gradient, output_nr());
204  } else {
205  return Edge(grad_accumulator(), 0);
206  }
207  }
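  // A minimal sketch of the two cases (hedged: `at::ones` comes from ATen, not
  // this header, and is assumed here to produce a plain, non-Variable tensor):
  //
  //   Variable leaf = make_variable(at::ones({2, 2}), /*requires_grad=*/true);
  //   Edge e1 = leaf.gradient_edge();   // e1.function is the grad accumulator
  //
  //   Variable y = leaf * 2;            // assuming the op records autograd history
  //   Edge e2 = y.gradient_edge();      // e2 == Edge(y.grad_fn(), y.output_nr())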
208 
220  Variable detach() const;
221 
225  void detach_();
226 
228  void backward(
229  c10::optional<Tensor> gradient,
230  bool keep_graph,
231  bool create_graph) const;
232 
237  void set_data(const at::Tensor &new_data);
238 
245  void set_gradient_edge(Edge edge) noexcept;
246 
250  uint32_t output_nr() const noexcept;
251 
253  bool is_leaf() const noexcept;
254 
255  // Versions
256  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
257 
259  void bump_version() noexcept;
260  void set_version_counter(const VariableVersion& version_counter) noexcept;
261 
263  const VariableVersion& version_counter() const noexcept;
264 
267  uint32_t current_version() const noexcept;
268 
269  // Autograd Graph Interaction
270  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
271 
278  void rebase_history(Edge gradient_edge);
279 
280  // Hooks
281  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
282 
283  void add_hook(std::shared_ptr<FunctionPreHook> hook);
284  const std::vector<std::shared_ptr<FunctionPreHook>>& hooks() const noexcept;
285  void clear_hooks();
286 
287  // View Variables
288  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
289 
291  bool is_view() const noexcept;
292 
295  const Variable& base() const;
296 
297  // Miscellaneous
298  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
299 
300  void set_name(const std::string& name);
301  const std::string& name() const noexcept;
302 
303  PyObject* pyobj() const noexcept;
304  void set_pyobj(PyObject* pyobj) noexcept;
305 
306  struct AutogradMeta;
307  Variable::AutogradMeta* get_autograd_meta() const noexcept;
308 
309  private:
313  struct Impl;
314  struct DifferentiableViewImpl;
315  struct DifferentiableViewMeta;
316 
317  // Private Methods
318  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
319 
 320  Variable(c10::intrusive_ptr<Variable::Impl> self);
 321  Impl* get() const;
322 };
323 
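// A short conversion sketch (hedged: `at::ones` is from ATen, not this header,
// and is assumed to return a plain tensor; `raw`, `v`, and `t` are illustrative):
//
//   at::Tensor raw = at::ones({2, 2});            // plain ATen tensor
//   // Variable bad = raw;                        // fails the AT_CHECK above:
//   //                                            // `raw` is not yet a Variable
//   Variable v = make_variable(raw);              // wrap it instead
//   at::Tensor t = v;                             // upcast to Tensor is implicit
//   Variable again = t;                           // fine: `t` really is a Variable
//   again.data();                                 // the wrapped ATen tensor
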
324 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
325 // Variable::AutogradMeta
326 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
327 
330 
331 struct TORCH_API Variable::AutogradMeta : public c10::AutogradMetaInterface {
332  std::string name;
333 
334  Variable grad_;
335  std::shared_ptr<Function> grad_fn_;
336  std::weak_ptr<Function> grad_accumulator_;
337 
338  VariableVersion version_counter_;
339  std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
340 
341  // Only meaningful on leaf variables (must be false otherwise)
342  bool requires_grad_;
343 
344  bool is_view_;
345 
346  // The "output number" of this variable; e.g., if this variable
347  // was the second output of a function, then output_nr == 1.
348  // We use this to make sure we can setup the backwards trace
349  // correctly when this variable is passed to another function.
350  uint32_t output_nr_;
351  PyObject* pyobj_ = nullptr; // weak reference
352 
353  // Mutex to ensure that concurrent read operations that modify internal
354  // state are still thread-safe. Used by grad_fn() and
355  // grad_accumulator().
356  std::mutex mutex_;
357 
361  void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) override {
362  AT_CHECK(
363  !requires_grad || at::isFloatingType(at::typeMetaToScalarType(self_impl->dtype())),
364  "Only Tensors of floating point dtype can require gradients");
365  requires_grad_ = requires_grad;
366  }
367 
368  bool requires_grad() const override {
369  return requires_grad_ || grad_fn_;
370  }
371 
373  Variable& grad() override {
374  return grad_;
375  }
376 
377  const Variable& grad() const override {
378  return grad_;
379  }
380 };
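// A small sketch of the `requires_grad` rules above (hedged: `at::ones` and the
// dtype constants are from ATen, and the check is assumed to be routed through
// `set_requires_grad` when `Variable::Impl` is constructed):
//
//   Variable f = make_variable(at::ones({2}, at::kFloat), /*requires_grad=*/true); // ok
//   Variable i = make_variable(at::ones({2}, at::kLong));                          // ok
//   // make_variable(at::ones({2}, at::kLong), /*requires_grad=*/true);
//   //   would fail: only floating point dtypes may require gradients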
381 
382 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
383 // Variable::DifferentiableViewMeta
384 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
385 
386 struct TORCH_API Variable::DifferentiableViewMeta : public Variable::AutogradMeta {
388  Variable base_;
389 
393  uint32_t attr_version;
394 
395  bool requires_grad() const override {
396  return requires_grad_ || grad_fn_ || (is_view_ && base_.requires_grad());
397  }
398 };
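// Sketch of the view rule above (hedged: uses `make_variable_view`, defined
// later in this file, and `at::ones` from ATen):
//
//   Variable base = make_variable(at::ones({4}), /*requires_grad=*/true);
//   Variable view = make_variable_view(base, base.data());
//   // view.get_autograd_meta()->requires_grad() is true because the base
//   // requires grad, even though the view's own requires_grad_ flag is false.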
399 
400 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
401 // Variable::Impl
402 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
403 
404 struct TORCH_API Variable::Impl : public at::TensorImpl {
405  explicit Impl(
406  at::Tensor data,
407  std::unique_ptr<Variable::AutogradMeta> autograd_meta,
408  bool requires_grad = false,
409  Edge gradient_edge = Edge());
410 
411  ~Impl() override;
412 
413  int64_t numel() const override;
414  at::IntArrayRef sizes() const override;
415  at::IntArrayRef strides() const override;
416  bool is_contiguous() const override;
417  int64_t size(int64_t d) const override;
418  int64_t stride(int64_t d) const override;
419  void resize_dim(int64_t ndim) override;
420  void set_size(int64_t dim, int64_t new_size) override;
421  void set_stride(int64_t dim, int64_t new_stride) override;
422  void set_storage_offset(int64_t storage_offset) override;
423 
424  int64_t dim() const override;
425  bool has_storage() const override;
426  const at::Storage& storage() const override;
427  void* slow_data() const override;
428 
429  void set_data(const at::Tensor &new_data);
430 
432  void release_resources() override;
433 
434  Variable::AutogradMeta* get_autograd_meta() const {
435  return static_cast<Variable::AutogradMeta*>(autograd_meta());
436  }
437 
438  int64_t storage_offset() const override;
439 
 442  at::Tensor data_;
 443 
444  private:
445  int64_t get_device_slow() const override;
446 };
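// The overrides above are assumed to forward to the wrapped `data_` tensor, so
// ordinary Tensor queries on a Variable behave as expected (sketch, hedged):
//
//   Variable v = make_variable(at::ones({2, 3}));
//   v.sizes();   // {2, 3}
//   v.numel();   // 6
//   v.dim();     // 2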
447 
448 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
449 // Variable::DifferentiableViewImpl
450 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
451 
521 struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
522  DifferentiableViewImpl(
523  Variable base,
524  at::Tensor data,
525  Edge gradient_edge,
526  std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta);
527 
529  void release_resources() override;
530 };
531 
532 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
533 // Variable Implementation
534 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
535 
536 // Factory Functions
537 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
538 
545 
546 // See NOTE [ Autograd View Variables ] for details.
547 inline Variable make_variable_view(
548  Variable base,
549  at::Tensor data,
550  bool is_differentiable = true,
551  bool allow_tensor_metadata_change = true,
552  Edge gradient_edge = Edge()) {
553  if (data.defined()) {
554  if (is_differentiable) {
556  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
557  data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
558  auto data_copy = at::Tensor(data_impl_copy);
559  auto diff_view_meta = c10::guts::make_unique<Variable::DifferentiableViewMeta>();
560  return Variable(c10::make_intrusive<Variable::DifferentiableViewImpl>(
561  std::move(base), std::move(data_copy), std::move(gradient_edge), std::move(diff_view_meta)));
562  } else {
564  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
565  data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
566  auto data_copy = at::Tensor(data_impl_copy);
567  auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
568  auto var = Variable(c10::make_intrusive<Variable::Impl>(
569  std::move(data_copy), std::move(autograd_meta), false, std::move(gradient_edge)));
570  var.set_version_counter(base.version_counter());
571  return var;
572  }
573  }
574  return Variable();
575 }
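// Usage sketch for the two branches above (hedged: copies of `VariableVersion`
// are assumed to share the same underlying counter):
//
//   Variable base = make_variable(at::ones({4}));
//   Variable diff_view  = make_variable_view(base, base.data());   // records history
//   Variable plain_view = make_variable_view(base, base.data(),
//                                            /*is_differentiable=*/false);
//   // The non-differentiable view still shares the base's version counter, so
//   // bumping the base's version is visible through plain_view.current_version().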
576 
577 inline Variable make_variable(
578  at::Tensor data,
579  bool requires_grad = false,
580  bool allow_tensor_metadata_change = true) {
581  AT_CHECK(
582  !data.is_variable(),
583  "Must not create a new variable from a variable, use its .data()");
584  if (data.defined()) {
585  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
586  data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
587  auto data_copy = at::Tensor(data_impl_copy);
588  auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
589  return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta), requires_grad));
590  }
591  return Variable();
592 }
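// Typical call (hedged: `at::ones` is assumed to produce a non-Variable tensor):
//
//   at::Tensor data = at::ones({2, 2});
//   Variable v = make_variable(data, /*requires_grad=*/true);
//   // make_variable(v);   // would trip the AT_CHECK: `v` is already a Variable;
//   //                     // pass v.data() instead.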
593 
594 inline Variable make_variable_consuming(
595  at::Tensor data,
596  bool requires_grad = false,
597  bool allow_tensor_metadata_change = true) {
598  AT_CHECK(
599  !data.is_variable(),
600  "Must not create a new variable from a variable, use its .data()");
601  if (data.defined()) {
602  AT_ASSERT(data.getIntrusivePtr().use_count() == 1);
603  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
604  auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
605  return Variable(c10::make_intrusive<Variable::Impl>(std::move(data), std::move(autograd_meta), requires_grad));
606  }
607  return Variable();
608 }
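// The AT_ASSERT above requires sole ownership of `data`, so callers hand the
// tensor over as a temporary or via std::move (sketch, hedged):
//
//   Variable a = make_variable_consuming(at::ones({2, 2}));   // temporary: one owner
//   at::Tensor t = at::ones({2, 2});
//   Variable b = make_variable_consuming(std::move(t));       // ownership moved in
//   // Passing a tensor that still has other live owners would fire the assert.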
609 
610 inline Variable make_variable(
611  at::Tensor data,
612  Edge gradient_edge,
613  bool allow_tensor_metadata_change = true) {
614  AT_CHECK(
615  !data.is_variable(),
616  "Must not create a new variable from a variable, use its .data()");
617  if (data.defined()) {
618  auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
619  data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
620  auto data_copy = at::Tensor(data_impl_copy);
621  auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
622  return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta), false, std::move(gradient_edge)));
623  }
624  return Variable();
625 }
626 
627 // Tensor Conversion
628 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
629 
633 inline Variable& as_variable_ref(at::Tensor& tensor) {
634  AT_CHECK(
635  tensor.is_variable(),
636  "Attempted to cast a Tensor to a Variable, but "
637  "the dynamic type of the value is not Variable.");
638  return static_cast<Variable&>(tensor);
639 }
640 
641 inline const Variable& as_variable_ref(const at::Tensor& tensor) {
642  AT_CHECK(
643  tensor.is_variable(),
644  "Attempted to cast a Tensor to a Variable, but "
645  "the dynamic type of the value is not Variable.");
646  return static_cast<const Variable&>(tensor);
647 }
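// Typical use (hedged; `bump_inplace_version` is an illustrative helper): code
// that receives an `at::Tensor` but needs Variable-only operations.
//
//   void bump_inplace_version(at::Tensor& t) {
//     // Throws via AT_CHECK if `t` is not actually a Variable.
//     as_variable_ref(t).bump_version();
//   }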
648 
649 inline const at::Tensor& Variable::data() const noexcept {
650  return get()->data_;
651 }
652 
653 inline at::Tensor& Variable::data() noexcept {
654  return get()->data_;
655 }
656 
657 // Gradient Function and Edges
658 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
659 
660 inline Function* Variable::grad_fn_unsafe() const {
661  return get_autograd_meta()->grad_fn_.get();
662 }
663 
 664 inline void Variable::set_grad_accumulator(
 665  std::weak_ptr<Function> grad_accumulator) {
666  get_autograd_meta()->grad_accumulator_ = std::move(grad_accumulator);
667 }
668 
669 inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const {
670  return get_autograd_meta()->grad_accumulator_.lock();
671 }
672 
673 inline Variable Variable::detach() const {
674  auto var = make_variable_view(*this, get()->data_, /*is_differentiable=*/false, /*allow_tensor_metadata_change=*/false, Edge());
675  return var;
676 }
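// Sketch (hedged): the detached Variable is a non-differentiable view over the
// same data, with no gradient edge of its own.
//
//   Variable w = make_variable(at::ones({3}), /*requires_grad=*/true);
//   Variable d = w.detach();
//   // d.is_leaf() is true and d.grad_fn() is null; w is unaffected.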
677 
678 inline void Variable::set_data(const at::Tensor &new_data) {
679  get()->set_data(new_data);
680 }
681 
682 inline void Variable::set_gradient_edge(Edge edge) noexcept {
683  get_autograd_meta()->grad_fn_ = std::move(edge.function);
684  get_autograd_meta()->output_nr_ = edge.input_nr;
685 }
686 
687 inline uint32_t Variable::output_nr() const noexcept {
688  return get_autograd_meta()->output_nr_;
689 }
690 
691 inline bool Variable::is_leaf() const noexcept {
692  return get_autograd_meta()->grad_fn_ == nullptr;
693 }
694 
695 // Versions
696 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
697 
698 inline void Variable::set_version_counter(
699  const VariableVersion& version_counter) noexcept {
700  get_autograd_meta()->version_counter_ = version_counter;
701 }
702 
703 inline void Variable::bump_version() noexcept {
704  get_autograd_meta()->version_counter_.bump();
705 }
706 
707 inline uint32_t Variable::current_version() const noexcept {
708  return get_autograd_meta()->version_counter_.current_version();
709 }
710 
711 inline const VariableVersion& Variable::version_counter() const noexcept {
712  return get_autograd_meta()->version_counter_;
713 }
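// Version counters back the in-place correctness checks (sketch, hedged):
//
//   Variable v = make_variable(at::ones({2}));
//   uint32_t before = v.current_version();
//   v.bump_version();                        // what in-place mutations do
//   AT_ASSERT(v.current_version() == before + 1);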
714 
715 // Hooks
716 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
717 
718 inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
719  get_autograd_meta()->hooks_.push_back(std::move(hook));
720 }
721 
722 inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks()
723  const noexcept {
724  return get_autograd_meta()->hooks_;
725 }
726 
727 inline void Variable::clear_hooks() {
728  get_autograd_meta()->hooks_.clear();
729 }
730 
731 // View Variables
732 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
733 
734 inline bool Variable::is_view() const noexcept {
735  return get_autograd_meta()->is_view_;
736 }
737 
738 inline const Variable& Variable::base() const {
739  if (is_view()) {
740  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
741  return diff_view_meta->base_;
742  } else {
743  throw std::runtime_error("Can't get base of non-view Variable");
744  }
745 }
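// Sketch (hedged):
//
//   Variable base = make_variable(at::ones({4}));
//   Variable view = make_variable_view(base, base.data());
//   // view.is_view() is true and view.base() returns `base`;
//   // base.base() would throw std::runtime_error as above.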
746 
747 // Miscellaneous
748 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
749 
750 inline void Variable::set_name(const std::string& name) {
751  get_autograd_meta()->name = name;
752 }
753 
754 inline const std::string& Variable::name() const noexcept {
755  return get_autograd_meta()->name;
756 }
757 
758 inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
759  get_autograd_meta()->pyobj_ = pyobj;
760 }
761 
762 inline PyObject* Variable::pyobj() const noexcept {
763  return get_autograd_meta()->pyobj_;
764 }
765 
766 inline Variable::AutogradMeta* Variable::get_autograd_meta() const noexcept {
767  return get()->get_autograd_meta();
768 }
769 
770 // Private Methods
771 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
772 
 773 inline Variable::Variable(c10::intrusive_ptr<Variable::Impl> self)
 774  : at::Tensor(std::move(self)) {}
775 
776 inline Variable::Impl* Variable::get() const {
777  AT_CHECK(defined(), "Called Variable::get() on an undefined Variable");
778  return static_cast<Variable::Impl*>(impl_.get());
779 }
780 }} // namespace torch::autograd
Variable()=default
Default constructor.
void bump_version() noexcept
Increments the version count of this Variable.
Definition: variable.h:703
Variable & grad() override
Accesses the gradient Variable of this Variable.
Definition: variable.h:373
const caffe2::TypeMeta & dtype() const
Returns the TypeMeta of a tensor, which describes what data type it is (e.g., int, float, ...)
Definition: TensorImpl.h:629
at::Tensor data_
The underlying data tensor for this Variable.
Definition: variable.h:442
void set_gradient_edge(Edge edge) noexcept
Set the gradient edge, i.e. grad_fn and input_nr, of the Variable.
Definition: variable.h:682
void set_grad_accumulator(std::weak_ptr< Function > grad_accumulator)
Set the gradient accumulator of the Variable.
Definition: variable.h:664
virtual void set_allow_tensor_metadata_change(bool value)
Set whether a tensor allows changes to its metadata.
Definition: TensorImpl.h:814
Function * grad_fn_unsafe() const
Gets the raw gradient function pointer, whatever it currently is.
Definition: variable.h:660
Represents a particular input of a function.
Definition: edge.h:14
The low-level representation of a tensor, which contains a pointer to a storage...
Definition: TensorImpl.h:211
bool is_variable() const noexcept
Returns true if the Tensor is actually a torch::autograd::Variable.
bool is_view() const noexcept
Returns true if this Variable is a view of another Variable.
Definition: variable.h:734
uint32_t attr_version
The value of the version_counter at the time grad_fn was created.
Definition: variable.h:393
Variable detach() const
Returns a copy of this Variable that is detached from its autograd graph and has a blank version...
Definition: variable.h:673
const Variable & base() const
Returns the Variable that this Variable is a view of.
Definition: variable.h:738
const VariableVersion & version_counter() const noexcept
Retrieves this Variable's version counter.
Definition: variable.h:711
bool is_leaf() const noexcept
True if this Variable is a leaf and thus does not have a grad_fn.
Definition: variable.h:691
std::shared_ptr< Function > try_get_grad_accumulator() const
Attempts to get a pointer to the gradient accumulator of the Variable, if it still exists...
Definition: variable.h:669
Variable
A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
void set_data(const at::Tensor &new_data)
Sets the Tensor held by this Variable to the one supplied.
Definition: variable.h:678
uint32_t current_version() const noexcept
Retrieves the current value of the Variable's version counter.
Definition: variable.h:707
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...
Variable base_
The base Variable (never a view).
Definition: variable.h:388
void set_requires_grad(bool requires_grad, at::TensorImpl *self_impl) override
Sets the requires_grad property of Variable.
Definition: variable.h:361
uint32_t output_nr() const noexcept
Returns the input index of the gradient Function to which this Variable is connected.
Definition: variable.h:687
Edge gradient_edge() const
Returns the "canonical" gradient edge of this Variable, i.e.
Definition: variable.h:195