1 #include <torch/csrc/jit/python_arg_flatten.h> 2 #include <torch/csrc/utils/six.h> 4 #include <torch/csrc/autograd/grad_mode.h> 15 static constexpr
char ListOpen =
'[';
16 static constexpr
char ListClose =
']';
17 static constexpr
char TupleOpen =
'(';
18 static constexpr
char TupleClose =
')';
19 static constexpr
char Variable =
'v';
25 py::object cast_handle_sequence(std::vector<py::handle> objs) {
26 auto num_objs = objs.size();
28 for (
size_t i = 0; i < num_objs; ++i)
29 sequence[i] = py::reinterpret_borrow<py::object>(objs[i]);
33 void flatten_rec(PyObject* obj, ParsedArgs& args) {
34 auto& structure = args.desc.structure;
35 if (six::isTuple(obj)) {
36 structure.push_back(D::TupleOpen);
37 for (
auto item : py::reinterpret_borrow<py::tuple>(obj))
38 flatten_rec(item.ptr(), args);
39 structure.push_back(D::TupleClose);
40 }
else if (PyList_Check(obj)) {
41 structure.push_back(D::ListOpen);
42 for (
auto item : py::reinterpret_borrow<py::list>(obj))
43 flatten_rec(item.ptr(), args);
44 structure.push_back(D::ListClose);
45 }
else if (THPVariable_Check(obj)) {
46 auto& var =
reinterpret_cast<THPVariable*
>(obj)->cdata;
47 args.vars.push_back(var);
48 args.desc.metadata.emplace_back(var);
49 args.desc.structure.push_back(D::Variable);
52 "Only tuples, lists and Variables supported as JIT inputs, but got ";
53 msg += THPUtils_typename(obj);
54 throw std::runtime_error(msg);
60 ParsedArgs flatten(py::handle obj) {
62 args.desc.grad_enabled = autograd::GradMode::is_enabled();
63 flatten_rec(obj.ptr(), args);
70 py::object cast_sequence(std::vector<py::object> objs) {
71 auto num_objs = objs.size();
73 for (
size_t i = 0; i < num_objs; ++i)
74 sequence[i] = std::move(objs[i]);
75 return std::move(sequence);
78 py::object unflatten_rec(
79 ArrayRef<Variable>::iterator& var_it,
80 ArrayRef<Variable>::iterator& var_it_end,
81 std::string::const_iterator& desc_it) {
82 char type = *desc_it++;
83 if (type == D::TupleOpen) {
84 std::vector<py::object> objs;
85 while (*desc_it != D::TupleClose)
86 objs.push_back(unflatten_rec(var_it, var_it_end, desc_it));
88 return cast_sequence<py::tuple>(objs);
89 }
else if (type == D::ListOpen) {
90 std::vector<py::object> objs;
91 while (*desc_it != D::ListClose)
92 objs.push_back(unflatten_rec(var_it, var_it_end, desc_it));
94 return cast_sequence<py::list>(objs);
96 if (var_it == var_it_end)
97 throw std::runtime_error(
"Not enough Variables given to unflatten");
99 return py::reinterpret_steal<py::object>(THPVariable_Wrap(var));
105 PyObject* unflatten(ArrayRef<Variable> vars,
const IODescriptor& desc) {
108 auto vars_it = vars.begin();
109 auto vars_it_end = vars.end();
110 auto desc_it = desc.structure.begin();
111 auto output = unflatten_rec(vars_it, vars_it_end, desc_it);
112 if (vars_it != vars_it_end)
113 throw std::runtime_error(
"Too many Variables given to unflatten");
114 return output.release().ptr();
// NOTE(review): the following text is extraction residue from unrelated
// documentation, preserved here as a comment so it cannot be parsed as code:
//   "Variable: A Variable augments a Tensor with the ability to interact
//   in our autograd machinery..."
//   "Flush-To-Zero and Denormals-Are-Zero mode."