Caffe2 - C++ API
A deep learning, cross platform ML framework
python_arg_flatten.cpp
1 #include <torch/csrc/jit/python_arg_flatten.h>
2 #include <torch/csrc/utils/six.h>
3 
4 #include <torch/csrc/autograd/grad_mode.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace python {
9 
10 using namespace torch::autograd;
11 using namespace at;
12 
// Single-character alphabet ("D" for desc) used to encode the nesting shape of
// flattened JIT inputs/outputs into a descriptor's `structure` string:
//   'v'        one Variable leaf
//   '(' ... ')' a tuple
//   '[' ... ']' a list
namespace D {
static constexpr char Variable = 'v';
static constexpr char TupleOpen = '(';
static constexpr char TupleClose = ')';
static constexpr char ListOpen = '[';
static constexpr char ListClose = ']';
} // namespace D
21 
22 namespace {
23 
24 template <typename T>
25 py::object cast_handle_sequence(std::vector<py::handle> objs) {
26  auto num_objs = objs.size();
27  T sequence{num_objs};
28  for (size_t i = 0; i < num_objs; ++i)
29  sequence[i] = py::reinterpret_borrow<py::object>(objs[i]);
30  return sequence;
31 }
32 
33 void flatten_rec(PyObject* obj, ParsedArgs& args) {
34  auto& structure = args.desc.structure;
35  if (six::isTuple(obj)) {
36  structure.push_back(D::TupleOpen);
37  for (auto item : py::reinterpret_borrow<py::tuple>(obj))
38  flatten_rec(item.ptr(), args);
39  structure.push_back(D::TupleClose);
40  } else if (PyList_Check(obj)) {
41  structure.push_back(D::ListOpen);
42  for (auto item : py::reinterpret_borrow<py::list>(obj))
43  flatten_rec(item.ptr(), args);
44  structure.push_back(D::ListClose);
45  } else if (THPVariable_Check(obj)) {
46  auto& var = reinterpret_cast<THPVariable*>(obj)->cdata;
47  args.vars.push_back(var);
48  args.desc.metadata.emplace_back(var);
49  args.desc.structure.push_back(D::Variable);
50  } else {
51  std::string msg =
52  "Only tuples, lists and Variables supported as JIT inputs, but got ";
53  msg += THPUtils_typename(obj);
54  throw std::runtime_error(msg);
55  }
56 }
57 
58 } // anonymous namespace
59 
60 ParsedArgs flatten(py::handle obj) {
61  ParsedArgs args;
62  args.desc.grad_enabled = autograd::GradMode::is_enabled();
63  flatten_rec(obj.ptr(), args);
64  return args;
65 }
66 
67 namespace {
68 
69 template <typename T>
70 py::object cast_sequence(std::vector<py::object> objs) {
71  auto num_objs = objs.size();
72  T sequence{num_objs};
73  for (size_t i = 0; i < num_objs; ++i)
74  sequence[i] = std::move(objs[i]);
75  return std::move(sequence);
76 }
77 
78 py::object unflatten_rec(
79  ArrayRef<Variable>::iterator& var_it,
80  ArrayRef<Variable>::iterator& var_it_end,
81  std::string::const_iterator& desc_it) {
82  char type = *desc_it++;
83  if (type == D::TupleOpen) {
84  std::vector<py::object> objs;
85  while (*desc_it != D::TupleClose)
86  objs.push_back(unflatten_rec(var_it, var_it_end, desc_it));
87  ++desc_it;
88  return cast_sequence<py::tuple>(objs);
89  } else if (type == D::ListOpen) {
90  std::vector<py::object> objs;
91  while (*desc_it != D::ListClose)
92  objs.push_back(unflatten_rec(var_it, var_it_end, desc_it));
93  ++desc_it;
94  return cast_sequence<py::list>(objs);
95  } else {
96  if (var_it == var_it_end)
97  throw std::runtime_error("Not enough Variables given to unflatten");
98  auto var = *var_it++;
99  return py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
100  }
101 }
102 
103 } // anonymous namespace
104 
105 PyObject* unflatten(ArrayRef<Variable> vars, const IODescriptor& desc) {
106  // NB: We don't do correctness checking on descriptor.
107  // It has to be a correct bytes object produced by unflatten.
108  auto vars_it = vars.begin();
109  auto vars_it_end = vars.end();
110  auto desc_it = desc.structure.begin();
111  auto output = unflatten_rec(vars_it, vars_it_end, desc_it);
112  if (vars_it != vars_it_end)
113  throw std::runtime_error("Too many Variables given to unflatten");
114  return output.release().ptr();
115 }
116 
117 } // namespace python
118 } // namespace jit
119 } // namespace torch
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
Flush-To-Zero and Denormals-Are-Zero mode.
Definition: static.cpp:70