1 #include <torch/csrc/jit/script/compiler.h>     3 #include <c10/util/Exception.h>     4 #include <torch/csrc/jit/hooks_for_testing.h>     5 #include <torch/csrc/jit/interpreter.h>     6 #include <torch/csrc/jit/ir.h>     7 #include <torch/csrc/jit/operator.h>     8 #include <torch/csrc/jit/passes/constant_pooling.h>     9 #include <torch/csrc/jit/passes/lower_tuples.h>    10 #include <torch/csrc/jit/script/final_returns.h>    11 #include <torch/csrc/jit/script/parser.h>    12 #include <torch/csrc/jit/script/schema_matching.h>    13 #include <torch/csrc/jit/script/script_type_parser.h>    14 #include <torch/csrc/utils/object_ptr.h>    16 #include <torch/csrc/jit/constants.h>    18 #include <c10/util/Optional.h>    27 using SugaredValuePtr = std::shared_ptr<SugaredValue>;
    28 using FunctionTable = std::unordered_map<std::string, Method&>;
    29 using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
    30 using AttributeMap = std::unordered_map<std::string, Const>;
    31 using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
    33 using TypeAndRange = std::pair<TypePtr, const SourceRange*>;
    39   std::map<std::string, TypeAndRange> mappings_;
    41   void setRefinement(
const std::string& name, TypeAndRange mapping) {
    42     mappings_[name] = std::move(mapping);
    46     const auto& maybe_mapping = mappings_.find(name);
    47     if (maybe_mapping == mappings_.end()) {
    50     return maybe_mapping->second;
    55   void intersectRefinements(
const Refinements& other) {
    57     for (
const auto& name_mapping : mappings_) {
    58       const auto& name = name_mapping.first;
    59       const auto& mapping = name_mapping.second;
    60       if (
auto other_mapping = other.getRefinement(name_mapping.first)) {
    61         const auto maybe_unified_type =
    62             unifyTypes(mapping.first, other_mapping->first);
    63         if (maybe_unified_type) {
    65               name, TypeAndRange(*maybe_unified_type, mapping.second));
    69     mappings_ = std::move(ret.mappings_);
    76     for (
const auto& name_mapping : mappings_) {
    77       const auto& name = name_mapping.first;
    78       const auto& mapping = name_mapping.second;
    79       TypePtr t_1 = mapping.first;
    80       if (
auto other_mapping = other.getRefinement(name_mapping.first)) {
    81         TypePtr t_2 = other_mapping->first;
    83         if (t_1->isSubtypeOf(t_2)) {
    84           maybe_unified_type = t_1;
    85         } 
else if (t_2->isSubtypeOf(t_1)) {
    86           maybe_unified_type = t_2;
    88         if (maybe_unified_type) {
    90               name, TypeAndRange(*maybe_unified_type, mapping.second));
    93         ret.setRefinement(name, mapping);
    97     for (
auto& name_mapping : other.mappings_) {
    98       if (!getRefinement(name_mapping.first)) {
    99         ret.setRefinement(name_mapping.first, name_mapping.second);
   103     mappings_ = std::move(ret.mappings_);
   114       : true_refinements_(std::move(true_refinements)),
   115         false_refinements_(std::move(false_refinements)){};
   126     true_refinements_.intersectRefinements(other.true_refinements_);
   127     false_refinements_.unionRefinements(other.false_refinements_);
   136     true_refinements_.unionRefinements(other.true_refinements_);
   137     false_refinements_.intersectRefinements(other.false_refinements_);
   142 static Value* asSimple(
const SugaredValuePtr& value) {
   143   if (
SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
   144     return sv->getValue();
   151 static bool meaningfulName(
const std::string& name) {
   152   if (name.size() == 0)
   158   for (
size_t i = 1; i < name.size(); ++i) {
   159     if (!isdigit(name[i]))
   196       std::shared_ptr<Environment> next = 
nullptr)
   198         resolver(std::move(resolver)),
   200         next(std::move(next)) {}
   204   std::vector<std::string> captured_inputs;
   205   std::unordered_map<std::string, std::string> error_messages;
   208   std::shared_ptr<Environment> next;
   212   void setVariableTypeError(
const std::string& name, 
const std::string& msg) {
   214     while (runner->next) {
   215       runner = runner->next.get();
   217     runner->error_messages[name] = msg;
   223     while (runner->next) {
   224       runner = runner->next.get();
   226     auto msg = runner->error_messages.find(name);
   227     if (msg != runner->error_messages.end()) {
   234   SugaredValuePtr findInThisFrame(
const std::string& name) {
   235     auto it = value_table.find(name);
   236     if (it != value_table.end()) {
   242   SugaredValuePtr findInParentFrame(
const std::string& name) {
   243     return next ? next->findInAnyFrame(name) : 
nullptr;
   246   SugaredValuePtr findInAnyFrame(
const std::string& name) {
   247     for (
auto runner = 
this; runner; runner = runner->next.get()) {
   248       if (
auto r = runner->findInThisFrame(name)) {
   255   Value* getValueInThisFrame(
const SourceRange& loc, 
const std::string& name) {
   256     return value_table.at(name)->asValue(loc, method);
   259   SugaredValuePtr createCapturedInput(
Value* orig, 
const std::string& name) {
   263     size_t insert_pos = 0;
   264     while (insert_pos < captured_inputs.size() &&
   265            name > captured_inputs[insert_pos]) {
   268     captured_inputs.insert(captured_inputs.begin() + insert_pos, name);
   271     const size_t loop_carried_block_inputs_offset = 1;
   273         b->insertInput(loop_carried_block_inputs_offset + insert_pos)
   274             ->setType(orig->type());
   277     auto sv = std::make_shared<SimpleValue>(new_input);
   278     value_table[name] = sv;
   283   SugaredValuePtr createCapturedInputIfNeeded(
   285       const std::string& ident) {
   286     auto in_frame = findInThisFrame(ident);
   293         next ? next->createCapturedInputIfNeeded(loc, ident) : 
nullptr;
   296     if (from_parent && getBlockOwningKind() == prim::Loop) {
   297       if (
Value* simple_val = asSimple(from_parent))
   298         from_parent = createCapturedInput(simple_val, ident);
   306   Symbol getBlockOwningKind() {
   308     if (b->owningNode()) {
   309       owning_kind = b->owningNode()->kind();
   314   void setVar(
const SourceRange& loc, 
const std::string& name, 
Value* value) {
   315     setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
   320       const std::string& name,
   321       SugaredValuePtr value) {
   322     Value* as_simple_value = asSimple(value);
   323     if (as_simple_value && !as_simple_value->hasUniqueName() &&
   324         meaningfulName(name) &&
   328         as_simple_value->node()->owningBlock() == block()) {
   329       as_simple_value->setUniqueName(name);
   338     if (
auto parent = findInParentFrame(name)) {
   339       if (!as_simple_value) {
   341             << 
"Cannot re-assign '" << name << 
"' to a value of type "   342             << value->kind() << 
" because " << name
   343             << 
" is not a first-class value.  Only reassignments to first-class values are allowed";
   345       Value* simple_parent = asSimple(parent);
   346       if (!simple_parent) {
   348             << 
"Cannot re-assign '" << name << 
"' because it has type "   349             << value->kind() << 
" and " << name
   350             << 
" is not a first-class value.  Only reassignments to first-class values are allowed";
   352       if (!as_simple_value->type()->isSubtypeOf(
   353               unshapedType(simple_parent->type()))) {
   354         std::stringstream errMsg;
   355         errMsg << 
"variable '" << name << 
"' previously has type "   356                << simple_parent->type()->str()
   357                << 
" but is now being assigned to a value of type "   358                << as_simple_value->type()->str();
   360         if (simple_parent->type()->kind() == TypeKind::ListType &&
   361             as_simple_value->type()->kind() == TypeKind::ListType) {
   362           errMsg << 
"\n. (Note: empty lists are constructed as Tensor[]; "   363                  << 
"if you want an empty list of a different type, "   364                  << 
"use `torch.jit.annotate(List[T], [])`, "   365                  << 
"where `T` is the type of elements in the list)";
   371       createCapturedInputIfNeeded(loc, name);
   372     value_table[name] = std::move(value);
   375   SugaredValuePtr getSugaredVar(
const Ident& ident, 
bool required = 
true) {
   376     return getSugaredVar(ident.name(), ident.range());
   379     return getSugaredVar(ident)->asValue(ident.range(), method);
   382   SugaredValuePtr getSugaredVar(
   383       const std::string& ident,
   385       bool required = 
true) {
   386     auto retval = createCapturedInputIfNeeded(range, ident);
   389       static std::unordered_map<std::string, SugaredValuePtr> globals = {
   390           {
"print", std::make_shared<PrintValue>()},
   391           {
"float", std::make_shared<CastValue>(FloatType::get(), prim::Float)},
   392           {
"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
   393           {
"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
   394           {
"getattr", std::make_shared<GetAttrValue>()},
   395           {
"isinstance", std::make_shared<IsInstanceValue>()},
   399            std::make_shared<CastValue>(TensorType::get(), prim::NumToTensor)},
   400           {
"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
   401           {
"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
   402           {
"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
   403           {
"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
   404           {
"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
   406       auto it = globals.find(ident);
   407       if (it != globals.end()) {
   413       if (
auto class_type = ClassType::get(ident)) {
   414         retval = std::make_shared<script::ClassValue>(class_type);
   419       retval = resolver(ident, method, range);
   422     if (!retval && required) {
   425       if (
auto msg = findVariableTypeError(ident)) {
   426         throw ErrorReport(range) << *msg << 
"and was used here";
   428       throw ErrorReport(range) << 
"undefined value " << ident;
   434     return getSugaredVar(ident, range)->asValue(range, method);
   448     AT_ASSERT(b->inputs().size() == b->outputs().size());
   449     AT_ASSERT(b->inputs().size() == captured_inputs.size() + 1);
   450     for (
size_t i = b->inputs().size() - 1; i > 0; i--) {
   452       if (b->inputs()[i] == b->outputs()[i]) {
   453         auto name = captured_inputs[i - 1];
   454         Value* orig = findInParentFrame(name)->asValue(loc, method);
   455         b->inputs()[i]->replaceAllUsesWith(orig);
   458         captured_inputs.erase(captured_inputs.begin() + i - 1);
   462   std::vector<std::string> definedVariables() {
   463     std::vector<std::string> result;
   464     for (
auto& kv : value_table) {
   465       result.push_back(kv.first);
   471   ValueTable value_table;
   475 static Value* materializeConstant(
   479     std::unordered_map<T, Value*>& map) {
   480   auto existing_constant = map.find(val);
   481   if (existing_constant != map.end()) {
   482     return existing_constant->second;
   486   auto new_constant = graph.insertConstant(val, 
nullptr, r);
   487   map[val] = new_constant;
   493   if (!v->type()->isSubtypeOf(IntType::get())) {
   495         << 
"expected a int but found a " << v->type()->str();
   500 inline bool isSupportedListElementType(
const TypePtr& type) {
   501   return type->isSubtypeOf(TensorType::get()) ||
   502       type->isSubtypeOf(NumberType::get());
   509   TypePtr declared_return_type_; 
   510   TypePtr merged_return_type_; 
   520         graph(method.graph()),
   521         resolver(std::move(resolver_)),
   522         environment_stack(
nullptr) {
   524     pushFrame(graph->block(), 
true);
   529     if (
self && def.decl().params().size() == 0) {
   531           << 
"methods must have a self argument";
   534     method.setSchema(emitDef(def, 
self, graph->block()));
   536     runCleanupPasses(graph);
   541   std::shared_ptr<Graph> graph;
   543   std::unordered_map<int64_t, Value*> integral_constants;
   544   std::unordered_map<double, Value*> fp_constants;
   548   std::shared_ptr<Environment> environment_stack;
   549   std::vector<DefContext> def_stack_;
   551   void pushFrame(
Block* b, 
bool starts_def = 
false) {
   553       def_stack_.emplace_back();
   556         std::make_shared<Environment>(method, resolver, b, environment_stack);
   558   std::shared_ptr<Environment> popFrame(
bool ends_def = 
false) {
   559     auto old_frame = environment_stack;
   560     environment_stack = environment_stack->next;
   562       def_stack_.pop_back();
   567   void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
   569     LowerSimpleTuples(to_clean);
   570     ConstantPooling(to_clean);
   577     auto schema = extractSchemaFromDef(def, 
self);
   579     if (schema.returns().size() == 1) {
   580       def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
   582     std::vector<Argument> arguments =
   583         emitFormalArguments(def, 
self, schema, block);
   586     auto stmts_list = moveAllReturnsToEnd(def.statements());
   587     emitStatements(stmts_list.begin(), stmts_list.end());
   588     std::vector<Argument> returns = {emitOutput(def.range(), schema, block)};
   589     return {def.name().name(), 
"", std::move(arguments), std::move(returns)};
   592   std::vector<IValue> evaluateDefaults(
   594       const std::vector<Expr>& default_types,
   595       const std::vector<Expr>& default_exprs) {
   596     std::vector<IValue> default_values;
   597     if (default_exprs.empty())
   598       return default_values;
   606     auto tuple_type = Subscript::create(
   608         Var::create(r, Ident::create(r, 
"Tuple")),
   610     auto blank_decl = Decl::create(
   615     auto ret = Return::create(r, tuple_expr);
   616     auto def = Def::create(
   618         Ident::create(r, 
"defaults"),
   621     auto m = std::make_shared<Module>();
   622     defineMethodsInModule(m, {def}, {resolver}, c10::nullopt);
   624     m->get_method(
"defaults").run(stack);
   625     return stack.at(0).toTuple()->elements();
   628   std::vector<Argument> parseArgsFromDecl(
   631     auto params_begin = decl.params().begin();
   632     auto params_end = decl.params().end();
   636     std::vector<Argument> retval;
   638     std::vector<Expr> default_types;
   639     std::vector<Expr> default_exprs;
   641     for (
auto it = params_begin; it != params_end; ++it) {
   643       auto def = param.defaultValue();
   645         default_types.emplace_back(param.type());
   646         default_exprs.emplace_back(def.get());
   649     auto default_values =
   650         evaluateDefaults(decl.range(), default_types, default_exprs);
   652     auto defaults_it = default_values.begin();
   653     for (
auto it = params_begin; it != params_end; ++it) {
   660       if (
auto maybe_broad_list = parseBroadcastList(decl_arg.type())) {
   661         type = maybe_broad_list->first;
   662         N = maybe_broad_list->second;
   664         type = parseTypeFromExpr(decl_arg.type());
   668       if (decl_arg.defaultValue().present()) {
   669         default_value = *defaults_it++;
   672           decl_arg.ident().name(),
   676           decl_arg.kwarg_only());
   677       retval.push_back(arg);
   682   std::vector<Argument> parseReturnFromDecl(
const Decl& decl) {
   687     if (!decl.return_type().present())
   690     if (parseBroadcastList(decl.return_type().get()))
   692           << 
"Broadcastable lists cannot appear as a return type";
   693     auto parsed_type = parseTypeFromExpr(decl.return_type().get());
   704     const auto name = def.name().name();
   705     std::vector<Argument> args = parseArgsFromDecl(def.decl(), 
self);
   706     std::vector<Argument> returns = parseReturnFromDecl(def.decl());
   708         name, 
"", std::move(args), std::move(returns), 
false, 
false);
   711   std::vector<Argument> emitFormalArguments(
   716     std::vector<Argument> arguments; 
   718     auto it = def.decl().params().begin();
   719     auto end = def.decl().params().end();
   720     auto expected_annotation_size = def.decl().params().size();
   722       expected_annotation_size--;
   724     if (schema.arguments().size() != expected_annotation_size) {
   726           << 
"Number of type annotations for"   727           << 
" function parameters (" << schema.arguments().size() << 
")"   728           << 
" does not match the number of parameters on the function ("   729           << expected_annotation_size << 
")!";
   733       AT_ASSERT(it != end);
   734       const auto& name = (*it).ident().name();
   735       if (
auto type = self->asFirstClass()) {
   737             block->addInput()->setUniqueName(name)->setType(type);
   738         environment_stack->setVar((*it).ident().range(), name, new_input);
   739         arguments.emplace_back(name, type);
   741         environment_stack->setSugaredVar(def.range(), name, 
self->asSugared());
   745     size_t arg_annotation_idx = 0;
   746     for (; it != end; ++it) {
   747       auto& name = (*it).ident().name();
   749       Value* new_input = block->addInput();
   750       if (meaningfulName(name)) {
   751         new_input->setUniqueName(name);
   753       environment_stack->setVar((*it).ident().range(), name, new_input);
   756       arguments.push_back(schema.arguments().at(arg_annotation_idx++));
   757       new_input->setType(arguments.back().type());
   767     AT_ASSERT(def_stack_.back().merged_return_type_);
   769     Value* result = environment_stack->getVar(
"$return", range);
   770     block->registerOutput(result);
   771     return Argument(
"", def_stack_.back().merged_return_type_);
   774   void emitStatements(
const List<Stmt>& statements) {
   775     return emitStatements(statements.begin(), statements.end());
   777   std::pair<std::shared_ptr<Graph>, 
Value*> lambdaLift(
Block* block) {
   778     auto subgraph = std::make_shared<Graph>();
   781         graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
   782     Value* context = subgraph->addInput(
"context");
   784     Node* unpack_context =
   785         subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
   787     std::unordered_map<Value*, Value*> captures;
   789       auto it = captures.find(v);
   790       if (it != captures.end()) {
   793       pack_context->addInput(v);
   794       Value* r = unpack_context->addOutput()->copyMetadata(v);
   798     subgraph->block()->cloneFrom(block, env);
   799     auto context_type = TupleType::create(
   800         fmap(pack_context->inputs(), [](
Value* v) { 
return v->type(); }));
   801     pack_context->output()->setType(context_type);
   802     context->setType(context_type);
   803     return std::make_pair(std::move(subgraph), pack_context->output());
   817   void emitClosure(
const Def& def) {
   818     Node* closure_node = graph->insertNode(graph->create(prim::Function, 1));
   819     closure_node->output()->setType(
   822     Block* block = closure_node->addBlock();
   825       pushFrame(block, 
true);
   833     std::shared_ptr<Graph> subgraph;
   835     std::tie(subgraph, context) = lambdaLift(block);
   836     runCleanupPasses(subgraph);
   837     closure_node->eraseBlock(0);
   838     closure_node->g_(attr::Subgraph, std::move(subgraph));
   840         graph->insertNode(graph->createTuple({closure_node->output(), context}))
   842     environment_stack->setVar(def.name().range(), def.name().name(), tup);
   845   void emitReturn(
const Return& stmt) {
   846     Value* result = emitExpr(stmt.expr());
   847     TypePtr result_type = def_stack_.back().declared_return_type_;
   853       if (!(result_type->isSubtypeOf(TensorType::get()) &&
   854             result->type()->isSubtypeOf(NoneType::get()))) {
   855         result = tryConvertToType(
   863       if (!result->type()->isSubtypeOf(result_type)) {
   865             << 
"Return value was annotated as having type "   866             << result_type->python_str() << 
" but is actually of type "   867             << result->type()->python_str();
   870       result_type = def_stack_.back().merged_return_type_;
   872         result_type = result->type();
   874       if (!unifyTypes(result_type, result->type())) {
   876             << 
"Previous return statement returned a value of type "   877             << result_type->python_str()
   878             << 
" but this return statement returns a value of type "   879             << result->type()->python_str();
   882     AT_ASSERT(result_type);
   883     def_stack_.back().merged_return_type_ = result_type;
   884     environment_stack->setVar(stmt.range(), 
"$return", result);
   890     for (; begin != end; ++begin) {
   892       switch (stmt.kind()) {
   897           emitWhile(
While(stmt));
   903           emitAssignment(
Assign(stmt));
   909           for (
auto ident : 
Global(stmt).names()) {
   910             const auto& name = 
Ident(ident).name();
   911             environment_stack->setVar(
   912                 ident.range(), name, graph->addInput(name));
   917           emitSugaredExpr(expr, 0);
   920           emitRaise(
Raise(stmt).range());
   932           emitClosure(
Def(stmt));
   936               << 
"Unrecognized statement kind " << kindToString(stmt.kind());
   941   std::shared_ptr<Environment> emitSingleIfBranch(
   947     insertRefinements(refinements);
   948     emitStatements(branch);
   953     return graph->create(kind, n_outputs)
   954         ->setSourceLocation(std::make_shared<SourceRange>(loc));
   958     const auto& bool_info = findRefinements(expr.cond());
   959     Value* cond_value = emitCond(expr.cond());
   960     auto true_expr = [&] {
   961       insertRefinements(bool_info.true_refinements_);
   962       return emitExpr(expr.true_expr());
   964     auto false_expr = [&] {
   965       insertRefinements(bool_info.false_refinements_);
   966       return emitExpr(expr.false_expr());
   968     return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
   973     for (
const auto& name_mappings : ref.mappings_) {
   974       const std::string& name = name_mappings.first;
   975       auto type = name_mappings.second.first;
   976       const auto& range = *name_mappings.second.second;
   977       Value* v = environment_stack->getVar(name, range);
   978       if (type != NoneType::get()) {
   979         Value* output = graph->insert(prim::unchecked_unwrap_optional, {v});
   980         environment_stack->setVar(range, name, output);
   986   Value* emitShortCircuitIf(
   988       const TreeRef& first_expr,
   989       const TreeRef& second_expr,
   991     const auto first_bool_info = findRefinements(first_expr);
   992     Value* first_value = emitCond(
Expr(first_expr));
   999       first_expr_refinements = &first_bool_info.true_refinements_;
  1000       second_expr_refinements = &first_bool_info.false_refinements_;
  1002       first_expr_refinements = &first_bool_info.false_refinements_;
  1003       second_expr_refinements = &first_bool_info.true_refinements_;
  1006     auto get_first_expr = [&] {
  1007       insertRefinements(*first_expr_refinements);
  1011     auto get_second_expr = [&] {
  1012       insertRefinements(*second_expr_refinements);
  1013       return emitCond(
Expr(second_expr));
  1019       return emitIfExpr(loc, first_value, get_first_expr, get_second_expr);
  1021       return emitIfExpr(loc, first_value, get_second_expr, get_first_expr);
  1028       std::function<
Value*()> true_expr,
  1029       std::function<
Value*()> false_expr) {
  1030     Node* n = graph->insertNode(create(prim::If, range, 0));
  1032     n->addInput(cond_value);
  1033     auto* true_block = n->addBlock();
  1034     auto* false_block = n->addBlock();
  1036     auto emit_if_expr = [
this](
Block* b, std::function<Value*()> expr_value) {
  1039       Value* out_val = expr_value();
  1040       b->registerOutput(out_val);
  1044     emit_if_expr(true_block, std::move(true_expr));
  1045     emit_if_expr(false_block, std::move(false_expr));
  1047     auto true_type = true_block->outputs().at(0)->type();
  1048     auto false_type = false_block->outputs().at(0)->type();
  1049     auto unified = unifyTypes(true_type, false_type);
  1052           << 
"if-expression's true branch has type " << true_type->str()
  1053           << 
" but false branch has type " << false_type->str();
  1057     auto expr_value = n->addOutput()->setType(*unified); 
  1063     Value* v = emitExpr(cond);
  1064     if (!v->type()->isSubtypeOf(BoolType::get())) {
  1066       error << 
"expected a boolean expression for condition but found "  1067             << v->type()->str();
  1068       if (v->type()->isSubtypeOf(TensorType::get())) {
  1069         error << 
", to use a tensor in a boolean"  1070               << 
" expression, explicitly cast it with `bool()`";
  1077   void emitIfElseBlocks(
Value* cond_value, 
const If& stmt) {
  1078     Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
  1079     n->addInput(cond_value);
  1080     const auto bool_info = findRefinements(stmt.cond());
  1081     auto* true_block = n->addBlock();
  1082     auto* false_block = n->addBlock();
  1085     auto save_true = emitSingleIfBranch(
  1086         true_block, stmt.trueBranch(), bool_info.true_refinements_);
  1087     auto save_false = emitSingleIfBranch(
  1088         false_block, stmt.falseBranch(), bool_info.false_refinements_);
  1114     std::set<std::string> mutated_variables;
  1116     for (
auto& v : save_true->definedVariables()) {
  1117       if (save_false->findInAnyFrame(v)) {
  1118         mutated_variables.insert(v);
  1121     for (
auto& v : save_false->definedVariables()) {
  1122       if (save_true->findInAnyFrame(v)) {
  1123         mutated_variables.insert(v);
  1128     for (
const auto& x : mutated_variables) {
  1129       auto tv = save_true->getVar(x, stmt.range());
  1130       auto fv = save_false->getVar(x, stmt.range());
  1131       auto unified = unifyTypes(tv->type(), fv->type());
  1146         error << 
"Type mismatch: " << x << 
" is set to type "  1147               << tv->type()->str() << 
" in the true branch"  1148               << 
" and type " << fv->type()->str() << 
" in the false branch";
  1149         if (save_true->findInParentFrame(x) ||
  1150             save_false->findInParentFrame(x)) {
  1156           save_true->setVariableTypeError(x, error.what());
  1160       true_block->registerOutput(tv);
  1161       false_block->registerOutput(fv);
  1162       environment_stack->setVar(
  1163           stmt.range(), x, n->addOutput()->setType(*unified));
  1167   void emitIf(
const If& stmt) {
  1171     Expr cond = stmt.cond();
  1173     if (cond.kind() != TK_IS && cond.kind() != TK_ISNOT) {
  1175       Value* cond_value = emitCond(cond);
  1176       emitIfElseBlocks(cond_value, stmt);
  1181     auto cond_op = 
BinOp(cond);
  1182     SugaredValuePtr lhs_val = emitSugaredExpr(cond_op.lhs(), 1);
  1183     SugaredValuePtr rhs_val = emitSugaredExpr(cond_op.rhs(), 1);
  1186         cond.kind() == TK_IS ? stmt.trueBranch() : stmt.falseBranch();
  1188         cond.kind() == TK_IS ? stmt.falseBranch() : stmt.trueBranch();
  1190     auto lhs_none = lhs_val->isNone();
  1191     auto rhs_none = rhs_val->isNone();
  1199     if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
  1201       emitStatements(always_none_branch);
  1203         (lhs_none == ALWAYS && rhs_none == NEVER) ||
  1204         (lhs_none == NEVER && rhs_none == ALWAYS)) {
  1206       emitStatements(never_none_branch);
  1210       auto lhs_range = cond_op.lhs().get()->range();
  1211       auto rhs_range = cond_op.rhs().get()->range();
  1213       auto kind = getNodeKind(cond.kind(), cond.get()->trees().size());
  1214       Value* cond_value = emitBuiltinCall(
  1215           cond.get()->range(),
  1219           {lhs_val->asValue(lhs_range, method),
  1220            rhs_val->asValue(rhs_range, method)},
  1223       emitIfElseBlocks(cond_value, stmt);
  1243   void emitLoopCommon(
  1249       bool in_list = 
false) {
  1250     Node* n = graph->insertNode(create(prim::Loop, range, 0));
  1251     Value *max_trip_count_val, *cond_val;
  1254       if (max_trip_count) {
  1256           auto listArg = emitExpr(max_trip_count.value());
  1258           max_trip_count_val = emitBuiltinCall(
  1259               max_trip_count->range(),
  1267           max_trip_count_val = ensureInt(
  1268               max_trip_count->range(), emitExpr(max_trip_count.value()));
  1271         max_trip_count_val = materializeConstant(
  1272             std::numeric_limits<int64_t>::max(),
  1275             integral_constants);
  1278         cond_val = emitCond(cond.value());
  1280         cond_val = graph->insertConstant(
true, 
nullptr, range);
  1283     n->addInput(max_trip_count_val);
  1284     n->addInput(cond_val);
  1285     auto* body_block = n->addBlock();
  1287         body_block->addInput()->setType(IntType::get()); 
  1290       pushFrame(body_block);
  1295           auto listArg = emitExpr(max_trip_count.value());
  1296           trip_count = emitBuiltinCall(
  1297               max_trip_count->range(),
  1301               {listArg, trip_count},
  1305         environment_stack->setVar(
  1306             itr_ident->range(), itr_ident->name(), trip_count);
  1308       emitStatements(body);
  1312         Value* body_cond_value = emitCond(cond.value());
  1313         body_block->registerOutput(body_cond_value);
  1315         Value* cond_value_dummy = graph->insertConstant(
true, 
nullptr, range);
  1316         body_block->registerOutput(cond_value_dummy);
  1319       auto body_frame = popFrame();
  1320       auto outer_frame = environment_stack;
  1324       for (
const auto& x : body_frame->captured_inputs) {
  1325         auto fv = body_frame->getValueInThisFrame(range, x);
  1326         body_block->registerOutput(fv);
  1331       body_frame->deleteExtraInputs(range);
  1334       for (
size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
  1335         auto x = body_frame->captured_inputs[i];
  1336         n->addInput(outer_frame->getVar(x, range));
  1339         auto typ = body_block->inputs()[i + 1]->type();
  1340         outer_frame->setVar(range, x, n->addOutput()->setType(typ));
  1347       const Ident& target,
  1351     if (args.size() != 1) {
  1353           << 
"range() expects 1 argument but got " << args.size();
  1355     emitLoopCommon(range, {args[0]}, {}, body, target);
  1358   void emitFor(
const For& stmt) {
  1360     auto targets = stmt.targets();
  1361     auto itrs = stmt.itrs();
  1362     auto body = stmt.body();
  1364     if (stmt.itrs().size() != 1) {
  1366           << 
"List of iterables is not supported currently.";
  1368     if (targets.size() != 1) {
  1370           << 
"Iteration variable unpacking is not supported";
  1373     if (targets[0].kind() != TK_VAR) {
  1375           << 
"unexpected expression in variable initialization of for loop";
  1377     auto target = 
Var(targets[0]).name();
  1381     if (itrs[0].kind() == TK_APPLY) {
  1383       if (range_iterator.callee().kind() == TK_VAR) {
  1384         Var var = 
Var(range_iterator.callee());
  1385         if (var.name().name() == 
"range") {
  1386           return emitForRange(
  1387               stmt.range(), target, range_iterator.inputs(), body);
  1394     auto sv = emitSugaredExpr(itrs[0], 1);
  1396     if (
auto siv = std::dynamic_pointer_cast<SimpleValue>(sv)) {
  1397       if (siv->getValue()->type()->kind() == TypeKind::ListType) {
  1398         return emitLoopCommon(
  1399             stmt.range(), {itrs[0]}, {}, body, {target}, 
true);
  1402     auto instances = sv->asTuple(stmt.range(), method);
  1403     const std::string& target_name = target.name();
  1404     pushFrame(environment_stack->block());
  1405     for (
const auto& inst : instances) {
  1406       environment_stack->setSugaredVar(itrs[0].range(), target_name, inst);
  1407       emitStatements(body);
  1410     for (
const auto& n : environment_stack->definedVariables()) {
  1411       if (environment_stack->findInParentFrame(n)) {
  1412         environment_stack->next->setVar(
  1413             stmt.range(), n, environment_stack->getVar(n, stmt.range()));
  1419   void emitWhile(
const While& stmt) {
  1420     auto cond = stmt.cond();
  1421     emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
  1437     const std::string exception = 
"Exception";
  1438     auto string_input = insertConstant(*graph, exception, 
nullptr, loc);
  1439     graph->insert(prim::RaiseException, {string_input}, {}, loc);
  1442   void emitAssert(
const Assert& stmt) {
  1443     Value* cond_value = emitCond(stmt.test());
  1444     Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
  1446     n->addInput(cond_value);
  1448     auto* false_block = n->addBlock();
  1451     pushFrame(false_block);
  1453     emitRaise(stmt.range());
  1466     size_t num_normal_assign = 0;
  1467     size_t num_starred = 0;
  1468     for (
const auto& assignee : lhs) {
  1469       if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT) {
  1470         num_normal_assign++;
  1471       } 
else if (assignee.kind() == TK_STARRED) {
  1474         throw ErrorReport(assignee) << 
"lhs of assignment must be a variable, "  1475                                     << 
"subscript, or starred expression.";
  1479     if (num_starred > 1) {
  1481           << 
"Only one starred expression is allowed on the lhs.";
  1484     if (num_starred > 0 && num_normal_assign == 0) {
  1485       throw ErrorReport(r) << 
"A Starred expression may only appear on the "  1486                            << 
"lhs within the presence of another non-starred"  1497     switch (stmt.aug_op()) {
  1499         return isTensor ? aten::add_ : aten::add;
  1501         return isTensor ? aten::sub_ : aten::sub;
  1503         return isTensor ? aten::div_ : aten::div;
  1505         return isTensor ? aten::mul_ : aten::mul;
  1508             << 
"Unknown augmented assignment: " << kindToString(stmt.aug_op());
  1513   void emitAugAssignment(
const AugAssign& stmt) {
  1514     switch (stmt.lhs().kind()) {
  1516         emitAugAssignmentToVar(stmt);
  1519         emitAugAssignmentToSelectVar(stmt);
  1521       case TK_SUBSCRIPT: {
  1522         emitAugAssignmentToSubscript(stmt);
  1526             << 
"unexpected expression on "  1527             << 
"left-hand side of augmented assignment.";
  1545   void emitAugAssignmentToSelectVar(
const AugAssign& stmt) {
  1546     const auto lhs = 
Select(stmt.lhs());
  1547     const auto lhsSugaredVar =
  1548         environment_stack->getSugaredVar(
Var(lhs.value()).name());
  1549     const auto lhsValue =
  1550         lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
  1551             ->asValue(lhs.range(), method);
  1552     if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
  1555       const auto rhs = 
NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
  1556       const auto self = 
NamedValue(stmt.lhs().range(), 
"self", lhsValue);
  1560           getAugOp(stmt, 
true),
  1568           << 
"left-hand side of augmented assignment to module "  1569           << 
"parameters/buffers can only be tensor types";
  1573   void emitAugAssignmentToVar(
const AugAssign& stmt) {
  1574     const auto lhs = 
Var(stmt.lhs());
  1575     const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
  1576                               ->asValue(lhs.range(), method);
  1577     if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
  1579       const auto rhs = 
NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
  1580       const auto self = 
NamedValue(stmt.lhs().range(), 
"self", lhsValue);
  1581       const auto output = emitBuiltinCall(
  1584           getAugOp(stmt, 
true),
  1590       environment_stack->setVar(lhs.range(), lhs.name().name(), output);
  1594       Ident lhs = 
Var(stmt.lhs()).name();
  1595       Expr expr = BinOp::create(
  1598           Var::create(lhs.range(), lhs),
  1600       environment_stack->setVar(lhs.range(), lhs.name(), emitExpr(expr));
  1604   void emitAugAssignmentToSubscript(
const AugAssign& stmt) {
  1607     const auto sliceable = emitExpr(lhs.value());
  1609     if (sliceable->type()->isSubtypeOf(TensorType::get())) {
  1612       std::vector<Value*> tensorIndices;
  1614       std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
  1615           lhs.range(), sliceable, lhs.subscript_exprs());
  1617       const auto slicedArg = 
NamedValue(stmt.lhs().range(), 
"self", sliced);
  1618       const auto rhs = 
NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
  1619       if (tensorIndices.size() == 0) {
  1625             getAugOp(stmt, 
true),
  1633         const auto indices = graph
  1634                                  ->insertNode(graph->createList(
  1635                                      OptionalType::ofTensor(), tensorIndices))
  1637         const auto indexed =
  1638             graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range());
  1639         const auto augmented = emitBuiltinCall(
  1642             getAugOp(stmt, 
true),
  1649             {slicedArg, indices, augmented},
  1657       const auto listType = sliceable->type()->cast<
ListType>();
  1658       AT_ASSERT(listType != 
nullptr);
  1661           listType->getElementType()->isSubtypeOf(TensorType::get());
  1664       const auto subscriptExprs = lhs.subscript_exprs();
  1665       if (subscriptExprs.size() != 1) {
  1667             << 
"Sliced expression not yet supported for"  1668             << 
" subscripted list augmented assignment. "  1669             << 
"File a bug if you want this.";
  1671       const auto idxValue = emitExpr(subscriptExprs[0]);
  1673       const auto listArg = 
NamedValue(lhs.value().range(), 
"list", sliceable);
  1674       const auto idxArg = 
NamedValue(subscriptExprs.range(), 
"idx", idxValue);
  1675       const auto valueArg =
  1676           NamedValue(stmt.rhs().range(), 
"value", emitExpr(stmt.rhs()));
  1678       const auto getItem =
  1679           graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range());
  1680       const auto augmentedItem = graph->insert(
  1681           getAugOp(stmt, isTensorList), {getItem, valueArg}, {}, stmt.range());
  1683           aten::_set_item, {listArg, idxArg, augmentedItem}, {}, stmt.range());
  1688   void emitSubscriptAssign(
  1692     emitSubscriptAssign(stmtRange, lhs, 
NamedValue(rhs.range(), emitExpr(rhs)));
  1695   void emitSubscriptAssign(
  1700     auto sliceable = emitExpr(lhs.value());
  1703     if (sliceable->type()->isSubtypeOf(TensorType::get())) {
  1704       std::vector<Value*> tensorIndices;
  1711       std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
  1712           lhs.range(), sliceable, lhs.subscript_exprs());
  1714       const auto slicedArg = 
NamedValue(lhs.range(), sliced);
  1715       if (tensorIndices.size() == 0) {
  1718         graph->insert(aten::copy_, {slicedArg, rhs}, {}, stmtRange);
  1722         const auto indices = graph
  1723                                  ->insertNode(graph->createList(
  1724                                      OptionalType::ofTensor(), tensorIndices))
  1728             aten::index_put_, {slicedArg, indices, rhs}, {}, stmtRange);
  1734       const auto subscript = lhs.subscript_exprs();
  1735       if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) {
  1737             << 
"Sliced expression not yet supported for"  1738             << 
" subscripted list assignment. "  1739             << 
"File a bug if you want this.";
  1742       std::vector<NamedValue> args;
  1743       args.emplace_back(lhs.value().range(), 
"list", sliceable);
  1745           lhs.subscript_exprs().range(), 
"idx", emitExpr(subscript[0]));
  1746       args.push_back(rhs);
  1748       graph->insert(aten::_set_item, args, {}, stmtRange);
  1753     size_t n_binders = tl.inputs().size();
  1754     bool starred_unpack = calcNumStarredUnpack(tl.inputs(), tl.range());
  1757     auto output = emitSugaredExpr(rhs, n_binders);
  1758     auto outputs = output->asTuple(
  1762     if (outputs.size() < n_binders) {
  1764           << 
"need " << (starred_unpack ? 
"at least " : 
"") << n_binders
  1765           << 
" values to unpack but found only " << outputs.size();
  1767     if (outputs.size() > n_binders && !starred_unpack) {
  1768       throw ErrorReport(tl) << 
"too many values to unpack: need " << n_binders
  1769                             << 
" but found " << outputs.size();
  1772     for (
auto assignee : tl.inputs()) {
  1773       switch (assignee.kind()) {
  1775           emitSubscriptAssign(
  1779                   rhs.range(), outputs.at(i)->asValue(rhs.range(), method)));
  1783           environment_stack->setSugaredVar(
  1784               assignee.range(), 
Var(assignee).name().name(), outputs.at(i));
  1788           auto var = 
Starred(assignee).expr();
  1789           if (var.kind() != TK_VAR) {
  1791                 << 
"Cannot pack a tuple into a non-variable.";
  1793           size_t n_matched = outputs.size() - n_binders;
  1796               outputs_ref.
slice(i, n_matched),
  1797               [&](
const std::shared_ptr<SugaredValue>& v) {
  1798                 return v->asValue(assignee.range(), method);
  1800           auto tup = graph->insertNode(graph->createTuple(values))->output();
  1801           environment_stack->setVar(var.range(), 
Var(var).name().name(), tup);
  1806               << 
"unexpected expression on the left-hand side";
  1811   void emitAssignment(
const Assign& stmt) {
  1812     switch (stmt.lhs().kind()) {
  1814         auto v = 
Var(stmt.lhs());
  1815         environment_stack->setSugaredVar(
  1816             v.range(), v.name().name(), emitSugaredExpr(stmt.rhs(), 1));
  1818       case TK_TUPLE_LITERAL:
  1822         emitSelectAssign(stmt);
  1825         emitSubscriptAssign(stmt.range(), 
Subscript(stmt.lhs()), stmt.rhs());
  1829             << 
"unexpected expression on left-hand side of assignment.";
  1833   void emitSelectAssign(
const Assign& stmt) {
  1834     const auto lhs = 
Select(stmt.lhs());
  1835     const auto basename = 
Var(lhs.value()).name();
  1836     const auto rhsValue =
  1837         emitSugaredExpr(stmt.rhs(), 1)->asValue(stmt.rhs().range(), method);
  1838     auto userObject = environment_stack->getSugaredVar(basename);
  1839     userObject->setAttr(stmt.range(), method, lhs.selector().name(), rhsValue);
  1842   NodeKind getNodeKind(
int kind, 
int ninputs) {
  1848       case TK_UNARY_MINUS:
  1855         return aten::matmul;
  1857         return prim::Starred;
  1861         return aten::remainder;
  1875         return aten::__and__;
  1877         return aten::__or__;
  1879         return aten::__is__;
  1881         return aten::__isnot__;
  1883         return aten::__not__;
  1885         return aten::floordiv;
  1887         return aten::__and__;
  1889         return aten::__or__;
  1891         return aten::__xor__;
  1893         throw std::runtime_error(
"unknown kind " + std::to_string(kind));
  1897   std::vector<NamedValue> getNamedValues(
  1898       const TreeList& trees,
  1899       bool maybe_unpack) {
  1900     std::vector<NamedValue> values;
  1901     for (
const auto& tree : trees) {
  1902       if (maybe_unpack && tree->kind() == TK_STARRED) {
  1904         auto entries = emitSugaredExpr(starred.expr(), 1)
  1905                            ->asTuple(starred.range(), method);
  1906         for (
const auto& entry : entries) {
  1907           values.emplace_back(
  1908               tree->range(), entry->asValue(starred.range(), method));
  1911         values.emplace_back(tree->range(), emitExpr(
Expr(tree)));
  1916   std::vector<NamedValue> getNamedValues(
  1918       bool maybe_unpack) {
  1919     return getNamedValues(trees.tree()->trees(), maybe_unpack);
  1922   std::vector<Value*> getValues(
const TreeList& trees, 
bool maybe_unpack) {
  1923     return toValues(*graph, getNamedValues(trees, maybe_unpack));
  1925   std::vector<Value*> getValues(
const List<Expr>& trees, 
bool maybe_unpack) {
  1926     return getValues(trees.tree()->trees(), maybe_unpack);
  1929   std::vector<NamedValue> emitAttributes(
const List<Attribute>& attributes) {
  1930     return fmap(attributes, [&](
const Attribute& attr) {
  1932           attr.range(), attr.name().name(), emitExpr(attr.value()));
  1937     if (apply.inputs().size() != 2) {
  1939                              << 
" expected exactly two arguments but found "  1940                              << apply.inputs().size();
  1942     if (apply.attributes().size() > 0) {
  1944           << 
Var(apply.callee()).name().name() << 
" takes no keyword arguments";
  1948   std::shared_ptr<SugaredValue> emitApplyExpr(
Apply& apply, 
size_t n_binders) {
  1949     auto sv = emitSugaredExpr(apply.callee(), 1);
  1950     auto loc = apply.callee().range();
  1951     if (
auto fork_value = dynamic_cast<ForkValue*>(sv.get())) {
  1952       auto& trees = apply.inputs().tree()->trees();
  1953       if (trees.size() < 1) {
  1954         throw ErrorReport(loc) << 
"Expected at least one argument to fork()";
  1957       auto forked = emitSugaredExpr(
Expr(trees[0]), 1);
  1958       TreeList sliced_trees(trees.begin() + 1, trees.end());
  1959       auto inputs = getNamedValues(sliced_trees, 
true);
  1960       auto attributes = emitAttributes(apply.attributes());
  1961       return emitForkExpr(loc, forked, inputs, attributes);
  1962     } 
else if (
auto annotate_value = dynamic_cast<AnnotateValue*>(sv.get())) {
  1963       checkApplyExpr(apply, loc);
  1964       TypePtr type = parseTypeFromExpr(apply.inputs()[0]);
  1965       Value* expr = tryConvertToType(
  1969           emitExpr(apply.inputs()[1], type),
  1978       bool forget_opt_annotate =
  1979           opt_type && *opt_type->getElementType() == *type;
  1981       if (!forget_opt_annotate && !expr->type()->isSubtypeOf(type)) {
  1983             << 
"expected an expression of type " << type->python_str()
  1984             << 
" but found " << expr->type()->python_str();
  1986       return std::make_shared<SimpleValue>(expr);
  1987     } 
else if (
auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
  1988       checkApplyExpr(apply, loc);
  1989       auto obj = emitSugaredExpr(apply.inputs()[0], 1);
  1990       auto selector = apply.inputs()[1];
  1991       if (selector.kind() != TK_STRINGLITERAL) {
  1993             << 
"getattr's second argument must be a string literal";
  1996       return obj->attr(apply.range(), method, name);
  1997     } 
else if (
auto isinstance = dynamic_cast<IsInstanceValue*>(sv.get())) {
  2001       std::function<bool(Expr, Expr)> isInstanceCheck = [&](
Expr obj,
  2003         if (classinfo.kind() == TK_TUPLE_LITERAL) {
  2007             if (isInstanceCheck(obj, e)) {
  2013         auto type_name = parseBaseTypeName(classinfo);
  2016               << 
"type must be a type identifier";
  2018         auto val = emitExpr(obj);
  2022         if (*type_name == 
"list" && val->type()->cast<
ListType>()) {
  2024         } 
else if (*type_name == 
"tuple" && val->type()->cast<
TupleType>()) {
  2028               << 
"Optional isinstance check is not supported, "  2029               << 
"consider use is/isnot None instead";
  2031           TypePtr type = parseTypeFromExpr(classinfo);
  2032           if (val->type()->isSubtypeOf(type)) {
  2038       checkApplyExpr(apply, loc);
  2039       bool is_instance_val =
  2040           isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
  2041       return std::make_shared<SimpleValue>(
  2042           graph->insertConstant(is_instance_val, 
nullptr, loc));
  2043     } 
else if (
auto classNew = dynamic_cast<ClassNewMethod*>(sv.get())) {
  2044       if (apply.inputs().size() != 1) {
  2045         throw ErrorReport(loc) << 
"Only one argument to __new__ allowed";
  2047       return classNew->createObject(
  2048           apply.range(), method, 
Var(apply.inputs()[0]).name().name());;
  2050       auto inputs = getNamedValues(apply.inputs(), 
true);
  2051       auto attributes = emitAttributes(apply.attributes());
  2052       return sv->call(loc, method, inputs, attributes, n_binders);
  2056   BoolInfo findRefinements(
const TreeRef& tree) {
  2057     switch (tree->kind()) {
  2060         const auto& inputs = tree->trees();
  2061         if (inputs.at(0)->kind() == TK_VAR && inputs.at(1)->kind() == TK_NONE) {
  2062           const std::string& var_name = 
Var(inputs[0]).name().name();
  2065               environment_stack->getVar(var_name, inputs[0]->range())->type();
  2067             false_info.setRefinement(
  2069                 TypeAndRange(opt_type->getElementType(), &tree->range()));
  2070             true_info.setRefinement(
  2071                 var_name, TypeAndRange(NoneType::get(), &tree->range()));
  2073           if (tree->kind() == TK_IS) {
  2074             return BoolInfo(true_info, false_info);
  2076             return BoolInfo(false_info, true_info);
  2081         const auto& inputs = tree->trees();
  2082         auto bool_info = findRefinements(inputs[0]);
  2084             bool_info.false_refinements_, bool_info.true_refinements_);
  2088         const auto& inputs = tree->trees();
  2089         auto first = findRefinements(inputs[0]);
  2090         auto second = findRefinements(inputs[1]);
  2091         if (tree->kind() == TK_OR) {
  2092           return *first.mergeOr(second);
  2094           return *first.mergeAnd(second);
  2101   Value* emitExpr(
const Expr& tree, 
const TypePtr& type_hint = 
nullptr) {
  2102     return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method);
  2106     if (kind == aten::lt) {
  2108     } 
else if (kind == aten::le) {
  2110     } 
else if (kind == aten::gt) {
  2112     } 
else if (kind == aten::ge) {
  2115     throw std::runtime_error(
  2116         "reverseComparision: unsupported NodeKind. File a bug");
  2126   std::shared_ptr<SugaredValue> emitSugaredExpr(
  2129       const TypePtr& type_hint = 
nullptr) {
  2130     switch (tree.kind()) {
  2132         return environment_stack->getSugaredVar(
Var(tree).name());
  2134         auto select = 
Select(tree);
  2135         auto sv = emitSugaredExpr(select.value(), 1);
  2136         return sv->attr(select.range(), method, select.selector().name());
  2139         auto apply = 
Apply(tree);
  2140         return emitApplyExpr(apply, n_binders);
  2143         return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
  2147   Value* emitNegate(
const TreeRef& tree) {
  2148     const auto& inputs = tree->trees();
  2149     auto named_values = getNamedValues(inputs, 
false);
  2151     auto neg_val = emitBuiltinCall(
  2161     auto maybe_constant_input = toIValue(neg_val->node()->input());
  2162     if (!maybe_constant_input) {
  2165     auto op = getOperation(neg_val->node());
  2167     stack.push_back(*maybe_constant_input);
  2169     AT_ASSERT(stack.size() == 1);
  2170     return graph->insertConstant(stack[0], 
nullptr, tree->range());
  2174   std::shared_ptr<SugaredValue> emitForkExpr(
  2176       const std::shared_ptr<SugaredValue>& forked,
  2182             ->insertNode(method.graph()->create(prim::fork, 1))
  2183             ->setSourceLocation(std::make_shared<SourceRange>(loc));
  2184     auto body_block = fork_node->addBlock();
  2190       auto fn_sugared_output = forked->call(loc, method, inputs, attributes, 1);
  2191       auto fn_simple_output = fn_sugared_output->asValue(loc, method);
  2192       body_block->registerOutput(fn_simple_output);
  2193       node_output = fork_node->output()->setType(
  2194           FutureType::create(fn_simple_output->type()));
  2198     lambdaLiftFork(fork_node);
  2200     return std::make_shared<SimpleValue>(node_output);
  2203   Value* emitSimpleExpr(
  2204       const TreeRef& tree,
  2205       const TypePtr& type_hint = 
nullptr) {
  2206     switch (tree->kind()) {
  2226       case TK_FLOOR_DIV: {
  2227         const auto& inputs = tree->trees();
  2228         auto kind = getNodeKind(tree->kind(), inputs.
size());
  2229         auto named_values = getNamedValues(inputs, 
false);
  2230         return emitBuiltinCall(
  2239       case TK_UNARY_MINUS: {
  2240         return emitNegate(tree);
  2244         const auto& inputs = tree->trees();
  2245         return emitShortCircuitIf(
  2246             tree->range(), inputs[0], inputs[1], tree->kind() == TK_OR);
  2250             << 
"Unexpected starred expansion. File a bug report.";
  2253         return emitConst(
Const(tree));
  2256         return graph->insertConstant(
true, 
nullptr, tree->range());
  2259         return graph->insertConstant(
false, 
nullptr, tree->range());
  2262         return graph->insertConstant(
IValue(), 
nullptr, tree->range());
  2264       case TK_SUBSCRIPT: {
  2270       case TK_STRINGLITERAL: {
  2273       case TK_LIST_LITERAL: {
  2275         auto values = getValues(ll.inputs(), 
true);
  2281         TypePtr elem_type = TensorType::get();
  2282         if (type_hint && type_hint->kind() == TypeKind::ListType) {
  2283           elem_type = type_hint->expect<
ListType>()->getElementType();
  2284         } 
else if (!values.empty()) {
  2285           elem_type = values.at(0)->type();
  2291         if (elem_type->isSubtypeOf(TensorType::get())) {
  2292           for (
const auto& value : values) {
  2293             elem_type = unifyTypes(elem_type, value->type()).value();
  2296         for (
auto v : values) {
  2297           if (!v->type()->isSubtypeOf(elem_type)) {
  2299                 << 
"Lists must contain only a single type, expected: "  2300                 << *elem_type << 
" but found " << *v->type() << 
" instead";
  2304             graph->insertNode(graph->createList(elem_type, values))->output();
  2307       case TK_TUPLE_LITERAL: {
  2309         auto values = getValues(ll.inputs(), 
true);
  2310         return graph->insertNode(graph->createTuple(values))->output();
  2312       case TK_DICT_LITERAL: {
  2314         auto key_trees = dl.key_inputs().tree()->trees();
  2315         auto value_trees = dl.value_inputs().tree()->trees();
  2316         AT_ASSERT(key_trees.size() == value_trees.size());
  2317         std::vector<Value*> keys, values;
  2318         for (
size_t i = 0; i < key_trees.size(); ++i) {
  2319           keys.push_back(emitExpr(
Expr(key_trees[i])));
  2320           values.push_back(emitExpr(
Expr(value_trees[i])));
  2323         TypePtr key_type = 
nullptr;
  2324         TypePtr value_type = 
nullptr;
  2326         if (type_hint && type_hint->kind() == TypeKind::DictType) {
  2327           auto dict_type = type_hint->expect<
DictType>();
  2328           key_type = dict_type->getKeyType();
  2329           value_type = dict_type->getValueType();
  2330         } 
else if (!keys.empty()) {
  2331           key_type = keys.at(0)->type();
  2332           value_type = values.at(0)->type();
  2334           key_type = StringType::get();
  2335           value_type = TensorType::get();
  2337         AT_ASSERT(key_type != 
nullptr && value_type != 
nullptr);
  2340             ->insertNode(graph->createDict(key_type, value_type, keys, values))
  2344         throw ErrorReport(tree) << 
"Cannot emit expr for: " << tree;
  2350     if (c.isFloatingPoint())
  2351       return materializeConstant(
  2352           c.asFloatingPoint(), *graph, c.range(), fp_constants);
  2354       return materializeConstant(
  2355           c.asIntegral(), *graph, c.range(), integral_constants);
  2359     return insertConstant(*graph, c.text(), 
nullptr, c.range());
  2368     return emitBuiltinCall(
  2373         {input, graph->insertConstant(dim, 
nullptr, loc), index},
  2385     std::vector<NamedValue> args;
  2387     args.emplace_back(loc, 
"self", input);
  2392       AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
  2394           loc, 
"dim", graph->insertConstant(dim.value(), 
nullptr, loc));
  2396       AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
  2399     args.emplace_back(loc, 
"begin", emitExpr(
Expr(slice.startOr(0))));
  2400     const auto has_end = slice.end().present();
  2402       args.emplace_back(loc, 
"end", emitExpr(
Expr(slice.end().get())));
  2406         return emitTupleSlice(loc, args[0], args[1],  args[2]);
  2408         return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
  2412         NamedValue(loc, 
"step", graph->insertConstant(1, 
nullptr, loc));
  2413     return emitBuiltinCall(
  2414         loc, *graph, aten::slice, c10::nullopt, args, {step}, 
true);
  2425         graph->insertNode(graph->createList(OptionalType::ofTensor(), indices))
  2427     return emitBuiltinCall(
  2428         loc, *graph, aten::index, c10::nullopt, {input, index}, {}, 
true);
  2438   std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
  2442     std::vector<Value*> tensor_indices;
  2445     auto handle_tensor = [&](
Value* tensor) {
  2447       tensor_indices.resize(dim + 1);
  2448       tensor_indices[dim] = tensor;
  2452     for (
const auto& subscript_expr : subscript_exprs) {
  2453       if (subscript_expr.kind() == TK_SLICE_EXPR) {
  2454         sliceable = emitSlice(loc, sliceable, dim, 
SliceExpr(subscript_expr));
  2458       auto index = emitExpr(subscript_expr, OptionalType::ofTensor());
  2459       if (index->type() == IntType::get()) {
  2460         sliceable = emitSelect(loc, sliceable, dim, index);
  2462       } 
else if (index->type()->isSubtypeOf(OptionalType::ofTensor())) {
  2464         handle_tensor(index);
  2468           << 
"Unsupported operation: indexing tensor with unsupported index type '"  2469           << index->type()->str()
  2470           << 
"'. Only ints, slices, and tensors are supported";
  2474     for (
auto& index : tensor_indices) {
  2475       if (index == 
nullptr) {
  2477             graph->insertNode(graph->createNone(TensorType::get()))->output();
  2480     return std::make_pair(sliceable, tensor_indices);
  2501   Value* emitMultidimSlicing(
  2505     if (!sliceable->type()->isSubtypeOf(TensorType::get())) {
  2507           << 
"Unsupported operation: attempted to use multidimensional "  2508           << 
"indexing on a non-tensor type.";
  2511     std::vector<Value*> tensor_indices;
  2512     std::tie(sliceable, tensor_indices) =
  2513         emitIntAndSliceIndexing(loc, sliceable, subscript_exprs);
  2515     if (tensor_indices.empty()) {
  2520     return emitIndex(loc, sliceable, tensor_indices);
  2525   Value* emitBasicSlice(
  2529     AT_ASSERT(subscript_exprs.size() == 1);
  2530     AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
  2531     auto slice_exp = 
SliceExpr(subscript_exprs[0]);
  2533     if (sliceable->type()->isSubtypeOf(TensorType::get())) {
  2537     return emitSlice(loc, sliceable, maybe_dim, slice_exp);
  2540   int64_t getTupleIndexVal(
  2542       const TupleTypePtr& tuple_type,
  2544       bool allow_out_of_bounds) {
  2547     if (ivalue && ivalue->isInt()) {
  2548       index = ivalue->to<int64_t>();
  2550       throw ErrorReport(loc) << 
"tuple indices must be integer constants";
  2553     int64_t adj_index = index;
  2554     int64_t tuple_len = tuple_type->elements().size();
  2556       adj_index = tuple_len + index;
  2558     if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) {
  2559       throw ErrorReport(loc) << 
"Tuple index out of range. Tuple is length "  2560                              << tuple_len << 
" and index is " << index;
  2565   Value* emitTupleIndex(
  2569     auto tuple_typ = tuple_val->type()->cast<
TupleType>();
  2570     auto adj_index = getTupleIndexVal(
  2571         loc, tuple_typ, idx_val,  
false);
  2572     return graph->insertNode(graph->createTupleIndex(tuple_val, adj_index))
  2576   Value* emitDictIndex(
  2580     auto dict_type = dict_val->type()->cast<
DictType>();
  2581     AT_ASSERT(key_val->type()->isSubtypeOf(dict_type->getKeyType()));
  2582     return graph->insertNode(graph->createDictIndex(dict_val, key_val))
  2586   Value* emitTupleSlice(
  2591     auto tuple_type = tuple_val.value(*graph)->type()->expect<
TupleType>();
  2592     int64_t beg = getTupleIndexVal(
  2593         loc, tuple_type, beg_val.value(*graph),  
true);
  2595     int64_t tuple_len = tuple_type->elements().size();
  2597       end = getTupleIndexVal(loc, tuple_type, end_val->value(*graph), 
true);
  2602     end = std::min(std::max((int64_t)0, end), tuple_len);
  2603     beg = std::min(std::max((int64_t)0, beg), tuple_len);
  2606         ->insertNode(graph->createTupleSlice(tuple_val.value(*graph), beg, end))
  2611     return emitSubscript(
  2613         emitExpr(subscript.value()),
  2614         subscript.subscript_exprs());
  2617   Value* emitSubscript(
  2621     if (subscript_exprs.size() != 1) {
  2622       return emitMultidimSlicing(loc, sliceable, subscript_exprs);
  2624     if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
  2625       return emitBasicSlice(loc, sliceable, subscript_exprs);
  2627       return emitBasicGather(loc, sliceable, subscript_exprs);
  2632   Value* emitBasicGather(
  2636     AT_ASSERT(subscript_exprs.size() == 1);
  2638     if (gatherable->type()->kind() == TypeKind::ListType) {
  2640       auto* idx = emitExpr(subscript_exprs[0]);
  2641       return emitBuiltinCall(
  2642           loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, 
true);
  2643     } 
else if (gatherable->type()->isSubtypeOf(TensorType::get())) {
  2644       return emitMultidimSlicing(loc, gatherable, subscript_exprs);
  2645     } 
else if (
auto tuple_type = gatherable->type()->cast<
TupleType>()) {
  2646       auto* idx = emitExpr(subscript_exprs[0]);
  2647       return emitTupleIndex(loc, gatherable, idx);
  2648     } 
else if (
auto dict_type = gatherable->type()->cast<
DictType>()) {
  2649       auto* idx = emitExpr(subscript_exprs[0]);
  2650       return emitDictIndex(loc, gatherable, idx);
  2653           << 
"Indexing only supported on lists, dictionaries, "  2654              "tensors, and tuples, but got type '"  2655           << gatherable->type()->str() << 
"'";
  2660 void defineMethodsInModule(
  2661     const std::shared_ptr<Module>& m,
  2662     const std::vector<Def>& definitions,
  2663     const std::vector<Resolver>& resolvers,
  2665   AT_ASSERT(definitions.size() == resolvers.size());
  2666   auto resolver_it = resolvers.begin();
  2667   std::vector<Method*> methods;
  2668   std::unordered_map<std::string, Method*> function_table;
  2669   for (
const Def& def : definitions) {
  2670     const std::string& name = def.name().name();
  2671     auto resolver = *resolver_it++;
  2672     AT_ASSERT(resolver);
  2677       resolver = [resolver, &function_table](
  2678                      const std::string& name,
  2680                      const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
  2681         auto it = function_table.find(name);
  2682         if (it != function_table.end()) {
  2683           return std::make_shared<MethodValue>(
nullptr, *it->second);
  2685         return resolver(name, m, loc);
  2688     auto creator = [def, resolver, 
self](
Method& method) {
  2689       AT_ASSERT(resolver);
  2690       to_ir(def, resolver, 
self, method);
  2692     Method& method = m->create_method(name, creator);
  2693     function_table[name] = &method;
  2694     methods.push_back(&method);
  2696   for (
Method* method : methods) {
  2697     method->ensure_defined();
  2699   if (!
self || !self->asFirstClass()) {
  2701     didFinishEmitModule(m);
  2705 void defineMethodsInModule(
  2706     const std::shared_ptr<Module>& m,
  2707     const std::string& source,
  2708     const Resolver& resolver,
  2711   std::vector<Def> definitions;
  2712   std::vector<Resolver> resolvers;
  2713   while (p.lexer().cur().kind != TK_EOF) {
  2714     auto def = 
Def(p.parseFunction(
bool(
self)));
  2715     definitions.push_back(def);
  2716     resolvers.push_back(resolver);
  2718   defineMethodsInModule(m, definitions, resolvers, 
self);
  2721 void lambdaLiftFork(
Node* fork_node) {
  2723   auto forked_graph = std::make_shared<Graph>();
  2724   auto body_block = fork_node->blocks()[0];
  2728   std::unordered_map<Value*, Value*> uncaptures_map;
  2730     if (!uncaptures_map.count(v)) {
  2732       uncaptures_map[v] = forked_graph->addInput()->copyMetadata(v);
  2733       fork_node->addInput(v);
  2735     return uncaptures_map[v];
  2737   forked_graph->block()->cloneFrom(body_block, env);
  2740   fork_node->g_(attr::Subgraph, forked_graph);
  2741   fork_node->eraseBlock(0);
 
AT_CPP14_CONSTEXPR ArrayRef< T > slice(size_t N, size_t M) const 
slice(n, m) - Chop off the first N elements of the array, and keep M elements in the array...
 
constexpr size_t size() const 
size - Get the array size. 
 
A utility class for setting temporary insertion points.
 
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...