Caffe2 - C++ API
A deep learning, cross-platform ML framework
python_arg_parser.cpp
#include <torch/csrc/utils/python_arg_parser.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>

#include <ATen/ATen.h>

#include <cerrno>
#include <cstdlib>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
15 
16 namespace torch {
17 
18 static std::unordered_map<std::string, ParameterType> type_map = {
19  {"Tensor", ParameterType::TENSOR},
20  {"Scalar", ParameterType::SCALAR},
21  {"int64_t", ParameterType::INT64},
22  {"double", ParameterType::DOUBLE},
23  {"TensorList", ParameterType::TENSOR_LIST},
24  {"IntArrayRef", ParameterType::INT_LIST},
25  {"Generator", ParameterType::GENERATOR},
26  {"bool", ParameterType::BOOL},
27  {"Storage", ParameterType::STORAGE},
28  {"PyObject*", ParameterType::PYOBJECT},
29  {"ScalarType", ParameterType::SCALARTYPE},
30  {"Layout", ParameterType::LAYOUT},
31  {"Device", ParameterType::DEVICE},
32  {"std::string", ParameterType::STRING},
33 };
34 
// TODO: remove this. This is a temporary list of functions that allow Python
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
// overloads, and binding to the Tensor overload with a number of a different
// type will trigger a type error.
//
// If you modify this, you will need to adjust the blacklist in
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
// functions.)
//
// Returns true iff `name` is exactly one of the listed function names.
static bool should_allow_numbers_as_tensors(const std::string& name) {
  static const std::unordered_set<std::string> number_friendly_ops{
    "add", "add_", "add_out",
    "div", "div_", "div_out",
    "mul", "mul_", "mul_out",
    "sub", "sub_", "sub_out",
  };
  return number_friendly_ops.count(name) != 0;
}
52 
53 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
54 FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
55  : optional(false)
56  , allow_none(false)
57  , keyword_only(keyword_only)
58  , size(0)
59  , default_scalar(0)
60 {
61  auto space = fmt.find(' ');
62  if (space == std::string::npos) {
63  throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
64  }
65 
66  auto type_str = fmt.substr(0, space);
67 
68  auto question = type_str.find('?');
69  if (question != std::string::npos) {
70  allow_none = true;
71  type_str = type_str.substr(0, question);
72  }
73 
74  // Parse and remove brackets from type_str
75  auto bracket = type_str.find('[');
76  if (bracket != std::string::npos) {
77  auto size_str = type_str.substr(bracket + 1, type_str.length() - bracket - 2);
78  size = atoi(size_str.c_str());
79  type_str = type_str.substr(0, bracket);
80  }
81 
82  auto name_str = fmt.substr(space + 1);
83  auto it = type_map.find(type_str);
84  if (it == type_map.end()) {
85  throw std::runtime_error("FunctionParameter(): invalid type string: " + type_str);
86  }
87  type_ = it->second;
88 
89  auto eq = name_str.find('=');
90  if (eq != std::string::npos) {
91  name = name_str.substr(0, eq);
92  optional = true;
93  set_default_str(name_str.substr(eq + 1));
94  } else {
95  name = name_str;
96  }
97 #if PY_MAJOR_VERSION == 2
98  python_name = PyString_InternFromString(name.c_str());
99 #else
100  python_name = PyUnicode_InternFromString(name.c_str());
101 #endif
102 }
103 
104 bool FunctionParameter::check(PyObject* obj) {
105  switch (type_) {
106  case ParameterType::TENSOR: {
107  return THPVariable_Check(obj) || (allow_numbers_as_tensors && THPUtils_checkDouble(obj));
108  }
109  case ParameterType::SCALAR:
110  if (PyComplex_Check(obj)) {
111  return true;
112  }
113  // fallthrough
114  case ParameterType::DOUBLE: {
115  if (THPUtils_checkDouble(obj)) {
116  return true;
117  }
118  if (THPVariable_Check(obj)) {
119  auto& var = ((THPVariable*)obj)->cdata;
120  return !var.requires_grad() && var.dim() == 0;
121  }
122  return false;
123  }
124  case ParameterType::INT64: {
125  if (THPUtils_checkLong(obj)) {
126  return true;
127  }
128  if (THPVariable_Check(obj)) {
129  auto& var = ((THPVariable*)obj)->cdata;
130  return at::isIntegralType(var.scalar_type()) && !var.requires_grad() && var.dim() == 0;
131  }
132  return false;
133  }
134  case ParameterType::TENSOR_LIST: return six::isTuple(obj) || PyList_Check(obj);
135  case ParameterType::INT_LIST: {
136  if (PyTuple_Check(obj) || PyList_Check(obj)) {
137  return true;
138  }
139  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int
140  return size > 0 && THPUtils_checkLong(obj);
141  }
142  case ParameterType::GENERATOR: return THPGenerator_Check(obj);
143  case ParameterType::BOOL: return PyBool_Check(obj);
144  case ParameterType::STORAGE: return isStorage(obj);
145  case ParameterType::PYOBJECT: return true;
146  case ParameterType::SCALARTYPE: return THPDtype_Check(obj);
147  case ParameterType::LAYOUT: return THPLayout_Check(obj);
148  case ParameterType::DEVICE:
149  return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj);
150  case ParameterType::STRING: return THPUtils_checkString(obj);
151  default: throw std::runtime_error("unknown parameter type");
152  }
153 }
154 
155 std::string FunctionParameter::type_name() const {
156  switch (type_) {
157  case ParameterType::TENSOR: return "Tensor";
158  case ParameterType::SCALAR: return "Number";
159  case ParameterType::INT64: return "int";
160  case ParameterType::DOUBLE: return "float";
161  case ParameterType::TENSOR_LIST: return "tuple of Tensors";
162  case ParameterType::INT_LIST: return "tuple of ints";
163  case ParameterType::GENERATOR: return "torch.Generator";
164  case ParameterType::BOOL: return "bool";
165  case ParameterType::STORAGE: return "torch.Storage";
166  case ParameterType::PYOBJECT: return "object";
167  case ParameterType::SCALARTYPE: return "torch.dtype";
168  case ParameterType::LAYOUT: return "torch.layout";
169  case ParameterType::DEVICE: return "torch.device";
170  case ParameterType::STRING: return "str";
171  default: throw std::runtime_error("unknown parameter type");
172  }
173 }
174 
175 static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
176  if (s.empty())
177  return c10::nullopt;
178  char *str_end;
179  long ans = strtol(s.c_str(), &str_end, 0);
180  // *str_end == 0 if the entire string was parsed as an integer.
181  return (*str_end == 0) ? c10::optional<int64_t>(ans) : c10::nullopt;
182 }
183 
184 /*
185 Parse default value of IntArrayRef declared at native_functions.yaml
186 
187 There are two kinds of default values:
188 1. IntArrayRef[2] x=1 (where size=2, value={1,1}
189 2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args)
190 */
191 static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int64_t size) {
192  size_t n = s.size();
193 
194  if (s.empty()) return std::vector<int64_t>();
195 
196  // case 1. s is an int (e.g., s=2)
197  if (s[0] != '{') {
198  return std::vector<int64_t>(size, std::stol(s));
199  }
200 
201  // case 2. s is a list of dims (e.g., s={1,2})
202 
203  // since already checked left brace '{' above, here only checks right brace '}'
204  AT_CHECK(s[n - 1] == '}', "Default value of IntArrayRef is missing right brace '}', found ", s[n - 1]);
205 
206  auto args = std::vector<int64_t>();
207  std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
208  std::string tok;
209 
210  while(std::getline(ss, tok, ',')) {
211  args.emplace_back(std::stol(tok));
212  }
213  return args;
214 }
215 
216 void FunctionParameter::set_default_str(const std::string& str) {
217  if (str == "None") {
218  allow_none = true;
219  }
220  if (type_ == ParameterType::TENSOR) {
221  if (str != "None") {
222  throw std::runtime_error("default value for Tensor must be none, got: " + str);
223  }
224  } else if (type_ == ParameterType::INT64) {
225  default_int = atol(str.c_str());
226  } else if (type_ == ParameterType::BOOL) {
227  default_bool = (str == "True" || str == "true");
228  } else if (type_ == ParameterType::DOUBLE) {
229  default_double = atof(str.c_str());
230  } else if (type_ == ParameterType::SCALAR) {
231  if (str != "None") {
232  // we sometimes rely on integer-vs-float values, e.g. with arange.
233  const auto as_integer = parse_as_integer(str);
234  default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
235  at::Scalar(atof(str.c_str()));
236  }
237  } else if (type_ == ParameterType::INT_LIST) {
238  if (str != "None") {
239  default_intlist = parse_intlist_args(str, size);
240  }
241  } else if (type_ == ParameterType::SCALARTYPE) {
242  if (str == "None") {
243  default_scalartype = at::ScalarType::Undefined;
244  } else if (str == "torch.int64") {
245  default_scalartype = at::ScalarType::Long;
246  } else {
247  throw std::runtime_error("invalid default value for ScalarType: " + str);
248  }
249  } else if (type_ == ParameterType::LAYOUT) {
250  if (str == "None") {
251  default_layout = nullptr;
252  } else if (str == "torch.strided") {
253  default_layout = torch::getLayout(at::Backend::CPU);
254  } else if (str == "torch.sparse_coo") {
255  default_layout = torch::getLayout(at::Backend::SparseCPU);
256  } else {
257  throw std::runtime_error("invalid default value for layout: " + str);
258  }
259  } else if (type_ == ParameterType::DEVICE) {
260  if (str != "None") {
261  throw std::runtime_error("invalid device: " + str);
262  }
263  } else if (type_ == ParameterType::STRING) {
264  if (str != "None" || str != "") {
265  throw std::runtime_error("invalid default string: " + str);
266  }
267  }
268 }
269 
270 FunctionSignature::FunctionSignature(const std::string& fmt)
271  : min_args(0)
272  , max_args(0)
273  , max_pos_args(0)
274  , hidden(false)
275  , deprecated(false)
276 {
277  auto open_paren = fmt.find('(');
278  if (open_paren == std::string::npos) {
279  throw std::runtime_error("missing opening parenthesis: " + fmt);
280  }
281  name = fmt.substr(0, open_paren);
282 
283  bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);
284 
285  auto last_offset = open_paren + 1;
286  auto next_offset = last_offset;
287  bool keyword_only = false;
288  bool done = false;
289  while (!done) {
290  auto offset = fmt.find(", ", last_offset);
291  if (offset == std::string::npos) {
292  offset = fmt.find(')', last_offset);
293  done = true;
294  next_offset = offset + 1;
295  } else {
296  next_offset = offset + 2;
297  }
298  if (offset == std::string::npos) {
299  throw std::runtime_error("missing closing parenthesis: " + fmt);
300  }
301  if (offset == last_offset) {
302  break;
303  }
304 
305  auto param_str = fmt.substr(last_offset, offset - last_offset);
306  last_offset = next_offset;
307  if (param_str == "*") {
308  keyword_only = true;
309  } else {
310  params.emplace_back(param_str, keyword_only);
311  params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
312  }
313  }
314 
315  if (fmt.substr(last_offset) == "|deprecated") {
316  hidden = true;
317  // TODO: raise warning when parsing deprecated signatures
318  deprecated = true;
319  } else if (fmt.substr(last_offset) == "|hidden") {
320  hidden = true;
321  }
322 
323  max_args = params.size();
324 
325  // count the number of non-optional args
326  for (auto& param : params) {
327  if (!param.optional) {
328  min_args++;
329  }
330  if (!param.keyword_only) {
331  max_pos_args++;
332  }
333  }
334 }
335 
336 std::string FunctionSignature::toString() const {
337  std::ostringstream ss;
338  ss << "(";
339  int i = 0;
340  for (auto& param : params) {
341  if (i != 0) {
342  ss << ", ";
343  }
344  ss << param.type_name() << " " << param.name;
345  i++;
346  }
347  ss << ")";
348  return ss.str();
349 }
350 
351 [[noreturn]]
352 static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
353  auto max_pos_args = signature.max_pos_args;
354  auto min_args = signature.min_args;
355  if (min_args != max_pos_args) {
356  throw TypeError("%s() takes from %d to %d positional arguments but %d were given",
357  signature.name.c_str(), min_args, max_pos_args, nargs);
358  }
359  throw TypeError("%s() takes %d positional argument%s but %d %s given",
360  signature.name.c_str(),
361  max_pos_args, max_pos_args == 1 ? "" : "s",
362  nargs, nargs == 1 ? "was" : "were");
363 }
364 
365 [[noreturn]]
366 static void missing_args(const FunctionSignature& signature, int idx) {
367  int num_missing = 0;
368  std::stringstream ss;
369 
370  auto& params = signature.params;
371  for (auto it = params.begin() + idx; it != params.end(); ++it) {
372  if (!it->optional) {
373  if (num_missing > 0) {
374  ss << ", ";
375  }
376  ss << '"' << it->name << '"';
377  num_missing++;
378  }
379  }
380 
381  throw TypeError("%s() missing %d required positional argument%s: %s",
382  signature.name.c_str(),
383  num_missing,
384  num_missing == 1 ? "s" : "",
385  ss.str().c_str());
386 }
387 
388 static ssize_t find_param(FunctionSignature& signature, PyObject* name) {
389  ssize_t i = 0;
390  for (auto& param : signature.params) {
391  int cmp = PyObject_RichCompareBool(name, param.python_name, Py_EQ);
392  if (cmp < 0) {
393  throw python_error();
394  } else if (cmp) {
395  return i;
396  }
397  i++;
398  }
399  return -1;
400 }
401 
402 [[noreturn]]
403 static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, ssize_t num_pos_args) {
404  PyObject *key, *value;
405  ssize_t pos = 0;
406 
407  while (PyDict_Next(kwargs, &pos, &key, &value)) {
408  if (!THPUtils_checkString(key)) {
409  throw TypeError("keywords must be strings");
410  }
411 
412  auto param_idx = find_param(signature, key);
413  if (param_idx < 0) {
414  throw TypeError("%s() got an unexpected keyword argument '%s'",
415  signature.name.c_str(), THPUtils_unpackString(key).c_str());
416  }
417 
418  if (param_idx < num_pos_args) {
419  throw TypeError("%s() got multiple values for argument '%s'",
420  signature.name.c_str(), THPUtils_unpackString(key).c_str());
421  }
422  }
423 
424  // this should never be hit
425  throw TypeError("invalid keyword arguments");
426 }
427 
428 bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
429  bool raise_exception) {
430  auto nargs = PyTuple_GET_SIZE(args);
431  ssize_t remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
432  ssize_t arg_pos = 0;
433  bool allow_varargs_intlist = false;
434 
435  // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
436  // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
437  if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
438  allow_varargs_intlist = true;
439  }
440 
441  if (nargs > max_pos_args && !allow_varargs_intlist) {
442  if (raise_exception) {
443  // foo() takes takes 2 positional arguments but 3 were given
444  extra_args(*this, nargs);
445  }
446  return false;
447  }
448 
449  int i = 0;
450  for (auto& param : params) {
451  PyObject* obj = nullptr;
452  bool is_kwd = false;
453  if (arg_pos < nargs) {
454  // extra positional args given after single positional IntArrayRef arg
455  if (param.keyword_only) {
456  if (raise_exception) {
457  extra_args(*this, nargs);
458  }
459  return false;
460  }
461  obj = PyTuple_GET_ITEM(args, arg_pos);
462  } else if (kwargs) {
463  obj = PyDict_GetItem(kwargs, param.python_name);
464  is_kwd = true;
465  }
466 
467  if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
468  dst[i++] = nullptr;
469  } else if (!obj) {
470  if (raise_exception) {
471  // foo() missing 1 required positional argument: "b"
472  missing_args(*this, i);
473  }
474  return false;
475  } else if (param.check(obj)) {
476  dst[i++] = obj;
477  // XXX: the Variable check is necessary because sizes become tensors when
478  // tracer is enabled. This behavior easily leads to ambiguities, and we
479  // should avoid having complex signatures that make use of it...
480  } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
481  THPUtils_checkIndex(obj)) {
482  // take all positional arguments as this parameter
483  // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
484  dst[i++] = args;
485  arg_pos = nargs;
486  continue;
487  } else if (raise_exception) {
488  if (is_kwd) {
489  // foo(): argument 'other' must be str, not int
490  throw TypeError("%s(): argument '%s' must be %s, not %s",
491  name.c_str(), param.name.c_str(), param.type_name().c_str(),
492  Py_TYPE(obj)->tp_name);
493  } else {
494  // foo(): argument 'other' (position 2) must be str, not int
495  throw TypeError("%s(): argument '%s' (position %d) must be %s, not %s",
496  name.c_str(), param.name.c_str(), arg_pos + 1,
497  param.type_name().c_str(), Py_TYPE(obj)->tp_name);
498  }
499  } else {
500  return false;
501  }
502 
503  if (!is_kwd) {
504  arg_pos++;
505  } else if (obj) {
506  remaining_kwargs--;
507  }
508  }
509 
510  if (remaining_kwargs > 0) {
511  if (raise_exception) {
512  // foo() got an unexpected keyword argument "b"
513  extra_kwargs(*this, kwargs, nargs);
514  }
515  return false;
516  }
517 
518  return true;
519 }
520 
521 PythonArgParser::PythonArgParser(std::vector<std::string> fmts, bool traceable)
522  : max_args(0)
523  , traceable(traceable)
524 {
525  for (auto& fmt : fmts) {
526  signatures_.emplace_back(fmt);
527  }
528  for (auto& signature : signatures_) {
529  if (signature.max_args > max_args) {
530  max_args = signature.max_args;
531  }
532  }
533  if (signatures_.size() > 0) {
534  function_name = signatures_[0].name;
535  }
536 }
537 
538 PythonArgs PythonArgParser::raw_parse(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {
539  if (signatures_.size() == 1) {
540  auto& signature = signatures_[0];
541  signature.parse(args, kwargs, parsed_args, true);
542  return PythonArgs(0, traceable, signature, parsed_args);
543  }
544 
545  int i = 0;
546  for (auto& signature : signatures_) {
547  if (signature.parse(args, kwargs, parsed_args, false)) {
548  return PythonArgs(i, traceable, signature, parsed_args);
549  }
550  i++;
551  }
552 
553  print_error(args, kwargs, parsed_args);
554 }
555 
556 void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {
557  auto num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0);
558  std::vector<int> plausible_idxs;
559  ssize_t i = 0;
560  for (auto& signature : signatures_) {
561  if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
562  plausible_idxs.push_back(i);
563  }
564  i++;
565  }
566 
567  if (plausible_idxs.size() == 1) {
568  auto& signature = signatures_[plausible_idxs[0]];
569  signature.parse(args, kwargs, parsed_args, true);
570  }
571 
572  std::vector<std::string> options;
573  for (auto& signature : signatures_) {
574  if (!signature.hidden) {
575  options.push_back(signature.toString());
576  }
577  }
578 
579  auto msg = torch::format_invalid_args(args, kwargs, function_name + "()", options);
580  throw TypeError("%s", msg.c_str());
581 }
582 
583 
584 } // namespace torch
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Definition: jit_type.h:17