Caffe2 - C++ API
A deep learning, cross-platform ML framework
tuple_parser.cpp
#include <torch/csrc/utils/tuple_parser.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_numbers.h>

#include <string>
#include <stdexcept>
#include <vector>

namespace torch {

// Wraps a tuple of positional arguments. If num_args is non-negative, the
// tuple must contain exactly that many elements.
TupleParser::TupleParser(PyObject* args, int num_args) : args(args), idx(0) {
  int size = (int) PyTuple_GET_SIZE(args);
  if (num_args >= 0 && size != num_args) {
    std::string msg("missing required arguments (expected ");
    msg += std::to_string(num_args) + " got " + std::to_string(size) + ")";
    throw std::runtime_error(msg);
  }
}

auto TupleParser::parse(bool& x, const std::string& param_name) -> void {
  PyObject* obj = next_arg();
  if (!PyBool_Check(obj)) {
    throw invalid_type("bool", param_name);
  }
  x = (obj == Py_True);
}

auto TupleParser::parse(int& x, const std::string& param_name) -> void {
  PyObject* obj = next_arg();
  if (!THPUtils_checkLong(obj)) {
    throw invalid_type("int", param_name);
  }
  x = THPUtils_unpackLong(obj);
}

auto TupleParser::parse(double& x, const std::string& param_name) -> void {
  PyObject* obj = next_arg();
  if (!THPUtils_checkDouble(obj)) {
    throw invalid_type("float", param_name);
  }
  x = THPUtils_unpackDouble(obj);
}

auto TupleParser::parse(std::vector<int>& x, const std::string& param_name) -> void {
  PyObject* obj = next_arg();
  if (!PyTuple_Check(obj)) {
    throw invalid_type("tuple of int", param_name);
  }
  int size = PyTuple_GET_SIZE(obj);
  x.resize(size);
  for (int i = 0; i < size; ++i) {
    PyObject* item = PyTuple_GET_ITEM(obj, i);
    if (!THPUtils_checkLong(item)) {
      throw invalid_type("tuple of int", param_name);
    }
    x[i] = THPUtils_unpackLong(item);
  }
}

auto TupleParser::parse(std::string& x, const std::string& param_name) -> void {
  PyObject* obj = next_arg();
  if (!THPUtils_checkString(obj)) {
    throw invalid_type("bytes/str", param_name);
  }
  x = THPUtils_unpackString(obj);
}

// Returns the next positional argument and advances the cursor.
auto TupleParser::next_arg() -> PyObject* {
  if (idx >= PyTuple_GET_SIZE(args)) {
    throw std::runtime_error("out of range");
  }
  return PyTuple_GET_ITEM(args, idx++);
}

// Builds an error describing the argument that failed to parse, including the
// runtime type(s) Python actually passed.
auto TupleParser::invalid_type(const std::string& expected, const std::string& param_name) -> std::runtime_error {
  std::string msg("argument ");
  msg += std::to_string(idx - 1);
  msg += " (";
  msg += param_name;
  msg += ") ";
  msg += "must be ";
  msg += expected;

  PyObject* obj = PyTuple_GET_ITEM(args, idx - 1);
  if (PyTuple_Check(obj)) {
    msg += " but got tuple of (";
    int size = PyTuple_GET_SIZE(obj);
    for (int i = 0; i < size; ++i) {
      msg += Py_TYPE(PyTuple_GET_ITEM(obj, i))->tp_name;
      if (i != size - 1) {
        msg += ", ";
      }
    }
    msg += ")";
  } else {
    msg += ", not ";
    msg += Py_TYPE(PyTuple_GET_ITEM(args, idx - 1))->tp_name;
  }
  return std::runtime_error(msg);
}

} // namespace torch
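
A minimal usage sketch, assuming TupleParser is called from a hand-written CPython binding compiled against torch/csrc/utils/tuple_parser.h. The function name example_binding, its registration, and the parameter names are hypothetical and not part of this file; they only illustrate that each parse() call consumes the next positional slot and throws std::runtime_error on a size or type mismatch, which the binding translates into a Python exception.

#include <torch/csrc/utils/tuple_parser.h>
#include <Python.h>
#include <stdexcept>
#include <vector>

// Hypothetical METH_VARARGS binding: expects exactly four positional
// arguments (bool, int, float, tuple of int) from Python.
static PyObject* example_binding(PyObject* self, PyObject* args) {
  try {
    bool training;
    int num_layers;
    double dropout;
    std::vector<int> sizes;

    torch::TupleParser parser(args, 4);      // tuple size is checked here
    parser.parse(training, "training");      // slot 0
    parser.parse(num_layers, "num_layers");  // slot 1
    parser.parse(dropout, "dropout");        // slot 2
    parser.parse(sizes, "sizes");            // slot 3

    // ... use the parsed values ...
    Py_RETURN_NONE;
  } catch (const std::exception& e) {
    PyErr_SetString(PyExc_RuntimeError, e.what());
    return nullptr;
  }
}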