1 #include <torch/csrc/utils/invalid_arguments.h> 3 #include <torch/csrc/utils/python_strings.h> 5 #include <torch/csrc/utils/memory.h> 8 #include <unordered_map> 15 std::string py_typename(PyObject *
object) {
16 return Py_TYPE(
object)->tp_name;
20 virtual bool is_matching(PyObject *
object) = 0;
21 virtual ~Type() =
default;
24 struct SimpleType:
public Type {
25 SimpleType(std::string& name): name(name) {};
27 bool is_matching(PyObject *
object)
override {
28 return py_typename(
object) == name;
34 struct MultiType:
public Type {
35 MultiType(std::initializer_list<std::string> accepted_types):
36 types(accepted_types) {};
38 bool is_matching(PyObject *
object)
override {
39 auto it = std::find(types.begin(), types.end(), py_typename(
object));
40 return it != types.end();
43 std::vector<std::string> types;
46 struct NullableType:
public Type {
47 NullableType(std::unique_ptr<Type> type): type(
std::move(type)) {};
49 bool is_matching(PyObject *
object)
override {
50 return object == Py_None || type->is_matching(
object);
53 std::unique_ptr<Type> type;
56 struct TupleType:
public Type {
57 TupleType(std::vector<std::unique_ptr<Type>> types):
58 types(
std::move(types)) {};
60 bool is_matching(PyObject *
object)
override {
61 if (!PyTuple_Check(
object))
return false;
62 auto num_elements = PyTuple_GET_SIZE(
object);
63 if (num_elements != (
long)types.size())
return false;
64 for (
int i = 0; i < num_elements; i++) {
65 if (!types[i]->is_matching(PyTuple_GET_ITEM(
object, i)))
71 std::vector<std::unique_ptr<Type>> types;
74 struct SequenceType:
public Type {
75 SequenceType(std::unique_ptr<Type> type):
76 type(
std::move(type)) {};
78 bool is_matching(PyObject *
object)
override {
79 if (!PySequence_Check(
object))
return false;
80 auto num_elements = PySequence_Length(
object);
81 for (
int i = 0; i < num_elements; i++) {
82 if (!type->is_matching(PySequence_GetItem(
object, i)))
88 std::unique_ptr<Type> type;
92 Argument(std::string name, std::unique_ptr<Type> type):
93 name(
std::move(name)), type(
std::move(type)) {};
96 std::unique_ptr<Type> type;
100 Option(std::vector<Argument> arguments,
bool is_variadic,
bool has_out):
101 arguments(
std::move(arguments)), is_variadic(is_variadic), has_out(has_out) {};
102 Option(
bool is_variadic,
bool has_out):
103 arguments(), is_variadic(is_variadic), has_out(has_out) {};
104 Option(
const Option&) =
delete;
105 Option(Option&& other):
106 arguments(
std::move(other.arguments)), is_variadic(other.is_variadic),
107 has_out(other.has_out) {};
109 std::vector<Argument> arguments;
// Splits `s` on every occurrence of `delim`, keeping empty fields
// ("a,,b" -> {"a", "", "b"}). The result always has at least one element.
// An empty `delim` yields {s}. Without that guard, find("") matches at
// `start` forever and the loop never terminates.
std::vector<std::string> _splitString(const std::string &s, const std::string& delim) {
  if (delim.empty())
    return {s};
  std::vector<std::string> tokens;
  std::string::size_type start = 0;
  std::string::size_type end;
  while ((end = s.find(delim, start)) != std::string::npos) {
    tokens.push_back(s.substr(start, end - start));
    start = end + delim.length();
  }
  tokens.push_back(s.substr(start));
  return tokens;
}
126 std::unique_ptr<Type> _buildType(std::string type_name,
bool is_nullable) {
127 std::unique_ptr<Type> result;
128 if (type_name ==
"float") {
129 result = torch::make_unique<MultiType>(MultiType{
"float",
"int",
"long"});
130 }
else if (type_name ==
"int") {
131 result = torch::make_unique<MultiType>(MultiType{
"int",
"long"});
132 }
else if (type_name.find(
"tuple[") == 0) {
133 auto type_list = type_name.substr(6);
134 type_list.pop_back();
135 std::vector<std::unique_ptr<Type>> types;
136 for (
auto& type: _splitString(type_list,
","))
137 types.emplace_back(_buildType(type,
false));
138 result = torch::make_unique<TupleType>(std::move(types));
139 }
else if (type_name.find(
"sequence[") == 0) {
140 auto subtype = type_name.substr(9);
142 result = torch::make_unique<SequenceType>(_buildType(subtype,
false));
144 result = torch::make_unique<SimpleType>(type_name);
147 result = torch::make_unique<NullableType>(std::move(result));
// Parses one signature string into an Option (declared arguments plus
// variadic/out flags) and a human-readable copy of the signature.
// NOTE(review): this copy of the function is missing several lines: its
// opening brace, the bodies of the '#' and '[' checks, the statement guarded
// by the "..." check, the `std::string name =` declaration, the loop's closing
// brace and the final `);` / `}`. The notes below flag each gap. Confirm
// against upstream.
151 std::pair<Option, std::string> _parseOption(
const std::string& _option_str,
152 const std::unordered_map<std::string, PyObject*>& kwargs)
// The literal "no arguments" maps to an empty Option that is neither variadic
// nor has out=.
154 if (_option_str ==
"no arguments")
155 return std::pair<Option, std::string>(Option(
false,
false), _option_str);
156 bool has_out =
false;
157 std::vector<Argument> arguments;
158 std::string printable_option = _option_str;
// Drop the first and last character (presumably the surrounding parentheses).
// NOTE(review): substr(1, ...) throws std::out_of_range for an empty string.
159 std::string option_str = _option_str.substr(1, _option_str.length()-2);
// A '#' in the signature appears to mark where the out= part begins (inferred
// from the kwargs.count("out") check). The printable form is rewritten:
//  - out= was passed: the text after '#' is kept, prefixed with "*, ";
//  - otherwise, if out_pos >= 2: everything from two characters before '#'
//    (presumably the ", " separator) is cut and ")" is re-appended.
162 auto out_pos = printable_option.find(
'#');
163 if (out_pos != std::string::npos) {
164 if (kwargs.count(
"out") > 0) {
165 std::string kwonly_part = printable_option.substr(out_pos+1);
166 printable_option.erase(out_pos);
167 printable_option +=
"*, ";
168 printable_option += kwonly_part;
169 }
else if (out_pos >= 2) {
170 printable_option.erase(out_pos-2);
171 printable_option +=
")";
// NOTE(review): the branch structure around these two lines is incomplete in
// this copy. They presumably form the fallback for out_pos < 2.
173 printable_option.erase(out_pos);
174 printable_option +=
")";
// Each argument appears to be written "<type> <name>"; arguments are
// separated by ", ".
179 for (
auto& arg: _splitString(option_str,
", ")) {
180 bool is_nullable =
false;
181 auto type_start_idx = 0;
// NOTE(review): the bodies of the '#' and '[' checks are missing from this
// copy. Presumably they set has_out / is_nullable and advance type_start_idx.
182 if (arg[type_start_idx] ==
'#') {
185 if (arg[type_start_idx] ==
'[') {
// Nullable arguments appear to be written "[<type> <name> or None]"; strip
// the trailing " or None]".
188 arg.erase(arg.length() - std::string(
" or None]").length());
// The type is everything before the last space; the name is everything after.
191 auto type_end_idx = arg.find_last_of(
' ');
192 auto name_start_idx = type_end_idx + 1;
// NOTE(review): the statement guarded by this "..." check is missing from this
// copy. As shown, the `if` would govern the following declaration.
196 auto dots_idx = arg.find(
"...");
197 if (dots_idx != std::string::npos)
200 std::string type_name =
201 arg.substr(type_start_idx, type_end_idx-type_start_idx);
// NOTE(review): the `std::string name =` line that should precede this
// expression is missing. `name` is used on the next line.
203 arg.substr(name_start_idx);
205 arguments.emplace_back(name, _buildType(type_name, is_nullable));
// The option is variadic if "..." appears anywhere in its argument list.
208 bool is_variadic = option_str.find(
"...") != std::string::npos;
209 return std::pair<Option, std::string>(
210 Option(std::move(arguments), is_variadic, has_out),
211 std::move(printable_option)
// Parameters and body of the argument-count check. Its header line (presumably
// `bool _argcountMatch(`, the name the callers use) and opening brace are not
// present in this copy.
// Compares the positional + keyword argument count with the number of
// declared arguments. Variadic options also accept more than declared.
216 const Option& option,
217 const std::vector<PyObject*>& arguments,
218 const std::unordered_map<std::string, PyObject*>& kwargs)
220 auto num_expected = option.arguments.size();
221 auto num_got = arguments.size() + kwargs.size();
// NOTE(review): the statement this `if` guards is missing from this copy. As
// shown, the `return` below becomes its body, and the function can fall off
// the end. Presumably it should discount the declared out= argument when no
// out= kwarg was given. Confirm against upstream.
223 if (option.has_out && kwargs.count(
"out") == 0)
225 return num_got == num_expected ||
226 (option.is_variadic && num_got > num_expected);
// Renders the received arguments as "(...)". Each one is checked against the
// declared type at the same index; extra arguments of a variadic option are
// checked against the last declared type. The reset escape sequences are only
// set when both stdout and stderr are terminals.
// NOTE(review): several lines of this copy are missing: the opening brace, the
// color start sequences, the per-argument color selection, the separator
// append, and the closing ")" / return. Confirm against upstream.
229 std::string _formattedArgDesc(
230 const Option& option,
231 const std::vector<PyObject*>& arguments,
232 const std::unordered_map<std::string, PyObject*>& kwargs)
235 std::string reset_red;
237 std::string reset_green;
// File descriptors 1 and 2 are stdout and stderr.
238 if (isatty(1) && isatty(2)) {
// "\33[0m" is the ANSI "reset all attributes" sequence.
240 reset_red =
"\33[0m";
242 reset_green =
"\33[0m";
250 auto num_args = arguments.size() + kwargs.size();
251 std::string result =
"(";
252 for (
size_t i = 0; i < num_args; i++) {
// Positions after the positional arguments are treated as keyword arguments
// and looked up by the declared name at the same index.
// NOTE(review): kwargs.at() throws std::out_of_range if that name was not
// passed, and option.arguments[i] is out of bounds when
// i >= option.arguments.size().
253 bool is_kwarg = i >= arguments.size();
254 PyObject *arg = is_kwarg ? kwargs.at(option.arguments[i].name) : arguments[i];
256 bool is_matching =
false;
257 if (i < option.arguments.size()) {
258 is_matching = option.arguments[i].type->is_matching(arg);
259 }
else if (option.is_variadic) {
260 is_matching = option.arguments.back().type->is_matching(arg);
267 if (is_kwarg) result += option.arguments[i].name +
"=";
268 result += py_typename(arg);
270 result += reset_green;
// NOTE(review): the last two characters (presumably a ", " separator whose
// append is not visible here) are only trimmed when positional arguments exist.
// With keyword arguments only, they would be left in place.
275 if (arguments.size() > 0)
276 result.erase(result.length()-2);
281 std::string _argDesc(
const std::vector<PyObject *>& arguments,
282 const std::unordered_map<std::string, PyObject *>& kwargs)
284 std::string result =
"(";
285 for (
auto& arg: arguments)
286 result += std::string(py_typename(arg)) +
", ";
287 for (
auto& kwarg: kwargs)
288 result += kwarg.first +
"=" + py_typename(kwarg.second) +
", ";
289 if (arguments.size() > 0)
290 result.erase(result.length()-2);
// Collects the names of keyword arguments that do not correspond to any of the
// option's trailing declared arguments. The search starts at index
// arguments.size() - kwargs.size(); an adjustment for an unpassed out= is
// suggested by the `if` below, but its body is not visible here.
// NOTE(review): this copy is missing lines: the statement guarded by the out=
// check, the per-entry match flag / break, the closing braces and the return.
// Confirm against upstream.
295 std::vector<std::string> _tryMatchKwargs(
const Option& option,
296 const std::unordered_map<std::string, PyObject*>& kwargs) {
297 std::vector<std::string> unmatched;
// NOTE(review): this is an unsigned size_t subtraction converted to int. If
// kwargs outnumber the declared arguments, the value typically becomes
// negative, and `unsigned int i = start_idx` below would then wrap to a huge
// value and skip the search. Any clamp to 0 is not visible in this copy.
298 int start_idx = option.arguments.size() - kwargs.size();
299 if (option.has_out && kwargs.count(
"out") == 0)
303 for (
auto& entry: kwargs) {
// Search only the trailing declared arguments for a matching name.
305 for (
unsigned int i = start_idx; i < option.arguments.size(); i++) {
306 if (option.arguments[i].name == entry.first) {
// NOTE(review): as shown, this push_back sits inside the "names match" branch,
// which would invert the meaning of `unmatched`. The intervening lines
// (presumably a found flag and break) are missing. Confirm.
312 unmatched.push_back(entry.first);
319 std::string format_invalid_args(
320 PyObject *given_args, PyObject *given_kwargs,
const std::string& function_name,
321 const std::vector<std::string>& options)
323 std::vector<PyObject *> args;
324 std::unordered_map<std::string, PyObject *> kwargs;
325 std::string error_msg;
326 error_msg.reserve(2000);
327 error_msg += function_name;
328 error_msg +=
" received an invalid combination of arguments - ";
330 Py_ssize_t num_args = PyTuple_Size(given_args);
331 for (
int i = 0; i < num_args; i++) {
332 PyObject *arg = PyTuple_GET_ITEM(given_args, i);
336 bool has_kwargs = given_kwargs && PyDict_Size(given_kwargs) > 0;
338 PyObject *key, *value;
341 while (PyDict_Next(given_kwargs, &pos, &key, &value)) {
342 kwargs.emplace(THPUtils_unpackString(key), value);
346 if (options.size() == 1) {
347 auto pair = _parseOption(options[0], kwargs);
348 auto& option = pair.first;
349 auto& option_str = pair.second;
350 std::vector<std::string> unmatched_kwargs;
352 unmatched_kwargs = _tryMatchKwargs(option, kwargs);
353 if (unmatched_kwargs.size()) {
354 error_msg +=
"got unrecognized keyword arguments: ";
355 for (
auto& kwarg: unmatched_kwargs)
356 error_msg += kwarg +
", ";
357 error_msg.erase(error_msg.length()-2);
360 if (_argcountMatch(option, args, kwargs)) {
361 error_msg += _formattedArgDesc(option, args, kwargs);
363 error_msg += _argDesc(args, kwargs);
365 error_msg +=
", but expected ";
366 error_msg += option_str;
370 error_msg += _argDesc(args, kwargs);
371 error_msg +=
", but expected one of:\n";
372 for (
auto &option_str: options) {
373 auto pair = _parseOption(option_str, kwargs);
374 auto& option = pair.first;
375 auto& printable_option_str = pair.second;
377 error_msg += printable_option_str;
379 if (_argcountMatch(option, args, kwargs)) {
380 std::vector<std::string> unmatched_kwargs;
382 unmatched_kwargs = _tryMatchKwargs(option, kwargs);
383 if (unmatched_kwargs.size() > 0) {
384 error_msg +=
" didn't match because some of the keywords were incorrect: ";
385 for (
auto& kwarg: unmatched_kwargs)
386 error_msg += kwarg +
", ";
387 error_msg.erase(error_msg.length()-2);
390 error_msg +=
" didn't match because some of the arguments have invalid types: ";
391 error_msg += _formattedArgDesc(option, args, kwargs);