1 #include <torch/csrc/jit/script/schema_matching.h> 2 #include <torch/csrc/jit/operator.h> 3 #include <torch/csrc/jit/script/builtin_functions.h> 4 #include <torch/csrc/jit/script/error_report.h> 10 inline TypePtr unwrapOptional(TypePtr opt_type) {
11 if (
auto unwrap_list_type = opt_type->cast<OptionalType>()) {
12 return unwrap_list_type->getElementType();
17 static inline bool isIntOrFloatUsedAsList(
19 const Argument& arg) {
21 const auto& v_type = value->type();
22 if (v_type != FloatType::get() && v_type != IntType::get())
24 auto arg_type = unwrapOptional(arg.type());
25 auto list_type = arg_type->cast<ListType>();
26 return list_type && list_type->getElementType() == v_type && arg.N();
29 inline bool convertibleToList(
const TypePtr& type,
const TypePtr& list_type_) {
30 auto list_type = list_type_->cast<ListType>();
34 if (type->isSubtypeOf(list_type_)) {
37 if (
auto tuple = type->cast<TupleType>()) {
39 tuple->elements().begin(),
40 tuple->elements().end(),
41 [&](
const TypePtr& t) {
42 return t->isSubtypeOf(list_type->getElementType());
50 Value* tryConvertToType(
51 const SourceRange& loc,
53 const TypePtr& concrete_type,
55 bool allow_conversions) {
56 if (
auto value_tuple = value->type()->cast<TupleType>()) {
59 if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
60 auto unpacked = createTupleUnpack(value);
62 unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
63 value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
66 if (
auto concrete_tuple = concrete_type->cast<TupleType>()) {
67 if (!value_tuple->isSubtypeOf(concrete_tuple) &&
68 concrete_tuple->elements().size() == value_tuple->elements().size()) {
69 auto unpacked = createTupleUnpack(value);
70 std::vector<Value*> converted;
71 for (
size_t i = 0; i < concrete_tuple->elements().size(); ++i) {
72 converted.emplace_back(tryConvertToType(
75 concrete_tuple->elements().at(i),
79 value = graph.insertNode(graph.createTuple(converted))->output();
84 if (value->type()->isSubtypeOf(NoneType::get()) &&
85 !concrete_type->isSubtypeOf(NoneType::get())) {
86 if (
auto optional_type = concrete_type->cast<OptionalType>()) {
88 graph.insertNode(graph.createNone(optional_type->getElementType()))
93 value = graph.insertNode(graph.createNone(concrete_type))->output();
98 if (allow_conversions) {
99 if (concrete_type->isSubtypeOf(NumberType::get()) &&
100 value->type()->isSubtypeOf(TensorType::get())) {
101 auto n = graph.createImplicitTensorToNum(concrete_type, value);
102 value = graph.insertNode(n)
103 ->setSourceLocation(std::make_shared<SourceRange>(loc))
106 if (value->type()->isSubtypeOf(StringType::get()) &&
107 DeviceObjType::get()->isSubtypeOf(concrete_type)) {
108 return graph.insert(aten::device, {value}, {}, loc);
110 if (concrete_type == FloatType::get() &&
111 value->type() == NumberType::get()) {
112 return graph.insert(prim::Float, {value}, {}, loc);
119 Value* tryMatchArgument(
122 const SourceRange& loc,
123 const NamedValue& named_value,
124 const std::function<std::ostream&()>& err,
125 bool allow_conversions,
127 Value* value = named_value.value(graph);
132 if (isIntOrFloatUsedAsList(value, arg)) {
133 std::vector<Value*> repeated(*arg.N(), value);
135 graph.insertNode(graph.createList(value->type(), repeated))->output();
138 const MatchTypeReturn matched_type =
139 matchTypeVariables(arg.type(), value->type(), type_env);
140 if (!matched_type.type) {
141 err() <<
"could not match type " << value->type()->str() <<
" to " 142 << arg.type()->str() <<
" in argument '" << arg.name()
143 <<
"': " << matched_type.errMsg <<
"\n" 144 << named_value.locOr(loc);
147 const auto concrete_type = *matched_type.type;
149 value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions);
151 if (!value->type()->isSubtypeOf(concrete_type)) {
152 auto& ostream = err() <<
"expected a value of type " << concrete_type->str()
153 <<
" for argument '" << arg.name() <<
"' but found " 154 << value->type()->str() <<
"\n";
156 if (value->type() == NumberType::get() &&
157 value->node()->kind() == aten::item) {
159 <<
"Use int(tensor) or float(tensor) to retrieve item() from a tensor with the appropriate type\n";
161 ostream << named_value.locOr(loc);
168 const std::string& name,
170 for (
size_t i = 0; i < kwargs.
size(); ++i) {
171 if (kwargs[i].name() == name)
177 Value* tryCreateList(
178 const TypePtr& elem_type,
180 const SourceRange& loc,
182 const std::function<std::ostream&()>& err,
183 bool convert_tensor_to_num,
185 Argument elem_arg(
"<varargs>", elem_type);
186 std::vector<Value*> list_ctor;
187 for (
const auto& a : varargs) {
188 Value* av = tryMatchArgument(
189 elem_arg, graph, loc, a, err, convert_tensor_to_num, type_env);
192 list_ctor.push_back(av);
194 return graph.insertNode(graph.createList(elem_type, list_ctor))->output();
198 const FunctionSchema& schema,
199 const SourceRange& loc,
204 std::ostream& failure_messages,
205 bool allow_conversions) {
206 auto err = [&]() -> std::ostream& {
207 failure_messages <<
"\nfor operator " << schema <<
":\n";
208 return failure_messages;
212 std::vector<Value*> positional_inputs;
213 std::vector<bool> used_kwarg(kwargs.
size(),
false);
216 size_t used_args = 0;
217 for (
size_t schema_i = 0; schema_i < schema.arguments().size(); ++schema_i) {
218 const auto& arg = schema.arguments()[schema_i];
220 if (arg.name() ==
"self" &&
self) {
223 }
else if (!arg.kwarg_only() && used_args < args.
size()) {
225 if (allow_conversions &&
226 arg.type()->kind() ==
227 TypeKind::ListType &&
230 (schema_i + 1 == schema.arguments().size() ||
231 schema.arguments()[schema_i + 1]
233 auto actual_type = args[used_args].value(graph)->type();
234 if (actual_type->kind() != TypeKind::ListType &&
237 unwrapOptional(arg.type()))) {
240 unwrapOptional(arg.type())->expect<ListType>()->getElementType();
241 Value* list = tryCreateList(
251 used_args = args.
size();
252 positional_inputs.push_back(list);
259 }
else if (
auto idx = findInputWithName(arg.name(), kwargs)) {
260 const NamedValue& nv = kwargs[*idx];
261 if (used_kwarg[*idx]) {
262 err() <<
"argument " << nv.name()
263 <<
" specified twice in schema, submit a bug report!\n" 267 used_kwarg[*idx] =
true;
269 }
else if (arg.default_value()) {
270 v = NamedValue(*arg.default_value());
272 err() <<
"argument " << schema.arguments()[schema_i].name()
273 <<
" not provided.\n" 278 tryMatchArgument(arg, graph, loc, *v, err, allow_conversions, type_env);
281 positional_inputs.push_back(positional);
284 if (
self != c10::nullopt) {
285 err() <<
"provided self argument not used in schema\n";
288 if (schema.is_vararg()) {
289 for (; used_args < args.
size(); ++used_args) {
290 positional_inputs.push_back(args[used_args].value(graph));
295 if (used_args < args.
size()) {
296 err() <<
"expected at most " << used_args <<
" arguments " 297 <<
"but found " << args.
size() <<
" positional arguments.\n" 302 for (
size_t i = 0; i < kwargs.
size(); ++i) {
303 const auto& nv = kwargs[i];
304 if (!used_kwarg[i]) {
305 if (!schema.argumentIndexWithName(nv.name())) {
306 err() <<
"keyword argument " << nv.name() <<
" unknown\n";
308 err() <<
"keyword argument " << nv.name() <<
" specified twice\n";
314 const auto& returns = schema.returns();
315 auto return_types = fmap(returns, [&](
const Argument& r) {
316 return evalTypeVariables(r.type(), type_env);
320 bool return_has_field_names =
321 std::all_of(returns.begin(), returns.end(), [&](
const Argument& r) {
322 return r.name().length() > 0;
325 if (return_has_field_names) {
327 fmap(returns, [&](
const Argument& r) {
return r.name(); });
329 return MatchedSchema{std::move(positional_inputs),
330 std::move(return_types),
331 std::move(return_field_names)};
340 if (values.
size() == 1) {
343 return g.insertNode(g.createTuple(values, std::move(field_names)))->output();
348 static Value* emitBuiltinNode(
349 const MatchedSchema& matched_schema,
350 const SourceRange& loc,
353 auto n = graph.insertNode(graph.create(name, matched_schema.inputs, 0))
354 ->setSourceLocation(std::make_shared<SourceRange>(loc));
356 for (
auto& ret : matched_schema.return_types) {
357 n->addOutput()->setType(ret);
364 return packOutputs(graph, n->outputs(), matched_schema.return_field_names);
// Return `str` with `prefix` inserted at the start of every line
// (including the first). Used to indent multi-line failure messages.
static std::string prefixLine(
    const std::string& str,
    const std::string& prefix) {
  std::stringstream ss;
  // Start true so the very first line also gets the prefix.
  bool was_newline = true;
  for (auto c : str) {
    if (was_newline)
      ss << prefix;
    ss.put(c);
    was_newline = c == '\n';
  }
  return ss.str();
}
383 Value* emitBuiltinCall(
384 const SourceRange& loc,
393 const auto& variants = getAllOperatorsFor(name);
394 const auto& builtin_functions = getAllBuiltinFunctionsFor(name);
396 std::stringstream failure_messages;
399 for (
bool allow_conversions : {
false,
true}) {
401 failure_messages.str(
"");
402 for (
const std::shared_ptr<Operator>& op : variants) {
403 const auto matched_schema = tryMatchSchema(
412 if (matched_schema) {
413 return emitBuiltinNode(*matched_schema, loc, graph, name);
416 for (Method* method : builtin_functions) {
417 if (
auto result = try_emit_call_to(
426 allow_conversions)) {
437 if (variants.size() == 0) {
438 const auto close_symbols = findSimilarOperators(name);
439 auto error = ErrorReport(loc);
440 const auto& user_function_name = name.toQualString();
441 error <<
"unknown builtin op: " << user_function_name <<
"\n";
442 if (close_symbols.size() == 0) {
444 <<
"Could not find any similar ops to " << user_function_name
445 <<
". This op may not exist or may not be currently supported in TorchScript\n";
447 error <<
"Here are some suggestions: \n";
448 for (
const auto& sym : close_symbols) {
449 error <<
"\t" << sym.toQualString() <<
"\n";
455 throw ErrorReport(loc) <<
"arguments for call are not valid:\n" 456 << prefixLine(failure_messages.str(),
" ")
// NOTE(review): the three lines below are extraction artifacts (doxygen
// hover text for at::ArrayRef::size), not source code; commented out to
// keep the file well-formed.
// constexpr size_t size() const
// size - Get the array size.
// ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...