1 #include <torch/csrc/utils/python_arg_parser.h> 3 #include <torch/csrc/Exceptions.h> 4 #include <torch/csrc/Layout.h> 5 #include <torch/csrc/utils/invalid_arguments.h> 6 #include <torch/csrc/utils/python_strings.h> 13 #include <unordered_map> 18 static std::unordered_map<std::string, ParameterType> type_map = {
// Keys are the type spellings accepted in a parameter description string;
// values are the ParameterType each spelling resolves to.
// FunctionParameter's constructor looks the parsed type token up in this
// table and throws std::runtime_error ("invalid type string") when the
// spelling is not listed, so a new parameter type needs an entry here.
19 {
"Tensor", ParameterType::TENSOR},
20 {
"Scalar", ParameterType::SCALAR},
21 {
"int64_t", ParameterType::INT64},
22 {
"double", ParameterType::DOUBLE},
23 {
"TensorList", ParameterType::TENSOR_LIST},
24 {
"IntArrayRef", ParameterType::INT_LIST},
25 {
"Generator", ParameterType::GENERATOR},
26 {
"bool", ParameterType::BOOL},
27 {
"Storage", ParameterType::STORAGE},
28 {
"PyObject*", ParameterType::PYOBJECT},
29 {
"ScalarType", ParameterType::SCALARTYPE},
30 {
"Layout", ParameterType::LAYOUT},
31 {
"Device", ParameterType::DEVICE},
32 {
"std::string", ParameterType::STRING},
// Reports whether the function called `name` belongs to the fixed set of
// arithmetic entry points (add/div/mul/sub plus their in-place `_` and
// `_out` variants). The result is copied onto every parameter of a signature
// as `allow_numbers_as_tensors`, which FunctionParameter::check consults
// when deciding whether a plain number may stand in for a Tensor argument.
static bool should_allow_numbers_as_tensors(const std::string& name) {
  // Built once on first call and reused afterwards.
  static std::unordered_set<std::string> number_friendly_ops{
      "add", "add_", "add_out",
      "div", "div_", "div_out",
      "mul", "mul_", "mul_out",
      "sub", "sub_", "sub_out",
  };
  const bool is_listed = number_friendly_ops.count(name) != 0;
  return is_listed;
}
// Builds one parameter from its textual description `fmt`, which has the
// shape "<type> <name>" or "<type> <name>=<default>". The type token may carry
// a '?' suffix and a "[N]" size suffix; both are stripped before the type is
// resolved through type_map.
//
// Throws std::runtime_error when `fmt` contains no space (no type/name split)
// or when the type token is not a key of type_map.
//
// NOTE(review): this listing omits some original lines (the embedded numbers
// skip), including most of the member-initializer list; only the visible
// statements are described.
54 FunctionParameter::FunctionParameter(
const std::string& fmt,
bool keyword_only)
57 , keyword_only(keyword_only)
// The first space separates the type token from the name (and default).
61 auto space = fmt.find(
' ');
62 if (space == std::string::npos) {
63 throw std::runtime_error(
"FunctionParameter(): missing type: " + fmt);
66 auto type_str = fmt.substr(0, space);
// A '?' in the type token is cut off before lookup.
// NOTE(review): the statement recording its meaning is elided here;
// presumably it marks the parameter as accepting None — confirm.
68 auto question = type_str.find(
'?');
69 if (question != std::string::npos) {
71 type_str = type_str.substr(0, question);
// "[N]" suffix: the text after '[' up to (but excluding) the token's final
// character is parsed as the declared size; the suffix is then removed.
75 auto bracket = type_str.find(
'[');
76 if (bracket != std::string::npos) {
77 auto size_str = type_str.substr(bracket + 1, type_str.length() - bracket - 2);
78 size = atoi(size_str.c_str());
79 type_str = type_str.substr(0, bracket);
82 auto name_str = fmt.substr(space + 1);
83 auto it = type_map.find(type_str);
84 if (it == type_map.end()) {
85 throw std::runtime_error(
"FunctionParameter(): invalid type string: " + type_str);
// "name=default": the name is the text before '=', the remainder is handed
// to set_default_str for type-specific interpretation.
89 auto eq = name_str.find(
'=');
90 if (eq != std::string::npos) {
91 name = name_str.substr(0, eq);
93 set_default_str(name_str.substr(eq + 1));
// Keep an interned Python string for the name (Python 2 and Python 3 use
// different C-API entry points); find_param and parse use python_name for
// keyword lookups.
97 #if PY_MAJOR_VERSION == 2 98 python_name = PyString_InternFromString(name.c_str());
100 python_name = PyUnicode_InternFromString(name.c_str());
// Returns whether the Python object `obj` is an acceptable value for this
// parameter. The case labels show one test per ParameterType; an
// unrecognised type throws std::runtime_error("unknown parameter type").
//
// NOTE(review): several original lines are elided from this listing (the
// embedded numbers skip), so some return paths are not visible; only the
// visible tests are described below.
104 bool FunctionParameter::check(PyObject* obj) {
106 case ParameterType::TENSOR: {
// A Variable always qualifies; an object passing THPUtils_checkDouble
// qualifies only when allow_numbers_as_tensors is set on this parameter.
107 return THPVariable_Check(obj) || (allow_numbers_as_tensors && THPUtils_checkDouble(obj));
109 case ParameterType::SCALAR:
110 if (PyComplex_Check(obj)) {
114 case ParameterType::DOUBLE: {
115 if (THPUtils_checkDouble(obj)) {
// A Variable is accepted here only if it does not require grad and is
// zero-dimensional.
118 if (THPVariable_Check(obj)) {
120 return !var.requires_grad() && var.dim() == 0;
124 case ParameterType::INT64: {
125 if (THPUtils_checkLong(obj)) {
// Same Variable rule as DOUBLE, with the extra requirement that its scalar
// type is integral.
128 if (THPVariable_Check(obj)) {
130 return at::isIntegralType(var.scalar_type()) && !var.requires_grad() && var.dim() == 0;
134 case ParameterType::TENSOR_LIST:
return six::isTuple(obj) || PyList_Check(obj);
135 case ParameterType::INT_LIST: {
136 if (PyTuple_Check(obj) || PyList_Check(obj)) {
// Not a tuple/list: a lone integer is accepted only when a positive size
// was declared for this parameter (the "[N]" suffix parsed by the
// constructor).
140 return size > 0 && THPUtils_checkLong(obj);
142 case ParameterType::GENERATOR:
return THPGenerator_Check(obj);
143 case ParameterType::BOOL:
return PyBool_Check(obj);
144 case ParameterType::STORAGE:
return isStorage(obj);
// PYOBJECT places no constraint on the value.
145 case ParameterType::PYOBJECT:
return true;
146 case ParameterType::SCALARTYPE:
return THPDtype_Check(obj);
147 case ParameterType::LAYOUT:
return THPLayout_Check(obj);
// A device may be given as an integer, a string, or a device object.
148 case ParameterType::DEVICE:
149 return THPUtils_checkLong(obj) || THPUtils_checkString(obj) || THPDevice_Check(obj);
150 case ParameterType::STRING:
return THPUtils_checkString(obj);
151 default:
throw std::runtime_error(
"unknown parameter type");
// Returns the user-facing spelling of this parameter's type. The result is
// embedded in text shown to the caller: FunctionSignature::toString writes
// "<type_name> <name>" for each parameter, and FunctionSignature::parse puts
// it into its "must be %s, not %s" TypeError messages. The names therefore
// follow Python vocabulary ("int", "float", "tuple of ints", "torch.dtype")
// rather than the C++ spellings used as type_map keys.
//
// Throws std::runtime_error("unknown parameter type") for an unhandled type.
155 std::string FunctionParameter::type_name()
const {
157 case ParameterType::TENSOR:
return "Tensor";
158 case ParameterType::SCALAR:
return "Number";
159 case ParameterType::INT64:
return "int";
160 case ParameterType::DOUBLE:
return "float";
161 case ParameterType::TENSOR_LIST:
return "tuple of Tensors";
162 case ParameterType::INT_LIST:
return "tuple of ints";
163 case ParameterType::GENERATOR:
return "torch.Generator";
164 case ParameterType::BOOL:
return "bool";
165 case ParameterType::STORAGE:
return "torch.Storage";
166 case ParameterType::PYOBJECT:
return "object";
167 case ParameterType::SCALARTYPE:
return "torch.dtype";
168 case ParameterType::LAYOUT:
return "torch.layout";
169 case ParameterType::DEVICE:
return "torch.device";
170 case ParameterType::STRING:
return "str";
171 default:
throw std::runtime_error(
"unknown parameter type");
// NOTE(review): the enclosing function header is elided from this listing;
// presumably this is the body of parse_as_integer, which set_default_str
// calls — confirm against the full file.
// Base 0 makes strtol pick the radix from the text itself ("0x" prefix ->
// hexadecimal, leading "0" -> octal, otherwise decimal); str_end is set to
// the first character that was not consumed.
179 long ans = strtol(s.c_str(), &str_end, 0);
// Converts the textual default of an IntArrayRef parameter into a vector.
// Visible forms:
//   - empty text            -> empty vector
//   - a single number       -> that number repeated `size` times
//   - a brace-delimited list -> the comma-separated interior, each item
//                               converted with std::stol
// std::stol throws (std::invalid_argument / std::out_of_range) on text it
// cannot convert; AT_CHECK fires when the closing '}' is missing.
//
// NOTE(review): the guards that choose between the forms, and the
// declarations of `n` and `tok`, are elided from this listing.
191 static inline std::vector<int64_t> parse_intlist_args(
const std::string& s, int64_t size) {
194 if (s.empty())
return std::vector<int64_t>();
// Single-number form: `size` copies of the parsed value.
198 return std::vector<int64_t>(size, std::stol(s));
// List form: the last character must be the closing brace.
// (`n` is presumably s.length() — its definition is not visible here.)
204 AT_CHECK(s[n - 1] ==
'}',
"Default value of IntArrayRef is missing right brace '}', found ", s[n - 1]);
206 auto args = std::vector<int64_t>();
// Drop the first and last characters (the braces) and split what is left
// on ','.
207 std::istringstream ss(s.substr(1, s.length() - 2));
210 while(std::getline(ss, tok,
',')) {
211 args.emplace_back(std::stol(tok));
216 void FunctionParameter::set_default_str(
const std::string& str) {
220 if (type_ == ParameterType::TENSOR) {
222 throw std::runtime_error(
"default value for Tensor must be none, got: " + str);
224 }
else if (type_ == ParameterType::INT64) {
225 default_int = atol(str.c_str());
226 }
else if (type_ == ParameterType::BOOL) {
227 default_bool = (str ==
"True" || str ==
"true");
228 }
else if (type_ == ParameterType::DOUBLE) {
229 default_double = atof(str.c_str());
230 }
else if (type_ == ParameterType::SCALAR) {
233 const auto as_integer = parse_as_integer(str);
234 default_scalar = as_integer.has_value() ?
at::Scalar(as_integer.value()) :
237 }
else if (type_ == ParameterType::INT_LIST) {
239 default_intlist = parse_intlist_args(str, size);
241 }
else if (type_ == ParameterType::SCALARTYPE) {
243 default_scalartype = at::ScalarType::Undefined;
244 }
else if (str ==
"torch.int64") {
245 default_scalartype = at::ScalarType::Long;
247 throw std::runtime_error(
"invalid default value for ScalarType: " + str);
249 }
else if (type_ == ParameterType::LAYOUT) {
251 default_layout =
nullptr;
252 }
else if (str ==
"torch.strided") {
253 default_layout = torch::getLayout(at::Backend::CPU);
254 }
else if (str ==
"torch.sparse_coo") {
255 default_layout = torch::getLayout(at::Backend::SparseCPU);
257 throw std::runtime_error(
"invalid default value for layout: " + str);
259 }
else if (type_ == ParameterType::DEVICE) {
261 throw std::runtime_error(
"invalid device: " + str);
263 }
else if (type_ == ParameterType::STRING) {
264 if (str !=
"None" || str !=
"") {
265 throw std::runtime_error(
"invalid default string: " + str);
// Parses a signature description of the shape
//   name(<param>, <param>, *, <param>)
// optionally followed by "|deprecated" or "|hidden", into `name`, `params`
// and the argument-count members.
//
// Throws std::runtime_error when the opening or closing parenthesis is
// missing. FunctionParameter's constructor may throw for a malformed
// parameter description.
//
// NOTE(review): the loop header and several branch bodies are elided from
// this listing (the embedded numbers skip); only visible statements are
// described.
270 FunctionSignature::FunctionSignature(
const std::string& fmt)
277 auto open_paren = fmt.find(
'(');
278 if (open_paren == std::string::npos) {
279 throw std::runtime_error(
"missing opening parenthesis: " + fmt);
281 name = fmt.substr(0, open_paren);
// Decided once per signature from the function name and copied onto every
// parameter below.
283 bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);
285 auto last_offset = open_paren + 1;
286 auto next_offset = last_offset;
287 bool keyword_only =
false;
// Parameters are separated by ", "; when no separator is left, the closing
// ')' ends the final parameter. next_offset skips past whichever delimiter
// was found (1 character for ')', 2 for ", ").
290 auto offset = fmt.find(
", ", last_offset);
291 if (offset == std::string::npos) {
292 offset = fmt.find(
')', last_offset);
294 next_offset = offset + 1;
296 next_offset = offset + 2;
298 if (offset == std::string::npos) {
299 throw std::runtime_error(
"missing closing parenthesis: " + fmt);
// Empty parameter text (delimiter found immediately).
// NOTE(review): handling is elided here; presumably ends the scan.
301 if (offset == last_offset) {
305 auto param_str = fmt.substr(last_offset, offset - last_offset);
306 last_offset = next_offset;
// A bare "*" is a marker rather than a parameter.
// NOTE(review): its body is elided; presumably sets keyword_only so that the
// parameters after it are keyword-only — confirm.
307 if (param_str ==
"*") {
310 params.emplace_back(param_str, keyword_only);
311 params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
// Whatever follows the parameter list selects a signature flag.
// NOTE(review): the flag assignments are elided from this listing.
315 if (fmt.substr(last_offset) ==
"|deprecated") {
319 }
else if (fmt.substr(last_offset) ==
"|hidden") {
323 max_args = params.size();
// Tally over the parameters: one test for non-optional ones, one for
// non-keyword-only ones.
// NOTE(review): the counters are elided; presumably min_args and
// max_pos_args, which extra_args and parse read.
326 for (
auto& param : params) {
327 if (!param.optional) {
330 if (!param.keyword_only) {
// Renders this signature as text, writing "<type_name> <name>" for each
// parameter into a string stream. PythonArgParser::print_error collects
// these strings as the list of valid overloads shown to the user.
// NOTE(review): the surrounding punctuation and the return statement are
// elided from this listing.
336 std::string FunctionSignature::toString()
const {
337 std::ostringstream ss;
340 for (
auto& param : params) {
344 ss << param.type_name() <<
" " << param.name;
352 static void extra_args(
const FunctionSignature& signature, ssize_t nargs) {
353 auto max_pos_args = signature.max_pos_args;
354 auto min_args = signature.min_args;
355 if (min_args != max_pos_args) {
356 throw TypeError(
"%s() takes from %d to %d positional arguments but %d were given",
357 signature.name.c_str(), min_args, max_pos_args, nargs);
359 throw TypeError(
"%s() takes %d positional argument%s but %d %s given",
360 signature.name.c_str(),
361 max_pos_args, max_pos_args == 1 ?
"" :
"s",
362 nargs, nargs == 1 ?
"was" :
"were");
366 static void missing_args(
const FunctionSignature& signature,
int idx) {
368 std::stringstream ss;
370 auto& params = signature.params;
371 for (
auto it = params.begin() + idx; it != params.end(); ++it) {
373 if (num_missing > 0) {
376 ss <<
'"' << it->name <<
'"';
381 throw TypeError(
"%s() missing %d required positional argument%s: %s",
382 signature.name.c_str(),
384 num_missing == 1 ?
"s" :
"",
// Searches signature.params for the parameter whose interned python_name
// compares equal to the Python object `name`, using rich comparison with
// Py_EQ. extra_kwargs uses the result as a parameter index.
// NOTE(review): the handling of `cmp` and the return statements are elided
// from this listing, so the not-found value is not visible here.
388 static ssize_t find_param(FunctionSignature& signature, PyObject* name) {
390 for (
auto& param : signature.params) {
391 int cmp = PyObject_RichCompareBool(name, param.python_name, Py_EQ);
// Throws the TypeError explaining why the keyword arguments in `kwargs`
// could not be matched against `signature`. Each entry is inspected in
// dictionary iteration order and the first problem found is reported:
//   - a key that is not a string
//   - a key that names no parameter of the signature
//   - a key naming a parameter already filled by one of the first
//     `num_pos_args` positional arguments
// If no entry explains the failure, a generic "invalid keyword arguments"
// error is thrown, so the function never returns normally.
// NOTE(review): the declaration of `pos` and the test guarding the
// "unexpected keyword" throw are elided from this listing.
403 static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, ssize_t num_pos_args) {
404 PyObject *key, *value;
407 while (PyDict_Next(kwargs, &pos, &key, &value)) {
408 if (!THPUtils_checkString(key)) {
409 throw TypeError(
"keywords must be strings");
412 auto param_idx = find_param(signature, key);
414 throw TypeError(
"%s() got an unexpected keyword argument '%s'",
415 signature.name.c_str(), THPUtils_unpackString(key).c_str());
// The keyword names a parameter that a positional argument already filled.
418 if (param_idx < num_pos_args) {
419 throw TypeError(
"%s() got multiple values for argument '%s'",
420 signature.name.c_str(), THPUtils_unpackString(key).c_str());
// Fallback when no individual keyword explains the mismatch.
425 throw TypeError(
"invalid keyword arguments");
// Tries to match the positional tuple `args` and the optional dict `kwargs`
// against this signature, visiting the parameters in declaration order.
//
//   args             positional arguments (a tuple; its size is read with
//                    PyTuple_GET_SIZE)
//   kwargs           keyword arguments, may be null
//   dst              output array for the matched objects
//   raise_exception  when true, a mismatch is reported by throwing TypeError
//                    (directly or through extra_args / missing_args /
//                    extra_kwargs)
//
// NOTE(review): many original lines are elided from this listing (the
// embedded numbers skip) — including the writes to `dst`, the bookkeeping of
// `arg_pos`, `i`, `is_kwd` and `remaining_kwargs`, and every return
// statement. Presumably the function returns true on a match and false on a
// mismatch when raise_exception is false; confirm against the full file.
428 bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
429 bool raise_exception) {
430 auto nargs = PyTuple_GET_SIZE(args);
431 ssize_t remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
433 bool allow_varargs_intlist =
false;
// A signature whose only positional parameter is an int list may receive
// more positional arguments than max_pos_args; the varargs branch further
// down handles that case.
437 if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
438 allow_varargs_intlist =
true;
// Too many positional arguments.
441 if (nargs > max_pos_args && !allow_varargs_intlist) {
442 if (raise_exception) {
444 extra_args(*
this, nargs);
450 for (
auto& param : params) {
451 PyObject* obj =
nullptr;
// Positional arguments are consumed first; a positional argument landing
// on a keyword-only parameter counts as one positional argument too many.
453 if (arg_pos < nargs) {
455 if (param.keyword_only) {
456 if (raise_exception) {
457 extra_args(*
this, nargs);
461 obj = PyTuple_GET_ITEM(args, arg_pos);
// Otherwise look the parameter up by its interned name in kwargs.
463 obj = PyDict_GetItem(kwargs, param.python_name);
// Either nothing was supplied for a parameter that has a default, or None
// was supplied and the parameter allows None.
467 if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
470 if (raise_exception) {
472 missing_args(*
this, i);
475 }
else if (param.check(obj)) {
// Varargs int list: the first positional argument is not itself a list but
// passes THPUtils_checkIndex.
480 }
else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
481 THPUtils_checkIndex(obj)) {
// Wrong type: two message variants, the second also naming the 1-based
// position. NOTE(review): the condition choosing between them is elided;
// presumably keyword vs positional (is_kwd).
487 }
else if (raise_exception) {
490 throw TypeError(
"%s(): argument '%s' must be %s, not %s",
491 name.c_str(), param.name.c_str(), param.type_name().c_str(),
492 Py_TYPE(obj)->tp_name);
495 throw TypeError(
"%s(): argument '%s' (position %d) must be %s, not %s",
496 name.c_str(), param.name.c_str(), arg_pos + 1,
497 param.type_name().c_str(), Py_TYPE(obj)->tp_name);
// Keyword arguments left over after all parameters were visited.
510 if (remaining_kwargs > 0) {
511 if (raise_exception) {
513 extra_kwargs(*
this, kwargs, nargs);
// Builds a parser from a list of signature description strings, one
// FunctionSignature per entry (construction may throw std::runtime_error for
// a malformed description).
//
//   fmts       signature descriptions, in the order they should be tried
//   traceable  stored as-is and forwarded to every PythonArgs result
//
// After construction, max_args holds the largest max_args over all
// signatures and function_name is taken from the first signature (left
// untouched when `fmts` is empty).
// NOTE(review): part of the member-initializer list is elided from this
// listing, so the initial value of max_args is not visible here.
521 PythonArgParser::PythonArgParser(std::vector<std::string> fmts,
bool traceable)
523 , traceable(traceable)
525 for (
auto& fmt : fmts) {
526 signatures_.emplace_back(fmt);
528 for (
auto& signature : signatures_) {
529 if (signature.max_args > max_args) {
530 max_args = signature.max_args;
533 if (signatures_.size() > 0) {
534 function_name = signatures_[0].name;
// Matches (args, kwargs) against the registered signatures and returns a
// PythonArgs naming the signature that matched; `parsed_args` is the output
// array handed to FunctionSignature::parse.
//
// With exactly one signature, parsing runs with raise_exception = true so
// a mismatch surfaces as that signature's specific TypeError. With several
// signatures, each is tried in order without raising and the first match
// wins; if none matches, print_error builds and throws the error.
// NOTE(review): the declaration and increment of the index `i` are elided
// from this listing.
538 PythonArgs PythonArgParser::raw_parse(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {
539 if (signatures_.size() == 1) {
540 auto& signature = signatures_[0];
541 signature.parse(args, kwargs, parsed_args,
true);
542 return PythonArgs(0, traceable, signature, parsed_args);
546 for (
auto& signature : signatures_) {
547 if (signature.parse(args, kwargs, parsed_args,
false)) {
548 return PythonArgs(i, traceable, signature, parsed_args);
553 print_error(args, kwargs, parsed_args);
// Reports a failed overload match by throwing TypeError.
//
// Strategy: count the supplied arguments (positional plus keyword) and
// collect the non-hidden signatures whose [min_args, max_args] range
// contains that count. If exactly one signature is plausible, it is
// re-parsed with raise_exception = true so the user gets its specific
// message. Otherwise a generic message is built by
// torch::format_invalid_args from the call's arguments and the text of
// every non-hidden signature.
// NOTE(review): the declaration and increment of the index `i` are elided
// from this listing.
556 void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) {
557 auto num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0);
558 std::vector<int> plausible_idxs;
560 for (
auto& signature : signatures_) {
561 if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
562 plausible_idxs.push_back(i);
// Exactly one candidate: let it raise its own, more precise error.
567 if (plausible_idxs.size() == 1) {
568 auto& signature = signatures_[plausible_idxs[0]];
569 signature.parse(args, kwargs, parsed_args,
true);
// Otherwise list every non-hidden overload in a generic message.
572 std::vector<std::string> options;
573 for (
auto& signature : signatures_) {
574 if (!signature.hidden) {
575 options.push_back(signature.toString());
579 auto msg = torch::format_invalid_args(args, kwargs, function_name +
"()", options);
580 throw TypeError(
"%s", msg.c_str());
Scalar represents a 0-dimensional tensor which contains a single element.