3 #include <torch/csrc/autograd/variable.h> 4 #include <torch/csrc/jit/pybind.h> 5 #include <torch/csrc/utils/hash.h> 19 : sizes(var.sizes().vec()),
20 type(var.scalar_type()),
22 requires_grad(var.requires_grad()) {}
25 return std::tie(device, requires_grad, type, sizes) ==
26 std::tie(o.device, o.requires_grad, o.type, o.sizes);
30 return get_hash(m.sizes, m.device, m.requires_grad, m.type);
33 std::vector<int64_t> sizes;
40 return std::tie(structure, metadata, grad_enabled) ==
41 std::tie(o.structure, o.metadata, o.grad_enabled);
45 return get_hash(o.structure, o.metadata, o.grad_enabled);
48 void extend(
const autograd::variable_list& list) {
49 metadata.reserve(metadata.size() + list.size());
50 for (
auto& var : list)
51 metadata.emplace_back(var);
61 std::string structure;
62 std::vector<VariableMetadata> metadata;
63 bool grad_enabled =
false;
66 static inline std::ostream& operator<<(
70 auto& t = at::getNonVariableType(
71 meta_device.
is_cpu() ? at::Backend::CPU : at::Backend::CUDA, meta.type);
72 out << t <<
"(requires_grad=" << meta.requires_grad;
74 out <<
", device=" << meta_device.
index();
77 for (
size_t i = 0; i < meta.sizes.size(); ++i) {
86 static inline std::ostream& operator<<(
89 out << desc.structure <<
"\n";
90 out <<
" with grad_enabled=" << desc.grad_enabled <<
"\n";
91 for (
size_t i = 0; i < desc.metadata.size(); ++i) {
92 out <<
" with v" << i <<
" having type " << desc.metadata[i] <<
"\n";
99 autograd::variable_list vars;
104 void extend(
const autograd::variable_list& list) {
107 vars.reserve(vars.size() + list.size());
108 for (
auto& var : list)
109 vars.emplace_back(var);
bool is_cuda() const noexcept
Returns true if the device is of CUDA type.
Represents a compute device on which a tensor is located.
Device device() const
Returns a Tensor's device.
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
bool is_cpu() const noexcept
Returns true if the device is of CPU type.
ArrayRef - Represents a constant reference to an array (0 or more elements stored consecutively in memory)...
DeviceIndex index() const noexcept
Returns the optional index.