Caffe2 - C++ API
A deep learning, cross platform ML framework
python_arg_flatten.h
1 #pragma once
2 
3 #include <torch/csrc/autograd/variable.h>
4 #include <torch/csrc/jit/pybind.h>
5 #include <torch/csrc/utils/hash.h>
6 
7 #include <ATen/ATen.h>
8 #include <functional>
9 #include <tuple>
10 #include <vector>
11 
12 namespace torch {
13 namespace jit {
14 namespace python {
15 
16 struct IODescriptor {
19  : sizes(var.sizes().vec()),
20  type(var.scalar_type()),
21  device(var.device()),
22  requires_grad(var.requires_grad()) {}
23 
24  bool operator==(const VariableMetadata& o) const {
25  return std::tie(device, requires_grad, type, sizes) ==
26  std::tie(o.device, o.requires_grad, o.type, o.sizes);
27  }
28 
29  static size_t hash(const VariableMetadata& m) {
30  return get_hash(m.sizes, m.device, m.requires_grad, m.type);
31  }
32 
33  std::vector<int64_t> sizes;
34  at::ScalarType type;
35  at::Device device;
36  bool requires_grad;
37  };
38 
39  bool operator==(const IODescriptor& o) const {
40  return std::tie(structure, metadata, grad_enabled) ==
41  std::tie(o.structure, o.metadata, o.grad_enabled);
42  }
43 
44  static size_t hash(const IODescriptor& o) {
45  return get_hash(o.structure, o.metadata, o.grad_enabled);
46  }
47 
48  void extend(const autograd::variable_list& list) {
49  metadata.reserve(metadata.size() + list.size());
50  for (auto& var : list)
51  metadata.emplace_back(var);
52  }
53 
54  // Description of argument structure. Variables are replaced with
55  // different characters, depending on their flags, beginnings and
56  // ends of tuples and lists are denoted by a pair of parenthesis
57  // of their corresponding kind. They should always be paired.
58  // Example desc: (vv[v(v)v])
59  // NOTE: if extend() was ever called then metadata.size() can be
60  // different than the number of 'v's in structure.
61  std::string structure;
62  std::vector<VariableMetadata> metadata;
63  bool grad_enabled = false;
64 };
65 
66 static inline std::ostream& operator<<(
67  std::ostream& out,
68  const IODescriptor::VariableMetadata& meta) {
69  at::Device meta_device = meta.device;
70  auto& t = at::getNonVariableType(
71  meta_device.is_cpu() ? at::Backend::CPU : at::Backend::CUDA, meta.type);
72  out << t << "(requires_grad=" << meta.requires_grad;
73  if (meta_device.is_cuda()) {
74  out << ", device=" << meta_device.index();
75  }
76  out << ") {";
77  for (size_t i = 0; i < meta.sizes.size(); ++i) {
78  if (i > 0)
79  out << ", ";
80  out << meta.sizes[i];
81  }
82  out << "}";
83  return out;
84 }
85 
86 static inline std::ostream& operator<<(
87  std::ostream& out,
88  const IODescriptor& desc) {
89  out << desc.structure << "\n";
90  out << " with grad_enabled=" << desc.grad_enabled << "\n";
91  for (size_t i = 0; i < desc.metadata.size(); ++i) {
92  out << " with v" << i << " having type " << desc.metadata[i] << "\n";
93  }
94  return out;
95 }
96 
97 struct ParsedArgs {
98  // Flat vector of Variables found in arguments
99  autograd::variable_list vars;
100  // Metadata describing nesting of objects received from Python and
101  // metadata of vars and whether grad is enabled.
102  IODescriptor desc;
103 
104  void extend(const autograd::variable_list& list) {
105  if (list.empty())
106  return;
107  vars.reserve(vars.size() + list.size());
108  for (auto& var : list)
109  vars.emplace_back(var);
110  desc.extend(list);
111  }
112 };
113 
114 ParsedArgs flatten(py::handle obj);
115 PyObject* unflatten(
117  const IODescriptor& structure);
118 
119 } // namespace python
120 } // namespace jit
121 } // namespace torch
bool is_cuda() const noexcept
Return true if the device is of CUDA type.
Definition: Device.h:80
Represents a compute device on which a tensor is located.
Definition: Device.h:30
Device device() const
Returns a Tensor's device.
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
bool is_cpu() const noexcept
Return true if the device is of CPU type.
Definition: Device.h:85
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
DeviceIndex index() const noexcept
Returns the optional index.
Definition: Device.h:70