Caffe2 - C++ API
A deep learning, cross platform ML framework
invalid_arguments.cpp
1 #include <torch/csrc/utils/invalid_arguments.h>
2 
3 #include <torch/csrc/utils/python_strings.h>
4 
5 #include <torch/csrc/utils/memory.h>
6 
7 #include <algorithm>
8 #include <unordered_map>
9 #include <memory>
10 
11 namespace torch {
12 
13 namespace {
14 
15 std::string py_typename(PyObject *object) {
16  return Py_TYPE(object)->tp_name;
17 }
18 
19 struct Type {
20  virtual bool is_matching(PyObject *object) = 0;
21  virtual ~Type() = default;
22 };
23 
24 struct SimpleType: public Type {
25  SimpleType(std::string& name): name(name) {};
26 
27  bool is_matching(PyObject *object) override {
28  return py_typename(object) == name;
29  }
30 
31  std::string name;
32 };
33 
34 struct MultiType: public Type {
35  MultiType(std::initializer_list<std::string> accepted_types):
36  types(accepted_types) {};
37 
38  bool is_matching(PyObject *object) override {
39  auto it = std::find(types.begin(), types.end(), py_typename(object));
40  return it != types.end();
41  }
42 
43  std::vector<std::string> types;
44 };
45 
46 struct NullableType: public Type {
47  NullableType(std::unique_ptr<Type> type): type(std::move(type)) {};
48 
49  bool is_matching(PyObject *object) override {
50  return object == Py_None || type->is_matching(object);
51  }
52 
53  std::unique_ptr<Type> type;
54 };
55 
56 struct TupleType: public Type {
57  TupleType(std::vector<std::unique_ptr<Type>> types):
58  types(std::move(types)) {};
59 
60  bool is_matching(PyObject *object) override {
61  if (!PyTuple_Check(object)) return false;
62  auto num_elements = PyTuple_GET_SIZE(object);
63  if (num_elements != (long)types.size()) return false;
64  for (int i = 0; i < num_elements; i++) {
65  if (!types[i]->is_matching(PyTuple_GET_ITEM(object, i)))
66  return false;
67  }
68  return true;
69  }
70 
71  std::vector<std::unique_ptr<Type>> types;
72 };
73 
74 struct SequenceType: public Type {
75  SequenceType(std::unique_ptr<Type> type):
76  type(std::move(type)) {};
77 
78  bool is_matching(PyObject *object) override {
79  if (!PySequence_Check(object)) return false;
80  auto num_elements = PySequence_Length(object);
81  for (int i = 0; i < num_elements; i++) {
82  if (!type->is_matching(PySequence_GetItem(object, i)))
83  return false;
84  }
85  return true;
86  }
87 
88  std::unique_ptr<Type> type;
89 };
90 
91 struct Argument {
92  Argument(std::string name, std::unique_ptr<Type> type):
93  name(std::move(name)), type(std::move(type)) {};
94 
95  std::string name;
96  std::unique_ptr<Type> type;
97 };
98 
99 struct Option {
100  Option(std::vector<Argument> arguments, bool is_variadic, bool has_out):
101  arguments(std::move(arguments)), is_variadic(is_variadic), has_out(has_out) {};
102  Option(bool is_variadic, bool has_out):
103  arguments(), is_variadic(is_variadic), has_out(has_out) {};
104  Option(const Option&) = delete;
105  Option(Option&& other):
106  arguments(std::move(other.arguments)), is_variadic(other.is_variadic),
107  has_out(other.has_out) {};
108 
109  std::vector<Argument> arguments;
110  bool is_variadic;
111  bool has_out;
112 };
113 
// Splits `s` on every occurrence of `delim` and returns the pieces between
// delimiters, including empty pieces. N occurrences yield N+1 tokens.
// An empty `delim` would make find() match at `start` forever, so in that
// case `s` is returned as a single token instead of looping.
std::vector<std::string> _splitString(const std::string &s, const std::string& delim) {
  std::vector<std::string> tokens;
  if (delim.empty()) {
    tokens.push_back(s);
    return tokens;
  }
  size_t start = 0;
  size_t end;
  while ((end = s.find(delim, start)) != std::string::npos) {
    tokens.push_back(s.substr(start, end - start));
    start = end + delim.length();
  }
  tokens.push_back(s.substr(start));
  return tokens;
}
125 
126 std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
127  std::unique_ptr<Type> result;
128  if (type_name == "float") {
129  result = torch::make_unique<MultiType>(MultiType{"float", "int", "long"});
130  } else if (type_name == "int") {
131  result = torch::make_unique<MultiType>(MultiType{"int", "long"});
132  } else if (type_name.find("tuple[") == 0) {
133  auto type_list = type_name.substr(6);
134  type_list.pop_back();
135  std::vector<std::unique_ptr<Type>> types;
136  for (auto& type: _splitString(type_list, ","))
137  types.emplace_back(_buildType(type, false));
138  result = torch::make_unique<TupleType>(std::move(types));
139  } else if (type_name.find("sequence[") == 0) {
140  auto subtype = type_name.substr(9);
141  subtype.pop_back();
142  result = torch::make_unique<SequenceType>(_buildType(subtype, false));
143  } else {
144  result = torch::make_unique<SimpleType>(type_name);
145  }
146  if (is_nullable)
147  result = torch::make_unique<NullableType>(std::move(result));
148  return result;
149 }
150 
151 std::pair<Option, std::string> _parseOption(const std::string& _option_str,
152  const std::unordered_map<std::string, PyObject*>& kwargs)
153 {
154  if (_option_str == "no arguments")
155  return std::pair<Option, std::string>(Option(false, false), _option_str);
156  bool has_out = false;
157  std::vector<Argument> arguments;
158  std::string printable_option = _option_str;
159  std::string option_str = _option_str.substr(1, _option_str.length()-2);
160 
162  auto out_pos = printable_option.find('#');
163  if (out_pos != std::string::npos) {
164  if (kwargs.count("out") > 0) {
165  std::string kwonly_part = printable_option.substr(out_pos+1);
166  printable_option.erase(out_pos);
167  printable_option += "*, ";
168  printable_option += kwonly_part;
169  } else if (out_pos >= 2) {
170  printable_option.erase(out_pos-2);
171  printable_option += ")";
172  } else {
173  printable_option.erase(out_pos);
174  printable_option += ")";
175  }
176  has_out = true;
177  }
178 
179  for (auto& arg: _splitString(option_str, ", ")) {
180  bool is_nullable = false;
181  auto type_start_idx = 0;
182  if (arg[type_start_idx] == '#') {
183  type_start_idx++;
184  }
185  if (arg[type_start_idx] == '[') {
186  is_nullable = true;
187  type_start_idx++;
188  arg.erase(arg.length() - std::string(" or None]").length());
189  }
190 
191  auto type_end_idx = arg.find_last_of(' ');
192  auto name_start_idx = type_end_idx + 1;
193 
194  // "type ... name" => "type ... name"
195  // ^ ^
196  auto dots_idx = arg.find("...");
197  if (dots_idx != std::string::npos)
198  type_end_idx -= 4;
199 
200  std::string type_name =
201  arg.substr(type_start_idx, type_end_idx-type_start_idx);
202  std::string name =
203  arg.substr(name_start_idx);
204 
205  arguments.emplace_back(name, _buildType(type_name, is_nullable));
206  }
207 
208  bool is_variadic = option_str.find("...") != std::string::npos;
209  return std::pair<Option, std::string>(
210  Option(std::move(arguments), is_variadic, has_out),
211  std::move(printable_option)
212  );
213 }
214 
215 bool _argcountMatch(
216  const Option& option,
217  const std::vector<PyObject*>& arguments,
218  const std::unordered_map<std::string, PyObject*>& kwargs)
219 {
220  auto num_expected = option.arguments.size();
221  auto num_got = arguments.size() + kwargs.size();
222  // Note: variadic functions don't accept kwargs, so it's ok
223  if (option.has_out && kwargs.count("out") == 0)
224  num_expected--;
225  return num_got == num_expected ||
226  (option.is_variadic && num_got > num_expected);
227 }
228 
229 std::string _formattedArgDesc(
230  const Option& option,
231  const std::vector<PyObject*>& arguments,
232  const std::unordered_map<std::string, PyObject*>& kwargs)
233 {
234  std::string red;
235  std::string reset_red;
236  std::string green;
237  std::string reset_green;
238  if (isatty(1) && isatty(2)) {
239  red = "\33[31;1m";
240  reset_red = "\33[0m";
241  green = "\33[32;1m";
242  reset_green = "\33[0m";
243  } else {
244  red = "!";
245  reset_red = "!";
246  green = "";
247  reset_green = "";
248  }
249 
250  auto num_args = arguments.size() + kwargs.size();
251  std::string result = "(";
252  for (size_t i = 0; i < num_args; i++) {
253  bool is_kwarg = i >= arguments.size();
254  PyObject *arg = is_kwarg ? kwargs.at(option.arguments[i].name) : arguments[i];
255 
256  bool is_matching = false;
257  if (i < option.arguments.size()) {
258  is_matching = option.arguments[i].type->is_matching(arg);
259  } else if (option.is_variadic) {
260  is_matching = option.arguments.back().type->is_matching(arg);
261  }
262 
263  if (is_matching)
264  result += green;
265  else
266  result += red;
267  if (is_kwarg) result += option.arguments[i].name + "=";
268  result += py_typename(arg);
269  if (is_matching)
270  result += reset_green;
271  else
272  result += reset_red;
273  result += ", ";
274  }
275  if (arguments.size() > 0)
276  result.erase(result.length()-2);
277  result += ")";
278  return result;
279 }
280 
281 std::string _argDesc(const std::vector<PyObject *>& arguments,
282  const std::unordered_map<std::string, PyObject *>& kwargs)
283 {
284  std::string result = "(";
285  for (auto& arg: arguments)
286  result += std::string(py_typename(arg)) + ", ";
287  for (auto& kwarg: kwargs)
288  result += kwarg.first + "=" + py_typename(kwarg.second) + ", ";
289  if (arguments.size() > 0)
290  result.erase(result.length()-2);
291  result += ")";
292  return result;
293 }
294 
295 std::vector<std::string> _tryMatchKwargs(const Option& option,
296  const std::unordered_map<std::string, PyObject*>& kwargs) {
297  std::vector<std::string> unmatched;
298  int start_idx = option.arguments.size() - kwargs.size();
299  if (option.has_out && kwargs.count("out") == 0)
300  start_idx--;
301  if (start_idx < 0)
302  start_idx = 0;
303  for (auto& entry: kwargs) {
304  bool found = false;
305  for (unsigned int i = start_idx; i < option.arguments.size(); i++) {
306  if (option.arguments[i].name == entry.first) {
307  found = true;
308  break;
309  }
310  }
311  if (!found)
312  unmatched.push_back(entry.first);
313  }
314  return unmatched;
315 }
316 
317 } // anonymous namespace
318 
319 std::string format_invalid_args(
320  PyObject *given_args, PyObject *given_kwargs, const std::string& function_name,
321  const std::vector<std::string>& options)
322 {
323  std::vector<PyObject *> args;
324  std::unordered_map<std::string, PyObject *> kwargs;
325  std::string error_msg;
326  error_msg.reserve(2000);
327  error_msg += function_name;
328  error_msg += " received an invalid combination of arguments - ";
329 
330  Py_ssize_t num_args = PyTuple_Size(given_args);
331  for (int i = 0; i < num_args; i++) {
332  PyObject *arg = PyTuple_GET_ITEM(given_args, i);
333  args.push_back(arg);
334  }
335 
336  bool has_kwargs = given_kwargs && PyDict_Size(given_kwargs) > 0;
337  if (has_kwargs) {
338  PyObject *key, *value;
339  Py_ssize_t pos = 0;
340 
341  while (PyDict_Next(given_kwargs, &pos, &key, &value)) {
342  kwargs.emplace(THPUtils_unpackString(key), value);
343  }
344  }
345 
346  if (options.size() == 1) {
347  auto pair = _parseOption(options[0], kwargs);
348  auto& option = pair.first;
349  auto& option_str = pair.second;
350  std::vector<std::string> unmatched_kwargs;
351  if (has_kwargs)
352  unmatched_kwargs = _tryMatchKwargs(option, kwargs);
353  if (unmatched_kwargs.size()) {
354  error_msg += "got unrecognized keyword arguments: ";
355  for (auto& kwarg: unmatched_kwargs)
356  error_msg += kwarg + ", ";
357  error_msg.erase(error_msg.length()-2);
358  } else {
359  error_msg += "got ";
360  if (_argcountMatch(option, args, kwargs)) {
361  error_msg += _formattedArgDesc(option, args, kwargs);
362  } else {
363  error_msg += _argDesc(args, kwargs);
364  }
365  error_msg += ", but expected ";
366  error_msg += option_str;
367  }
368  } else {
369  error_msg += "got ";
370  error_msg += _argDesc(args, kwargs);
371  error_msg += ", but expected one of:\n";
372  for (auto &option_str: options) {
373  auto pair = _parseOption(option_str, kwargs);
374  auto& option = pair.first;
375  auto& printable_option_str = pair.second;
376  error_msg += " * ";
377  error_msg += printable_option_str;
378  error_msg += "\n";
379  if (_argcountMatch(option, args, kwargs)) {
380  std::vector<std::string> unmatched_kwargs;
381  if (has_kwargs)
382  unmatched_kwargs = _tryMatchKwargs(option, kwargs);
383  if (unmatched_kwargs.size() > 0) {
384  error_msg += " didn't match because some of the keywords were incorrect: ";
385  for (auto& kwarg: unmatched_kwargs)
386  error_msg += kwarg + ", ";
387  error_msg.erase(error_msg.length()-2);
388  error_msg += "\n";
389  } else {
390  error_msg += " didn't match because some of the arguments have invalid types: ";
391  error_msg += _formattedArgDesc(option, args, kwargs);
392  error_msg += "\n";
393  }
394  }
395  }
396  }
397  return error_msg;
398 }
399 
400 
401 } // namespace torch
Definition: jit_type.h:17