Caffe2 - C++ API
A deep learning, cross-platform ML framework
compiler.cpp
1 #include <torch/csrc/jit/script/compiler.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/jit/hooks_for_testing.h>
5 #include <torch/csrc/jit/interpreter.h>
6 #include <torch/csrc/jit/ir.h>
7 #include <torch/csrc/jit/operator.h>
8 #include <torch/csrc/jit/passes/constant_pooling.h>
9 #include <torch/csrc/jit/passes/lower_tuples.h>
10 #include <torch/csrc/jit/script/final_returns.h>
11 #include <torch/csrc/jit/script/parser.h>
12 #include <torch/csrc/jit/script/schema_matching.h>
13 #include <torch/csrc/jit/script/script_type_parser.h>
14 #include <torch/csrc/utils/object_ptr.h>
15 
16 #include <torch/csrc/jit/constants.h>
17 
18 #include <c10/util/Optional.h>
19 
20 #include <climits>
21 #include <set>
22 
23 namespace torch {
24 namespace jit {
25 namespace script {
26 
27 using SugaredValuePtr = std::shared_ptr<SugaredValue>;
28 using FunctionTable = std::unordered_map<std::string, Method&>;
29 using ValueTable = std::unordered_map<std::string, SugaredValuePtr>;
30 using AttributeMap = std::unordered_map<std::string, Const>;
31 using ListAttributeMap = std::unordered_map<std::string, std::vector<Const>>;
32 
33 using TypeAndRange = std::pair<TypePtr, const SourceRange*>;
34 
35 // Holds mappings from a variable name to a refined type for that variable
36 // E.g. if x is not None is true, then we can refine x from type t? to t.
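// For example (an illustrative sketch, not verbatim compiler output):
//   if x is not None:
//     y = x + 1   # within this branch "x" is refined from Optional[int] to int
// Here the true-branch refinements would map "x" -> int.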
37 struct Refinements {
38  // using ordered map for deterministic graph output
39  std::map<std::string, TypeAndRange> mappings_;
40 
41  void setRefinement(const std::string& name, TypeAndRange mapping) {
42  mappings_[name] = std::move(mapping);
43  }
44 
45  c10::optional<TypeAndRange> getRefinement(const std::string& name) const {
46  const auto& maybe_mapping = mappings_.find(name);
47  if (maybe_mapping == mappings_.end()) {
48  return c10::nullopt;
49  }
50  return maybe_mapping->second;
51  }
52 
53 // return the intersection of the values to type mappings between this
54 // and other, keeping entries present in both whose types can be unified
55  void intersectRefinements(const Refinements& other) {
56  Refinements ret;
57  for (const auto& name_mapping : mappings_) {
58  const auto& name = name_mapping.first;
59  const auto& mapping = name_mapping.second;
60  if (auto other_mapping = other.getRefinement(name_mapping.first)) {
61  const auto maybe_unified_type =
62  unifyTypes(mapping.first, other_mapping->first);
63  if (maybe_unified_type) {
64  ret.setRefinement(
65  name, TypeAndRange(*maybe_unified_type, mapping.second));
66  }
67  }
68  }
69  mappings_ = std::move(ret.mappings_);
70  }
71 
72 // return the union of the values to type mappings in this and other;
73 // entries present in both are kept only if their types can be unified
74  void unionRefinements(const Refinements& other) {
75  Refinements ret;
76  for (const auto& name_mapping : mappings_) {
77  const auto& name = name_mapping.first;
78  const auto& mapping = name_mapping.second;
79  TypePtr t_1 = mapping.first;
80  if (auto other_mapping = other.getRefinement(name_mapping.first)) {
81  TypePtr t_2 = other_mapping->first;
82  c10::optional<TypePtr> maybe_unified_type = c10::nullopt;
83  if (t_1->isSubtypeOf(t_2)) {
84  maybe_unified_type = t_1;
85  } else if (t_2->isSubtypeOf(t_1)) {
86  maybe_unified_type = t_2;
87  }
88  if (maybe_unified_type) {
89  ret.setRefinement(
90  name, TypeAndRange(*maybe_unified_type, mapping.second));
91  }
92  } else {
93  ret.setRefinement(name, mapping);
94  }
95  }
96 
97  for (auto& name_mapping : other.mappings_) {
98  if (!getRefinement(name_mapping.first)) {
99  ret.setRefinement(name_mapping.first, name_mapping.second);
100  }
101  }
102 
103  mappings_ = std::move(ret.mappings_);
104  }
105 };
106 
107 // When a comparison like x is None is made, we associate type refinements
108 // with its true value and its false value. If a boolean that has refinements
109 // associated with it is used in the condition of an if statement, the true and
110 // false refinements are inserted into the corresponding blocks.
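// For example (sketch): for `x is not None and y is not None`, the merged
// true_refinements_ carry refinements for both x and y (their union), while
// the merged false_refinements_ keep only what the operands agree on (their
// intersection), which here is empty.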
111 
112 struct BoolInfo {
113  BoolInfo(Refinements true_refinements, Refinements false_refinements)
114  : true_refinements_(std::move(true_refinements)),
115  false_refinements_(std::move(false_refinements)){};
116  BoolInfo() = default;
117 
118  Refinements true_refinements_;
119  Refinements false_refinements_;
120 
121  BoolInfo* mergeOr(const BoolInfo& other) {
122 // if the result of an OR is true, either a or b could have been true,
123  // so we take the intersection of a.true_refinements & b.true_refinements.
124  // if the result is false, both a and b had to be false,
125  // so we take their union.
126  true_refinements_.intersectRefinements(other.true_refinements_);
127  false_refinements_.unionRefinements(other.false_refinements_);
128  return this;
129  }
130 
131  BoolInfo* mergeAnd(const BoolInfo& other) {
132 // if the result of an AND is true, both a and b had to be true,
133  // so we take the union of a.true_refinements and b.true_refinements.
134  // if the result is false, either a or b could have been false,
135  // so we take their intersection.
136  true_refinements_.unionRefinements(other.true_refinements_);
137  false_refinements_.intersectRefinements(other.false_refinements_);
138  return this;
139  }
140 };
141 
142 static Value* asSimple(const SugaredValuePtr& value) {
143  if (SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
144  return sv->getValue();
145  }
146  return nullptr;
147 }
148 // we consider _N, where N is a number, to be a non-meaningful name
149 // and do not record it as a unique name. This allows python printing to
150 // export and import graphs with more consistent naming.
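// For example: "_1" and "_42" are treated as non-meaningful, while "x",
// "_foo", and "weight" are meaningful and may become unique names.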
151 static bool meaningfulName(const std::string& name) {
152  if (name.size() == 0)
153  return false;
154  if (name[0] == '$')
155  return false;
156  if (name[0] != '_')
157  return true;
158  for (size_t i = 1; i < name.size(); ++i) {
159  if (!isdigit(name[i]))
160  return true;
161  }
162  return false;
163 }
164 
165 // Auxiliary data structure for desugaring variable binding into our always
166 // explicitly scoped language as we descend down nested control structures in
167 // the frontend (which themselves don't introduce scopes)
168 //
169 // The algorithm is roughly as follows:
170 // 1) While emitting a block within a control operator, add inputs and outputs
171 // from the block for each value referenced (both "reads" and "writes").
172 // This sets the value up as a candidate loop carried dependency.
173 // 2) When we reach the end of the block, examine all the values in the current
174 // scope's value map. If the name also resides in an outer scope with a
175 // different Value*, this is a true loop-carried dependency. If not, this
176 // value was not assigned to. Replace all references to the block input
177 // with the Value* pointed to in the tightest enclosing scope. Then delete
178 // that block input and output.
179 // 3) When we emit the actual control operator, take all of the loop-carried
180 // dependency values as inputs and return them as outputs from the control
181 // op
182 //
183 // Note that an alternative implementation could only add the loop-carried dep
184 // inputs and outputs when we see a value that is mutated. This, however,
185 // requires replacing all references to that value *within the current
186 // block* with a new input. That is to say: we need to traverse the pre-
187 // decessor nodes and replace inputs that reference that value with the
188 // newly-created input. This could be made less expensive with a change to
189 // the IR API, but for now we choose to pessimistically create inputs and
190 // delete unnecessary ones later with replaceAllUsesWith().
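// For example (sketch): in
//   x = 0
//   while cond:
//     x = x + 1
//   y = x
// "x" is a true loop-carried dependency and stays a block input/output of
// the prim::Loop node, whereas a value that is only read inside the loop has
// its pessimistically-created input deleted again by deleteExtraInputs().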
191 struct Environment {
192  Environment(
193  Method& method,
194  Resolver resolver,
195  Block* b,
196  std::shared_ptr<Environment> next = nullptr)
197  : method(method),
198  resolver(std::move(resolver)),
199  b(b),
200  next(std::move(next)) {}
201 
202  Method& method;
203  Resolver resolver;
204  std::vector<std::string> captured_inputs;
205  std::unordered_map<std::string, std::string> error_messages;
206  Block* b;
207 
208  std::shared_ptr<Environment> next;
209 
210  // set type error in the lowest environment. if the variable is used after an
211  // error has been set, then we will use the more informative error message
212  void setVariableTypeError(const std::string& name, const std::string& msg) {
213  auto runner = this;
214  while (runner->next) {
215  runner = runner->next.get();
216  }
217  runner->error_messages[name] = msg;
218  }
219 
220  // see if type error has been set for a variable
221  c10::optional<std::string> findVariableTypeError(const std::string& name) {
222  auto runner = this;
223  while (runner->next) {
224  runner = runner->next.get();
225  }
226  auto msg = runner->error_messages.find(name);
227  if (msg != runner->error_messages.end()) {
228  return msg->second;
229  } else {
230  return c10::nullopt;
231  }
232  }
233 
234  SugaredValuePtr findInThisFrame(const std::string& name) {
235  auto it = value_table.find(name);
236  if (it != value_table.end()) {
237  return it->second;
238  }
239  return nullptr;
240  }
241 
242  SugaredValuePtr findInParentFrame(const std::string& name) {
243  return next ? next->findInAnyFrame(name) : nullptr;
244  }
245 
246  SugaredValuePtr findInAnyFrame(const std::string& name) {
247  for (auto runner = this; runner; runner = runner->next.get()) {
248  if (auto r = runner->findInThisFrame(name)) {
249  return r;
250  }
251  }
252  return nullptr;
253  }
254 
255  Value* getValueInThisFrame(const SourceRange& loc, const std::string& name) {
256  return value_table.at(name)->asValue(loc, method);
257  }
258 
259  SugaredValuePtr createCapturedInput(Value* orig, const std::string& name) {
260  // insert the captured input alphabetically in the capture list.
261  // this ensures consistency of the order of loop-carried dependencies
262  // even when the use in the loop is in a different order
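// e.g. even if the loop body uses "b" before "a", captured_inputs ends up
// as ["a", "b"]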
263  size_t insert_pos = 0;
264  while (insert_pos < captured_inputs.size() &&
265  name > captured_inputs[insert_pos]) {
266  insert_pos++;
267  }
268  captured_inputs.insert(captured_inputs.begin() + insert_pos, name);
269 
270  // Create the input
271  const size_t loop_carried_block_inputs_offset = 1;
272  Value* new_input =
273  b->insertInput(loop_carried_block_inputs_offset + insert_pos)
274  ->setType(orig->type());
275 
276  // Associate this name with this value
277  auto sv = std::make_shared<SimpleValue>(new_input);
278  value_table[name] = sv;
279 
280  return sv;
281  }
282 
283  SugaredValuePtr createCapturedInputIfNeeded(
284  const SourceRange& loc,
285  const std::string& ident) {
286  auto in_frame = findInThisFrame(ident);
287  if (in_frame) {
288  return in_frame;
289  }
290 
291  // recursively handles the case where parent blocks are also loops
292  auto from_parent =
293  next ? next->createCapturedInputIfNeeded(loc, ident) : nullptr;
294 
295 // create the captured input if this block is a loop block
296  if (from_parent && getBlockOwningKind() == prim::Loop) {
297  if (Value* simple_val = asSimple(from_parent))
298  from_parent = createCapturedInput(simple_val, ident);
299  }
300  return from_parent;
301  }
302 
303  Block* block() {
304  return b;
305  }
306  Symbol getBlockOwningKind() {
307  Symbol owning_kind = Symbol();
308  if (b->owningNode()) {
309  owning_kind = b->owningNode()->kind();
310  }
311  return owning_kind;
312  }
313 
314  void setVar(const SourceRange& loc, const std::string& name, Value* value) {
315  setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
316  }
317 
318  void setSugaredVar(
319  const SourceRange& loc,
320  const std::string& name,
321  SugaredValuePtr value) {
322  Value* as_simple_value = asSimple(value);
323  if (as_simple_value && !as_simple_value->hasUniqueName() &&
324  meaningfulName(name) &&
325  // note: if the value wasn't defined in this block, we might be giving a
326  // name only used inside this block to a value outside of this. this is
327  // not normally helpful for debugging and causes import/export jitter.
328  as_simple_value->node()->owningBlock() == block()) {
329  as_simple_value->setUniqueName(name);
330  }
331  // prevent re-assignment involving any sugared values
332  // any reassignment like:
333  // a = ...
334  // while ...
335  // a = ..
336  // requires 'a' to be first-class in the graph since its value depends on
337  // control flow
338  if (auto parent = findInParentFrame(name)) {
339  if (!as_simple_value) {
340  throw ErrorReport(loc)
341  << "Cannot re-assign '" << name << "' to a value of type "
342  << value->kind() << " because " << name
343  << " is not a first-class value. Only reassignments to first-class values are allowed";
344  }
345  Value* simple_parent = asSimple(parent);
346  if (!simple_parent) {
347  throw ErrorReport(loc)
348  << "Cannot re-assign '" << name << "' because it has type "
349  << value->kind() << " and " << name
350  << " is not a first-class value. Only reassignments to first-class values are allowed";
351  }
352  if (!as_simple_value->type()->isSubtypeOf(
353  unshapedType(simple_parent->type()))) {
354  std::stringstream errMsg;
355  errMsg << "variable '" << name << "' previously has type "
356  << simple_parent->type()->str()
357  << " but is now being assigned to a value of type "
358  << as_simple_value->type()->str();
359  // Special-cased error msg if we're trying to assign to a tensor list.
360  if (simple_parent->type()->kind() == TypeKind::ListType &&
361  as_simple_value->type()->kind() == TypeKind::ListType) {
362  errMsg << "\n. (Note: empty lists are constructed as Tensor[]; "
363  << "if you want an empty list of a different type, "
364  << "use `torch.jit.annotate(List[T], [])`, "
365  << "where `T` is the type of elements in the list)";
366  }
367  throw ErrorReport(loc) << errMsg.str();
368  }
369  }
370  if (as_simple_value)
371  createCapturedInputIfNeeded(loc, name);
372  value_table[name] = std::move(value);
373  }
374 
375  SugaredValuePtr getSugaredVar(const Ident& ident, bool required = true) {
376  return getSugaredVar(ident.name(), ident.range());
377  }
378  Value* getVar(const Ident& ident) {
379  return getSugaredVar(ident)->asValue(ident.range(), method);
380  }
381 
382  SugaredValuePtr getSugaredVar(
383  const std::string& ident,
384  const SourceRange& range,
385  bool required = true) {
386  auto retval = createCapturedInputIfNeeded(range, ident);
387 
388  if (!retval) {
389  static std::unordered_map<std::string, SugaredValuePtr> globals = {
390  {"print", std::make_shared<PrintValue>()},
391  {"float", std::make_shared<CastValue>(FloatType::get(), prim::Float)},
392  {"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
393  {"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
394  {"getattr", std::make_shared<GetAttrValue>()},
395  {"isinstance", std::make_shared<IsInstanceValue>()},
396  // todo(zach): remove when we can correctly export torch.full via ONNX
397  // or we have implicit conversion that can convert numbers to tensors
398  {"_to_tensor",
399  std::make_shared<CastValue>(TensorType::get(), prim::NumToTensor)},
400  {"len", std::make_shared<BuiltinFunction>(aten::len, at::nullopt)},
401  {"min", std::make_shared<BuiltinFunction>(prim::min, at::nullopt)},
402  {"max", std::make_shared<BuiltinFunction>(prim::max, at::nullopt)},
403  {"list", std::make_shared<BuiltinFunction>(aten::list, at::nullopt)},
404  {"rangelist", std::make_shared<BuiltinFunction>(prim::rangelist, at::nullopt)},
405  };
406  auto it = globals.find(ident);
407  if (it != globals.end()) {
408  retval = it->second;
409  }
410  }
411 
412  if (!retval) {
413  if (auto class_type = ClassType::get(ident)) {
414  retval = std::make_shared<script::ClassValue>(class_type);
415  }
416  }
417 
418  if (!retval) {
419  retval = resolver(ident, method, range);
420  }
421 
422  if (!retval && required) {
423  // check if this value was not emitted in an if statement because of a
424  // type mismatch. if it was, then we print a more informative error msg
425  if (auto msg = findVariableTypeError(ident)) {
426  throw ErrorReport(range) << *msg << "and was used here";
427  }
428  throw ErrorReport(range) << "undefined value " << ident;
429  }
430  return retval;
431  }
432 
433  Value* getVar(const std::string& ident, const SourceRange& range) {
434  return getSugaredVar(ident, range)->asValue(range, method);
435  }
436 
437  // Given that after emitting statements in a block, we've added block inputs
438  // for all value references and assignments, delete inputs for which there was
439  // no assignment, only references.
440  void deleteExtraInputs(const SourceRange& loc) {
441  // note: skip i == 0, it is the loop trip count for inputs
442  // and the loop condition for outputs.
443  // captured_inputs is indexed by i - 1 since it only contains loop
444  // carried dependencies
445  // inputs: loop_counter, lcd0, lcd1, ...
446  // outputs: loop_condition, lcd0, lcd1, ...
447  // captured_inputs: lcd0, lcd1, ...
448  AT_ASSERT(b->inputs().size() == b->outputs().size());
449  AT_ASSERT(b->inputs().size() == captured_inputs.size() + 1);
450  for (size_t i = b->inputs().size() - 1; i > 0; i--) {
451  // nothing changed along this loop
452  if (b->inputs()[i] == b->outputs()[i]) {
453  auto name = captured_inputs[i - 1];
454  Value* orig = findInParentFrame(name)->asValue(loc, method);
455  b->inputs()[i]->replaceAllUsesWith(orig);
456  b->eraseInput(i);
457  b->eraseOutput(i);
458  captured_inputs.erase(captured_inputs.begin() + i - 1);
459  }
460  }
461  }
462  std::vector<std::string> definedVariables() {
463  std::vector<std::string> result;
464  for (auto& kv : value_table) {
465  result.push_back(kv.first);
466  }
467  return result;
468  }
469 
470  private:
471  ValueTable value_table;
472 };
473 
474 template <class T>
475 static Value* materializeConstant(
476  T val,
477  Graph& graph,
478  const SourceRange& r,
479  std::unordered_map<T, Value*>& map) {
480  auto existing_constant = map.find(val);
481  if (existing_constant != map.end()) {
482  return existing_constant->second;
483  }
484 
485  WithInsertPoint guard(graph.block()->nodes().front());
486  auto new_constant = graph.insertConstant(val, nullptr, r);
487  map[val] = new_constant;
488 
489  return new_constant;
490 }
491 
492 static Value* ensureInt(const SourceRange& range, Value* v) {
493  if (!v->type()->isSubtypeOf(IntType::get())) {
494  throw ErrorReport(range)
495  << "expected an int but found a " << v->type()->str();
496  }
497  return v;
498 }
499 
500 inline bool isSupportedListElementType(const TypePtr& type) {
501  return type->isSubtypeOf(TensorType::get()) ||
502  type->isSubtypeOf(NumberType::get());
503 }
504 
505 // Information for each def being emitted.
506 // Defs can be nested to support closures, so we need a stack of this information.
507 // Currently records information about the function's return type.
508 struct DefContext {
509  TypePtr declared_return_type_; // nullptr if not annotated
510  TypePtr merged_return_type_; // nullptr if a Return has not been seen yet
511 };
512 
513 struct to_ir {
514  to_ir(
515  const Def& def,
516  Resolver resolver_,
517  const c10::optional<Self>& self,
518  Method& method) // method being constructed
519  : method(method),
520  graph(method.graph()),
521  resolver(std::move(resolver_)),
522  environment_stack(nullptr) {
523  AT_ASSERT(resolver);
524  pushFrame(graph->block(), /*starts_def=*/true);
525 
526  // Type annotations exclude explicitly typing the "self" parameter, so in
527  // the case that this is a method with self we expect one fewer parameter
528  // annotation than the number of parameters this Def takes.
529  if (self && def.decl().params().size() == 0) {
530  throw ErrorReport(def.decl().params().range())
531  << "methods must have a self argument";
532  }
533 
534  method.setSchema(emitDef(def, self, graph->block()));
535 
536  runCleanupPasses(graph);
537  }
538 
539  private:
540  Method& method;
541  std::shared_ptr<Graph> graph;
542  Resolver resolver;
543  std::unordered_map<int64_t, Value*> integral_constants;
544  std::unordered_map<double, Value*> fp_constants;
545 
546  // Singly-linked list of environments. The top element contains a member
547  // `next` that points to the environment of the most immediate enclosing scope.
548  std::shared_ptr<Environment> environment_stack;
549  std::vector<DefContext> def_stack_;
550 
551  void pushFrame(Block* b, bool starts_def = false) {
552  if (starts_def) {
553  def_stack_.emplace_back();
554  }
555  environment_stack =
556  std::make_shared<Environment>(method, resolver, b, environment_stack);
557  }
558  std::shared_ptr<Environment> popFrame(bool ends_def = false) {
559  auto old_frame = environment_stack;
560  environment_stack = environment_stack->next;
561  if (ends_def) {
562  def_stack_.pop_back();
563  }
564  return old_frame;
565  }
566 
567  void runCleanupPasses(std::shared_ptr<Graph>& to_clean) {
568  // remove any uses of tuples that we inserted that are not needed
569  LowerSimpleTuples(to_clean);
570  ConstantPooling(to_clean);
571  }
572 
573  FunctionSchema emitDef(
574  const Def& def,
575  const c10::optional<Self>& self,
576  Block* block) {
577  auto schema = extractSchemaFromDef(def, self);
578  // TODO need guards on init returning none
579  if (schema.returns().size() == 1) {
580  def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
581  }
582  std::vector<Argument> arguments =
583  emitFormalArguments(def, self, schema, block);
584 
585  // body
586  auto stmts_list = moveAllReturnsToEnd(def.statements());
587  emitStatements(stmts_list.begin(), stmts_list.end());
588  std::vector<Argument> returns = {emitOutput(def.range(), schema, block)};
589  return {def.name().name(), "", std::move(arguments), std::move(returns)};
590  }
591 
592  std::vector<IValue> evaluateDefaults(
593  const SourceRange& r,
594  const std::vector<Expr>& default_types,
595  const std::vector<Expr>& default_exprs) {
596  std::vector<IValue> default_values;
597  if (default_exprs.empty())
598  return default_values;
599  // To evaluate the default expressions, we create a graph with no inputs,
600  // and whose returns are the default values we need.
601  // We then run constant prop on this graph and check the results are
602  // constant. This approach avoids having to have separate handling of
603  // default arguments from standard expressions by piecing together existing
604  // machinery for graph generation, constant propagation, and constant
605  // extraction.
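// For example (an illustrative sketch): for `def f(x : int = 1 + 1)` this
// builds and runs something equivalent to
//   def defaults() -> Tuple[int]:
//     return (1 + 1,)
// and reads the constant 2 back off the stack as the default IValue.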
606  auto tuple_type = Subscript::create(
607  r,
608  Var::create(r, Ident::create(r, "Tuple")),
609  List<Expr>::create(r, default_types));
610  auto blank_decl = Decl::create(
611  r, List<Param>::create(r, {}), Maybe<Expr>::create(r, tuple_type));
612 
613  auto tuple_expr =
614  TupleLiteral::create(r, List<Expr>::create(r, default_exprs));
615  auto ret = Return::create(r, tuple_expr);
616  auto def = Def::create(
617  r,
618  Ident::create(r, "defaults"),
619  blank_decl,
620  List<Stmt>::create(r, {ret}));
621  auto m = std::make_shared<Module>();
622  defineMethodsInModule(m, {def}, {resolver}, c10::nullopt);
623  Stack stack;
624  m->get_method("defaults").run(stack);
625  return stack.at(0).toTuple()->elements();
626  }
627 
628  std::vector<Argument> parseArgsFromDecl(
629  const Decl& decl,
630  const c10::optional<Self>& self) {
631  auto params_begin = decl.params().begin();
632  auto params_end = decl.params().end();
633  if (self) {
634  ++params_begin;
635  }
636  std::vector<Argument> retval;
637 
638  std::vector<Expr> default_types;
639  std::vector<Expr> default_exprs;
640  // gather any non-empty default arguments
641  for (auto it = params_begin; it != params_end; ++it) {
642  auto param = *it;
643  auto def = param.defaultValue();
644  if (def.present()) {
645  default_types.emplace_back(param.type());
646  default_exprs.emplace_back(def.get());
647  }
648  }
649  auto default_values =
650  evaluateDefaults(decl.range(), default_types, default_exprs);
651 
652  auto defaults_it = default_values.begin();
653  for (auto it = params_begin; it != params_end; ++it) {
654  auto decl_arg = *it;
655 
656  TypePtr type;
657  c10::optional<int32_t> N;
658 
659  // a BroadcastList can only appear at the argument level
660  if (auto maybe_broad_list = parseBroadcastList(decl_arg.type())) {
661  type = maybe_broad_list->first;
662  N = maybe_broad_list->second;
663  } else {
664  type = parseTypeFromExpr(decl_arg.type());
665  N = c10::nullopt;
666  }
667  c10::optional<IValue> default_value = c10::nullopt;
668  if (decl_arg.defaultValue().present()) {
669  default_value = *defaults_it++;
670  }
671  auto arg = Argument(
672  decl_arg.ident().name(),
673  type,
674  N,
675  default_value,
676  decl_arg.kwarg_only());
677  retval.push_back(arg);
678  }
679  return retval;
680  }
681 
682  std::vector<Argument> parseReturnFromDecl(const Decl& decl) {
683  // we represent no annotation on a return type as having no values in the
684  // schema's return() list
685  // in emitReturn we take the actual return value to be the value of the
686  // return statement if none was provided here
687  if (!decl.return_type().present())
688  return {};
689 
690  if (parseBroadcastList(decl.return_type().get()))
691  throw ErrorReport(decl.return_type().range())
692  << "Broadcastable lists cannot appear as a return type";
693  auto parsed_type = parseTypeFromExpr(decl.return_type().get());
694  return {Argument(
695  "",
696  parsed_type,
697  /*N =*/c10::nullopt,
698  /*default_value =*/c10::nullopt,
699  /*kwarg_only =*/false)};
700  }
701  FunctionSchema extractSchemaFromDef(
702  const Def& def,
703  const c10::optional<Self>& self) {
704  const auto name = def.name().name();
705  std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
706  std::vector<Argument> returns = parseReturnFromDecl(def.decl());
707  return FunctionSchema(
708  name, "", std::move(args), std::move(returns), false, false);
709  }
710 
711  std::vector<Argument> emitFormalArguments(
712  const Def& def,
713  const c10::optional<Self>& self,
714  const FunctionSchema& schema,
715  Block* block) {
716  std::vector<Argument> arguments; // for schema
717  // inputs
718  auto it = def.decl().params().begin();
719  auto end = def.decl().params().end();
720  auto expected_annotation_size = def.decl().params().size();
721  if (self) {
722  expected_annotation_size--;
723  }
724  if (schema.arguments().size() != expected_annotation_size) {
725  throw ErrorReport(def.decl().params().range())
726  << "Number of type annotations for"
727  << " function parameters (" << schema.arguments().size() << ")"
728  << " does not match the number of parameters on the function ("
729  << expected_annotation_size << ")!";
730  }
731 
732  if (self) {
733  AT_ASSERT(it != end);
734  const auto& name = (*it).ident().name();
735  if (auto type = self->asFirstClass()) {
736  Value* new_input =
737  block->addInput()->setUniqueName(name)->setType(type);
738  environment_stack->setVar((*it).ident().range(), name, new_input);
739  arguments.emplace_back(name, type);
740  } else {
741  environment_stack->setSugaredVar(def.range(), name, self->asSugared());
742  }
743  ++it;
744  }
745  size_t arg_annotation_idx = 0;
746  for (; it != end; ++it) {
747  auto& name = (*it).ident().name();
748  // Add the input to the graph
749  Value* new_input = block->addInput();
750  if (meaningfulName(name)) {
751  new_input->setUniqueName(name);
752  }
753  environment_stack->setVar((*it).ident().range(), name, new_input);
754 
755  // Record the type for the schema and set the Type on the Value*
756  arguments.push_back(schema.arguments().at(arg_annotation_idx++));
757  new_input->setType(arguments.back().type());
758  }
759  return arguments;
760  }
761 
762  Argument emitOutput(
763  const SourceRange& range,
764  const FunctionSchema& schema,
765  Block* block) {
766  // rewrites ensure there is always a return statement in the program
767  AT_ASSERT(def_stack_.back().merged_return_type_);
768  // outputs
769  Value* result = environment_stack->getVar("$return", range);
770  block->registerOutput(result);
771  return Argument("", def_stack_.back().merged_return_type_);
772  }
773 
774  void emitStatements(const List<Stmt>& statements) {
775  return emitStatements(statements.begin(), statements.end());
776  }
777  std::pair<std::shared_ptr<Graph>, Value*> lambdaLift(Block* block) {
778  auto subgraph = std::make_shared<Graph>();
779  // note: type is set later on pack_context and context when we know it
780  Node* pack_context =
781  graph->insertNode(graph->create(prim::TupleConstruct, {}, 1));
782  Value* context = subgraph->addInput("context");
783  // cannot use createTupleUnpack because the type is not known yet
784  Node* unpack_context =
785  subgraph->insertNode(subgraph->create(prim::TupleUnpack, {context}, 0));
786 
787  std::unordered_map<Value*, Value*> captures;
788  auto env = [&](Value* v) -> Value* {
789  auto it = captures.find(v);
790  if (it != captures.end()) {
791  return it->second;
792  }
793  pack_context->addInput(v);
794  Value* r = unpack_context->addOutput()->copyMetadata(v);
795  captures[v] = r;
796  return r;
797  };
798  subgraph->block()->cloneFrom(block, env);
799  auto context_type = TupleType::create(
800  fmap(pack_context->inputs(), [](Value* v) { return v->type(); }));
801  pack_context->output()->setType(context_type);
802  context->setType(context_type);
803  return std::make_pair(std::move(subgraph), pack_context->output());
804  }
805  // XXX - right now closures are used _only_ for defining gradients internally
806  // There are several unfinished aspects that make them unusable generally
807  // 1. We do not have a type, ivalue, operator to represent prim::Function, so
808  // closure_node has type None
809  // and any graphs that contain it cannot be run
810  // 2. There is no export logic for it yet, so it cannot be
811  // exported/python_printed
812  // 3. There is nothing preventing the assignment of already existing variables
813 // inside the closures;
814 // the changes to those variables will just get forgotten.
815  // 4. There is no parsing support in frontend.py, this is intentional since it
816  // prevents people from accidentally using this feature.
817  void emitClosure(const Def& def) {
818  Node* closure_node = graph->insertNode(graph->create(prim::Function, 1));
819  closure_node->output()->setType(
820  NoneType::get()); // it is not a real thing yet, so just say the type is
821  // none.
822  Block* block = closure_node->addBlock();
823  {
824  WithInsertPoint guard(block);
825  pushFrame(block, /*starts_def=*/true);
826  emitDef(
827  def,
828  c10::nullopt,
829  block); // ignore schema return, we just won't use it for now since we
830  // never create a Method for the closure
831  popFrame(/*ends_def=*/true);
832  }
833  std::shared_ptr<Graph> subgraph;
834  Value* context;
835  std::tie(subgraph, context) = lambdaLift(block);
836  runCleanupPasses(subgraph);
837  closure_node->eraseBlock(0);
838  closure_node->g_(attr::Subgraph, std::move(subgraph));
839  auto tup =
840  graph->insertNode(graph->createTuple({closure_node->output(), context}))
841  ->output();
842  environment_stack->setVar(def.name().range(), def.name().name(), tup);
843  }
844 
845  void emitReturn(const Return& stmt) {
846  Value* result = emitExpr(stmt.expr());
847  TypePtr result_type = def_stack_.back().declared_return_type_;
848  // result type is annotated, every return must convert to that type
849  if (result_type) {
850  // this guard skips implicit conversion from None -> Tensor for the return
851  // type. otherwise, forgetting a return in a function returning a tensor will
852  // cause a None to be converted to a tensor.
853  if (!(result_type->isSubtypeOf(TensorType::get()) &&
854  result->type()->isSubtypeOf(NoneType::get()))) {
855  result = tryConvertToType(
856  stmt.range(),
857  *graph,
858  result_type,
859  result,
860  /*allow_conversions=*/true);
861  }
862 
863  if (!result->type()->isSubtypeOf(result_type)) {
864  throw ErrorReport(stmt.range())
865  << "Return value was annotated as having type "
866  << result_type->python_str() << " but is actually of type "
867  << result->type()->python_str();
868  }
869  } else {
870  result_type = def_stack_.back().merged_return_type_;
871  if (!result_type) {
872  result_type = result->type();
873  }
874  if (!unifyTypes(result_type, result->type())) {
875  throw ErrorReport(stmt.range())
876  << "Previous return statement returned a value of type "
877  << result_type->python_str()
878  << " but this return statement returns a value of type "
879  << result->type()->python_str();
880  }
881  }
882  AT_ASSERT(result_type);
883  def_stack_.back().merged_return_type_ = result_type;
884  environment_stack->setVar(stmt.range(), "$return", result);
885  }
886 
887  void emitStatements(
888  List<Stmt>::const_iterator begin,
889  List<Stmt>::const_iterator end) {
890  for (; begin != end; ++begin) {
891  auto stmt = *begin;
892  switch (stmt.kind()) {
893  case TK_IF:
894  emitIf(If(stmt));
895  break;
896  case TK_WHILE:
897  emitWhile(While(stmt));
898  break;
899  case TK_FOR:
900  emitFor(For(stmt));
901  break;
902  case TK_ASSIGN:
903  emitAssignment(Assign(stmt));
904  break;
905  case TK_AUG_ASSIGN:
906  emitAugAssignment(AugAssign(stmt));
907  break;
908  case TK_GLOBAL:
909  for (auto ident : Global(stmt).names()) {
910  const auto& name = Ident(ident).name();
911  environment_stack->setVar(
912  ident.range(), name, graph->addInput(name));
913  }
914  break;
915  case TK_EXPR_STMT: {
916  auto expr = ExprStmt(stmt).expr();
917  emitSugaredExpr(expr, 0);
918  } break;
919  case TK_RAISE:
920  emitRaise(Raise(stmt).range());
921  break;
922  case TK_ASSERT:
923  emitAssert(Assert(stmt));
924  break;
925  case TK_RETURN: {
926  emitReturn(Return(stmt));
927  } break;
928  case TK_PASS:
929  // Emit nothing for pass
930  break;
931  case TK_DEF:
932  emitClosure(Def(stmt));
933  break;
934  default:
935  throw ErrorReport(stmt)
936  << "Unrecognized statement kind " << kindToString(stmt.kind());
937  }
938  }
939  }
940 
941  std::shared_ptr<Environment> emitSingleIfBranch(
942  Block* b,
943  const List<Stmt>& branch,
944  const Refinements& refinements) {
945  pushFrame(b);
946  WithInsertPoint guard(b);
947  insertRefinements(refinements);
948  emitStatements(branch);
949  return popFrame();
950  }
951 
952  Node* create(Symbol kind, const SourceRange& loc, size_t n_outputs) {
953  return graph->create(kind, n_outputs)
954  ->setSourceLocation(std::make_shared<SourceRange>(loc));
955  }
956 
957  Value* emitTernaryIf(const TernaryIf& expr) {
958  const auto& bool_info = findRefinements(expr.cond());
959  Value* cond_value = emitCond(expr.cond());
960  auto true_expr = [&] {
961  insertRefinements(bool_info.true_refinements_);
962  return emitExpr(expr.true_expr());
963  };
964  auto false_expr = [&] {
965  insertRefinements(bool_info.false_refinements_);
966  return emitExpr(expr.false_expr());
967  };
968  return emitIfExpr(expr.range(), cond_value, true_expr, false_expr);
969  }
970 
971  // Insert subtyping refinements
972  void insertRefinements(const Refinements& ref) {
973  for (const auto& name_mappings : ref.mappings_) {
974  const std::string& name = name_mappings.first;
975  auto type = name_mappings.second.first;
976  const auto& range = *name_mappings.second.second;
977  Value* v = environment_stack->getVar(name, range);
978  if (type != NoneType::get()) {
979  Value* output = graph->insert(prim::unchecked_unwrap_optional, {v});
980  environment_stack->setVar(range, name, output);
981  }
982  // todo @eellison - revisit inserting Nones when None subtypes Optional
983  }
984  }
985 
986  Value* emitShortCircuitIf(
987  const SourceRange& loc,
988  const TreeRef& first_expr,
989  const TreeRef& second_expr,
990  bool is_or) {
991  const auto first_bool_info = findRefinements(first_expr);
992  Value* first_value = emitCond(Expr(first_expr));
993 
994  const Refinements* first_expr_refinements;
995  const Refinements* second_expr_refinements;
996  // if it's an OR the first expr is emitted in the true branch
997  // and the second expr in the false branch; if it's an AND, the opposite
998  if (is_or) {
999  first_expr_refinements = &first_bool_info.true_refinements_;
1000  second_expr_refinements = &first_bool_info.false_refinements_;
1001  } else {
1002  first_expr_refinements = &first_bool_info.false_refinements_;
1003  second_expr_refinements = &first_bool_info.true_refinements_;
1004  }
1005 
1006  auto get_first_expr = [&] {
1007  insertRefinements(*first_expr_refinements);
1008  return first_value;
1009  };
1010 
1011  auto get_second_expr = [&] {
1012  insertRefinements(*second_expr_refinements);
1013  return emitCond(Expr(second_expr));
1014  };
1015 
1016  // if this is an OR, eval second expression if first expr is False
1017  // If this is an AND, eval second expression if first expr is True
1018  if (is_or) {
1019  return emitIfExpr(loc, first_value, get_first_expr, get_second_expr);
1020  } else {
1021  return emitIfExpr(loc, first_value, get_second_expr, get_first_expr);
1022  }
1023  }
1024 
1025  Value* emitIfExpr(
1026  const SourceRange& range,
1027  Value* cond_value,
1028  std::function<Value*()> true_expr,
1029  std::function<Value*()> false_expr) {
1030  Node* n = graph->insertNode(create(prim::If, range, 0));
1031 
1032  n->addInput(cond_value);
1033  auto* true_block = n->addBlock();
1034  auto* false_block = n->addBlock();
1035 
1036  auto emit_if_expr = [this](Block* b, std::function<Value*()> expr_value) {
1037  pushFrame(b);
1038  WithInsertPoint guard(b);
1039  Value* out_val = expr_value();
1040  b->registerOutput(out_val);
1041  popFrame();
1042  };
1043 
1044  emit_if_expr(true_block, std::move(true_expr));
1045  emit_if_expr(false_block, std::move(false_expr));
1046 
1047  auto true_type = true_block->outputs().at(0)->type();
1048  auto false_type = false_block->outputs().at(0)->type();
1049  auto unified = unifyTypes(true_type, false_type);
1050  if (!unified) {
1051  throw ErrorReport(range)
1052  << "if-expression's true branch has type " << true_type->str()
1053  << " but false branch has type " << false_type->str();
1054  }
1055 
1056  // Add op outputs
1057  auto expr_value = n->addOutput()->setType(*unified); // Resulting value
1058 
1059  return expr_value;
1060  }
1061 
1062  Value* emitCond(const Expr& cond) {
1063  Value* v = emitExpr(cond);
1064  if (!v->type()->isSubtypeOf(BoolType::get())) {
1065  ErrorReport error(cond);
1066  error << "expected a boolean expression for condition but found "
1067  << v->type()->str();
1068  if (v->type()->isSubtypeOf(TensorType::get())) {
1069  error << ", to use a tensor in a boolean"
1070  << " expression, explicitly cast it with `bool()`";
1071  }
1072  throw error;
1073  }
1074  return v;
1075  }
1076 
1077  void emitIfElseBlocks(Value* cond_value, const If& stmt) {
1078  Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
1079  n->addInput(cond_value);
1080  const auto bool_info = findRefinements(stmt.cond());
1081  auto* true_block = n->addBlock();
1082  auto* false_block = n->addBlock();
1083 
1084  // Emit both blocks once to get the union of all mutated values
1085  auto save_true = emitSingleIfBranch(
1086  true_block, stmt.trueBranch(), bool_info.true_refinements_);
1087  auto save_false = emitSingleIfBranch(
1088  false_block, stmt.falseBranch(), bool_info.false_refinements_);
1089 
1090  // In python, every variable assigned in an if statement escapes
1091  // the scope of the if statement (all variables are scoped to the function).
1092  // Script is a subset of python: we consider variables to be in scope
1093  // as long as there is a definition of the variable along all paths
1094  // through the if statement
1095  // ----
1096  // if ...:
1097  // a =
1098  // else:
1099  // ...
1100  // ... = a # error, a is not defined along all paths
1101  // ----
1102  // if ...:
1103  // a =
1104  // else:
1105  // a =
1106  // ... = a # OK, a is defined along all paths
1107  // ----
1108  // a = ...
1109  // if ...:
1110  // a =
1111  // ... = a # OK, a is defined along all paths
1112 
1113  // ordered set, because we want deterministic graph output
1114  std::set<std::string> mutated_variables;
1115 
1116  for (auto& v : save_true->definedVariables()) {
1117  if (save_false->findInAnyFrame(v)) {
1118  mutated_variables.insert(v);
1119  }
1120  }
1121  for (auto& v : save_false->definedVariables()) {
1122  if (save_true->findInAnyFrame(v)) {
1123  mutated_variables.insert(v);
1124  }
1125  }
1126 
1127  // Register outputs in each block
1128  for (const auto& x : mutated_variables) {
1129  auto tv = save_true->getVar(x, stmt.range());
1130  auto fv = save_false->getVar(x, stmt.range());
1131  auto unified = unifyTypes(tv->type(), fv->type());
1132 
1133  // attempt to unify the types. we allow variables to be set to different
1134  // types in each branch as long as that variable is not already in scope,
1135  // or if that variable does not get used later. here, we save the error
1136  // so that the error message will be more informative in the case that it is
1137  // used later. When a is accessed in (a + 1), the error will get printed:
1138  // if cond:
1139  // a = 1
1140  // else:
1141  // a = tensor
1142  // b = a + 1
1143  //
1144  if (!unified) {
1145  ErrorReport error(stmt);
1146  error << "Type mismatch: " << x << " is set to type "
1147  << tv->type()->str() << " in the true branch"
1148  << " and type " << fv->type()->str() << " in the false branch";
1149  if (save_true->findInParentFrame(x) ||
1150  save_false->findInParentFrame(x)) {
1151  throw error;
1152  } else {
1153  // error gets saved in the lowest environment because all
1154  // variables are scoped to the function. it doesn't matter if this is
1155  // accessed through save_true or save_false
1156  save_true->setVariableTypeError(x, error.what());
1157  continue;
1158  }
1159  }
1160  true_block->registerOutput(tv);
1161  false_block->registerOutput(fv);
1162  environment_stack->setVar(
1163  stmt.range(), x, n->addOutput()->setType(*unified));
1164  }
1165  }
1166 
1167  void emitIf(const If& stmt) {
1168  // NOTE: emitIf checks the If stmt condition to see if the cond AST kind is
1169  // is/is not; for such cases we do metaprogramming and skip emitting the
1170  // branches that cannot be taken
1171  Expr cond = stmt.cond();
1172 
1173  if (cond.kind() != TK_IS && cond.kind() != TK_ISNOT) {
1174  // emit normal IF stmt for cases except TK_IS and TK_ISNOT
1175  Value* cond_value = emitCond(cond);
1176  emitIfElseBlocks(cond_value, stmt);
1177  return;
1178  }
1179  // metaprogramming on the AST for is/is not cases: emit branches based on the
1180  // possible output of cond
1181  auto cond_op = BinOp(cond);
1182  SugaredValuePtr lhs_val = emitSugaredExpr(cond_op.lhs(), 1);
1183  SugaredValuePtr rhs_val = emitSugaredExpr(cond_op.rhs(), 1);
1184 
1185  List<Stmt> always_none_branch =
1186  cond.kind() == TK_IS ? stmt.trueBranch() : stmt.falseBranch();
1187  List<Stmt> never_none_branch =
1188  cond.kind() == TK_IS ? stmt.falseBranch() : stmt.trueBranch();
1189 
1190  auto lhs_none = lhs_val->isNone();
1191  auto rhs_none = rhs_val->isNone();
1192 
1193  // Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
1194  //
1195  // AA -> emit always_none_branch
1196  // AN, NA -> emit never_none_branch
1197  // MA, MM, MN, NM, NN, AM -> emit both conditional branches
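  // For example (sketch): with `x` annotated as Optional[Tensor], `if x is
  // None:` is the MAYBE case and both branches are emitted under a prim::If;
  // for `if None is None:` only the always_none_branch is emitted and no
  // prim::If node is created.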
1198 
1199  if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
1200  // None is/is not None: only emit the always_none_branch
1201  emitStatements(always_none_branch);
1202  } else if (
1203  (lhs_none == ALWAYS && rhs_none == NEVER) ||
1204  (lhs_none == NEVER && rhs_none == ALWAYS)) {
1205  // lhs_val/rhs_val with A/N: only emit never_none_branch
1206  emitStatements(never_none_branch);
1207  } else {
1208  // all other cases for lhs_val and rhs_val
1209  // emit the whole If stmt as usual, finish emitCond first
1210  auto lhs_range = cond_op.lhs().get()->range();
1211  auto rhs_range = cond_op.rhs().get()->range();
1212 
1213  auto kind = getNodeKind(cond.kind(), cond.get()->trees().size());
1214  Value* cond_value = emitBuiltinCall(
1215  cond.get()->range(),
1216  *method.graph(),
1217  kind,
1218  c10::nullopt,
1219  {lhs_val->asValue(lhs_range, method),
1220  rhs_val->asValue(rhs_range, method)},
1221  {},
1222  /*required=*/true);
1223  emitIfElseBlocks(cond_value, stmt);
1224  }
1225  }
1226 
1227  // *********************** Loop Operators ************************************
1228  // Emits a loop operator conforming to the semantics specified at
1229  // https://github.com/onnx/onnx/blob/master/docs/Operators.md#experimental-loop
1230  // TODO: implement scan_outputs
1231 
1232  // the format of the Loop instruction is:
1233  // loop_carried_outputs* = Loop(max_trip_count, start_condition,
1234  // loop_carried_inputs*)
1235  // block0(loop_counter, loop_carried_block*) {
1236  // <body>
1237  // -> (continue_condition, loop_carried_block_outputs*)
1238  // }
1239  // all loop_carried_... lists are the same length and represent the value of
1240  // loop-carried variables whose definitions are updated as the loop executes
1241  // in a way that ensures single static assignment.
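  // For example (a rough sketch of the emitted IR, not verbatim output):
  //   for i in range(3): x = x + 1
  // becomes approximately
  //   %x.1 = prim::Loop(%3, %true, %x)
  //     block0(%i, %x.0):
  //       %x.2 = aten::add(%x.0, %one)
  //       -> (%true, %x.2)
  // where %x is the loop-carried input and %x.1 the corresponding output.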
1242 
1243  void emitLoopCommon(
1244  SourceRange range,
1245  c10::optional<Expr> max_trip_count,
1246  c10::optional<Expr> cond,
1247  const List<Stmt>& body,
1248  c10::optional<Ident> itr_ident,
1249  bool in_list = false) {
1250  Node* n = graph->insertNode(create(prim::Loop, range, 0));
1251  Value *max_trip_count_val, *cond_val;
1252  {
1253  WithInsertPoint guard(n);
1254  if (max_trip_count) {
1255  if (in_list) {
1256  auto listArg = emitExpr(max_trip_count.value());
1257 
1258  max_trip_count_val = emitBuiltinCall(
1259  max_trip_count->range(),
1260  *graph,
1261  aten::len,
1262  c10::nullopt,
1263  {listArg},
1264  {},
1265  /*required=*/true);
1266  } else {
1267  max_trip_count_val = ensureInt(
1268  max_trip_count->range(), emitExpr(max_trip_count.value()));
1269  }
1270  } else {
1271  max_trip_count_val = materializeConstant(
1272  std::numeric_limits<int64_t>::max(),
1273  *graph,
1274  range,
1275  integral_constants);
1276  }
1277  if (cond) {
1278  cond_val = emitCond(cond.value());
1279  } else {
1280  cond_val = graph->insertConstant(true, nullptr, range);
1281  }
1282  }
1283  n->addInput(max_trip_count_val);
1284  n->addInput(cond_val);
1285  auto* body_block = n->addBlock();
1286  Value* trip_count =
1287  body_block->addInput()->setType(IntType::get()); // Iteration num
1288 
1289  {
1290  pushFrame(body_block);
1291  WithInsertPoint guard(body_block);
1292  if (itr_ident) {
1293  if (in_list) {
1294  // set user's iterator variable to the current element
1295  auto listArg = emitExpr(max_trip_count.value());
1296  trip_count = emitBuiltinCall(
1297  max_trip_count->range(),
1298  *graph,
1299  aten::select,
1300  c10::nullopt,
1301  {listArg, trip_count},
1302  {},
1303  /*required=*/true);
1304  }
1305  environment_stack->setVar(
1306  itr_ident->range(), itr_ident->name(), trip_count);
1307  }
1308  emitStatements(body);
1309 
1310  // Also emit the conditional
1311  if (cond) {
1312  Value* body_cond_value = emitCond(cond.value());
1313  body_block->registerOutput(body_cond_value);
1314  } else {
1315  Value* cond_value_dummy = graph->insertConstant(true, nullptr, range);
1316  body_block->registerOutput(cond_value_dummy);
1317  }
1318 
1319  auto body_frame = popFrame();
1320  auto outer_frame = environment_stack;
1321 
1322  // Add block outputs to correspond to each captured input
1323  // some of these will be removed.
1324  for (const auto& x : body_frame->captured_inputs) {
1325  auto fv = body_frame->getValueInThisFrame(range, x);
1326  body_block->registerOutput(fv);
1327  }
1328 
1329  // Remove inputs for values that did not mutate within the
1330  // block
1331  body_frame->deleteExtraInputs(range);
1332 
1333  // register node inputs/outputs for the true loop carried deps,
1334  for (size_t i = 0; i < body_frame->captured_inputs.size(); ++i) {
1335  auto x = body_frame->captured_inputs[i];
1336  n->addInput(outer_frame->getVar(x, range));
1337  // body_block->inputs(): loop_counter, lcd0, lcd1, ...
1338  // captured_inputs: lcd0, lcd1, ...
1339  auto typ = body_block->inputs()[i + 1]->type();
1340  outer_frame->setVar(range, x, n->addOutput()->setType(typ));
1341  }
1342  }
1343  }
1344 
1345  void emitForRange(
1346  const SourceRange& range,
1347  const Ident& target,
1348  const List<Expr>& args,
1349  const List<Stmt>& body) {
1350  // TODO: start, stop, step loop
1351  if (args.size() != 1) {
1352  throw ErrorReport(range)
1353  << "range() expects 1 argument but got " << args.size();
1354  }
1355  emitLoopCommon(range, {args[0]}, {}, body, target);
1356  }
1357 
1358  void emitFor(const For& stmt) {
1359  // For now, we only support range loops. e.g. for i in range(3): ...
1360  auto targets = stmt.targets();
1361  auto itrs = stmt.itrs();
1362  auto body = stmt.body();
1363 
1364  if (stmt.itrs().size() != 1) {
1365  throw ErrorReport(stmt)
1366  << "List of iterables is not supported currently.";
1367  }
1368  if (targets.size() != 1) {
1369  throw ErrorReport(stmt)
1370  << "Iteration variable unpacking is not supported";
1371  }
1372 
1373  if (targets[0].kind() != TK_VAR) {
1374  throw ErrorReport(targets[0])
1375  << "unexpected expression in variable initialization of for loop";
1376  }
1377  auto target = Var(targets[0]).name();
1378 
1379  // match range(<expr>) style loops
1380  // itrs must consist of a single Apply node
1381  if (itrs[0].kind() == TK_APPLY) {
1382  Apply range_iterator = Apply(itrs[0]);
1383  if (range_iterator.callee().kind() == TK_VAR) {
1384  Var var = Var(range_iterator.callee());
1385  if (var.name().name() == "range") {
1386  return emitForRange(
1387  stmt.range(), target, range_iterator.inputs(), body);
1388  }
1389  }
1390  }
1391 
1392  // it isn't a range(<expr>) loop, so treat it as a sugared value that can
1393  // perhaps be unrolled
1394  auto sv = emitSugaredExpr(itrs[0], 1);
1395  // check if a value is simple and list-like
1396  if (auto siv = std::dynamic_pointer_cast<SimpleValue>(sv)) {
1397  if (siv->getValue()->type()->kind() == TypeKind::ListType) {
1398  return emitLoopCommon(
1399  stmt.range(), {itrs[0]}, {}, body, {target}, true);
1400  }
1401  }
1402  auto instances = sv->asTuple(stmt.range(), method);
1403  const std::string& target_name = target.name();
1404  pushFrame(environment_stack->block());
1405  for (const auto& inst : instances) {
1406  environment_stack->setSugaredVar(itrs[0].range(), target_name, inst);
1407  emitStatements(body);
1408  }
1409 
1410  for (const auto& n : environment_stack->definedVariables()) {
1411  if (environment_stack->findInParentFrame(n)) {
1412  environment_stack->next->setVar(
1413  stmt.range(), n, environment_stack->getVar(n, stmt.range()));
1414  }
1415  }
1416  popFrame();
1417  }
1418 
1419  void emitWhile(const While& stmt) {
1420  auto cond = stmt.cond();
1421  emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
1422  }
1423 
1424  // Currently we do not support assigning exceptions to variables,
1425  // a = Exception("hi")
1426  // raise a
1427  //
1428  // We ignore the expression following raise
1429  //
1430  // NYI: add exception logic to control-flow nodes
1431  // if True:
1432  // a = 1
1433  // else
1434  // raise Exception("Hi")
1435  // print(a)
1436  void emitRaise(const SourceRange& loc) {
1437  const std::string exception = "Exception";
1438  auto string_input = insertConstant(*graph, exception, nullptr, loc);
1439  graph->insert(prim::RaiseException, {string_input}, {}, loc);
1440  }
1441 
1442  void emitAssert(const Assert& stmt) {
1443  Value* cond_value = emitCond(stmt.test());
1444  Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
1445 
1446  n->addInput(cond_value);
1447  /* true_block =*/n->addBlock();
1448  auto* false_block = n->addBlock();
1449 
1450  // if the assert test is false, throw an exception
1451  pushFrame(false_block);
1452  WithInsertPoint guard(false_block);
1453  emitRaise(stmt.range());
1454  popFrame();
1455  }
1456 
1457  // Validate that the `lhs` Expr's in an assignment statement are valid. That
1458  // is:
1459  //
1460  // 1) All lhs Expr's are either Var, Subscript, or Starred nodes
1461  // 2) There is at most one Starred node in the lhs Expr
1462  // 3) A Starred node can only appear when there is another non-Starred lhs
1463  // Expr. Concretely this means that `*abc = func()` is illegal. Unpacking
1464  // all outputs into a tuple is covered by `abc = func()`.
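// For example: `a, *b = f()` and `*a, b = f()` pass these checks, while
// `*a = f()` and `*a, *b = f()` are rejected.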
1465  bool calcNumStarredUnpack(const List<Expr>& lhs, const SourceRange& r) {
1466  size_t num_normal_assign = 0;
1467  size_t num_starred = 0;
1468  for (const auto& assignee : lhs) {
1469  if (assignee.kind() == TK_VAR || assignee.kind() == TK_SUBSCRIPT) {
1470  num_normal_assign++;
1471  } else if (assignee.kind() == TK_STARRED) {
1472  num_starred++;
1473  } else {
1474  throw ErrorReport(assignee) << "lhs of assignment must be a variable, "
1475  << "subscript, or starred expression.";
1476  }
1477  }
1478 
1479  if (num_starred > 1) {
1480  throw ErrorReport(r)
1481  << "Only one starred expression is allowed on the lhs.";
1482  }
1483 
1484  if (num_starred > 0 && num_normal_assign == 0) {
1485  throw ErrorReport(r) << "A Starred expression may only appear on the "
1486  << "lhs in the presence of another non-starred"
1487  << " expression.";
1488  }
1489 
1490  return num_starred;
1491  }
1492 
1493  // Get the appropriate builtin op for this augmented assignment
1494  // If the LHS is a tensor, return the corresponding ATen in-place op
1495  // If it's a list of scalars, then return the corresponding list augment op
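// For example: for `+=`, a tensor LHS yields the in-place aten::add_, while
// a non-tensor LHS (e.g. a list element) yields the functional aten::add.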
1496  Symbol getAugOp(const AugAssign& stmt, bool isTensor) {
1497  switch (stmt.aug_op()) {
1498  case '+':
1499  return isTensor ? aten::add_ : aten::add;
1500  case '-':
1501  return isTensor ? aten::sub_ : aten::sub;
1502  case '/':
1503  return isTensor ? aten::div_ : aten::div;
1504  case '*':
1505  return isTensor ? aten::mul_ : aten::mul;
1506  default:
1507  throw ErrorReport(stmt)
1508  << "Unknown augmented assignment: " << kindToString(stmt.aug_op());
1509  }
1510  }
1511 
1512  // Emit nodes for augmented assignments like `+=`
1513  void emitAugAssignment(const AugAssign& stmt) {
1514  switch (stmt.lhs().kind()) {
1515  case TK_VAR: {
1516  emitAugAssignmentToVar(stmt);
1517  } break;
1518  case '.': {
1519  emitAugAssignmentToSelectVar(stmt);
1520  } break;
1521  case TK_SUBSCRIPT: {
1522  emitAugAssignmentToSubscript(stmt);
1523  } break;
1524  default:
1525  throw ErrorReport(stmt.lhs())
1526  << "unexpected expression on "
1527  << "left-hand side of augmented assignment.";
1528  }
1529  }
1530 
1531  // This will be called when there is a class param or module buffer
1532  // mutation which makes the LHS of the expr a Select expression
1533  //
1534  // Example like:
1535  // class A(Module):
1536  // def __init__(self):
1537  // self.register_buffer("running_var", torch.zeros(1))
1538  //
1539  // def forward(self):
1540  // self.running_var += 1
1541  //
1542  // In this case we will only consider the scenario that the module
1543  // buffer type is a tensor, and we emit the corresponding tensor
1544  // in-place op, and throw an error for other unsupported types
1545  void emitAugAssignmentToSelectVar(const AugAssign& stmt) {
1546  const auto lhs = Select(stmt.lhs());
1547  const auto lhsSugaredVar =
1548  environment_stack->getSugaredVar(Var(lhs.value()).name());
1549  const auto lhsValue =
1550  lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
1551  ->asValue(lhs.range(), method);
1552  if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
1553  // for module parameter/buffer assignment, only consider tensor types,
1554  // emit the corresponding in-place op
1555  const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
1556  const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
1557  emitBuiltinCall(
1558  stmt.range(),
1559  *method.graph(),
1560  getAugOp(stmt, /*isTensor=*/true),
1561  self,
1562  {rhs},
1563  {},
1564  /*required=*/true);
1565 
1566  } else {
1567  throw ErrorReport(stmt.lhs())
1568  << "left-hand side of augmented assignment to module "
1569  << "parameters/buffers can only be tensor types";
1570  }
1571  }
1572 
1573  void emitAugAssignmentToVar(const AugAssign& stmt) {
1574  const auto lhs = Var(stmt.lhs());
1575  const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
1576  ->asValue(lhs.range(), method);
1577  if (lhsValue->type()->isSubtypeOf(TensorType::get())) {
1578  // for tensors, emit the corresponding in-place op
1579  const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
1580  const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
1581  const auto output = emitBuiltinCall(
1582  stmt.range(),
1583  *method.graph(),
1584  getAugOp(stmt, /*isTensor=*/true),
1585  self,
1586  {rhs},
1587  {},
1588  /*required=*/true);
1589 
1590  environment_stack->setVar(lhs.range(), lhs.name().name(), output);
1591  } else {
1592  // for primitive types, desugar into a simple assignment
1593  // e.g. foo += 1 becomes foo.2 = foo + 1
1594  Ident lhs = Var(stmt.lhs()).name();
1595  Expr expr = BinOp::create(
1596  stmt.range(),
1597  stmt.aug_op(),
1598  Var::create(lhs.range(), lhs),
1599  stmt.rhs());
1600  environment_stack->setVar(lhs.range(), lhs.name(), emitExpr(expr));
1601  }
1602  }
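
 // Illustrative sketch (not in the original source): for a tensor `t`,
 // `t += x` emits roughly
 //   %out = aten::add_(%t, %x)
 // and rebinds `t` to %out; for an int it desugars to `t = t + x`.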
1603 
1604  void emitAugAssignmentToSubscript(const AugAssign& stmt) {
1605  // Process the base list value
1606  const auto lhs = Subscript(stmt.lhs());
1607  const auto sliceable = emitExpr(lhs.value());
1608 
1609  if (sliceable->type()->isSubtypeOf(TensorType::get())) {
1610  // If it's a tensor, just fully evaluate the subscript operation and emit
1611  // an in-place assignment
1612  std::vector<Value*> tensorIndices;
1613  Value* sliced;
1614  std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
1615  lhs.range(), sliceable, lhs.subscript_exprs());
1616 
1617  const auto slicedArg = NamedValue(stmt.lhs().range(), "self", sliced);
1618  const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
1619  if (tensorIndices.size() == 0) {
1620  // Common case: we only tried to index with int and slices. Emit the
1621  // correct augmented assignment op to the sliced value
1622  emitBuiltinCall(
1623  stmt.range(),
1624  *method.graph(),
1625  getAugOp(stmt, /*isTensor=*/true),
1626  slicedArg,
1627  {rhs},
1628  {},
1629  /*required=*/true);
1630  } else {
1631  // Special case: we tried to do "advanced indexing". Lower this expr
1632  // into `index` and `index_put_` ops with tensor indices of Tensor?[]
1633  const auto indices = graph
1634  ->insertNode(graph->createList(
1635  OptionalType::ofTensor(), tensorIndices))
1636  ->output();
1637  const auto indexed =
1638  graph->insert(aten::index, {slicedArg, indices}, {}, stmt.range());
1639  const auto augmented = emitBuiltinCall(
1640  stmt.range(),
1641  *method.graph(),
1642  getAugOp(stmt, /*isTensor=*/true),
1643  indexed,
1644  {rhs},
1645  {},
1646  /*required=*/true);
1647  graph->insert(
1648  aten::index_put_,
1649  {slicedArg, indices, augmented},
1650  {},
1651  stmt.range());
1652  }
1653  } else {
1654  // Otherwise, it should be a list. Lower this expression into:
1655  // list.set_item(get_item(idx).add_(value))
1656  // similar to how Python handles things.
1657  const auto listType = sliceable->type()->cast<ListType>();
1658  AT_ASSERT(listType != nullptr);
1659 
1660  bool isTensorList =
1661  listType->getElementType()->isSubtypeOf(TensorType::get());
1662 
1663  // Get the idx to augment
1664  const auto subscriptExprs = lhs.subscript_exprs();
1665  if (subscriptExprs.size() != 1) {
1666  throw ErrorReport(subscriptExprs)
1667  << "Sliced expression not yet supported for"
1668  << " subscripted list augmented assignment. "
1669  << "File a bug if you want this.";
1670  }
1671  const auto idxValue = emitExpr(subscriptExprs[0]);
1672 
1673  const auto listArg = NamedValue(lhs.value().range(), "list", sliceable);
1674  const auto idxArg = NamedValue(subscriptExprs.range(), "idx", idxValue);
1675  const auto valueArg =
1676  NamedValue(stmt.rhs().range(), "value", emitExpr(stmt.rhs()));
1677 
1678  const auto getItem =
1679  graph->insert(aten::select, {listArg, idxArg}, {}, stmt.range());
1680  const auto augmentedItem = graph->insert(
1681  getAugOp(stmt, isTensorList), {getItem, valueArg}, {}, stmt.range());
1682  graph->insert(
1683  aten::_set_item, {listArg, idxArg, augmentedItem}, {}, stmt.range());
1684  }
1685  }
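
 // Illustrative sketch (not in the original source): `lst[i] += v` on a
 // list lowers to roughly
 //   %item = aten::select(%lst, %i)
 //   %aug  = <op from getAugOp>(%item, %v)
 //   aten::_set_item(%lst, %i, %aug)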
1686 
1687  // Emit mutating assignments like `foo[0] = bar`
1688  void emitSubscriptAssign(
1689  const SourceRange& stmtRange,
1690  const Subscript& lhs,
1691  const Expr& rhs) {
1692  emitSubscriptAssign(stmtRange, lhs, NamedValue(rhs.range(), emitExpr(rhs)));
1693  }
1694 
1695  void emitSubscriptAssign(
1696  const SourceRange& stmtRange,
1697  const Subscript& lhs,
1698  const NamedValue& rhs) {
1699  // First check the base value.
1700  auto sliceable = emitExpr(lhs.value());
1701 
1702  // If it's a tensor, copy the RHS data into it
1703  if (sliceable->type()->isSubtypeOf(TensorType::get())) {
1704  std::vector<Value*> tensorIndices;
1705  Value* sliced;
1706  // Handle multi-dimensional slicing: first emit int/slice indexing
1707  // TODO: the Python equivalent code has special-cased copy_to
1708  // broadcasting to match NumPy semantics (see PR#4853). We can't
1709  // replicate that without knowing the size of the Tensor; so really that
1710  // code should be moved into the aten function
1711  std::tie(sliced, tensorIndices) = emitIntAndSliceIndexing(
1712  lhs.range(), sliceable, lhs.subscript_exprs());
1713 
1714  const auto slicedArg = NamedValue(lhs.range(), sliced);
1715  if (tensorIndices.size() == 0) {
1716  // Common case: we only tried to index with int and slices. Copy the
1717  // RHS into the resulting tensor.
1718  graph->insert(aten::copy_, {slicedArg, rhs}, {}, stmtRange);
1719  } else {
1720  // Special case: we tried to do "advanced indexing" with a tensor.
1721  // Dispatch to `aten::index_put_` with tensor indices of Tensor?[]
1722  const auto indices = graph
1723  ->insertNode(graph->createList(
1724  OptionalType::ofTensor(), tensorIndices))
1725  ->output();
1726 
1727  graph->insert(
1728  aten::index_put_, {slicedArg, indices, rhs}, {}, stmtRange);
1729  }
1730 
1731  // Otherwise, this is a list. Dispatch to aten::_set_item to both select
1732  // and assign
1733  } else {
1734  const auto subscript = lhs.subscript_exprs();
1735  if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) {
1736  throw ErrorReport(subscript)
1737  << "Sliced expression not yet supported for"
1738  << " subscripted list assignment. "
1739  << "File a bug if you want this.";
1740  }
1741 
1742  std::vector<NamedValue> args;
1743  args.emplace_back(lhs.value().range(), "list", sliceable);
1744  args.emplace_back(
1745  lhs.subscript_exprs().range(), "idx", emitExpr(subscript[0]));
1746  args.push_back(rhs);
1747 
1748  graph->insert(aten::_set_item, args, {}, stmtRange);
1749  }
1750  }
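
 // Illustrative sketch (not in the original source): `t[0] = v` on a
 // tensor copies the RHS into the sliced view via aten::copy_, while
 // `t[mask] = v` with a tensor index dispatches to aten::index_put_.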
1751 
1752  void emitTupleAssign(const TupleLiteral& tl, const Expr& rhs) {
1753  size_t n_binders = tl.inputs().size();
1754  bool starred_unpack = calcNumStarredUnpack(tl.inputs(), tl.range());
1755  if (starred_unpack)
1756  n_binders--;
1757  auto output = emitSugaredExpr(rhs, n_binders);
1758  auto outputs = output->asTuple(
1759  rhs.range(),
1760  method,
1761  starred_unpack ? c10::nullopt : c10::optional<size_t>{n_binders});
1762  if (outputs.size() < n_binders) {
1763  throw ErrorReport(tl)
1764  << "need " << (starred_unpack ? "at least " : "") << n_binders
1765  << " values to unpack but found only " << outputs.size();
1766  }
1767  if (outputs.size() > n_binders && !starred_unpack) {
1768  throw ErrorReport(tl) << "too many values to unpack: need " << n_binders
1769  << " but found " << outputs.size();
1770  }
1771  int i = 0;
1772  for (auto assignee : tl.inputs()) {
1773  switch (assignee.kind()) {
1774  case TK_SUBSCRIPT:
1775  emitSubscriptAssign(
1776  rhs.range(),
1777  Subscript(assignee),
1778  NamedValue(
1779  rhs.range(), outputs.at(i)->asValue(rhs.range(), method)));
1780  i++;
1781  break;
1782  case TK_VAR:
1783  environment_stack->setSugaredVar(
1784  assignee.range(), Var(assignee).name().name(), outputs.at(i));
1785  i++;
1786  break;
1787  case TK_STARRED: {
1788  auto var = Starred(assignee).expr();
1789  if (var.kind() != TK_VAR) {
1790  throw ErrorReport(var)
1791  << "Cannot pack a tuple into a non-variable.";
1792  }
1793  size_t n_matched = outputs.size() - n_binders;
1794  ArrayRef<std::shared_ptr<SugaredValue>> outputs_ref = outputs;
1795  auto values = fmap(
1796  outputs_ref.slice(i, n_matched),
1797  [&](const std::shared_ptr<SugaredValue>& v) {
1798  return v->asValue(assignee.range(), method);
1799  });
1800  auto tup = graph->insertNode(graph->createTuple(values))->output();
1801  environment_stack->setVar(var.range(), Var(var).name().name(), tup);
1802  i += n_matched;
1803  } break;
1804  default:
1805  throw ErrorReport(assignee)
1806  << "unexpected expression on the left-hand side";
1807  }
1808  }
1809  }
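
 // Illustrative sketch (not in the original source):
 //   a, *b, c = (1, 2, 3, 4)
 // binds a=1 and c=4, and packs b into the tuple (2, 3) constructed by
 // graph->createTuple above.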
1810 
1811  void emitAssignment(const Assign& stmt) {
1812  switch (stmt.lhs().kind()) {
1813  case TK_VAR: {
1814  auto v = Var(stmt.lhs());
1815  environment_stack->setSugaredVar(
1816  v.range(), v.name().name(), emitSugaredExpr(stmt.rhs(), 1));
1817  } break;
1818  case TK_TUPLE_LITERAL:
1819  emitTupleAssign(TupleLiteral(stmt.lhs()), stmt.rhs());
1820  break;
1821  case '.':
1822  emitSelectAssign(stmt);
1823  break;
1824  case TK_SUBSCRIPT:
1825  emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), stmt.rhs());
1826  break;
1827  default:
1828  throw ErrorReport(stmt.lhs())
1829  << "unexpected expression on left-hand side of assignment.";
1830  }
1831  }
1832 
1833  void emitSelectAssign(const Assign& stmt) {
1834  const auto lhs = Select(stmt.lhs());
1835  const auto basename = Var(lhs.value()).name();
1836  const auto rhsValue =
1837  emitSugaredExpr(stmt.rhs(), 1)->asValue(stmt.rhs().range(), method);
1838  auto userObject = environment_stack->getSugaredVar(basename);
1839  userObject->setAttr(stmt.range(), method, lhs.selector().name(), rhsValue);
1840  }
1841 
1842  NodeKind getNodeKind(int kind, int ninputs) {
1843  switch (kind) {
1844  case '+':
1845  return aten::add;
1846  case '-':
1847  return aten::sub;
1848  case TK_UNARY_MINUS:
1849  return aten::neg;
1850  case '*':
1851  return aten::mul;
1852  case TK_POW:
1853  return aten::pow;
1854  case '@':
1855  return aten::matmul;
1856  case TK_STARRED:
1857  return prim::Starred;
1858  case '/':
1859  return aten::div;
1860  case '%':
1861  return aten::remainder;
1862  case TK_NE:
1863  return aten::ne;
1864  case TK_EQ:
1865  return aten::eq;
1866  case '<':
1867  return aten::lt;
1868  case '>':
1869  return aten::gt;
1870  case TK_LE:
1871  return aten::le;
1872  case TK_GE:
1873  return aten::ge;
1874  case TK_AND:
1875  return aten::__and__;
1876  case TK_OR:
1877  return aten::__or__;
1878  case TK_IS:
1879  return aten::__is__;
1880  case TK_ISNOT:
1881  return aten::__isnot__;
1882  case TK_NOT:
1883  return aten::__not__;
1884  case TK_FLOOR_DIV:
1885  return aten::floordiv;
1886  case '&':
1887  return aten::__and__;
1888  case '|':
1889  return aten::__or__;
1890  case '^':
1891  return aten::__xor__;
1892  default:
1893  throw std::runtime_error("unknown kind " + std::to_string(kind));
1894  }
1895  }
1896 
1897  std::vector<NamedValue> getNamedValues(
1898  const TreeList& trees,
1899  bool maybe_unpack) {
1900  std::vector<NamedValue> values;
1901  for (const auto& tree : trees) {
1902  if (maybe_unpack && tree->kind() == TK_STARRED) {
1903  auto starred = Starred(tree);
1904  auto entries = emitSugaredExpr(starred.expr(), 1)
1905  ->asTuple(starred.range(), method);
1906  for (const auto& entry : entries) {
1907  values.emplace_back(
1908  tree->range(), entry->asValue(starred.range(), method));
1909  }
1910  } else {
1911  values.emplace_back(tree->range(), emitExpr(Expr(tree)));
1912  }
1913  }
1914  return values;
1915  }
1916  std::vector<NamedValue> getNamedValues(
1917  const List<Expr>& trees,
1918  bool maybe_unpack) {
1919  return getNamedValues(trees.tree()->trees(), maybe_unpack);
1920  }
1921 
1922  std::vector<Value*> getValues(const TreeList& trees, bool maybe_unpack) {
1923  return toValues(*graph, getNamedValues(trees, maybe_unpack));
1924  }
1925  std::vector<Value*> getValues(const List<Expr>& trees, bool maybe_unpack) {
1926  return getValues(trees.tree()->trees(), maybe_unpack);
1927  }
1928 
1929  std::vector<NamedValue> emitAttributes(const List<Attribute>& attributes) {
1930  return fmap(attributes, [&](const Attribute& attr) {
1931  return NamedValue(
1932  attr.range(), attr.name().name(), emitExpr(attr.value()));
1933  });
1934  }
1935 
1936  void checkApplyExpr(Apply& apply, SourceRange& loc) {
1937  if (apply.inputs().size() != 2) {
1938  throw ErrorReport(loc) << Var(apply.callee()).name().name()
1939  << " expected exactly two arguments but found "
1940  << apply.inputs().size();
1941  }
1942  if (apply.attributes().size() > 0) {
1943  throw ErrorReport(loc)
1944  << Var(apply.callee()).name().name() << " takes no keyword arguments";
1945  }
1946  }
1947 
1948  std::shared_ptr<SugaredValue> emitApplyExpr(Apply& apply, size_t n_binders) {
1949  auto sv = emitSugaredExpr(apply.callee(), 1);
1950  auto loc = apply.callee().range();
1951  if (auto fork_value = dynamic_cast<ForkValue*>(sv.get())) {
1952  auto& trees = apply.inputs().tree()->trees();
1953  if (trees.size() < 1) {
1954  throw ErrorReport(loc) << "Expected at least one argument to fork()";
1955  }
1956 
1957  auto forked = emitSugaredExpr(Expr(trees[0]), 1);
1958  TreeList sliced_trees(trees.begin() + 1, trees.end());
1959  auto inputs = getNamedValues(sliced_trees, true);
1960  auto attributes = emitAttributes(apply.attributes());
1961  return emitForkExpr(loc, forked, inputs, attributes);
1962  } else if (auto annotate_value = dynamic_cast<AnnotateValue*>(sv.get())) {
1963  checkApplyExpr(apply, loc);
1964  TypePtr type = parseTypeFromExpr(apply.inputs()[0]);
1965  Value* expr = tryConvertToType(
1966  apply.range(),
1967  *graph,
1968  type,
1969  emitExpr(apply.inputs()[1], type),
1970  /*allow_conversions=*/true);
1971 
1972  // This is to ensure that even if the user forgets to annotate None with
1973  // the Optional wrapper type, we still generate the correct value with the
1974  // Optional type. e.g. it makes annotate(Tensor, None) behave the same
1975  // as annotate(Optional[Tensor], None). It also maintains backward
1976  // compatibility of exported models for Optional undefined tensor/None
1977  auto opt_type = expr->type()->cast<OptionalType>();
1978  bool forget_opt_annotate =
1979  opt_type && *opt_type->getElementType() == *type;
1980 
1981  if (!forget_opt_annotate && !expr->type()->isSubtypeOf(type)) {
1982  throw ErrorReport(apply.inputs())
1983  << "expected an expression of type " << type->python_str()
1984  << " but found " << expr->type()->python_str();
1985  }
1986  return std::make_shared<SimpleValue>(expr);
1987  } else if (auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
1988  checkApplyExpr(apply, loc);
1989  auto obj = emitSugaredExpr(apply.inputs()[0], 1);
1990  auto selector = apply.inputs()[1];
1991  if (selector.kind() != TK_STRINGLITERAL) {
1992  throw ErrorReport(loc)
1993  << "getattr's second argument must be a string literal";
1994  }
1995  const std::string& name = StringLiteral(selector).text();
1996  return obj->attr(apply.range(), method, name);
1997  } else if (auto isinstance = dynamic_cast<IsInstanceValue*>(sv.get())) {
1998  // NOTE: for the `isinstance` builtin call in the JIT, we only check the
1999  // static types of the inputs to evaluate it, and insert the corresponding
2000  // constant node
2001  std::function<bool(Expr, Expr)> isInstanceCheck = [&](Expr obj,
2002  Expr classinfo) {
2003  if (classinfo.kind() == TK_TUPLE_LITERAL) {
2004  // handle the case for recursive tuple classinfo
2005  // return true if obj is an instance of any of the types
2006  for (Expr e : TupleLiteral(classinfo).inputs()) {
2007  if (isInstanceCheck(obj, e)) {
2008  return true;
2009  }
2010  }
2011  return false;
2012  }
2013  auto type_name = parseBaseTypeName(classinfo);
2014  if (!type_name) {
2015  throw ErrorReport(classinfo.range())
2016  << "type must be a type identifier";
2017  }
2018  auto val = emitExpr(obj);
2019  // Special casing for list and tuple, since isinstance(x, list) and
2020  // isinstance(x, tuple) do not accept subscripted type annotations
2021  // like List[int] / Tuple[int] as in Python
2022  if (*type_name == "list" && val->type()->cast<ListType>()) {
2023  return true;
2024  } else if (*type_name == "tuple" && val->type()->cast<TupleType>()) {
2025  return true;
2026  } else if (val->type()->cast<OptionalType>()) {
2027  throw ErrorReport(loc)
2028  << "Optional isinstance check is not supported, "
2029  << "consider using is / is not None instead";
2030  } else {
2031  TypePtr type = parseTypeFromExpr(classinfo);
2032  if (val->type()->isSubtypeOf(type)) {
2033  return true;
2034  }
2035  }
2036  return false;
2037  };
2038  checkApplyExpr(apply, loc);
2039  bool is_instance_val =
2040  isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
2041  return std::make_shared<SimpleValue>(
2042  graph->insertConstant(is_instance_val, nullptr, loc));
2043  } else if (auto classNew = dynamic_cast<ClassNewMethod*>(sv.get())) {
2044  if (apply.inputs().size() != 1) {
2045  throw ErrorReport(loc) << "Only one argument to __new__ allowed";
2046  }
2047  return classNew->createObject(
2048  apply.range(), method, Var(apply.inputs()[0]).name().name());
2049  } else {
2050  auto inputs = getNamedValues(apply.inputs(), true);
2051  auto attributes = emitAttributes(apply.attributes());
2052  return sv->call(loc, method, inputs, attributes, n_binders);
2053  }
2054  }
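
 // Illustrative examples of the special-cased calls above (not in the
 // original source):
 //   torch.jit.annotate(List[int], [])  # value typed as List[int]
 //   getattr(self, "weight")            # second arg must be a string literal
 //   isinstance(x, (int, float))        # folded into a constant node
 //   fork(fn, x)                        # handled by emitForkExpr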
2055 
2056  BoolInfo findRefinements(const TreeRef& tree) {
2057  switch (tree->kind()) {
2058  case TK_IS:
2059  case TK_ISNOT: {
2060  const auto& inputs = tree->trees();
2061  if (inputs.at(0)->kind() == TK_VAR && inputs.at(1)->kind() == TK_NONE) {
2062  const std::string& var_name = Var(inputs[0]).name().name();
2063  Refinements true_info, false_info;
2064  auto type =
2065  environment_stack->getVar(var_name, inputs[0]->range())->type();
2066  if (auto opt_type = type->cast<OptionalType>()) {
2067  false_info.setRefinement(
2068  var_name,
2069  TypeAndRange(opt_type->getElementType(), &tree->range()));
2070  true_info.setRefinement(
2071  var_name, TypeAndRange(NoneType::get(), &tree->range()));
2072  }
2073  if (tree->kind() == TK_IS) {
2074  return BoolInfo(true_info, false_info);
2075  } else {
2076  return BoolInfo(false_info, true_info);
2077  }
2078  }
2079  } break;
2080  case TK_NOT: {
2081  const auto& inputs = tree->trees();
2082  auto bool_info = findRefinements(inputs[0]);
2083  return BoolInfo(
2084  bool_info.false_refinements_, bool_info.true_refinements_);
2085  }
2086  case TK_OR:
2087  case TK_AND: {
2088  const auto& inputs = tree->trees();
2089  auto first = findRefinements(inputs[0]);
2090  auto second = findRefinements(inputs[1]);
2091  if (tree->kind() == TK_OR) {
2092  return *first.mergeOr(second);
2093  } else {
2094  return *first.mergeAnd(second);
2095  }
2096  }
2097  }
2098  return BoolInfo();
2099  }
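
 // Illustrative sketch (not in the original source): given x of type
 // Optional[T], `if x is None:` refines x to None in the true branch and
 // to T in the false branch; `is not`, `not`, `and`, and `or` combine
 // these refinements as above.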
2100 
2101  Value* emitExpr(const Expr& tree, const TypePtr& type_hint = nullptr) {
2102  return emitSugaredExpr(tree, 1, type_hint)->asValue(tree.range(), method);
2103  }
2104 
2105  NodeKind reverseComparision(NodeKind kind) {
2106  if (kind == aten::lt) {
2107  return aten::gt;
2108  } else if (kind == aten::le) {
2109  return aten::ge;
2110  } else if (kind == aten::gt) {
2111  return aten::lt;
2112  } else if (kind == aten::ge) {
2113  return aten::le;
2114  }
2115  throw std::runtime_error(
2116  "reverseComparision: unsupported NodeKind. File a bug");
2117  }
2118 
2119  // any expression that can produce a SugaredValue is handled here
2120  // expressions that only return a single Value* are handled in emitSimpleExpr
2121  // type_hint is set if there is a type that this value is expected to be
2122  // e.g. a : List[int] = []
2123  // or a = torch.jit.annotate(List[int], [])
2124  // the caller is responsible for checking that the result matches type_hint
2125  // emitSugaredExpr is free to ignore it.
2126  std::shared_ptr<SugaredValue> emitSugaredExpr(
2127  const Expr& tree,
2128  size_t n_binders,
2129  const TypePtr& type_hint = nullptr) {
2130  switch (tree.kind()) {
2131  case TK_VAR:
2132  return environment_stack->getSugaredVar(Var(tree).name());
2133  case '.': {
2134  auto select = Select(tree);
2135  auto sv = emitSugaredExpr(select.value(), 1);
2136  return sv->attr(select.range(), method, select.selector().name());
2137  }
2138  case TK_APPLY: {
2139  auto apply = Apply(tree);
2140  return emitApplyExpr(apply, n_binders);
2141  } break;
2142  default:
2143  return std::make_shared<SimpleValue>(emitSimpleExpr(tree, type_hint));
2144  }
2145  }
2146 
2147  Value* emitNegate(const TreeRef& tree) {
2148  const auto& inputs = tree->trees();
2149  auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
2150 
2151  auto neg_val = emitBuiltinCall(
2152  tree->range(),
2153  *method.graph(),
2154  aten::neg,
2155  c10::nullopt,
2156  named_values,
2157  {},
2158  /*required=*/true);
2159 
2160  // constant fold the input if possible
2161  auto maybe_constant_input = toIValue(neg_val->node()->input());
2162  if (!maybe_constant_input) {
2163  return neg_val;
2164  }
2165  auto op = getOperation(neg_val->node());
2166  Stack stack;
2167  stack.push_back(*maybe_constant_input);
2168  op(stack);
2169  AT_ASSERT(stack.size() == 1);
2170  return graph->insertConstant(stack[0], nullptr, tree->range());
2171  }
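
 // Illustrative sketch (not in the original source): a literal like `-3`
 // is folded here into a single prim::Constant instead of remaining as
 // aten::neg applied to a constant input.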
2172 
2173  // This function emits a fork node and lambda-lifts its body block into a
2174  // new subgraph
2174  std::shared_ptr<SugaredValue> emitForkExpr(
2175  SourceRange loc,
2176  const std::shared_ptr<SugaredValue>& forked,
2177  at::ArrayRef<NamedValue> inputs,
2178  at::ArrayRef<NamedValue> attributes) {
2179  // Build the fork node without inputs
2180  auto fork_node =
2181  method.graph()
2182  ->insertNode(method.graph()->create(prim::fork, 1))
2183  ->setSourceLocation(std::make_shared<SourceRange>(loc));
2184  auto body_block = fork_node->addBlock();
2185 
2186  // Build a template of the graph to be executed
2187  Value* node_output;
2188  {
2189  WithInsertPoint guard(body_block);
2190  auto fn_sugared_output = forked->call(loc, method, inputs, attributes, 1);
2191  auto fn_simple_output = fn_sugared_output->asValue(loc, method);
2192  body_block->registerOutput(fn_simple_output);
2193  node_output = fork_node->output()->setType(
2194  FutureType::create(fn_simple_output->type()));
2195  }
2196 
2197  // Lambda lift block(0) into attr::Subgraph
2198  lambdaLiftFork(fork_node);
2199 
2200  return std::make_shared<SimpleValue>(node_output);
2201  }
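
 // Illustrative sketch (not in the original source): `fork(fn, x)`
 // produces a prim::fork node whose body block calls fn(x) and whose
 // output has type Future[T], where T is fn's return type; the body is
 // later lambda-lifted into attr::Subgraph.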
2202 
2203  Value* emitSimpleExpr(
2204  const TreeRef& tree,
2205  const TypePtr& type_hint = nullptr) {
2206  switch (tree->kind()) {
2207  case '@':
2208  case TK_POW:
2209  case TK_IS:
2210  case TK_ISNOT:
2211  case TK_NOT:
2212  case TK_NE:
2213  case TK_EQ:
2214  case '<':
2215  case '>':
2216  case TK_LE:
2217  case TK_GE:
2218  case '*':
2219  case '/':
2220  case '+':
2221  case '-':
2222  case '%':
2223  case '&':
2224  case '|':
2225  case '^':
2226  case TK_FLOOR_DIV: {
2227  const auto& inputs = tree->trees();
2228  auto kind = getNodeKind(tree->kind(), inputs.size());
2229  auto named_values = getNamedValues(inputs, /*maybe_unpack=*/false);
2230  return emitBuiltinCall(
2231  tree->range(),
2232  *method.graph(),
2233  kind,
2234  c10::nullopt,
2235  named_values,
2236  {},
2237  /*required=*/true);
2238  }
2239  case TK_UNARY_MINUS: {
2240  return emitNegate(tree);
2241  }
2242  case TK_AND:
2243  case TK_OR: {
2244  const auto& inputs = tree->trees();
2245  return emitShortCircuitIf(
2246  tree->range(), inputs[0], inputs[1], tree->kind() == TK_OR);
2247  }
2248  case TK_STARRED: {
2249  throw ErrorReport(tree)
2250  << "Unexpected starred expansion. File a bug report.";
2251  }
2252  case TK_CONST: {
2253  return emitConst(Const(tree));
2254  } break;
2255  case TK_TRUE: {
2256  return graph->insertConstant(true, nullptr, tree->range());
2257  } break;
2258  case TK_FALSE: {
2259  return graph->insertConstant(false, nullptr, tree->range());
2260  } break;
2261  case TK_NONE: {
2262  return graph->insertConstant(IValue(), nullptr, tree->range());
2263  } break;
2264  case TK_SUBSCRIPT: {
2265  return emitSubscript(Subscript(tree));
2266  } break;
2267  case TK_IF_EXPR: {
2268  return emitTernaryIf(TernaryIf(tree));
2269  } break;
2270  case TK_STRINGLITERAL: {
2271  return emitStringLiteral(StringLiteral(tree));
2272  } break;
2273  case TK_LIST_LITERAL: {
2274  auto ll = ListLiteral(tree);
2275  auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
2276 
2277  // determine the element type of the list
2278  // if we have a type hint of List[T], use T
2279  // if the list is non-empty use type_of(list[0])
2280  // otherwise assume it is List[Tensor]
2281  TypePtr elem_type = TensorType::get();
2282  if (type_hint && type_hint->kind() == TypeKind::ListType) {
2283  elem_type = type_hint->expect<ListType>()->getElementType();
2284  } else if (!values.empty()) {
2285  elem_type = values.at(0)->type();
2286  }
2287 
2288  // Tensors are special because they have dynamic properties. So any
2289  // list containing tensors should be typed with the unified type of all
2290  // the elements.
2291  if (elem_type->isSubtypeOf(TensorType::get())) {
2292  for (const auto& value : values) {
2293  elem_type = unifyTypes(elem_type, value->type()).value();
2294  }
2295  }
2296  for (auto v : values) {
2297  if (!v->type()->isSubtypeOf(elem_type)) {
2298  throw ErrorReport(tree)
2299  << "Lists must contain only a single type, expected: "
2300  << *elem_type << " but found " << *v->type() << " instead";
2301  }
2302  }
2303  Value* result =
2304  graph->insertNode(graph->createList(elem_type, values))->output();
2305  return result;
2306  } break;
2307  case TK_TUPLE_LITERAL: {
2308  auto ll = TupleLiteral(tree);
2309  auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
2310  return graph->insertNode(graph->createTuple(values))->output();
2311  } break;
2312  case TK_DICT_LITERAL: {
2313  auto dl = DictLiteral(tree);
2314  auto key_trees = dl.key_inputs().tree()->trees();
2315  auto value_trees = dl.value_inputs().tree()->trees();
2316  AT_ASSERT(key_trees.size() == value_trees.size());
2317  std::vector<Value*> keys, values;
2318  for (size_t i = 0; i < key_trees.size(); ++i) {
2319  keys.push_back(emitExpr(Expr(key_trees[i])));
2320  values.push_back(emitExpr(Expr(value_trees[i])));
2321  }
2322 
2323  TypePtr key_type = nullptr;
2324  TypePtr value_type = nullptr;
2325 
2326  if (type_hint && type_hint->kind() == TypeKind::DictType) {
2327  auto dict_type = type_hint->expect<DictType>();
2328  key_type = dict_type->getKeyType();
2329  value_type = dict_type->getValueType();
2330  } else if (!keys.empty()) {
2331  key_type = keys.at(0)->type();
2332  value_type = values.at(0)->type();
2333  } else {
2334  key_type = StringType::get();
2335  value_type = TensorType::get();
2336  }
2337  AT_ASSERT(key_type != nullptr && value_type != nullptr);
2338 
2339  return graph
2340  ->insertNode(graph->createDict(key_type, value_type, keys, values))
2341  ->output();
2342  } break;
2343  default:
2344  throw ErrorReport(tree) << "Cannot emit expr for: " << tree;
2345  break;
2346  }
2347  }
2348 
2349  Value* emitConst(const Const& c) {
2350  if (c.isFloatingPoint())
2351  return materializeConstant(
2352  c.asFloatingPoint(), *graph, c.range(), fp_constants);
2353  else
2354  return materializeConstant(
2355  c.asIntegral(), *graph, c.range(), integral_constants);
2356  }
2357 
2358  Value* emitStringLiteral(const StringLiteral& c) {
2359  return insertConstant(*graph, c.text(), nullptr, c.range());
2360  }
2361 
2362  // Desugars select indexing: tensor[i] -> tensor.select(dim, i)
2363  Value* emitSelect(
2364  const SourceRange& loc,
2365  Value* input,
2366  int64_t dim,
2367  Value* index) {
2368  return emitBuiltinCall(
2369  loc,
2370  *graph,
2371  aten::select,
2372  c10::nullopt,
2373  {input, graph->insertConstant(dim, nullptr, loc), index},
2374  {},
2375  true);
2376  }
2377 
2378  // Desugars slice indexing: tensor[begin:end] -> tensor.slice(dim, begin, end,
2379  // 1)
2380  Value* emitSlice(
2381  const SourceRange& loc,
2382  Value* input,
2383  c10::optional<int64_t> dim, // Only used for tensor slicing
2384  const SliceExpr& slice) {
2385  std::vector<NamedValue> args;
2386  args.reserve(4);
2387  args.emplace_back(loc, "self", input);
2388 
2389  // XXX: If list slicing becomes more complicated or stops using
2390  // aten::slice, we should separate it from this function.
2391  if (dim) {
2392  AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
2393  args.emplace_back(
2394  loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
2395  } else {
2396  AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
2397  }
2398 
2399  args.emplace_back(loc, "begin", emitExpr(Expr(slice.startOr(0))));
2400  const auto has_end = slice.end().present();
2401  if (has_end) {
2402  args.emplace_back(loc, "end", emitExpr(Expr(slice.end().get())));
2403  }
2404  if (input->type()->cast<TupleType>()) {
2405  if (has_end) {
2406  return emitTupleSlice(loc, args[0], args[1], /*end*/ args[2]);
2407  } else {
2408  return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
2409  }
2410  }
2411  NamedValue step =
2412  NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
2413  return emitBuiltinCall(
2414  loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
2415  }
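
 // Illustrative sketch (not in the original source): `t[1:4]` on a
 // tensor emits aten::slice(t, dim, 1, 4, step=1); a slice of a tuple is
 // instead handled by emitTupleSlice below.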
2416 
2417  Value* emitIndex(
2418  const SourceRange& loc,
2419  Value* input,
2420  at::ArrayRef<Value*> indices) {
2421  // NB: the index of aten::index should have type List[Optional[Tensor]];
2422  // this is to support cases like t[:, :, 1] where : indicates a
2423  // None/undefined tensor (optional tensor)
2424  auto* index =
2425  graph->insertNode(graph->createList(OptionalType::ofTensor(), indices))
2426  ->output();
2427  return emitBuiltinCall(
2428  loc, *graph, aten::index, c10::nullopt, {input, index}, {}, true);
2429  }
2430 
2431  // Emits multidimensional slicing with int and slice indices.
2432  // Returns:
2433  // - Value*: the input after it has been indexed by int and slice indices.
2434  // - vector<Value*>: A list of tensor Value* indices that have not been
2435  // applied yet.
2436  // Should be NULL at indices where sliceable (post-slicing) isn't indexed by
2437  // a tensor.
2438  std::pair<Value*, std::vector<Value*>> emitIntAndSliceIndexing(
2439  const SourceRange& loc,
2440  Value* sliceable,
2441  const List<Expr>& subscript_exprs) {
2442  std::vector<Value*> tensor_indices;
2443  size_t dim = 0;
2444 
2445  auto handle_tensor = [&](Value* tensor) {
2446  // NB: tensor_indices can have None holes because of how at::index works.
2447  tensor_indices.resize(dim + 1);
2448  tensor_indices[dim] = tensor;
2449  dim++;
2450  };
2451 
2452  for (const auto& subscript_expr : subscript_exprs) {
2453  if (subscript_expr.kind() == TK_SLICE_EXPR) {
2454  sliceable = emitSlice(loc, sliceable, dim, SliceExpr(subscript_expr));
2455  ++dim;
2456  continue;
2457  }
2458  auto index = emitExpr(subscript_expr, OptionalType::ofTensor());
2459  if (index->type() == IntType::get()) {
2460  sliceable = emitSelect(loc, sliceable, dim, index);
2461  continue;
2462  } else if (index->type()->isSubtypeOf(OptionalType::ofTensor())) {
2463  // NB: the index type can either be a Tensor or : (a None of type Optional[Tensor])
2464  handle_tensor(index);
2465  continue;
2466  }
2467  throw ErrorReport(loc)
2468  << "Unsupported operation: indexing tensor with unsupported index type '"
2469  << index->type()->str()
2470  << "'. Only ints, slices, and tensors are supported";
2471  }
2472  // at::index takes in a List[Optional[Tensor]] where some dims can be None.
2473  // create None node with optional tensor output type and pass to at::index.
2474  for (auto& index : tensor_indices) {
2475  if (index == nullptr) {
2476  index =
2477  graph->insertNode(graph->createNone(TensorType::get()))->output();
2478  }
2479  }
2480  return std::make_pair(sliceable, tensor_indices);
2481  }
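
 // Illustrative sketch (not in the original source): for `t[0, :, idx]`
 // this emits a select for `0` and a slice for `:`, then returns
 // tensor_indices = [None, idx] to be consumed by aten::index.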
2482 
2483  // Desugars multidim slicing into slice/select/index calls.
2484  //
2485  // XXX: Errors in user code are not elegantly reported.
2486  // Let's say someone were to do the following:
2487  // @torch.jit.script
2488  // def fn(x):
2489  // return x[0, 1]
2490  // fn(torch.randn(5))
2491  // Because we desugar this into two aten::select ops, the error message
2492  // complains about aten::select failing rather than there "not being
2493  // enough dimensions to index".
2494  //
2495  // The strategy is to slice and select the tensor for int and slices first
2496  // in one pass and then apply at::index on the result of the
2497  // slicing/selecting. Call the tensor after we've applied slice / select the
2498  // `sliced`. tensor_indices should have the same size as sliced.dim():
2499  // - tensor_indices[i] = NULL if we should not index `sliced` at dim i
2500  // - tensor_indices[i] = t if we should index `sliced` at dim i with tensor t.
2501  Value* emitMultidimSlicing(
2502  const SourceRange& loc,
2503  Value* sliceable,
2504  const List<Expr>& subscript_exprs) {
2505  if (!sliceable->type()->isSubtypeOf(TensorType::get())) {
2506  throw ErrorReport(loc)
2507  << "Unsupported operation: attempted to use multidimensional "
2508  << "indexing on a non-tensor type.";
2509  }
2510 
2511  std::vector<Value*> tensor_indices;
2512  std::tie(sliceable, tensor_indices) =
2513  emitIntAndSliceIndexing(loc, sliceable, subscript_exprs);
2514 
2515  if (tensor_indices.empty()) {
2516  // XXX: Might need to at::alias this when we support mutability
2517  return sliceable;
2518  }
2519 
2520  return emitIndex(loc, sliceable, tensor_indices);
2521  }
2522 
2523  // Desugars slice syntactic sugar tensor[begin:end] -> tensor.slice(begin,
2524  // end).
2525  Value* emitBasicSlice(
2526  const SourceRange& loc,
2527  Value* sliceable,
2528  const List<Expr>& subscript_exprs) {
2529  AT_ASSERT(subscript_exprs.size() == 1);
2530  AT_ASSERT(subscript_exprs[0].kind() == TK_SLICE_EXPR);
2531  auto slice_exp = SliceExpr(subscript_exprs[0]);
2532  c10::optional<int64_t> maybe_dim;
2533  if (sliceable->type()->isSubtypeOf(TensorType::get())) {
2534  // If the sliceable object is a tensor, specify a default dimension
2535  maybe_dim = 0;
2536  }
2537  return emitSlice(loc, sliceable, maybe_dim, slice_exp);
2538  }
2539 
2540  int64_t getTupleIndexVal(
2541  const SourceRange& loc,
2542  const TupleTypePtr& tuple_type,
2543  Value* idx_val,
2544  bool allow_out_of_bounds) {
2545  int64_t index;
2546  at::optional<IValue> ivalue = toIValue(idx_val);
2547  if (ivalue && ivalue->isInt()) {
2548  index = ivalue->to<int64_t>();
2549  } else {
2550  throw ErrorReport(loc) << "tuple indices must be integer constants";
2551  }
2552  // set index to be positive to simplify logic in runtime
2553  int64_t adj_index = index;
2554  int64_t tuple_len = tuple_type->elements().size();
2555  if (index < 0) {
2556  adj_index = tuple_len + index;
2557  }
2558  if (!allow_out_of_bounds && (adj_index >= tuple_len || adj_index < 0)) {
2559  throw ErrorReport(loc) << "Tuple index out of range. Tuple is length "
2560  << tuple_len << " and index is " << index;
2561  }
2562  return adj_index;
2563  }
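
 // Illustrative sketch (not in the original source): for a 3-element
 // tuple, index -1 normalizes to 2, while index 5 throws unless
 // allow_out_of_bounds is set (as it is for slicing).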
2564 
2565  Value* emitTupleIndex(
2566  const SourceRange& loc,
2567  Value* tuple_val,
2568  Value* idx_val) {
2569  auto tuple_typ = tuple_val->type()->cast<TupleType>();
2570  auto adj_index = getTupleIndexVal(
2571  loc, tuple_typ, idx_val, /*allow_out_of_bounds*/ false);
2572  return graph->insertNode(graph->createTupleIndex(tuple_val, adj_index))
2573  ->output();
2574  }
2575 
2576  Value* emitDictIndex(
2577  const SourceRange& loc,
2578  Value* dict_val,
2579  Value* key_val) {
2580  auto dict_type = dict_val->type()->cast<DictType>();
2581  AT_ASSERT(key_val->type()->isSubtypeOf(dict_type->getKeyType()));
2582  return graph->insertNode(graph->createDictIndex(dict_val, key_val))
2583  ->output();
2584  }
2585 
2586  Value* emitTupleSlice(
2587  const SourceRange& loc,
2588  const NamedValue& tuple_val,
2589  const NamedValue& beg_val,
2590  const at::optional<NamedValue>& end_val) {
2591  auto tuple_type = tuple_val.value(*graph)->type()->expect<TupleType>();
2592  int64_t beg = getTupleIndexVal(
2593  loc, tuple_type, beg_val.value(*graph), /*allow_out_of_bounds*/ true);
2594  int64_t end;
2595  int64_t tuple_len = tuple_type->elements().size();
2596  if (end_val) {
2597  end = getTupleIndexVal(loc, tuple_type, end_val->value(*graph), true);
2598  } else {
2599  end = tuple_len;
2600  }
2601  // slicing does not throw out of bounds errors
2602  end = std::min(std::max((int64_t)0, end), tuple_len);
2603  beg = std::min(std::max((int64_t)0, beg), tuple_len);
2604 
2605  return graph
2606  ->insertNode(graph->createTupleSlice(tuple_val.value(*graph), beg, end))
2607  ->output();
2608  }
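
 // Illustrative sketch (not in the original source): `tup[1:]` on a
 // 3-tuple lowers to a TupleSlice node with constant bounds beg=1,
 // end=3; like Python slicing, out-of-range bounds are clamped rather
 // than raising an error.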
2609 
2610  Value* emitSubscript(const Subscript& subscript) {
2611  return emitSubscript(
2612  subscript.range(),
2613  emitExpr(subscript.value()),
2614  subscript.subscript_exprs());
2615  }
2616 
2617  Value* emitSubscript(
2618  const SourceRange& loc,
2619  Value* sliceable,
2620  const List<Expr>& subscript_exprs) {
2621  if (subscript_exprs.size() != 1) {
2622  return emitMultidimSlicing(loc, sliceable, subscript_exprs);
2623  }
2624  if (subscript_exprs[0].kind() == TK_SLICE_EXPR) {
2625  return emitBasicSlice(loc, sliceable, subscript_exprs);
2626  } else {
2627  return emitBasicGather(loc, sliceable, subscript_exprs);
2628  }
2629  }
2630 
2631  // Desugars gather syntactic sugar foo[i]
2632  Value* emitBasicGather(
2633  const SourceRange& loc,
2634  Value* gatherable,
2635  const List<Expr>& subscript_exprs) {
2636  AT_ASSERT(subscript_exprs.size() == 1);
2637 
2638  if (gatherable->type()->kind() == TypeKind::ListType) {
2639  // if it's a list, emit a regular index selection op
2640  auto* idx = emitExpr(subscript_exprs[0]);
2641  return emitBuiltinCall(
2642  loc, *graph, aten::select, c10::nullopt, {gatherable, idx}, {}, true);
2643  } else if (gatherable->type()->isSubtypeOf(TensorType::get())) {
2644  return emitMultidimSlicing(loc, gatherable, subscript_exprs);
2645  } else if (auto tuple_type = gatherable->type()->cast<TupleType>()) {
2646  auto* idx = emitExpr(subscript_exprs[0]);
2647  return emitTupleIndex(loc, gatherable, idx);
2648  } else if (auto dict_type = gatherable->type()->cast<DictType>()) {
2649  auto* idx = emitExpr(subscript_exprs[0]);
2650  return emitDictIndex(loc, gatherable, idx);
2651  } else {
2652  throw ErrorReport(loc)
2653  << "Indexing only supported on lists, dictionaries, "
2654  "tensors, and tuples, but got type '"
2655  << gatherable->type()->str() << "'";
2656  }
2657  }
2658 };
2659 
2660 void defineMethodsInModule(
2661  const std::shared_ptr<Module>& m,
2662  const std::vector<Def>& definitions,
2663  const std::vector<Resolver>& resolvers,
2664  const c10::optional<Self>& self) {
2665  AT_ASSERT(definitions.size() == resolvers.size());
2666  auto resolver_it = resolvers.begin();
2667  std::vector<Method*> methods;
2668  std::unordered_map<std::string, Method*> function_table;
2669  for (const Def& def : definitions) {
2670  const std::string& name = def.name().name();
2671  auto resolver = *resolver_it++;
2672  AT_ASSERT(resolver);
2673  if (!self) {
2674  // if self is defined, then these are methods and do not go into the
2675  // global namespace; otherwise, they get defined together, so we add them
2676  // to the function table so the methods can see each other
2677  resolver = [resolver, &function_table](
2678  const std::string& name,
2679  Method& m,
2680  const SourceRange& loc) -> std::shared_ptr<SugaredValue> {
2681  auto it = function_table.find(name);
2682  if (it != function_table.end()) {
2683  return std::make_shared<MethodValue>(nullptr, *it->second);
2684  }
2685  return resolver(name, m, loc);
2686  };
2687  }
2688  auto creator = [def, resolver, self](Method& method) {
2689  AT_ASSERT(resolver);
2690  to_ir(def, resolver, self, method);
2691  };
2692  Method& method = m->create_method(name, creator);
2693  function_table[name] = &method;
2694  methods.push_back(&method);
2695  }
2696  for (Method* method : methods) {
2697  method->ensure_defined();
2698  }
2699  if (!self || !self->asFirstClass()) {
2700  // Disable module hooks if the module is only used to store a class's code.
2701  didFinishEmitModule(m);
2702  }
2703 }
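
// Illustrative sketch (not in the original source): because every
// definition is registered in function_table before ensure_defined()
// runs, mutually referring functions compile:
//   def f(x): return g(x)
//   def g(x): return x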
2704 
2705 void defineMethodsInModule(
2706  const std::shared_ptr<Module>& m,
2707  const std::string& source,
2708  const Resolver& resolver,
2709  const c10::optional<Self>& self) {
2710  Parser p(source);
2711  std::vector<Def> definitions;
2712  std::vector<Resolver> resolvers;
2713  while (p.lexer().cur().kind != TK_EOF) {
2714  auto def = Def(p.parseFunction(/*is_method=*/bool(self)));
2715  definitions.push_back(def);
2716  resolvers.push_back(resolver);
2717  }
2718  defineMethodsInModule(m, definitions, resolvers, self);
2719 }
2720 
2721 void lambdaLiftFork(Node* fork_node) {
2722  // Fork a new graph from its original owning graph
2723  auto forked_graph = std::make_shared<Graph>();
2724  auto body_block = fork_node->blocks()[0];
2725 
2726  // Make sure we capture everything in the new graph.
2727  // The uncaptured values will be added to the fork signature.
2728  std::unordered_map<Value*, Value*> uncaptures_map;
2729  auto env = [&](Value* v) -> Value* {
2730  if (!uncaptures_map.count(v)) {
2731  // Capture values for both graphs
2732  uncaptures_map[v] = forked_graph->addInput()->copyMetadata(v);
2733  fork_node->addInput(v);
2734  }
2735  return uncaptures_map[v];
2736  };
2737  forked_graph->block()->cloneFrom(body_block, env);
2738 
2739  // Separate the subgraph and clean up the original one
2740  fork_node->g_(attr::Subgraph, forked_graph);
2741  fork_node->eraseBlock(0);
2742 }
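
// Illustrative sketch (not in the original source): if the fork body
// reads an outer value %y, the env callback above appends %y to the
// fork node's inputs and adds a matching input to the lifted
// attr::Subgraph.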
2743 } // namespace script
2744 } // namespace jit
2745 } // namespace torch