#include <torch/csrc/jit/script/compiler.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/script/final_returns.h>
#include <torch/csrc/jit/script/parser.h>
#include <torch/csrc/jit/script/schema_matching.h>
#include <torch/csrc/jit/script/script_type_parser.h>
#include <torch/csrc/utils/object_ptr.h>

#include <torch/csrc/jit/constants.h>

#include <c10/util/Optional.h>

namespace torch {
namespace jit {
namespace script {

using SugaredValuePtr = std::shared_ptr<SugaredValue>;
using FunctionTable = std::unordered_map<std::string, Method&>;
using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
using AttributeMap = std::unordered_map<std::string, Const>;
using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;

using TypeAndRange = std::pair<TypePtr, const SourceRange*>;
// Holds mappings from a variable name to a refined type for that variable
// after a condition such as `x is None` has been tested.
struct Refinements {
  // using ordered map for deterministic graph output
  std::map<std::string, TypeAndRange> mappings_;

  void setRefinement(const std::string& name, TypeAndRange mapping) {
    mappings_[name] = std::move(mapping);
  }

  c10::optional<TypeAndRange> getRefinement(const std::string& name) const {
    const auto& maybe_mapping = mappings_.find(name);
    if (maybe_mapping == mappings_.end()) {
      return c10::nullopt;
    }
    return maybe_mapping->second;
  }

  // Take the intersection of the name-to-type mappings in this and other:
  // only names refined on both sides survive, and their types are unified.
  void intersectRefinements(const Refinements& other) {
    Refinements ret;
    for (const auto& name_mapping : mappings_) {
      const auto& name = name_mapping.first;
      const auto& mapping = name_mapping.second;
      if (auto other_mapping = other.getRefinement(name_mapping.first)) {
        const auto maybe_unified_type =
            unifyTypes(mapping.first, other_mapping->first);
        if (maybe_unified_type) {
          ret.setRefinement(
              name, TypeAndRange(*maybe_unified_type, mapping.second));
        }
      }
    }
    mappings_ = std::move(ret.mappings_);
  }
  // Take the union of the name-to-type mappings in this and other: names
  // refined on either side survive; when both sides refine a name, keep the
  // more specific (subtype) of the two types, or drop the name if the types
  // are unrelated.
  void unionRefinements(const Refinements& other) {
    Refinements ret;
    for (const auto& name_mapping : mappings_) {
      const auto& name = name_mapping.first;
      const auto& mapping = name_mapping.second;
      TypePtr t_1 = mapping.first;
      if (auto other_mapping = other.getRefinement(name_mapping.first)) {
        TypePtr t_2 = other_mapping->first;
        c10::optional<TypePtr> maybe_unified_type = c10::nullopt;
        if (t_1->isSubtypeOf(t_2)) {
          maybe_unified_type = t_1;
        } else if (t_2->isSubtypeOf(t_1)) {
          maybe_unified_type = t_2;
        }
        if (maybe_unified_type) {
          ret.setRefinement(
              name, TypeAndRange(*maybe_unified_type, mapping.second));
        }
      } else {
        ret.setRefinement(name, mapping);
      }
    }

    for (auto& name_mapping : other.mappings_) {
      if (!getRefinement(name_mapping.first)) {
        ret.setRefinement(name_mapping.first, name_mapping.second);
      }
    }

    mappings_ = std::move(ret.mappings_);
  }
};
// The refinements that hold on the true and the false path of a boolean
// expression.
struct BoolInfo {
  BoolInfo(Refinements true_refinements, Refinements false_refinements)
      : true_refinements_(std::move(true_refinements)),
        false_refinements_(std::move(false_refinements)){};
  BoolInfo() = default;

  Refinements true_refinements_;
  Refinements false_refinements_;

  BoolInfo* mergeOr(const BoolInfo& other) {
    // if the result of an OR is true, either operand could have been true, so
    // only the intersection of the true refinements is guaranteed; if it is
    // false, both operands were false, so the false refinements union.
    true_refinements_.intersectRefinements(other.true_refinements_);
    false_refinements_.unionRefinements(other.false_refinements_);
    return this;
  }

  BoolInfo* mergeAnd(const BoolInfo& other) {
    // if the result of an AND is true, both operands were true, so the true
    // refinements union; if it is false, either operand could have been false,
    // so only the intersection of the false refinements is guaranteed.
    true_refinements_.unionRefinements(other.true_refinements_);
    false_refinements_.intersectRefinements(other.false_refinements_);
    return this;
  }
};
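// A worked example of how the refinements combine (variable names are
// illustrative): with `x: Optional[int]`, the test
//   x is not None or y is not None
// has per-operand BoolInfos whose true-sets refine x (resp. y) to int;
// mergeOr intersects the true-sets (only one operand needs to be true, so
// neither refinement is guaranteed) and unions the false-sets (both operands
// were false, so both refinements hold).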
static Value* asSimple(const SugaredValuePtr& value) {
  if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
    return sv->getValue();
  }
  return nullptr;
}
// we consider _N, where N is a number, to be a non-meaningful name and do not
// record it as a unique name. This allows python printing to export and
// import more consistently named graphs.
static bool meaningfulName(const std::string& name) {
  if (name.size() == 0)
    return false;
  if (name[0] == '$')
    return false;
  if (name[0] != '_')
    return true;
  for (size_t i = 1; i < name.size(); ++i) {
    if (!isdigit(name[i]))
      return true;
  }
  return false;
}
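// e.g. meaningfulName("x") and meaningfulName("_bias") are true, while the
// empty string and compiler-generated temporaries like "_3" are rejected, so
// such names are never pinned onto a Value as its unique name.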
struct Environment {
  Environment(
      Method& method,
      Resolver resolver,
      Block* b,
      std::shared_ptr<Environment> next = nullptr)
      : method(method),
        resolver(std::move(resolver)),
        b(b),
        next(std::move(next)) {}

  Method& method;
  Resolver resolver;
  std::vector<std::string> captured_inputs;
  std::unordered_map<std::string, std::string> error_messages;
  Block* b;

  std::shared_ptr<Environment> next;

  Block* block() {
    return b;
  }
  // set a type error in the lowest (function-scope) environment; if the
  // variable is used later, the more informative error message is raised then
  void setVariableTypeError(const std::string& name, const std::string& msg) {
    auto runner = this;
    while (runner->next) {
      runner = runner->next.get();
    }
    runner->error_messages[name] = msg;
  }

  // see if a type error has been set for a variable
  c10::optional<std::string> findVariableTypeError(const std::string& name) {
    auto runner = this;
    while (runner->next) {
      runner = runner->next.get();
    }
    auto msg = runner->error_messages.find(name);
    if (msg != runner->error_messages.end()) {
      return msg->second;
    }
    return c10::nullopt;
  }
  SugaredValuePtr findInThisFrame(const std::string& name) {
    auto it = value_table.find(name);
    if (it != value_table.end()) {
      return it->second;
    }
    return nullptr;
  }

  SugaredValuePtr findInParentFrame(const std::string& name) {
    return next ? next->findInAnyFrame(name) : nullptr;
  }

  SugaredValuePtr findInAnyFrame(const std::string& name) {
    for (auto runner = this; runner; runner = runner->next.get()) {
      if (auto r = runner->findInThisFrame(name)) {
        return r;
      }
    }
    return nullptr;
  }

  Value* getValueInThisFrame(const SourceRange& loc, const std::string& name) {
    return value_table.at(name)->asValue(loc, method);
  }
  SugaredValuePtr createCapturedInput(Value* orig, const std::string& name) {
    // insert the captured input alphabetically in the capture list; this
    // keeps the order of loop-carried dependencies consistent even when the
    // uses inside the loop appear in a different order
    size_t insert_pos = 0;
    while (insert_pos < captured_inputs.size() &&
           name > captured_inputs[insert_pos]) {
      insert_pos++;
    }
    captured_inputs.insert(captured_inputs.begin() + insert_pos, name);

    // create the block input (offset by one to skip the loop trip count)
    const size_t loop_carried_block_inputs_offset = 1;
    Value* new_input =
        b->insertInput(loop_carried_block_inputs_offset + insert_pos)
            ->setType(orig->type());

    // associate this name with the new value
    auto sv = std::make_shared<SimpleValue>(new_input);
    value_table[name] = sv;
    return sv;
  }
  SugaredValuePtr createCapturedInputIfNeeded(
      const SourceRange& loc,
      const std::string& ident) {
    auto in_frame = findInThisFrame(ident);
    if (in_frame) {
      return in_frame;
    }

    // recursively handles the case where parent blocks are also loops
    auto from_parent =
        next ? next->createCapturedInputIfNeeded(loc, ident) : nullptr;

    // recursively create the captured input if this is a loop block
    if (from_parent && getBlockOwningKind() == prim::Loop) {
      if (Value* simple_val = asSimple(from_parent))
        from_parent = createCapturedInput(simple_val, ident);
    }
    return from_parent;
  }

  Symbol getBlockOwningKind() {
    Symbol owning_kind = Symbol();
    if (b->owningNode()) {
      owning_kind = b->owningNode()->kind();
    }
    return owning_kind;
  }
  void setVar(const SourceRange& loc, const std::string& name, Value* value) {
    setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
  }

  void setSugaredVar(
      const SourceRange& loc,
      const std::string& name,
      SugaredValuePtr value) {
    Value* as_simple_value = asSimple(value);
    if (as_simple_value && !as_simple_value->hasUniqueName() &&
        meaningfulName(name) &&
        // note: if the value wasn't defined in this block, we might be giving
        // a name only used inside this block to a value defined outside of it;
        // that is not helpful for debugging and causes import/export jitter
        as_simple_value->node()->owningBlock() == block()) {
      as_simple_value->setUniqueName(name);
    }

    // prevent re-assignment involving any sugared values: a reassignment
    // inside nested control flow requires the name to be a first-class value
    // in the graph, since its value then depends on control flow
    if (auto parent = findInParentFrame(name)) {
      if (!as_simple_value) {
        throw ErrorReport(loc)
            << "Cannot re-assign '" << name << "' to a value of type "
            << value->kind() << " because " << name
            << " is not a first-class value. Only reassignments to first-class values are allowed";
      }
      Value* simple_parent = asSimple(parent);
      if (!simple_parent) {
        throw ErrorReport(loc)
            << "Cannot re-assign '" << name << "' because it has type "
            << value->kind() << " and " << name
            << " is not a first-class value. Only reassignments to first-class values are allowed";
      }
      if (!as_simple_value->type()->isSubtypeOf(
              unshapedType(simple_parent->type()))) {
        std::stringstream errMsg;
        errMsg << "variable '" << name << "' previously has type "
               << simple_parent->type()->str()
               << " but is now being assigned to a value of type "
               << as_simple_value->type()->str();
        // special-case the common error of assigning a differently-typed list
        // to something initialized as an empty (Tensor) list
        if (simple_parent->type()->kind() == TypeKind::ListType &&
            as_simple_value->type()->kind() == TypeKind::ListType) {
          errMsg << "\n. (Note: empty lists are constructed as Tensor[]; "
                 << "if you want an empty list of a different type, "
                 << "use `torch.jit.annotate(List[T], [])`, "
                 << "where `T` is the type of elements in the list)";
        }
        throw ErrorReport(loc) << errMsg.str();
      }
    }
    if (as_simple_value)
      createCapturedInputIfNeeded(loc, name);
    value_table[name] = std::move(value);
  }
  SugaredValuePtr getSugaredVar(const Ident& ident, bool required = true) {
    return getSugaredVar(ident.name(), ident.range());
  }
  Value* getVar(const Ident& ident) {
    return getSugaredVar(ident)->asValue(ident.range(), method);
  }

  SugaredValuePtr getSugaredVar(
      const std::string& ident,
      const SourceRange& range,
      bool required = true) {
    auto retval = createCapturedInputIfNeeded(range, ident);

    if (!retval) {
      static std::unordered_map<std::string, SugaredValuePtr> globals = {
          {"print", std::make_shared<PrintValue>()},
          {"float", std::make_shared<CastValue>(FloatType::get(), prim::Float)},
          {"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
          {"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
          {"getattr", std::make_shared<GetAttrValue>()},
          {"isinstance", std::make_shared<IsInstanceValue>()},
          {"_to_tensor",
           std::make_shared<CastValue>(TensorType::get(), prim::NumToTensor)},
          {"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
          {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
          {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
          {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
          {"rangelist",
           std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
      };
      auto it = globals.find(ident);
      if (it != globals.end()) {
        retval = it->second;
      }
    }

    if (!retval) {
      if (auto class_type = ClassType::get(ident)) {
        retval = std::make_shared<script::ClassValue>(class_type);
      }
    }

    if (!retval) {
      retval = resolver(ident, method, range);
    }

    if (!retval && required) {
      // check whether this value failed to be emitted in an if statement
      // because of a type mismatch; if so, print the more informative message
      if (auto msg = findVariableTypeError(ident)) {
        throw ErrorReport(range) << *msg << "and was used here";
      }
      throw ErrorReport(range) << "undefined value " << ident;
    }
    return retval;
  }

  Value* getVar(const std::string& ident, const SourceRange& range) {
    return getSugaredVar(ident, range)->asValue(range, method);
  }
  // Given that after emitting statements in a block we've added block inputs
  // for all value references and assignments, delete the inputs for which
  // there was no assignment, only references.
  void deleteExtraInputs(const SourceRange& loc) {
    // note: skip i == 0, it is the loop trip count for inputs of Loop;
    // this is the opposite of createCapturedInput
    AT_ASSERT(b->inputs().size() == b->outputs().size());
    AT_ASSERT(b->inputs().size() == captured_inputs.size() + 1);
    for (size_t i = b->inputs().size() - 1; i > 0; i--) {
      // if this input is passed through unmodified, remove it
      if (b->inputs()[i] == b->outputs()[i]) {
        auto name = captured_inputs[i - 1];
        Value* orig = findInParentFrame(name)->asValue(loc, method);
        b->inputs()[i]->replaceAllUsesWith(orig);
        b->eraseInput(i);
        b->eraseOutput(i);
        captured_inputs.erase(captured_inputs.begin() + i - 1);
      }
    }
  }
  std::vector<std::string> definedVariables() {
    std::vector<std::string> result;
    for (auto& kv : value_table) {
      result.push_back(kv.first);
    }
    return result;
  }

 private:
  ValueTable value_table;
};
template <typename T>
static Value* materializeConstant(
    T val,
    Graph& graph,
    const SourceRange& r,
    std::unordered_map<T, Value*>& map) {
  auto existing_constant = map.find(val);
  if (existing_constant != map.end()) {
    return existing_constant->second;
  }

  WithInsertPoint guard(graph.block()->nodes().front());
  auto new_constant = graph.insertConstant(val, nullptr, r);
  map[val] = new_constant;
  return new_constant;
}
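// Repeated literals of the same value thus share a single prim::Constant node
// per graph: the map acts as a cache, and new constants are inserted at the
// front of the graph (via the insert-point guard) so they dominate every
// later use.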
static Value* ensureInt(const SourceRange& range, Value* v) {
  if (!v->type()->isSubtypeOf(IntType::get())) {
    throw ErrorReport(range)
        << "expected an int but found a " << v->type()->str();
  }
  return v;
}
inline bool isSupportedListElementType(const TypePtr& type) {
  return type->isSubtypeOf(TensorType::get()) ||
      type->isSubtypeOf(NumberType::get());
}
// Per-def state about return types while the def is being compiled.
struct DefContext {
  TypePtr declared_return_type_; // nullptr if no return type annotation
  TypePtr merged_return_type_; // nullptr if a Return has not been seen yet
};
struct to_ir {
  to_ir(
      const Def& def,
      Resolver resolver_,
      const c10::optional<Self>& self,
      Method& method) // method being constructed
      : method(method),
        graph(method.graph()),
        resolver(std::move(resolver_)),
        environment_stack(nullptr) {
    AT_ASSERT(resolver);
    pushFrame(graph->block(), /*starts_def=*/true);

    // Type annotations exclude explicitly typing the `self` parameter, so in
    // the case that this Def is part of a class, the number of parameters is
    // not equal to the number of annotations.
    if (self && def.decl().params().size() == 0) {
      throw ErrorReport(def.decl().params().range())
          << "methods must have a self argument";
    }
    method.setSchema(emitDef(def, self, graph->block()));

    runCleanupPasses(graph);
  }

 private:
  Method& method;
  std::shared_ptr<Graph> graph;
  Resolver resolver;
  std::unordered_map<int64_t, Value*> integral_constants;
  std::unordered_map<double, Value*> fp_constants;

  // Singly-linked list of environments; the back is the global scope, the
  // front is the scope of the block currently being compiled.
  std::shared_ptr<Environment> environment_stack;
  std::vector<DefContext> def_stack_;
  void pushFrame(Block* b, bool starts_def = false) {
    if (starts_def) {
      def_stack_.emplace_back();
    }
    environment_stack =
        std::make_shared<Environment>(method, resolver, b, environment_stack);
  }

  std::shared_ptr<Environment> popFrame(bool ends_def = false) {
    auto old_frame = environment_stack;
    environment_stack = environment_stack->next;
    if (ends_def) {
      def_stack_.pop_back();
    }
    return old_frame;
  }
  void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
    LowerSimpleTuples(to_clean);
    ConstantPooling(to_clean);
  }
  FunctionSchema emitDef(
      const Def& def,
      const c10::optional<Self>& self,
      Block* block) {
    auto schema = extractSchemaFromDef(def, self);
    if (schema.returns().size() == 1) {
      def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
    }
    std::vector<Argument> arguments =
        emitFormalArguments(def, self, schema, block);

    // body
    auto stmts_list = moveAllReturnsToEnd(def.statements());
    emitStatements(stmts_list.begin(), stmts_list.end());
    std::vector<Argument> returns = {emitOutput(def.range(), schema, block)};
    return {def.name().name(), "", std::move(arguments), std::move(returns)};
  }
  std::vector<IValue> evaluateDefaults(
      const SourceRange& r,
      const std::vector<Expr>& default_types,
      const std::vector<Expr>& default_exprs) {
    std::vector<IValue> default_values;
    if (default_exprs.empty())
      return default_values;
    // To evaluate the default expressions, we create a graph with no inputs
    // whose return values are the defaults we need, compile it, and run it.
    // This avoids separate handling of default arguments by reusing the
    // existing machinery for graph generation and constant extraction.
    auto tuple_type = Subscript::create(
        r,
        Var::create(r, Ident::create(r, "Tuple")),
        List<Expr>::create(r, default_types));
    auto blank_decl = Decl::create(
        r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));

    auto tuple_expr =
        TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
    auto ret = Return::create(r, tuple_expr);
    auto def = Def::create(
        r,
        Ident::create(r, "defaults"),
        blank_decl,
        List<Stmt>::create(r, {ret}));
    auto m = std::make_shared<Module>();
    defineMethodsInModule(m, {def}, {resolver}, c10::nullopt);
    Stack stack;
    m->get_method("defaults").run(stack);
    return stack.at(0).toTuple()->elements();
  }
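  // In effect, for `def f(x : int = 1, y : float = 2.0): ...` the synthesized
  // program is
  //   def defaults() -> Tuple[int, float]:
  //       return (1, 2.0)
  // which is compiled and run once; the resulting tuple elements become the
  // schema's default IValues. (The example values are illustrative.)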
  std::vector<Argument> parseArgsFromDecl(
      const Decl& decl,
      const c10::optional<Self>& self) {
    auto params_begin = decl.params().begin();
    auto params_end = decl.params().end();
    if (self) {
      ++params_begin;
    }
    std::vector<Argument> retval;

    std::vector<Expr> default_types;
    std::vector<Expr> default_exprs;
    // gather any non-empty default arguments
    for (auto it = params_begin; it != params_end; ++it) {
      auto param = *it;
      auto def = param.defaultValue();
      if (def.present()) {
        default_types.emplace_back(param.type());
        default_exprs.emplace_back(def.get());
      }
    }
    auto default_values =
        evaluateDefaults(decl.range(), default_types, default_exprs);

    auto defaults_it = default_values.begin();
    for (auto it = params_begin; it != params_end; ++it) {
      auto decl_arg = *it;

      TypePtr type;
      c10::optional<int32_t> N;
      // broadcast lists can only appear at the argument level
      if (auto maybe_broad_list = parseBroadcastList(decl_arg.type())) {
        type = maybe_broad_list->first;
        N = maybe_broad_list->second;
      } else {
        type = parseTypeFromExpr(decl_arg.type());
        N = c10::nullopt;
      }
      c10::optional<IValue> default_value = c10::nullopt;
      if (decl_arg.defaultValue().present()) {
        default_value = *defaults_it++;
      }
      auto arg = Argument(
          decl_arg.ident().name(),
          type,
          N,
          default_value,
          decl_arg.kwarg_only());
      retval.push_back(arg);
    }
    return retval;
  }
  std::vector<Argument> parseReturnFromDecl(const Decl& decl) {
    // we represent the absence of a return type annotation as an empty
    // returns() list in the schema; in emitReturn the actual return value's
    // type is used when no annotation was provided here
    if (!decl.return_type().present())
      return {};

    if (parseBroadcastList(decl.return_type().get()))
      throw ErrorReport(decl.return_type().range())
          << "Broadcastable lists cannot appear as a return type";
    auto parsed_type = parseTypeFromExpr(decl.return_type().get());
    return {Argument(
        "",
        parsed_type,
        /*N=*/c10::nullopt,
        /*default_value=*/c10::nullopt,
        /*kwarg_only=*/false)};
  }
  FunctionSchema extractSchemaFromDef(
      const Def& def,
      const c10::optional<Self>& self) {
    const auto name = def.name().name();
    std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
    std::vector<Argument> returns = parseReturnFromDecl(def.decl());
    return FunctionSchema(
        name, "", std::move(args), std::move(returns), false, false);
  }
  std::vector<Argument> emitFormalArguments(
      const Def& def,
      const c10::optional<Self>& self,
      const FunctionSchema& schema,
      Block* block) {
    std::vector<Argument> arguments; // for the schema
    // inputs
    auto it = def.decl().params().begin();
    auto end = def.decl().params().end();
    auto expected_annotation_size = def.decl().params().size();
    if (self) {
      expected_annotation_size--;
    }
    if (schema.arguments().size() != expected_annotation_size) {
      throw ErrorReport(def.decl().params().range())
          << "Number of type annotations for"
          << " function parameters (" << schema.arguments().size() << ")"
          << " does not match the number of parameters on the function ("
          << expected_annotation_size << ")!";
    }

    if (self) {
      AT_ASSERT(it != end);
      const auto& name = (*it).ident().name();
      if (auto type = self->asFirstClass()) {
        Value* new_input =
            block->addInput()->setUniqueName(name)->setType(type);
        environment_stack->setVar((*it).ident().range(), name, new_input);
        arguments.emplace_back(name, type);
      } else {
        environment_stack->setSugaredVar(def.range(), name, self->asSugared());
      }
      ++it;
    }

    size_t arg_annotation_idx = 0;
    for (; it != end; ++it) {
      auto& name = (*it).ident().name();
      // add the input to the graph
      Value* new_input = block->addInput();
      if (meaningfulName(name)) {
        new_input->setUniqueName(name);
      }
      environment_stack->setVar((*it).ident().range(), name, new_input);

      // record the type for the schema and set the Type on the Value*
      arguments.push_back(schema.arguments().at(arg_annotation_idx++));
      new_input->setType(arguments.back().type());
    }
    return arguments;
  }
  Argument emitOutput(
      const SourceRange& range,
      const FunctionSchema& schema,
      Block* block) {
    // earlier rewrites ensure there is always a return statement in the body
    AT_ASSERT(def_stack_.back().merged_return_type_);
    // outputs
    Value* result = environment_stack->getVar("$return", range);
    block->registerOutput(result);
    return Argument("", def_stack_.back().merged_return_type_);
  }
  void emitStatements(const List<Stmt>& statements) {
    return emitStatements(statements.begin(), statements.end());
  }
  std::pair<std::shared_ptr<Graph>, Value*> lambdaLift(Block* block) {
    auto subgraph = std::make_shared<Graph>();
    // note: the context's type is set later, once it is known
    Node* pack_context =
        graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
    Value* context = subgraph->addInput("context");
    // cannot use createTupleUnpack because the type is not known yet
    Node* unpack_context =
        subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));

    std::unordered_map<Value*, Value*> captures;
    auto env = [&](Value* v) -> Value* {
      auto it = captures.find(v);
      if (it != captures.end()) {
        return it->second;
      }
      pack_context->addInput(v);
      Value* r = unpack_context->addOutput()->copyMetadata(v);
      captures[v] = r;
      return r;
    };
    subgraph->block()->cloneFrom(block, env);
    auto context_type = TupleType::create(
        fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
    pack_context->output()->setType(context_type);
    context->setType(context_type);
    return std::make_pair(std::move(subgraph), pack_context->output());
  }
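  // lambdaLift is a closure conversion: every free Value the block refers to
  // is routed through a single TupleConstruct in the enclosing graph (the
  // "context"), and the subgraph receives the matching tuple as its one input
  // and re-expands it with TupleUnpack. The context tuple's type is only known
  // after cloning, which is why both ends are typed last.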
  void emitClosure(const Def& def) {
    Node* closure_node = graph->insertNode(graph->create(prim::Function, 1));
    closure_node->output()->setType(
        NoneType::get()); // it is not a real thing yet, so just say the type
                          // is None
    Block* block = closure_node->addBlock();
    {
      WithInsertPoint guard(block);
      pushFrame(block, /*starts_def=*/true);
      // ignore the schema return; we never create a Method for the closure
      emitDef(def, c10::nullopt, block);
      popFrame(/*ends_def=*/true);
    }
    std::shared_ptr<Graph> subgraph;
    Value* context;
    std::tie(subgraph, context) = lambdaLift(block);
    runCleanupPasses(subgraph);
    closure_node->eraseBlock(0);
    closure_node->g_(attr::Subgraph, std::move(subgraph));
    auto tup =
        graph->insertNode(graph->createTuple({closure_node->output(), context}))
            ->output();
    environment_stack->setVar(def.name().range(), def.name().name(), tup);
  }
  void emitReturn(const Return& stmt) {
    Value* result = emitExpr(stmt.expr());
    TypePtr result_type = def_stack_.back().declared_return_type_;
    // if the return type is annotated, every return must convert to it
    if (result_type) {
      // this guard skips implicit conversion from None -> Tensor for the
      // return type; otherwise a forgotten return in a function annotated to
      // return a tensor would silently convert None to a tensor
      if (!(result_type->isSubtypeOf(TensorType::get()) &&
            result->type()->isSubtypeOf(NoneType::get()))) {
        result = tryConvertToType(
            stmt.range(),
            *graph,
            result_type,
            result,
            /*allow_conversions=*/true);
      }

      if (!result->type()->isSubtypeOf(result_type)) {
        throw ErrorReport(stmt.range())
            << "Return value was annotated as having type "
            << result_type->python_str() << " but is actually of type "
            << result->type()->python_str();
      }
    } else {
      result_type = def_stack_.back().merged_return_type_;
      if (!result_type) {
        result_type = result->type();
      }
      if (!unifyTypes(result_type, result->type())) {
        throw ErrorReport(stmt.range())
            << "Previous return statement returned a value of type "
            << result_type->python_str()
            << " but this return statement returns a value of type "
            << result->type()->python_str();
      }
    }
    AT_ASSERT(result_type);
    def_stack_.back().merged_return_type_ = result_type;
    environment_stack->setVar(stmt.range(), "$return", result);
  }
  void emitStatements(
      List<Stmt>::const_iterator begin,
      List<Stmt>::const_iterator end) {
    for (; begin != end; ++begin) {
      auto stmt = *begin;
      switch (stmt.kind()) {
        case TK_IF: emitIf(If(stmt)); break;
        case TK_WHILE: emitWhile(While(stmt)); break;
        case TK_FOR: emitFor(For(stmt)); break;
        case TK_ASSIGN: emitAssignment(Assign(stmt)); break;
        case TK_AUG_ASSIGN: emitAugAssignment(AugAssign(stmt)); break;
        case TK_GLOBAL:
          for (auto ident : Global(stmt).names()) {
            const auto& name = Ident(ident).name();
            environment_stack->setVar(
                ident.range(), name, graph->addInput(name));
          }
          break;
        case TK_EXPR_STMT: emitSugaredExpr(ExprStmt(stmt).expr(), 0); break;
        case TK_RAISE: emitRaise(Raise(stmt).range()); break;
        case TK_ASSERT: emitAssert(Assert(stmt)); break;
        case TK_RETURN: emitReturn(Return(stmt)); break;
        case TK_PASS: break;
        case TK_DEF: emitClosure(Def(stmt)); break;
        default:
          throw ErrorReport(stmt)
              << "Unrecognized statement kind " << kindToString(stmt.kind());
      }
    }
  }
  std::shared_ptr<Environment> emitSingleIfBranch(
      Block* b,
      const List<Stmt>& branch,
      const Refinements& refinements) {
    pushFrame(b);
    WithInsertPoint guard(b);
    insertRefinements(refinements);
    emitStatements(branch);
    return popFrame();
  }

  Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
    return graph->create(kind, n_outputs)
        ->setSourceLocation(std::make_shared<SourceRange>(loc));
  }
  Value* emitTernaryIf(const TernaryIf& expr) {
    const auto& bool_info = findRefinements(expr.cond());
    Value* cond_value = emitCond(expr.cond());
    auto true_expr = [&] {
      insertRefinements(bool_info.true_refinements_);
      return emitExpr(expr.true_expr());
    };
    auto false_expr = [&] {
      insertRefinements(bool_info.false_refinements_);
      return emitExpr(expr.false_expr());
    };
    return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
  }
  // Insert the refinements produced by a boolean test: for every refined
  // name, unwrap the optional value on the path where it is known not to be
  // None.
  void insertRefinements(const Refinements& ref) {
    for (const auto& name_mappings : ref.mappings_) {
      const std::string& name = name_mappings.first;
      auto type = name_mappings.second.first;
      const auto& range = *name_mappings.second.second;
      Value* v = environment_stack->getVar(name, range);
      if (type != NoneType::get()) {
        Value* output = graph->insert(prim::unchecked_unwrap_optional, {v});
        environment_stack->setVar(range, name, output);
      }
    }
  }
  Value* emitShortCircuitIf(
      const SourceRange& loc,
      const TreeRef& first_expr,
      const TreeRef& second_expr,
      bool is_or) {
    const auto first_bool_info = findRefinements(first_expr);
    Value* first_value = emitCond(Expr(first_expr));

    const Refinements* first_expr_refinements;
    const Refinements* second_expr_refinements;
    // for an OR, the first expr is emitted in the true branch and the second
    // expr in the false branch; for an AND, the opposite
    if (is_or) {
      first_expr_refinements = &first_bool_info.true_refinements_;
      second_expr_refinements = &first_bool_info.false_refinements_;
    } else {
      first_expr_refinements = &first_bool_info.false_refinements_;
      second_expr_refinements = &first_bool_info.true_refinements_;
    }

    auto get_first_expr = [&] {
      insertRefinements(*first_expr_refinements);
      return first_value;
    };

    auto get_second_expr = [&] {
      insertRefinements(*second_expr_refinements);
      return emitCond(Expr(second_expr));
    };

    // for an OR, evaluate the second expression only if the first is False;
    // for an AND, only if the first is True
    if (is_or) {
      return emitIfExpr(loc, first_value, get_first_expr, get_second_expr);
    } else {
      return emitIfExpr(loc, first_value, get_second_expr, get_first_expr);
    }
  }
  Value* emitIfExpr(
      const SourceRange& range,
      Value* cond_value,
      std::function<Value*()> true_expr,
      std::function<Value*()> false_expr) {
    Node* n = graph->insertNode(create(prim::If, range, 0));

    n->addInput(cond_value);
    auto* true_block = n->addBlock();
    auto* false_block = n->addBlock();

    auto emit_if_expr = [this](Block* b, std::function<Value*()> expr_value) {
      pushFrame(b);
      WithInsertPoint guard(b);
      Value* out_val = expr_value();
      b->registerOutput(out_val);
      popFrame();
    };

    emit_if_expr(true_block, std::move(true_expr));
    emit_if_expr(false_block, std::move(false_expr));

    auto true_type = true_block->outputs().at(0)->type();
    auto false_type = false_block->outputs().at(0)->type();
    auto unified = unifyTypes(true_type, false_type);
    if (!unified) {
      throw ErrorReport(range)
          << "if-expression's true branch has type " << true_type->str()
          << " but false branch has type " << false_type->str();
    }

    // add the node output with the unified type
    auto expr_value = n->addOutput()->setType(*unified);

    return expr_value;
  }
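  // A ternary `x if c else y` therefore lowers to
  //   %out : T = prim::If(%c)
  //     block0(): ... -> (%x)
  //     block1(): ... -> (%y)
  // where T is the unified type of the two branch outputs.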
  Value* emitCond(const Expr& cond) {
    Value* v = emitExpr(cond);
    if (!v->type()->isSubtypeOf(BoolType::get())) {
      ErrorReport error(cond);
      error << "expected a boolean expression for condition but found "
            << v->type()->str();
      if (v->type()->isSubtypeOf(TensorType::get())) {
        error << ", to use a tensor in a boolean"
              << " expression, explicitly cast it with `bool()`";
      }
      throw error;
    }
    return v;
  }
  void emitIfElseBlocks(Value* cond_value, const If& stmt) {
    Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
    n->addInput(cond_value);
    const auto bool_info = findRefinements(stmt.cond());
    auto* true_block = n->addBlock();
    auto* false_block = n->addBlock();

    // emit both branches, each in its own frame
    auto save_true = emitSingleIfBranch(
        true_block, stmt.trueBranch(), bool_info.true_refinements_);
    auto save_false = emitSingleIfBranch(
        false_block, stmt.falseBranch(), bool_info.false_refinements_);

    // variables assigned in both branches escape the if; use an ordered set
    // for deterministic graph output
    std::set<std::string> mutated_variables;

    for (auto& v : save_true->definedVariables()) {
      if (save_false->findInAnyFrame(v)) {
        mutated_variables.insert(v);
      }
    }
    for (auto& v : save_false->definedVariables()) {
      if (save_true->findInAnyFrame(v)) {
        mutated_variables.insert(v);
      }
    }

    // register block outputs for each mutated variable
    for (const auto& x : mutated_variables) {
      auto tv = save_true->getVar(x, stmt.range());
      auto fv = save_false->getVar(x, stmt.range());
      auto unified = unifyTypes(tv->type(), fv->type());

      // attempt to unify the types; a variable may be set to different types
      // in each branch as long as it is not used afterwards
      if (!unified) {
        ErrorReport error(stmt);
        error << "Type mismatch: " << x << " is set to type "
              << tv->type()->str() << " in the true branch"
              << " and type " << fv->type()->str() << " in the false branch";
        if (save_true->findInParentFrame(x) ||
            save_false->findInParentFrame(x)) {
          throw error;
        } else {
          // the error is saved in the lowest environment (all variables are
          // scoped to the function) and raised only if the variable is
          // actually used later
          save_true->setVariableTypeError(x, error.what());
          continue;
        }
      }
      true_block->registerOutput(tv);
      false_block->registerOutput(fv);
      environment_stack->setVar(
          stmt.range(), x, n->addOutput()->setType(*unified));
    }
  }
  void emitIf(const If& stmt) {
    // emitIf inspects the condition: for `is`/`is not` None checks whose
    // outcome is statically known, we do meta-programming on the AST and skip
    // emitting the dead branch entirely.
    Expr cond = stmt.cond();

    if (cond.kind() != TK_IS && cond.kind() != TK_ISNOT) {
      // emit a normal if statement for every other condition
      Value* cond_value = emitCond(cond);
      emitIfElseBlocks(cond_value, stmt);
      return;
    }

    auto cond_op = BinOp(cond);
    SugaredValuePtr lhs_val = emitSugaredExpr(cond_op.lhs(), 1);
    SugaredValuePtr rhs_val = emitSugaredExpr(cond_op.rhs(), 1);

    List<Stmt> always_none_branch =
        cond.kind() == TK_IS ? stmt.trueBranch() : stmt.falseBranch();
    List<Stmt> never_none_branch =
        cond.kind() == TK_IS ? stmt.falseBranch() : stmt.trueBranch();

    auto lhs_none = lhs_val->isNone();
    auto rhs_none = rhs_val->isNone();

    // Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
    //   AA -> emit always_none_branch
    //   AN, NA -> emit never_none_branch
    //   MA, MM, MN, NM, NN, AM -> emit both conditional branches
    if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
      // None is/is not None: only emit the always_none_branch
      emitStatements(always_none_branch);
    } else if (
        (lhs_none == ALWAYS && rhs_none == NEVER) ||
        (lhs_none == NEVER && rhs_none == ALWAYS)) {
      // exactly one side is never None: only emit the never_none_branch
      emitStatements(never_none_branch);
    } else {
      // all other cases: emit the whole If stmt as usual
      auto lhs_range = cond_op.lhs().get()->range();
      auto rhs_range = cond_op.rhs().get()->range();

      auto kind = getNodeKind(cond.kind(), cond.get()->trees().size());
      Value* cond_value = emitBuiltinCall(
          cond.get()->range(),
          *method.graph(),
          kind,
          c10::nullopt,
          {lhs_val->asValue(lhs_range, method),
           rhs_val->asValue(rhs_range, method)},
          {},
          /*required=*/true);
      emitIfElseBlocks(cond_value, stmt);
    }
  }
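  // For example, with `x: Optional[Tensor]`, `if x is None:` still emits a
  // real prim::If (MAYBE/ALWAYS); `if None is None:` (ALWAYS/ALWAYS) emits
  // only the true branch, and the same test on a non-optional x
  // (NEVER/ALWAYS) emits only the false branch.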
  void emitLoopCommon(
      SourceRange range,
      c10::optional<Expr> max_trip_count,
      c10::optional<Expr> cond,
      const List<Stmt>& body,
      c10::optional<Ident> itr_ident,
      bool in_list = false) {
    Node* n = graph->insertNode(create(prim::Loop, range, 0));
    Value *max_trip_count_val, *cond_val;
    {
      WithInsertPoint guard(n);
      if (max_trip_count) {
        if (in_list) {
          auto listArg = emitExpr(max_trip_count.value());

          max_trip_count_val = emitBuiltinCall(
              max_trip_count->range(),
              *graph,
              aten::len,
              c10::nullopt,
              {listArg},
              {},
              /*required=*/true);
        } else {
          max_trip_count_val = ensureInt(
              max_trip_count->range(), emitExpr(max_trip_count.value()));
        }
      } else {
        max_trip_count_val = materializeConstant(
            std::numeric_limits<int64_t>::max(),
            *graph,
            range,
            integral_constants);
      }
      if (cond) {
        cond_val = emitCond(cond.value());
      } else {
        cond_val = graph->insertConstant(true, nullptr, range);
      }
    }
    n->addInput(max_trip_count_val);
    n->addInput(cond_val);
    auto* body_block = n->addBlock();
    Value* trip_count =
        body_block->addInput()->setType(IntType::get()); // iteration number

    {
      pushFrame(body_block);
      WithInsertPoint guard(body_block);
      if (itr_ident) {
        if (in_list) {
          // set the iterator variable to the current element of the list
          auto listArg = emitExpr(max_trip_count.value());
          trip_count = emitBuiltinCall(
              max_trip_count->range(),
              *graph,
              aten::select,
              c10::nullopt,
              {listArg, trip_count},
              {},
              /*required=*/true);
        }
        environment_stack->setVar(
            itr_ident->range(), itr_ident->name(), trip_count);
      }
      emitStatements(body);

      // also emit the continue condition
      if (cond) {
        Value* body_cond_value = emitCond(cond.value());
        body_block->registerOutput(body_cond_value);
      } else {
        Value* cond_value_dummy = graph->insertConstant(true, nullptr, range);
        body_block->registerOutput(cond_value_dummy);
      }

      auto body_frame = popFrame();
      auto outer_frame = environment_stack;

      // add block outputs corresponding to each captured input;
      // some of these will be removed below
      for (const auto& x : body_frame->captured_inputs) {
        auto fv = body_frame->getValueInThisFrame(range, x);
        body_block->registerOutput(fv);
      }

      // remove inputs for values that did not mutate within the block
      body_frame->deleteExtraInputs(range);

      // register node inputs/outputs for the true loop-carried dependencies
      for (size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
        auto x = body_frame->captured_inputs[i];
        n->addInput(outer_frame->getVar(x, range));
        // body_block->inputs(): loop_counter, lcd0, lcd1, ...
        // captured_inputs: lcd0, lcd1, ...
        auto typ = body_block->inputs()[i + 1]->type();
        outer_frame->setVar(range, x, n->addOutput()->setType(typ));
      }
    }
  }
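  // The emitted node thus has the shape
  //   %y_1, ..., %y_r = prim::Loop(%max_trip_count, %start_cond, %x_1, ..., %x_r)
  //     block0(%i, %b_1, ..., %b_r):
  //       <body>
  //       -> (%continue_cond, %b_1', ..., %b_r')
  // where the loop-carried values %x/%b/%y are exactly the captured_inputs
  // that survived deleteExtraInputs, offset by one block input to skip the
  // trip counter %i.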
  void emitForRange(
      const SourceRange& range,
      const Ident& target,
      const List<Expr>& args,
      const List<Stmt>& body) {
    // TODO: start, stop, step loop
    if (args.size() != 1) {
      throw ErrorReport(range)
          << "range() expects 1 argument but got " << args.size();
    }
    emitLoopCommon(range, {args[0]}, {}, body, target);
  }
  void emitFor(const For& stmt) {
    auto targets = stmt.targets();
    auto itrs = stmt.itrs();
    auto body = stmt.body();

    if (stmt.itrs().size() != 1) {
      throw ErrorReport(stmt)
          << "List of iterables is not supported currently.";
    }
    if (targets.size() != 1) {
      throw ErrorReport(stmt)
          << "Iteration variable unpacking is not supported";
    }
    if (targets[0].kind() != TK_VAR) {
      throw ErrorReport(targets[0])
          << "unexpected expression in variable initialization of for loop";
    }
    auto target = Var(targets[0]).name();

    // match range(<expr>) style loops; itrs must be a single Apply node
    if (itrs[0].kind() == TK_APPLY) {
      Apply range_iterator = Apply(itrs[0]);
      if (range_iterator.callee().kind() == TK_VAR) {
        Var var = Var(range_iterator.callee());
        if (var.name().name() == "range") {
          return emitForRange(
              stmt.range(), target, range_iterator.inputs(), body);
        }
      }
    }

    // it isn't a range(<expr>) loop; treat it as a sugared value that can
    // maybe be unrolled
    auto sv = emitSugaredExpr(itrs[0], 1);
    // if the value is simple and list-like, emit a regular loop over it
    if (auto siv = std::dynamic_pointer_cast<SimpleValue>(sv)) {
      if (siv->getValue()->type()->kind() == TypeKind::ListType) {
        return emitLoopCommon(
            stmt.range(), {itrs[0]}, {}, body, {target}, true);
      }
    }
    auto instances = sv->asTuple(stmt.range(), method);
    const std::string& target_name = target.name();
    pushFrame(environment_stack->block());
    for (const auto& inst : instances) {
      environment_stack->setSugaredVar(itrs[0].range(), target_name, inst);
      emitStatements(body);
    }

    for (const auto& n : environment_stack->definedVariables()) {
      if (environment_stack->findInParentFrame(n)) {
        environment_stack->next->setVar(
            stmt.range(), n, environment_stack->getVar(n, stmt.range()));
      }
    }
    popFrame();
  }
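  // Note that the tuple path above fully unrolls the loop at compile time:
  // the body is re-emitted once per element (as when iterating over a
  // container of sugared values such as submodules), and only variables also
  // defined in the parent frame survive the unrolled frame.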
  void emitWhile(const While& stmt) {
    auto cond = stmt.cond();
    emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
  }
  // Currently we do not support assigning exceptions to variables, e.g.
  //   a = Exception("hi"); raise a
  // and the expression following raise is ignored.
  void emitRaise(const SourceRange& loc) {
    const std::string exception = "Exception";
    auto string_input = insertConstant(*graph, exception, nullptr, loc);
    graph->insert(prim::RaiseException, {string_input}, {}, loc);
  }
  void emitAssert(const Assert& stmt) {
    Value* cond_value = emitCond(stmt.test());
    Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));

    n->addInput(cond_value);
    /*true_block=*/n->addBlock();
    auto* false_block = n->addBlock();

    // if the assert test is false, throw an exception
    pushFrame(false_block);
    WithInsertPoint guard(false_block);
    emitRaise(stmt.range());
    popFrame();
  }
  // Validate that the lhs Exprs of an assignment statement are valid:
  // 1) every lhs Expr is a Var, Subscript, or Starred node
  // 2) there is at most one Starred node
  // 3) a Starred node can only appear alongside another non-starred lhs Expr
  // Returns the number of starred expressions (0 or 1).
  bool calcNumStarredUnpack(const List<Expr>& lhs, const SourceRange& r) {
    size_t num_normal_assign = 0;
    size_t num_starred = 0;
    for (const auto& assignee : lhs) {
      if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT) {
        num_normal_assign++;
      } else if (assignee.kind() == TK_STARRED) {
        num_starred++;
      } else {
        throw ErrorReport(assignee) << "lhs of assignment must be a variable, "
                                    << "subscript, or starred expression.";
      }
    }

    if (num_starred > 1) {
      throw ErrorReport(r)
          << "Only one starred expression is allowed on the lhs.";
    }

    if (num_starred > 0 && num_normal_assign == 0) {
      throw ErrorReport(r) << "A Starred expression may only appear on the "
                           << "lhs within the presence of another non-starred"
                           << " expression.";
    }

    return num_starred;
  }

  // Get the appropriate builtin op for this augmented assignment: if the LHS
  // is a tensor, return the corresponding in-place op, otherwise the pure op.
  Symbol getAugOp(const AugAssign& stmt, bool isTensor) {
    switch (stmt.aug_op()) {
      case '+':
        return isTensor ? aten::add_ : aten::add;
      case '-':
        return isTensor ? aten::sub_ : aten::sub;
      case '/':
        return isTensor ? aten::div_ : aten::div;
      case '*':
        return isTensor ? aten::mul_ : aten::mul;
      default:
        throw ErrorReport(stmt)
            << "Unknown augmented assignment: " << kindToString(stmt.aug_op());
    }
  }
  // Emit nodes for augmented assignments like `+=`
  void emitAugAssignment(const AugAssign& stmt) {
    switch (stmt.lhs().kind()) {
      case TK_VAR: {
        emitAugAssignmentToVar(stmt);
      } break;
      case '.': {
        emitAugAssignmentToSelectVar(stmt);
      } break;
      case TK_SUBSCRIPT: {
        emitAugAssignmentToSubscript(stmt);
      } break;
      default:
        throw ErrorReport(stmt.lhs())
            << "unexpected expression on "
            << "left-hand side of augmented assignment.";
    }
  }
  // Called when the LHS of an augmented assignment is a class parameter or
  // module buffer, e.g. `self.num_batches += 1`. Only tensor-typed attributes
  // are supported; the corresponding in-place op is emitted.
  void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
    const auto lhs = Select(stmt.lhs());
    const auto lhsSugaredVar =
        environment_stack->getSugaredVar(Var(lhs.value()).name());
    const auto lhsValue =
        lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
            ->asValue(lhs.range(), method);
    if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
      // for module parameter/buffer assignment, only consider tensor types
      // and emit the corresponding in-place op
      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
      const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
      emitBuiltinCall(
          stmt.range(),
          *method.graph(),
          getAugOp(stmt, /*isTensor=*/true),
          self,
          {rhs},
          {},
          /*required=*/true);
    } else {
      throw ErrorReport(stmt.lhs())
          << "left-hand side of augmented assignment to module "
          << "parameters/buffers can only be tensor types";
    }
  }
  void emitAugAssignmentToVar(const AugAssign& stmt) {
    const auto lhs = Var(stmt.lhs());
    const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
                              ->asValue(lhs.range(), method);
    if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
      // for tensors, emit the corresponding in-place op
      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
      const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
      const auto output = emitBuiltinCall(
          stmt.range(),
          *method.graph(),
          getAugOp(stmt, /*isTensor=*/true),
          self,
          {rhs},
          {},
          /*required=*/true);

      environment_stack->setVar(lhs.range(), lhs.name().name(), output);
    } else {
      // for primitive types, desugar into a simple assignment,
      // e.g. `foo += 1` becomes `foo.2 = foo + 1`
      Ident lhs = Var(stmt.lhs()).name();
      Expr expr = BinOp::create(
          lhs.range(),
          stmt.aug_op(),
          Var::create(lhs.range(), lhs),
          stmt.rhs());
      environment_stack->setVar(lhs.range(), lhs.name(), emitExpr(expr));
    }
  }
  void emitAugAssignmentToSubscript(const AugAssign& stmt) {
    // process the base expression being indexed
    const auto lhs = Subscript(stmt.lhs());
    const auto sliceable = emitExpr(lhs.value());

    if (sliceable->type()->isSubtypeOf(TensorType::get())) {
      // if it's a tensor, fully evaluate the subscript operation and emit an
      // in-place assignment
      Value* sliced;
      std::vector<Value*> tensorIndices;
      std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
          lhs.range(), sliceable, lhs.subscript_exprs());

      const auto slicedArg = NamedValue(stmt.lhs().range(), "self", sliced);
      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
      if (tensorIndices.size() == 0) {
        // common case: we only indexed with ints and slices; emit the
        // augmented assignment op on the sliced value
        emitBuiltinCall(
            stmt.range(),
            *method.graph(),
            getAugOp(stmt, /*isTensor=*/true),
            slicedArg,
            {rhs},
            {},
            /*required=*/true);
      } else {
        // special case: "advanced indexing" with tensors; lower this into
        // `index` and `index_put_` ops with indices of type Tensor?[]
        const auto indices = graph
                                 ->insertNode(graph->createList(
                                     OptionalType::ofTensor(), tensorIndices))
                                 ->output();
        const auto indexed =
            graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range());
        const auto augmented = emitBuiltinCall(
            stmt.range(),
            *method.graph(),
            getAugOp(stmt, /*isTensor=*/true),
            indexed,
            {rhs},
            {},
            /*required=*/true);
        graph->insert(
            aten::index_put_,
            {slicedArg, indices, augmented},
            {},
            stmt.range());
      }
    } else {
      // otherwise, it should be a list; lower this expression into
      //   list.set_item(idx, get_item(idx).add_(value))
      // similar to how Python handles things
      const auto listType = sliceable->type()->cast<ListType>();
      AT_ASSERT(listType != nullptr);

      bool isTensorList =
          listType->getElementType()->isSubtypeOf(TensorType::get());

      // get the idx to augment
      const auto subscriptExprs = lhs.subscript_exprs();
      if (subscriptExprs.size() != 1) {
        throw ErrorReport(subscriptExprs)
            << "Sliced expression not yet supported for"
            << " subscripted list augmented assignment. "
            << "File a bug if you want this.";
      }
      const auto idxValue = emitExpr(subscriptExprs[0]);

      const auto listArg = NamedValue(lhs.value().range(), "list", sliceable);
      const auto idxArg = NamedValue(subscriptExprs.range(), "idx", idxValue);
      const auto valueArg =
          NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs()));

      const auto getItem =
          graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range());
      const auto augmentedItem = graph->insert(
          getAugOp(stmt, isTensorList), {getItem, valueArg}, {}, stmt.range());
      graph->insert(
          aten::_set_item, {listArg, idxArg, augmentedItem}, {}, stmt.range());
    }
  }
  // Emit mutating assignments like `foo[0] = bar`
  void emitSubscriptAssign(
      const SourceRange& stmtRange,
      const Subscript& lhs,
      const Expr& rhs) {
    emitSubscriptAssign(stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
  }
  void emitSubscriptAssign(
      const SourceRange& stmtRange,
      const Subscript& lhs,
      const NamedValue& rhs) {
    // first check the base value
    auto sliceable = emitExpr(lhs.value());

    // if it's a tensor, copy the RHS data into it
    if (sliceable->type()->isSubtypeOf(TensorType::get())) {
      std::vector<Value*> tensorIndices;
      Value* sliced;
      // handle multi-dimensional slicing: first emit int/slice indexing
      std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
          lhs.range(), sliceable, lhs.subscript_exprs());

      const auto slicedArg = NamedValue(lhs.range(), sliced);
      if (tensorIndices.size() == 0) {
        // common case: we only indexed with ints and slices; copy the RHS
        // into the resulting tensor
        graph->insert(aten::copy_, {slicedArg, rhs}, {}, stmtRange);
      } else {
        // special case: "advanced indexing" with a tensor; dispatch to
        // `aten::index_put_` with indices of type Tensor?[]
        const auto indices = graph
                                 ->insertNode(graph->createList(
                                     OptionalType::ofTensor(), tensorIndices))
                                 ->output();

        graph->insert(
            aten::index_put_, {slicedArg, indices, rhs}, {}, stmtRange);
      }
    } else {
      // otherwise this is a list; dispatch to aten::_set_item to both select
      // and assign
      const auto subscript = lhs.subscript_exprs();
      if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) {
        throw ErrorReport(subscript)
            << "Sliced expression not yet supported for"
            << " subscripted list assignment. "
            << "File a bug if you want this.";
      }

      std::vector<NamedValue> args;
      args.emplace_back(lhs.value().range(), "list", sliceable);
      args.emplace_back(
          lhs.subscript_exprs().range(), "idx", emitExpr(subscript[0]));
      args.push_back(rhs);

      graph->insert(aten::_set_item, args, {}, stmtRange);
    }
  }
  void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) {
    size_t n_binders = tl.inputs().size();
    bool starred_unpack = calcNumStarredUnpack(tl.inputs(), tl.range());
    if (starred_unpack)
      n_binders--;
    auto output = emitSugaredExpr(rhs, n_binders);
    auto outputs = output->asTuple(
        rhs.range(),
        method,
        starred_unpack ? c10::nullopt : c10::optional<size_t>{n_binders});
    if (outputs.size() < n_binders) {
      throw ErrorReport(tl)
          << "need " << (starred_unpack ? "at least " : "") << n_binders
          << " values to unpack but found only " << outputs.size();
    }
    if (outputs.size() > n_binders && !starred_unpack) {
      throw ErrorReport(tl) << "too many values to unpack: need " << n_binders
                            << " but found " << outputs.size();
    }
    size_t i = 0;
    for (auto assignee : tl.inputs()) {
      switch (assignee.kind()) {
        case TK_SUBSCRIPT:
          emitSubscriptAssign(
              rhs.range(),
              Subscript(assignee),
              NamedValue(
                  rhs.range(), outputs.at(i)->asValue(rhs.range(), method)));
          i++;
          break;
        case TK_VAR:
          environment_stack->setSugaredVar(
              assignee.range(), Var(assignee).name().name(), outputs.at(i));
          i++;
          break;
        case TK_STARRED: {
          auto var = Starred(assignee).expr();
          if (var.kind() != TK_VAR) {
            throw ErrorReport(var)
                << "Cannot pack a tuple into a non-variable.";
          }
          size_t n_matched = outputs.size() - n_binders;
          ArrayRef<std::shared_ptr<SugaredValue>> outputs_ref = outputs;
          auto values = fmap(
              outputs_ref.slice(i, n_matched),
              [&](const std::shared_ptr<SugaredValue>& v) {
                return v->asValue(assignee.range(), method);
              });
          auto tup = graph->insertNode(graph->createTuple(values))->output();
          environment_stack->setVar(var.range(), Var(var).name().name(), tup);
          i += n_matched;
        } break;
        default:
          throw ErrorReport(assignee)
              << "unexpected expression on the left-hand side";
      }
    }
  }
  void emitAssignment(const Assign& stmt) {
    switch (stmt.lhs().kind()) {
      case TK_VAR: {
        auto v = Var(stmt.lhs());
        environment_stack->setSugaredVar(
            v.range(), v.name().name(), emitSugaredExpr(stmt.rhs(), 1));
      } break;
      case TK_TUPLE_LITERAL:
        emitTupleAssign(TupleLiteral(stmt.lhs()), stmt.rhs());
        break;
      case '.':
        emitSelectAssign(stmt);
        break;
      case TK_SUBSCRIPT:
        emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), stmt.rhs());
        break;
      default:
        throw ErrorReport(stmt.lhs())
            << "unexpected expression on left-hand side of assignment.";
    }
  }
  void emitSelectAssign(const Assign& stmt) {
    const auto lhs = Select(stmt.lhs());
    const auto basename = Var(lhs.value()).name();
    const auto rhsValue =
        emitSugaredExpr(stmt.rhs(), 1)->asValue(stmt.rhs().range(), method);
    auto userObject = environment_stack->getSugaredVar(basename);
    userObject->setAttr(stmt.range(), method, lhs.selector().name(), rhsValue);
  }
  NodeKind getNodeKind(int kind, int ninputs) {
    switch (kind) {
      case '+': return aten::add;
      case TK_UNARY_MINUS: return aten::neg;
      case '-': return aten::sub;
      case '*': return aten::mul;
      case '@': return aten::matmul;
      case TK_STARRED: return prim::Starred;
      case '/': return aten::div;
      case '%': return aten::remainder;
      case TK_NE: return aten::ne;
      case TK_EQ: return aten::eq;
      case '<': return aten::lt;
      case '>': return aten::gt;
      case TK_LE: return aten::le;
      case TK_GE: return aten::ge;
      case TK_AND: return aten::__and__;
      case TK_OR: return aten::__or__;
      case TK_IS: return aten::__is__;
      case TK_ISNOT: return aten::__isnot__;
      case TK_NOT: return aten::__not__;
      case TK_FLOOR_DIV: return aten::floordiv;
      case '&': return aten::__and__;
      case '|': return aten::__or__;
      case '^': return aten::__xor__;
      default:
        throw std::runtime_error("unknown kind " + std::to_string(kind));
    }
  }
  std::vector<NamedValue> getNamedValues(
      const TreeList& trees,
      bool maybe_unpack) {
    std::vector<NamedValue> values;
    for (const auto& tree : trees) {
      if (maybe_unpack && tree->kind() == TK_STARRED) {
        auto starred = Starred(tree);
        auto entries = emitSugaredExpr(starred.expr(), 1)
                           ->asTuple(starred.range(), method);
        for (const auto& entry : entries) {
          values.emplace_back(
              tree->range(), entry->asValue(starred.range(), method));
        }
      } else {
        values.emplace_back(tree->range(), emitExpr(Expr(tree)));
      }
    }
    return values;
  }

  std::vector<NamedValue> getNamedValues(
      const List<Expr>& trees,
      bool maybe_unpack) {
    return getNamedValues(trees.tree()->trees(), maybe_unpack);
  }

  std::vector<Value*> getValues(const TreeList& trees, bool maybe_unpack) {
    return toValues(*graph, getNamedValues(trees, maybe_unpack));
  }

  std::vector<Value*> getValues(const List<Expr>& trees, bool maybe_unpack) {
    return getValues(trees.tree()->trees(), maybe_unpack);
  }
  std::vector<NamedValue> emitAttributes(const List<Attribute>& attributes) {
    return fmap(attributes, [&](const Attribute& attr) {
      return NamedValue(
          attr.range(), attr.name().name(), emitExpr(attr.value()));
    });
  }
  void checkApplyExpr(Apply& apply, SourceRange& loc) {
    if (apply.inputs().size() != 2) {
      throw ErrorReport(loc) << Var(apply.callee()).name().name()
                             << " expected exactly two arguments but found "
                             << apply.inputs().size();
    }
    if (apply.attributes().size() > 0) {
      throw ErrorReport(loc) << Var(apply.callee()).name().name()
                             << " takes no keyword arguments";
    }
  }
  std::shared_ptr<SugaredValue> emitApplyExpr(Apply& apply, size_t n_binders) {
    auto sv = emitSugaredExpr(apply.callee(), 1);
    auto loc = apply.callee().range();
    if (auto fork_value = dynamic_cast<ForkValue*>(sv.get())) {
      auto& trees = apply.inputs().tree()->trees();
      if (trees.size() < 1) {
        throw ErrorReport(loc) << "Expected at least one argument to fork()";
      }

      auto forked = emitSugaredExpr(Expr(trees[0]), 1);
      TreeList sliced_trees(trees.begin() + 1, trees.end());
      auto inputs = getNamedValues(sliced_trees, true);
      auto attributes = emitAttributes(apply.attributes());
      return emitForkExpr(loc, forked, inputs, attributes);
    } else if (auto annotate_value = dynamic_cast<AnnotateValue*>(sv.get())) {
      checkApplyExpr(apply, loc);
      TypePtr type = parseTypeFromExpr(apply.inputs()[0]);
      Value* expr = tryConvertToType(
          apply.range(),
          *graph,
          type,
          emitExpr(apply.inputs()[1], type),
          /*allow_conversions=*/true);

      // ensure that even if the user forgets the Optional wrapper when
      // annotating None, we still generate a value of the Optional type:
      // annotate(Tensor, None) behaves like annotate(Optional[Tensor], None)
      auto opt_type = expr->type()->cast<OptionalType>();
      bool forget_opt_annotate =
          opt_type && *opt_type->getElementType() == *type;

      if (!forget_opt_annotate && !expr->type()->isSubtypeOf(type)) {
        throw ErrorReport(apply.inputs())
            << "expected an expression of type " << type->python_str()
            << " but found " << expr->type()->python_str();
      }
      return std::make_shared<SimpleValue>(expr);
    } else if (auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
      checkApplyExpr(apply, loc);
      auto obj = emitSugaredExpr(apply.inputs()[0], 1);
      auto selector = apply.inputs()[1];
      if (selector.kind() != TK_STRINGLITERAL) {
        throw ErrorReport(loc)
            << "getattr's second argument must be a string literal";
      }
      const std::string& name = StringLiteral(selector).text();
      return obj->attr(apply.range(), method, name);
    } else if (auto isinstance = dynamic_cast<IsInstanceValue*>(sv.get())) {
      // for the isinstance builtin we only check the static types of the
      // inputs and insert the corresponding constant node
      std::function<bool(Expr, Expr)> isInstanceCheck = [&](Expr obj,
                                                            Expr classinfo) {
        if (classinfo.kind() == TK_TUPLE_LITERAL) {
          // handle a containment check against a tuple of types
          for (Expr e : TupleLiteral(classinfo).inputs()) {
            if (isInstanceCheck(obj, e)) {
              return true;
            }
          }
          return false;
        }
        auto type_name = parseBaseTypeName(classinfo);
        if (!type_name) {
          throw ErrorReport(classinfo.range())
              << "type must be a type identifier";
        }
        auto val = emitExpr(obj);
        // special-case list and tuple, since isinstance(x, list) and
        // isinstance(x, tuple) do not accept List[int] / Tuple[int] the way
        // subscripted type annotations do in Python
        if (*type_name == "list" && val->type()->cast<ListType>()) {
          return true;
        } else if (*type_name == "tuple" && val->type()->cast<TupleType>()) {
          return true;
        } else if (val->type()->cast<OptionalType>()) {
          throw ErrorReport(loc)
              << "Optional isinstance check is not supported, "
              << "consider use is/isnot None instead";
        } else {
          TypePtr type = parseTypeFromExpr(classinfo);
          if (val->type()->isSubtypeOf(type)) {
            return true;
          }
        }
        return false;
      };
      checkApplyExpr(apply, loc);
      bool is_instance_val =
          isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
      return std::make_shared<SimpleValue>(
          graph->insertConstant(is_instance_val, nullptr, loc));
    } else if (auto classNew = dynamic_cast<ClassNewMethod*>(sv.get())) {
      if (apply.inputs().size() != 1) {
        throw ErrorReport(loc) << "Only one argument to __new__ allowed";
      }
      return classNew->createObject(
          apply.range(), method, Var(apply.inputs()[0]).name().name());
    } else {
      auto inputs = getNamedValues(apply.inputs(), true);
      auto attributes = emitAttributes(apply.attributes());
      return sv->call(loc, method, inputs, attributes, n_binders);
    }
  }
  BoolInfo findRefinements(const TreeRef& tree) {
    switch (tree->kind()) {
      case TK_IS:
      case TK_ISNOT: {
        const auto& inputs = tree->trees();
        if (inputs.at(0)->kind() == TK_VAR && inputs.at(1)->kind() == TK_NONE) {
          const std::string& var_name = Var(inputs[0]).name().name();
          Refinements true_info, false_info;
          auto type =
              environment_stack->getVar(var_name, inputs[0]->range())->type();
          if (auto opt_type = type->cast<OptionalType>()) {
            false_info.setRefinement(
                var_name,
                TypeAndRange(opt_type->getElementType(), &tree->range()));
            true_info.setRefinement(
                var_name, TypeAndRange(NoneType::get(), &tree->range()));
          }
          if (tree->kind() == TK_IS) {
            return BoolInfo(true_info, false_info);
          } else {
            return BoolInfo(false_info, true_info);
          }
        }
      } break;
      case TK_NOT: {
        const auto& inputs = tree->trees();
        auto bool_info = findRefinements(inputs[0]);
        return BoolInfo(
            bool_info.false_refinements_, bool_info.true_refinements_);
      }
      case TK_OR:
      case TK_AND: {
        const auto& inputs = tree->trees();
        auto first = findRefinements(inputs[0]);
        auto second = findRefinements(inputs[1]);
        if (tree->kind() == TK_OR) {
          return *first.mergeOr(second);
        } else {
          return *first.mergeAnd(second);
        }
      }
    }
    return BoolInfo();
  }
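  // For `x is not None` with `x: Optional[int]`, this produces
  //   true_refinements  = {x: int}   (the optional is unwrapped on this path)
  //   false_refinements = {x: None}
  // which insertRefinements turns into prim::unchecked_unwrap_optional nodes
  // inside the corresponding branch.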
  Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) {
    return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method);
  }
  NodeKind reverseComparision(NodeKind kind) {
    if (kind == aten::lt) {
      return aten::gt;
    } else if (kind == aten::le) {
      return aten::ge;
    } else if (kind == aten::gt) {
      return aten::lt;
    } else if (kind == aten::ge) {
      return aten::le;
    }
    throw std::runtime_error(
        "reverseComparision: unsupported NodeKind. File a bug");
  }
  // Any expression that can produce a SugaredValue is handled here;
  // expressions that only return a single Value* are handled in
  // emitSimpleExpr. type_hint is set when there is a type the value is
  // expected to have, e.g. `a : List[int] = []`; emitSugaredExpr is free to
  // ignore it.
  std::shared_ptr<SugaredValue> emitSugaredExpr(
      const Expr& tree,
      size_t n_binders,
      const TypePtr& type_hint = nullptr) {
    switch (tree.kind()) {
      case TK_VAR:
        return environment_stack->getSugaredVar(Var(tree).name());
      case '.': {
        auto select = Select(tree);
        auto sv = emitSugaredExpr(select.value(), 1);
        return sv->attr(select.range(), method, select.selector().name());
      }
      case TK_APPLY: {
        auto apply = Apply(tree);
        return emitApplyExpr(apply, n_binders);
      }
      default:
        return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
    }
  }
  Value* emitNegate(const TreeRef& tree) {
    const auto& inputs = tree->trees();
    auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);

    auto neg_val = emitBuiltinCall(
        tree->range(),
        *method.graph(),
        aten::neg,
        c10::nullopt,
        named_values,
        {},
        /*required=*/true);

    // constant fold the negation if the input is a constant
    auto maybe_constant_input = toIValue(neg_val->node()->input());
    if (!maybe_constant_input) {
      return neg_val;
    }
    auto op = getOperation(neg_val->node());
    Stack stack;
    stack.push_back(*maybe_constant_input);
    op(stack);
    AT_ASSERT(stack.size() == 1);
    return graph->insertConstant(stack[0], nullptr, tree->range());
  }
  std::shared_ptr<SugaredValue> emitForkExpr(
      SourceRange loc,
      const std::shared_ptr<SugaredValue>& forked,
      at::ArrayRef<NamedValue> inputs,
      at::ArrayRef<NamedValue> attributes) {
    // build the fork node without inputs
    auto fork_node = method.graph()
                         ->insertNode(method.graph()->create(prim::fork, 1))
                         ->setSourceLocation(std::make_shared<SourceRange>(loc));
    auto body_block = fork_node->addBlock();

    // emit the forked call into the fork node's block
    Value* node_output;
    {
      WithInsertPoint guard(body_block);
      auto fn_sugared_output = forked->call(loc, method, inputs, attributes, 1);
      auto fn_simple_output = fn_sugared_output->asValue(loc, method);
      body_block->registerOutput(fn_simple_output);
      node_output = fork_node->output()->setType(
          FutureType::create(fn_simple_output->type()));
    }
    // lambda lift block(0) into attr::Subgraph
    lambdaLiftFork(fork_node);

    return std::make_shared<SimpleValue>(node_output);
  }
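  // fork(fn, args...) therefore yields a Future[T]: the call is first emitted
  // into block0 of the prim::fork node, and lambdaLiftFork (defined at the
  // end of this file) then splits that block out into the attr::Subgraph
  // attribute, appending any captured values as extra fork inputs.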
  Value* emitSimpleExpr(
      const TreeRef& tree,
      const TypePtr& type_hint = nullptr) {
    switch (tree->kind()) {
      case '@':
      case TK_POW:
      case TK_IS:
      case TK_ISNOT:
      case TK_NOT:
      case TK_NE:
      case TK_EQ:
      case '<':
      case '>':
      case TK_LE:
      case TK_GE:
      case '*':
      case '/':
      case '+':
      case '-':
      case '%':
      case '&':
      case '|':
      case '^':
      case TK_FLOOR_DIV: {
        const auto& inputs = tree->trees();
        auto kind = getNodeKind(tree->kind(), inputs.size());
        auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
        return emitBuiltinCall(
            tree->range(),
            *method.graph(),
            kind,
            c10::nullopt,
            named_values,
            {},
            /*required=*/true);
      }
      case TK_UNARY_MINUS: {
        return emitNegate(tree);
      }
      case TK_AND:
      case TK_OR: {
        const auto& inputs = tree->trees();
        return emitShortCircuitIf(
            tree->range(), inputs[0], inputs[1], tree->kind() == TK_OR);
      }
      case TK_STARRED: {
        throw ErrorReport(tree)
            << "Unexpected starred expansion. File a bug report.";
      }
      case TK_CONST: {
        return emitConst(Const(tree));
      }
      case TK_TRUE: {
        return graph->insertConstant(true, nullptr, tree->range());
      }
      case TK_FALSE: {
        return graph->insertConstant(false, nullptr, tree->range());
      }
      case TK_NONE: {
        return graph->insertConstant(IValue(), nullptr, tree->range());
      }
      case TK_SUBSCRIPT: {
        return emitSubscript(Subscript(tree));
      }
      case TK_IF_EXPR: {
        return emitTernaryIf(TernaryIf(tree));
      }
      case TK_STRINGLITERAL: {
        return emitStringLiteral(StringLiteral(tree));
      }
      case TK_LIST_LITERAL: {
        auto ll = ListLiteral(tree);
        auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);

        // determine the element type of the list: if we have a type hint of
        // List[T], use T; if the list is non-empty, use type_of(values[0]);
        // otherwise assume List[Tensor]
        TypePtr elem_type = TensorType::get();
        if (type_hint && type_hint->kind() == TypeKind::ListType) {
          elem_type = type_hint->expect<ListType>()->getElementType();
        } else if (!values.empty()) {
          elem_type = values.at(0)->type();
        }

        // tensors are special because they have dynamic properties, so any
        // list containing tensors is typed with the unified type of all of
        // its elements
        if (elem_type->isSubtypeOf(TensorType::get())) {
          for (const auto& value : values) {
            elem_type = unifyTypes(elem_type, value->type()).value();
          }
        }
        for (auto v : values) {
          if (!v->type()->isSubtypeOf(elem_type)) {
            throw ErrorReport(tree)
                << "Lists must contain only a single type, expected: "
                << *elem_type << " but found " << *v->type() << " instead";
          }
        }
        return
            graph->insertNode(graph->createList(elem_type, values))->output();
      }
      case TK_TUPLE_LITERAL: {
        auto ll = TupleLiteral(tree);
        auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
        return graph->insertNode(graph->createTuple(values))->output();
      }
      case TK_DICT_LITERAL: {
        auto dl = DictLiteral(tree);
        auto key_trees = dl.key_inputs().tree()->trees();
        auto value_trees = dl.value_inputs().tree()->trees();
        AT_ASSERT(key_trees.size() == value_trees.size());
        std::vector<Value*> keys, values;
        for (size_t i = 0; i < key_trees.size(); ++i) {
          keys.push_back(emitExpr(Expr(key_trees[i])));
          values.push_back(emitExpr(Expr(value_trees[i])));
        }

        TypePtr key_type = nullptr;
        TypePtr value_type = nullptr;

        if (type_hint && type_hint->kind() == TypeKind::DictType) {
          auto dict_type = type_hint->expect<DictType>();
          key_type = dict_type->getKeyType();
          value_type = dict_type->getValueType();
        } else if (!keys.empty()) {
          key_type = keys.at(0)->type();
          value_type = values.at(0)->type();
        } else {
          key_type = StringType::get();
          value_type = TensorType::get();
        }
        AT_ASSERT(key_type != nullptr && value_type != nullptr);

        return graph
            ->insertNode(graph->createDict(key_type, value_type, keys, values))
            ->output();
      }
      default:
        throw ErrorReport(tree) << "Cannot emit expr for: " << tree;
    }
  }
  Value* emitConst(const Const& c) {
    if (c.isFloatingPoint())
      return materializeConstant(
          c.asFloatingPoint(), *graph, c.range(), fp_constants);
    else
      return materializeConstant(
          c.asIntegral(), *graph, c.range(), integral_constants);
  }
  Value* emitStringLiteral(const StringLiteral& c) {
    return insertConstant(*graph, c.text(), nullptr, c.range());
  }
  // Desugars select indexing: tensor[i] -> tensor.select(dim, i)
  Value* emitSelect(
      const SourceRange& loc,
      Value* input,
      int64_t dim,
      Value* index) {
    return emitBuiltinCall(
        loc,
        *graph,
        aten::select,
        c10::nullopt,
        {input, graph->insertConstant(dim, nullptr, loc), index},
        {},
        true);
  }
  // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin,
  // end, 1)
  Value* emitSlice(
      const SourceRange& loc,
      Value* input,
      c10::optional<int64_t> dim, // only used for tensor slicing
      const SliceExpr& slice) {
    std::vector<NamedValue> args;
    args.reserve(4);
    args.emplace_back(loc, "self", input);

    // XXX: If list slicing becomes more complicated or stops using
    // aten::slice, we should separate it from this function.
    if (dim) {
      AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
      args.emplace_back(
          loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
    } else {
      AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
    }

    args.emplace_back(loc, "begin", emitExpr(Expr(slice.startOr(0))));
    const auto has_end = slice.end().present();
    if (has_end) {
      args.emplace_back(loc, "end", emitExpr(Expr(slice.end().get())));
    }
    if (input->type()->cast<TupleType>()) {
      if (has_end) {
        return emitTupleSlice(loc, args[0], args[1], args[2]);
      } else {
        return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
      }
    }
    NamedValue step =
        NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
    return emitBuiltinCall(
        loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
  }
  Value* emitIndex(
      const SourceRange& loc,
      Value* input,
      at::ArrayRef<Value*> indices) {
    auto* index =
        graph->insertNode(graph->createList(OptionalType::ofTensor(), indices))
            ->output();
    return emitBuiltinCall(
        loc, *graph, aten::index, c10::nullopt, {input, index}, {}, true);
  }
  // Emits multidimensional slicing with int and slice indices.
  // Returns:
  // - Value*: the input after it has been indexed by the int and slice
  //   indices.
  // - vector<Value*>: a list of tensor Value* indices that have not been
  //   applied yet; entries are null at positions where the (post-slicing)
  //   input isn't indexed by a tensor.
  std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    std::vector<Value*> tensor_indices;
    size_t dim = 0;

    auto handle_tensor = [&](Value* tensor) {
      // NB: tensor_indices can have None holes because of how at::index works
      tensor_indices.resize(dim + 1);
      tensor_indices[dim] = tensor;
      dim++;
    };

    for (const auto& subscript_expr : subscript_exprs) {
      if (subscript_expr.kind() == TK_SLICE_EXPR) {
        sliceable = emitSlice(loc, sliceable, dim, SliceExpr(subscript_expr));
        ++dim;
        continue;
      }
      auto index = emitExpr(subscript_expr, OptionalType::ofTensor());
      if (index->type() == IntType::get()) {
        sliceable = emitSelect(loc, sliceable, dim, index);
        continue;
      } else if (index->type()->isSubtypeOf(OptionalType::ofTensor())) {
        handle_tensor(index);
        continue;
      }
      throw ErrorReport(loc)
          << "Unsupported operation: indexing tensor with unsupported index type '"
          << index->type()->str()
          << "'. Only ints, slices, and tensors are supported";
    }
    // at::index takes a List[Optional[Tensor]] where some dims can be None;
    // replace the None holes with undefined tensors
    for (auto& index : tensor_indices) {
      if (index == nullptr) {
        index =
            graph->insertNode(graph->createNone(TensorType::get()))->output();
      }
    }
    return std::make_pair(sliceable, tensor_indices);
  }
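  // e.g. for `t[1, 2:4, mask]` (names illustrative): the int index emits
  // aten::select on dim 0, which drops that dimension; the slice then emits
  // aten::slice on the new dim 0 and advances dim to 1; the tensor index is
  // parked at tensor_indices[1], and the hole at position 0 is filled with an
  // undefined tensor so aten::index lines the indices up with the right
  // dimensions.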
  // Emits multidimensional slicing with int, slice, and tensor indices for a
  // tensor
  Value* emitMultidimSlicing(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    if (!sliceable->type()->isSubtypeOf(TensorType::get())) {
      throw ErrorReport(loc)
          << "Unsupported operation: attempted to use multidimensional "
          << "indexing on a non-tensor type.";
    }

    std::vector<Value*> tensor_indices;
    std::tie(sliceable, tensor_indices) =
        emitIntAndSliceIndexing(loc, sliceable, subscript_exprs);

    if (tensor_indices.empty()) {
      // XXX: Might need to at::alias this when we support mutability
      return sliceable;
    }

    return emitIndex(loc, sliceable, tensor_indices);
  }
  // Emits special handling of sliceable[subscript] when the subscript is a
  // single slice expression
  Value* emitBasicSlice(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    AT_ASSERT(subscript_exprs.size() == 1);
    AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
    auto slice_exp = SliceExpr(subscript_exprs[0]);
    c10::optional<int64_t> maybe_dim;
    if (sliceable->type()->isSubtypeOf(TensorType::get())) {
      // if the sliceable object is a tensor, slice along dimension 0
      maybe_dim = 0;
    }
    return emitSlice(loc, sliceable, maybe_dim, slice_exp);
  }
  int64_t getTupleIndexVal(
      const SourceRange& loc,
      const TupleTypePtr& tuple_type,
      Value* idx_val,
      bool allow_out_of_bounds) {
    int64_t index;
    at::optional<IValue> ivalue = toIValue(idx_val);
    if (ivalue && ivalue->isInt()) {
      index = ivalue->to<int64_t>();
    } else {
      throw ErrorReport(loc) << "tuple indices must be integer constants";
    }
    // make the index positive to simplify logic in the runtime
    int64_t adj_index = index;
    int64_t tuple_len = tuple_type->elements().size();
    if (index < 0) {
      adj_index = tuple_len + index;
    }
    if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) {
      throw ErrorReport(loc) << "Tuple index out of range. Tuple is length "
                             << tuple_len << " and index is " << index;
    }
    return adj_index;
  }
  Value* emitTupleIndex(
      const SourceRange& loc,
      Value* tuple_val,
      Value* idx_val) {
    auto tuple_typ = tuple_val->type()->cast<TupleType>();
    auto adj_index = getTupleIndexVal(
        loc, tuple_typ, idx_val, /*allow_out_of_bounds=*/false);
    return graph->insertNode(graph->createTupleIndex(tuple_val, adj_index))
        ->output();
  }

  Value* emitDictIndex(
      const SourceRange& loc,
      Value* dict_val,
      Value* key_val) {
    auto dict_type = dict_val->type()->cast<DictType>();
    AT_ASSERT(key_val->type()->isSubtypeOf(dict_type->getKeyType()));
    return graph->insertNode(graph->createDictIndex(dict_val, key_val))
        ->output();
  }
  Value* emitTupleSlice(
      const SourceRange& loc,
      const NamedValue& tuple_val,
      const NamedValue& beg_val,
      const at::optional<NamedValue>& end_val) {
    auto tuple_type = tuple_val.value(*graph)->type()->expect<TupleType>();
    int64_t beg = getTupleIndexVal(
        loc, tuple_type, beg_val.value(*graph), /*allow_out_of_bounds=*/true);
    int64_t end;
    int64_t tuple_len = tuple_type->elements().size();
    if (end_val) {
      end = getTupleIndexVal(loc, tuple_type, end_val->value(*graph), true);
    } else {
      end = tuple_len;
    }
    // slicing does not throw out-of-bounds errors
    end = std::min(std::max((int64_t)0, end), tuple_len);
    beg = std::min(std::max((int64_t)0, beg), tuple_len);

    return graph
        ->insertNode(graph->createTupleSlice(tuple_val.value(*graph), beg, end))
        ->output();
  }
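  // Like Python slicing, out-of-range bounds are clamped rather than raised:
  // for a 3-tuple, t[1:10] becomes createTupleSlice(t, 1, 3), and t[-1:]
  // slices from index 2.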
  Value* emitSubscript(const Subscript& subscript) {
    return emitSubscript(
        subscript.range(),
        emitExpr(subscript.value()),
        subscript.subscript_exprs());
  }

  Value* emitSubscript(
      const SourceRange& loc,
      Value* sliceable,
      const List<Expr>& subscript_exprs) {
    if (subscript_exprs.size() != 1) {
      return emitMultidimSlicing(loc, sliceable, subscript_exprs);
    }
    if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
      return emitBasicSlice(loc, sliceable, subscript_exprs);
    } else {
      return emitBasicGather(loc, sliceable, subscript_exprs);
    }
  }
  // Desugars gather syntactic sugar foo[i]
  Value* emitBasicGather(
      const SourceRange& loc,
      Value* gatherable,
      const List<Expr>& subscript_exprs) {
    AT_ASSERT(subscript_exprs.size() == 1);

    if (gatherable->type()->kind() == TypeKind::ListType) {
      // if it's a list, emit a regular index selection op
      auto* idx = emitExpr(subscript_exprs[0]);
      return emitBuiltinCall(
          loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, true);
    } else if (gatherable->type()->isSubtypeOf(TensorType::get())) {
      return emitMultidimSlicing(loc, gatherable, subscript_exprs);
    } else if (auto tuple_type = gatherable->type()->cast<TupleType>()) {
      auto* idx = emitExpr(subscript_exprs[0]);
      return emitTupleIndex(loc, gatherable, idx);
    } else if (auto dict_type = gatherable->type()->cast<DictType>()) {
      auto* idx = emitExpr(subscript_exprs[0]);
      return emitDictIndex(loc, gatherable, idx);
    } else {
      throw ErrorReport(loc)
          << "Indexing only supported on lists, dictionaries, "
             "tensors, and tuples, but got type '"
          << gatherable->type()->str() << "'";
    }
  }
};
void defineMethodsInModule(
    const std::shared_ptr<Module>& m,
    const std::vector<Def>& definitions,
    const std::vector<Resolver>& resolvers,
    const c10::optional<Self>& self) {
  AT_ASSERT(definitions.size() == resolvers.size());
  auto resolver_it = resolvers.begin();
  std::vector<Method*> methods;
  std::unordered_map<std::string, Method*> function_table;
  for (const Def& def : definitions) {
    const std::string& name = def.name().name();
    auto resolver = *resolver_it++;
    AT_ASSERT(resolver);
    if (!self) {
      // if self is defined, these are methods and do not go into the global
      // namespace; otherwise they get defined together, so we add them to the
      // function table so they can see each other
      resolver = [resolver, &function_table](
                     const std::string& name,
                     Method& m,
                     const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
        auto it = function_table.find(name);
        if (it != function_table.end()) {
          return std::make_shared<MethodValue>(nullptr, *it->second);
        }
        return resolver(name, m, loc);
      };
    }
    auto creator = [def, resolver, self](Method& method) {
      AT_ASSERT(resolver);
      to_ir(def, resolver, self, method);
    };
    Method& method = m->create_method(name, creator);
    function_table[name] = &method;
    methods.push_back(&method);
  }
  for (Method* method : methods) {
    method->ensure_defined();
  }
  if (!self || !self->asFirstClass()) {
    // disable module hooks if the module is only used to store a class's code
    didFinishEmitModule(m);
  }
}
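// Because every Def is entered into function_table as its Method is created,
// and bodies are only compiled later when ensure_defined() runs the creator,
// free functions compiled in one defineMethodsInModule call may call each
// other regardless of their textual order.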
void defineMethodsInModule(
    const std::shared_ptr<Module>& m,
    const std::string& source,
    const Resolver& resolver,
    const c10::optional<Self>& self) {
  Parser p(source);
  std::vector<Def> definitions;
  std::vector<Resolver> resolvers;
  while (p.lexer().cur().kind != TK_EOF) {
    auto def = Def(p.parseFunction(/*is_method=*/bool(self)));
    definitions.push_back(def);
    resolvers.push_back(resolver);
  }
  defineMethodsInModule(m, definitions, resolvers, self);
}
void lambdaLiftFork(Node* fork_node) {
  // Fork a new graph from its original owning graph
  auto forked_graph = std::make_shared<Graph>();
  auto body_block = fork_node->blocks()[0];

  // Make sure we capture everything in the new graph.
  // The uncaptured values will be added to the fork signature.
  std::unordered_map<Value*, Value*> uncaptures_map;
  auto env = [&](Value* v) -> Value* {
    if (!uncaptures_map.count(v)) {
      // capture values for both graphs
      uncaptures_map[v] = forked_graph->addInput()->copyMetadata(v);
      fork_node->addInput(v);
    }
    return uncaptures_map[v];
  };
  forked_graph->block()->cloneFrom(body_block, env);

  // Separate the subgraph and clean up the original one
  fork_node->g_(attr::Subgraph, forked_graph);
  fork_node->eraseBlock(0);
}

} // namespace script
} // namespace jit
} // namespace torch