Caffe2 - C++ API
A deep learning, cross-platform ML framework
interpreter.cpp
1 #include <torch/csrc/jit/interpreter.h>
2 
3 #include <torch/csrc/autograd/edge.h>
4 #include <torch/csrc/autograd/function.h>
5 #include <torch/csrc/autograd/generated/variable_factories.h>
6 #include <torch/csrc/autograd/grad_mode.h>
7 #include <torch/csrc/autograd/profiler.h>
8 #include <torch/csrc/autograd/variable.h>
9 #include <c10/util/Exception.h>
10 #include <torch/csrc/jit/constants.h>
11 #include <torch/csrc/jit/graph_executor.h>
12 #include <torch/csrc/jit/ir.h>
13 #include <ATen/core/ivalue.h>
14 #include <torch/csrc/jit/operator.h>
15 #include <torch/csrc/jit/script/jit_exception.h>
16 #include <c10/core/thread_pool.h>
17 
18 #include <exception>
19 #include <iostream>
20 #include <memory>
21 #include <mutex>
22 #include <ostream>
23 #include <stdexcept>
24 #include <typeinfo>
25 #include <unordered_map>
26 #include <unordered_set>
27 #include <utility>
28 #include <vector>
29 
30 namespace torch {
31 namespace jit {
32 
33 // Before we translate to interpreter instructions, we do
34 // some preprocessing of the graph to turn it into a form that is closer
35 // to what the instructions will look like.
36 // In particular we:
37 // * (TODO) desugar Loop trip counts into c = 0, c += 1 instructions in the loop
38 // * Turn inputs/outputs into Load/Store instructions
39 // * Compute move_flags (see Outputs below)
40 // * Insert Drop nodes for any value that is unused, creating a dummy use
41 // that will cause the interpreter to free the value.
42 // A Drop node is just a node with no outputs that pops its inputs off
43 // the stack, to ensure the interpreter releases references to values that are
44 // never used. Drop nodes are also inserted when the last use of a value is in
45 // some conditionally run control flow (e.g. one side of an If) and the
46 // interpreter must free the value only after the control flow has reconverged.
47 // Outputs are:
48 // * graph - the post-processed copy of g
49 // * move_flags[n] - a list of booleans, one for each input,
50 // indicating whether this is the last use of the value. The interpreter
51 // should generate a move rather than a copy in this case.
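//
// For intuition, a hand-written sketch (not produced by this file; value names
// are illustrative) of what the Load/Store rewrite looks like on a trivial graph:
//   before:                          after flattenIO():
//     graph(%x, %y):                   graph():
//       %z = aten::mul(%x, %y)           %x, %y = prim::Load()
//       return (%z)                      %z = aten::mul(%x, %y)
//                                        prim::Store(%z)
// move_flags for the aten::mul node would mark both inputs as last uses, so the
// interpreter moves them out of their registers rather than copying.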
52 
53 namespace {
54 
55 // new_cond = (i < max_trip_count) && cond
56 Value* createTripCountConjunctiveCondition(
57  Graph* g,
58  Value* cur_trip_count,
59  Value* max_trip_count,
60  Value* cond) {
61  // Emit initial comparison -- initial_trip_count < max_trip_count
62  Value* initial_comparison_value =
63  g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1))
64  ->output()
65  ->setType(BoolType::get());
66 
67  // Replace initial condition with logical `and` of trip count and
68  // initial condition
69  Value* new_cond =
70  g->insertNode(
71  g->create(aten::__and__, {initial_comparison_value, cond}, 1))
72  ->output()
73  ->setType(BoolType::get());
74  return new_cond;
75 }
76 
77 // Desugars the trip count into ordinary arithmetic: the trip count becomes a
78 // loop-carried value that is incremented at the end of each iteration, and the
79 // continuation condition becomes (i < max_trip_count) && cond
80 void desugarTripCounts(Block* b) {
81  for (auto n : b->nodes()) {
82  if (n->kind() == prim::Loop) {
83  auto g = n->owningGraph();
84  auto body_block = n->blocks()[0];
85 
86  Value* block_trip_count_input = body_block->inputs()[0];
87  // Treat loop iteration number as a loop-carried dependency. We emit an
88  // increment at the end of the body block.
89  n->insertOutput(0);
90 
91  Value* max_trip_count_value = n->input(0);
92  {
93  WithInsertPoint guard(n);
94  // int i = 0
95  Value* initial_trip_count = g->insertConstant(0);
96  // Set up initial iteration number value for loop-carried dependency
97  n->removeInput(0);
98  // Input 0 is now initial termination condition, insert this after that.
99  // LCD's start at index 1.
100  n->insertInput(1, initial_trip_count);
101 
102  Value* new_cond = createTripCountConjunctiveCondition(
103  g, initial_trip_count, max_trip_count_value, n->input(0));
104  n->replaceInput(0, new_cond);
105  }
106 
107  {
108  WithInsertPoint guard(body_block);
109  // Trip count is now a loop carried dependency. We emit an op to
110  // increment the trip count at the end of the body. Then, emit the same
111  // conjunctive stopping condition as above.
112 
113  Value* const_one = g->insertConstant(1);
114 
115  Value* inc_trip_count =
116  g->insertNode(
117  g->create(aten::add, {block_trip_count_input, const_one}, 1))
118  ->output()
119  ->setType(IntType::get());
120  body_block->insertOutput(1, inc_trip_count);
121 
122  Value* body_cond = createTripCountConjunctiveCondition(
123  g, inc_trip_count, max_trip_count_value, body_block->outputs()[0]);
124  body_block->eraseOutput(0);
125  body_block->insertOutput(0, body_cond);
126  }
127  }
128  for (auto sb : n->blocks()) {
129  desugarTripCounts(sb);
130  }
131  }
132 }
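// As a rough illustration (hand-written and simplified, not produced by this
// function; value names are made up), a loop
//   %y = prim::Loop(%max_trip_count, %cond, %y0)
//     block0(%i, %y_in): ... -> (%cond_out, %y_out)
// is rewritten so the trip count becomes an ordinary loop-carried value:
//   %i_final, %y = prim::Loop(%new_cond, 0, %y0)   // %new_cond = (0 < %max) && %cond
//     block0(%i, %y_in):
//       ...
//       %i_next = aten::add(%i, 1)
//       %continue = (%i_next < %max) && %cond_out
//       -> (%continue, %i_next, %y_out)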
133 
134 // removes all inputs and outputs of a graph, replacing them with Load/Store
135 // nodes
136 static void flattenIO(Graph& graph) {
137  auto load = graph.prependNode(graph.create(prim::Load, 0));
138  for (auto old_input : graph.inputs()) {
139  auto nv = load->addOutput();
140  nv->setType(old_input->type());
141  old_input->replaceAllUsesWith(nv);
142  }
143  graph.appendNode(graph.create(prim::Store, graph.outputs(), 0));
144 
145  while (graph.inputs().size() > 0)
146  graph.eraseInput(graph.inputs().size() - 1);
147  while (graph.outputs().size() > 0)
148  graph.eraseOutput(graph.outputs().size() - 1);
149 }
150 
151 // insert Drop nodes to kill references for anything unused:
152 // this can happen in a few places, e.g. when a node returns
153 // many values but only one is used
154 // a, b = foo()
155 // return a
156 void dropUnused(Block* b) {
157  auto createDropIfUnused = [&](ArrayRef<Value*> values) -> Node* {
158  std::vector<Value*> to_drop;
159  for (auto v : values) {
160  if (v->uses().size() == 0)
161  to_drop.push_back(v);
162  }
163  if (to_drop.size() == 0)
164  return nullptr;
165  return b->owningGraph()->create(prim::Drop, to_drop, 0);
166  };
167 
168  if (auto d = createDropIfUnused(b->inputs())) {
169  b->prependNode(d);
170  }
171  for (auto n : b->nodes()) {
172  if (auto d = createDropIfUnused(n->outputs())) {
173  d->insertAfter(n);
174  }
175  for (auto b : n->blocks())
176  dropUnused(b);
177  }
178 }
179 
180 // for each node input, records whether we should move rather than copy it
181 std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph& g) {
182  // struct to share common data structures
183  struct FindLastUses {
184  Graph& graph;
185  // have we seen this value yet? If not, this is its last use
186  std::unordered_set<Value*> seen;
187 
188  std::unordered_map<Node*, std::vector<uint8_t>> move_flags;
189  // A map from an If or Loop node to the optional Drop block that
190  // occurs directly after it to release any tensors that go out of scope
191  // when the If/Loop exits. These are created and inserted on demand.
192  std::unordered_map<Node*, Node*> drop_for_node;
193 
194  FindLastUses(Graph& g) : graph(g) {
195  scanBlock(graph.block());
196  }
197  void scanBlock(Block* b) {
198  scanNode(b->return_node());
199  for (auto n : b->nodes().reverse()) {
200  scanNode(n);
201  }
202  }
203  void scanNode(Node* n) {
204  for (auto b : n->blocks()) {
205  scanBlock(b);
206  }
207  move_flags[n].resize(n->inputs().size());
208  // scan backwards so if a value is used twice in the list then it is a
209  // move
210  for (size_t i = n->inputs().size(); i > 0; --i) {
211  scanUse(n, i - 1);
212  }
213  }
214  void scanUse(Node* n, size_t i) {
215  auto& move_flags_n = move_flags[n];
216  auto v = n->inputs()[i];
217  auto inserted = seen.insert(v).second;
218  if (!inserted) {
219  move_flags_n[i] = false;
220  return;
221  }
222 
223  // the last use of v may be in a nested block of an If or Loop statement
224  // find the node 'same_depth_node' at the same depth as the definition of
225  // v, and consider that node to be the last use of v. This ensures we do
226  // not delete nodes in nested scopes that may be executed multiple times
227  // and that nodes used on one side of an if
228  // but not the other get deleted regardless of the branch
229  // e.g.
230  // a = 4
231  // while <...>:
232  // y = a + a
233  // drop(a)
234  // In other words, we find the first program point for v that
235  // _reverse_ dominates the definition of v, and add a drop point there.
236  Node* same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
237  AT_ASSERT(
238  same_depth_node); // failure means v is not in scope for n, use lint!
239 
240  // In the case where v and n are in the same block, just mark
241  // its move_flags to be true
242  if (same_depth_node == n) {
243  move_flags_n[i] = true;
244  return;
245  }
246 
247  // in the case where the use is nested in a block
248  // add a Drop node after that block which will drop 'v'.
249  move_flags_n[i] = false;
250  addToDropIfNotExists(
251  findOrCreateDropInstructionForNode(same_depth_node), v);
252  }
253 
254  // finds the node in block 'block' that contains 'n',
255  // or nullptr if no such node exists, e.g.:
256  // n0: a = 4
257  // n1: if <cond>:
258  // n2: b = a + a
259  // findOwnerInBlock(n2, n0.block()) == n1
260  Node* findOwnerInBlock(Node* n, Block* block) {
261  while (n != nullptr && block != n->owningBlock()) {
262  n = n->owningBlock()->owningNode();
263  }
264  return n;
265  }
266 
267  Node* findOrCreateDropInstructionForNode(Node* n) {
268  auto it = drop_for_node.find(n);
269  if (it == drop_for_node.end()) {
270  auto drop_node = graph.create(prim::Drop, 0);
271  drop_node->insertAfter(n);
272  it = drop_for_node.emplace(n, drop_node).first;
273  }
274  return it->second;
275  }
276 
277  void addToDropIfNotExists(Node* drop, Value* v) {
278  for (auto i : drop->inputs()) {
279  // we already accounted for this use
280  if (i == v)
281  return;
282  }
283  drop->addInput(v);
284  move_flags[drop].push_back(true);
285  }
286  };
287 
288  return FindLastUses(g).move_flags;
289 }
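// Hand-worked example (not generated by this code): for
//   a = foo()
//   b = a + a
// scanUse visits the inputs of 'b = a + a' from last to first, and only the
// first time a value is seen counts as its last use. So the second 'a' in the
// input list is marked as a move and the first as a copy.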
290 } // namespace
291 
292 // pre-processing that happens once per graph
293 struct PreprocessGraph {
294  PreprocessGraph(Graph& g) : graph(g.copy()) {
295  n_outputs = graph->outputs().size();
296  desugarTripCounts(graph->block());
297  flattenIO(*graph);
298  dropUnused(graph->block());
299  // fill in move_flags by scanning blocks;
300  move_flags = findLastUses(*graph);
301  // TODO: desugar Loop trip counts, for now we drop trip counts
302  }
303  // Outputs of the preprocessing:
304  std::shared_ptr<Graph> graph;
305  // for each node input, should we move rather than copy it
306  std::unordered_map<Node*, std::vector<uint8_t>> move_flags;
307  // Record number of outputs before flattenIO()
308  size_t n_outputs;
309 };
310 
311 // Sometimes we want to pass things that are not tensors. Instead of
312 // coming up with some "superclass" for tensor, which is annoying since
313 // 99% of values are at::Tensor, we instead create a fake subclass of
314 // TensorImpl that can be subclassed to hold arbitrary things
315 // Note: this is currently unused but will probably be useful in the future,
316 // so we keep it around
317 struct ContainerTensor : public at::TensorImpl {
318  public:
319  ContainerTensor()
320  : TensorImpl(
321  at::UndefinedTensorId(),
322  caffe2::TypeMeta(),
323  nullptr,
324  /* is_variable */ false) {}
325 
326  ~ContainerTensor() override = default;
327  at::IntArrayRef sizes() const override {
328  throw std::runtime_error("sizes() on ContainerTensor");
329  }
330  at::IntArrayRef strides() const override {
331  throw std::runtime_error("strides() on ContainerTensor");
332  }
333  int64_t dim() const override {
334  throw std::runtime_error("dim() on ContainerTensor");
335  }
336  const at::Storage& storage() const override {
337  throw std::runtime_error("storage() on ContainerTensor");
338  }
339 };
340 
341 // We need some lists for inputs and outputs. To keep all the memory
342 // contiguous we allocate a single vector and use offsets into the vector
343 // which are stored in the ListHandle struct
344 // start is an offset into int_data of Code for ListHandle<int>
345 // and bool_data of Code for ListHandle<bool>
346 template <typename T>
347 struct ListHandle {
348  int start;
349  int size;
350 };
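// Illustrative example (the numbers are made up, just to show the indexing):
// a ListHandle<int>{/*start=*/4, /*size=*/2} used as an instruction's outputs
// refers to int_data[4] and int_data[5] of the owning CodeImpl.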
351 
352 struct UseList {
353  // values to be used
354  ListHandle<int> values;
355  // boolean flags indicating whether to free the Tensor after this use
356  ListHandle<bool> free_flags;
357 };
358 
359 // one instruction plus meta-data
360 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
361 struct Instruction {
362  Operation callback;
363  UseList inputs;
364  ListHandle<int> outputs;
365  Symbol debug_name; // used in dump to understand the generated code
366  std::shared_ptr<SourceLocation> debug_location; // for error reporting
367 };
368 
369 int relativeJump(int from_inst, int to_inst) {
370  return to_inst - (from_inst + 1);
371 }
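// Worked example: the interpreter loop below computes new_pc = pc + 1 + callback(stack),
// so relativeJump(5, 3) == 3 - (5 + 1) == -3, and an instruction at pc 5 whose
// callback returns -3 continues execution at pc 3.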
372 
373 struct CodeImpl {
374  CodeImpl(const std::shared_ptr<Graph>& graph_) : preprocess(*graph_) {
375  graph = preprocess.graph;
376  insertNodesFromBlock(graph->block());
377  }
378 
379  // jump when input is false
380  void createJumpFalse(int from_inst, int to_inst) {
381  auto& inst = instructions[from_inst];
382  AT_ASSERT(inst.debug_name == prim::Placeholder);
383  auto offset = relativeJump(from_inst, to_inst);
384  inst.callback = [offset](Stack& stack) {
385  auto t = pop(stack).toBool();
386  return t ? 0 : offset;
387  };
388  inst.debug_name = prim::JumpZ;
389  }
390 
391  // jump when input is true
392  void createJumpTrue(int from_inst, int to_inst) {
393  auto& inst = instructions[from_inst];
394  AT_ASSERT(inst.debug_name == prim::Placeholder);
395  auto offset = relativeJump(from_inst, to_inst);
396  inst.callback = [offset](Stack& stack) {
397  auto t = pop(stack).toBool();
398  return t ? offset : 0;
399  };
400  inst.debug_name = prim::JumpNZ;
401  }
402 
403  void createJump(int from_inst, int to_inst) {
404  auto& inst = instructions[from_inst];
405  AT_ASSERT(inst.debug_name == prim::Placeholder);
406  auto offset = relativeJump(from_inst, to_inst);
407  inst.callback = [=](Stack& stack) { return offset; };
408  inst.debug_name = prim::Jump;
409  }
410 
411  void insertNodesFromBlock(Block* block) {
412  for (auto node : block->nodes()) {
413  const auto& source_location = node->getSourceLocation();
414  switch (node->kind()) {
415  case prim::If: {
416  // x = if c:
417  // <then_block>
418  // -> (vt)
419  // else:
420  // <else_block>
421  // -> (vf)
422 
423  // turns into:
424  // JumpNZ c, then
425  // <else_block>
426  // x = vf
427  // Jump end
428  // then:
429  // <then_block>
430  // x = vt
431  // end:
432 
433  // prim::Placeholder instructions are replaced with branch
434  // instructions when the branch target locations are known
435  auto cond_branch = insertInstruction(
436  prim::Placeholder,
437  source_location,
438  node->inputs(),
439  moveFlags(node),
440  {});
441  auto then_block = node->blocks()[0];
442  auto else_block = node->blocks()[1];
443  insertNodesFromBlock(else_block);
444  insertAssign(
445  source_location,
446  else_block->outputs(),
447  moveFlags(else_block),
448  node->outputs());
449  auto jump =
450  insertInstruction(prim::Placeholder, source_location, {}, {}, {});
451  auto then_block_start = instructions.size();
452  insertNodesFromBlock(then_block);
453  insertAssign(
454  source_location,
455  then_block->outputs(),
456  moveFlags(then_block),
457  node->outputs());
458  createJump(jump, instructions.size());
459  createJumpTrue(cond_branch, then_block_start);
460  } break;
461  case prim::Loop: {
462  // o0 = while c i0
463  // block 0: l0
464  // <body>
465  // -> (v0, v1)
466 
467  // turns into:
468  // l0 = i0
469  // JumpZ c, end
470  // begin:
471  // <body>
472  // c, l0 = v0, v1
473  // JumpNZ c, begin
474  // end:
475 
476  auto body_block = node->blocks()[0];
477 
478  // before assign op: stack: ... <cond> <loop-carried-dependencies>
479  insertAssign(
480  source_location,
481  node->inputs(),
482  moveFlags(node),
483  body_block->inputs());
484  // after assign op: stack: ... <cond>
485  // cond_branch consumes <cond> from top of the stack
486  auto cond_branch =
487  insertInstruction(prim::Placeholder, source_location, {}, {}, {});
488  // after branch: stack: ...
489 
490  auto entry = instructions.size();
491  insertNodesFromBlock(body_block);
492  // before assign op: stack: ... <cond> <loop-carried-dependencies>
493  insertAssign(
494  source_location,
495  body_block->outputs(),
496  moveFlags(body_block),
497  body_block->inputs());
498  // after assign op: stack: ... <cond>
499  auto cond_branch_end =
500  insertInstruction(prim::Placeholder, source_location, {}, {}, {});
501  // after branch: stack: ...
502 
503  aliasRegistersTo(node->outputs(), body_block->inputs());
504  createJumpFalse(cond_branch, instructions.size());
505  createJumpTrue(cond_branch_end, entry);
506  } break;
507  default: { insertInstruction(node); } break;
508  }
509  }
510  }
511 
512  size_t insertInstruction(Node* n) {
513  auto inst = insertInstruction(
514  n->kind(),
515  n->getSourceLocation(),
516  n->inputs(),
517  moveFlags(n),
518  n->outputs());
519  instructions[inst].callback = getOperation(n);
520  return inst;
521  }
522  size_t insertInstruction(
523  Symbol sym,
524  std::shared_ptr<SourceLocation> debug_location,
525  ArrayRef<Value*> inputs,
526  ArrayRef<uint8_t> move_flags,
527  ArrayRef<Value*> outputs) {
528  instructions.emplace_back();
529  auto& inst = instructions.back();
530  inst.debug_name = sym;
531  inst.debug_location = std::move(debug_location);
532  listBegin(inst.inputs.values);
533  for (auto input : inputs) {
534  listInsert(inst.inputs.values, getOrAllocateRegister(input, true));
535  }
536  listBegin(inst.inputs.free_flags);
537  for (auto flag : move_flags) {
538  listInsert(inst.inputs.free_flags, flag);
539  }
540  listBegin(inst.outputs);
541  for (auto output : outputs) {
542  listInsert(inst.outputs, getOrAllocateRegister(output));
543  }
544  return instructions.size() - 1;
545  }
546  ArrayRef<uint8_t> moveFlags(Node* n) {
547  return preprocess.move_flags.at(n);
548  }
549  ArrayRef<uint8_t> moveFlags(Block* b) {
550  return moveFlags(b->return_node());
551  }
552 
553  size_t insertAssign(
554  std::shared_ptr<SourceLocation> debug_location,
555  ArrayRef<Value*> inputs,
556  ArrayRef<uint8_t> move_flags,
557  ArrayRef<Value*> outputs) {
558  auto inst = insertInstruction(
559  prim::Assign, std::move(debug_location), inputs, move_flags, outputs);
560  // This node effectively forwards its inputs into different places in a
561  // register list. We don't need to manipulate the stack in any way, because
562  // all inputs are also outputs, and the interpreter will take care of
563  // putting them in correct places.
564  instructions[inst].callback = [](Stack& stack) { return 0; };
565  return inst;
566  }
567 
568  // helpers to build/access RegList objects
569  int get(const ListHandle<int>& list, int i) const {
570  return int_data[list.start + i];
571  }
572  bool get(const ListHandle<bool>& list, int i) const {
573  return bool_data[list.start + i];
574  }
575  void listBegin(ListHandle<int>& list) {
576  list.start = int_data.size();
577  list.size = 0;
578  }
579  void listInsert(ListHandle<int>& list, int value) {
580  AT_CHECK(
581  list.start + list.size == (int)int_data.size(),
582  "another list already started");
583  int_data.push_back(value);
584  list.size++;
585  }
586  void listBegin(ListHandle<bool>& list) {
587  list.start = bool_data.size();
588  list.size = 0;
589  }
590  void listInsert(ListHandle<bool>& list, int value) {
591  AT_CHECK(
592  list.start + list.size == (int)bool_data.size(),
593  "another list already started");
594  bool_data.push_back(value);
595  list.size++;
596  }
597  // must be called before any new_allocations are used, otherwise they will
598  // already have registers assigned
599  void aliasRegistersTo(
600  ArrayRef<Value*> new_allocations,
601  ArrayRef<Value*> existing_allocations) {
602  AT_ASSERT(new_allocations.size() == existing_allocations.size());
603  for (size_t i = 0; i < new_allocations.size(); ++i) {
604  auto n = new_allocations[i]->unique();
605  auto e = existing_allocations[i]->unique();
606  AT_ASSERT(unique_to_reg.count(e) > 0 && unique_to_reg.count(n) == 0);
607  unique_to_reg[n] = unique_to_reg[e];
608  }
609  }
610  int getOrAllocateRegister(Value* n, bool required = false) {
611  size_t u = n->unique();
612  if (unique_to_reg.count(u) > 0)
613  return unique_to_reg[u];
614  AT_ASSERT(!required);
615  int r = register_size++;
616  unique_to_reg[u] = r;
617  return r;
618  }
619 
620  const std::vector<GraphExecutor*>& grad_executors() {
621  if (!grad_executors_) {
622  grad_executors_.emplace();
623  for (Instruction& instr : instructions) {
624  if (auto executor = detail::getGradExecutor(instr.callback)) {
625  grad_executors_->push_back(executor);
626  }
627  }
628  }
629  return *grad_executors_;
630  }
631 
632  void dumpInstruction(std::ostream& out, size_t pc) const {
633  auto writeList = [&](const ListHandle<int>& list) {
634  for (int i = 0; i < list.size; i++) {
635  if (i > 0)
636  out << ", ";
637  out << get(list, i);
638  }
639  };
640  auto writeUseList = [&](const UseList& list) {
641  for (int i = 0; i < list.values.size; i++) {
642  if (i > 0)
643  out << ", ";
644  if (get(list.free_flags, i))
645  out << "move(" << get(list.values, i) << ")";
646  else
647  out << get(list.values, i);
648  }
649  };
650  auto& inst = instructions.at(pc);
651  writeList(inst.outputs);
652  // NB: debug names are the kind of operator used to select
653  // dispatch
654  out << " = " << inst.debug_name.toUnqualString() << " ";
655  writeUseList(inst.inputs);
656  }
657  void dump(std::ostream& out) const {
658  for (size_t i = 0; i < instructions.size(); ++i) {
659  dumpInstruction(out, i);
660  out << "\n";
661  }
662  }
663 
664  // We MUST hold onto graph here because some Operators stored in the
665  // instruction lists have dependencies on meta-data stored in the graph
666  // that would be dead otherwise.
667  // It is also very useful for debugging interpreter problems to
668  // keep this around.
669  std::shared_ptr<Graph> graph;
670  c10::optional<std::vector<GraphExecutor*>> grad_executors_;
671  PreprocessGraph preprocess;
672 
673  std::unordered_map<size_t, int>
674  unique_to_reg; // map from unique of nodes to register in register table
675 
676  friend struct InterpreterState;
677  std::vector<Instruction> instructions;
678  int register_size = 0;
679 
680  // all memory ArrayRef<int> are slices of this, to make sure
681  // the interpreter is mostly linearly scanning through memory
682  std::vector<int> int_data;
683  std::vector<bool> bool_data;
684 };
685 
686 // InterpreterStateImpl holds the per-run state (registers, pc, future) used to run a Code
687 struct InterpreterStateImpl : c10::intrusive_ptr_target {
688  InterpreterStateImpl(const Code& code)
689  : function(code.pImpl),
690  int_data(function->int_data.data()),
691  bool_data(function->bool_data),
692  registers(function->register_size) {}
693 
694  private:
695  c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
696  c10::raw::intrusive_ptr::incref(this);
697  return c10::intrusive_ptr<InterpreterStateImpl>::reclaim(this);
698  }
699 
700  bool runImpl(Stack& stack) {
701  auto& instructions = function->instructions;
702  size_t last = instructions.size();
703 
704  while (pc < last) {
705  // std::cout << "executing " << pc << ": ";
706  // function->dumpInstruction(std::cout, pc);
707  // std::cout << "\n";
708  auto& inst = instructions[pc];
709  try {
710  loadTensorsFromRegisters(inst.inputs, stack);
711  size_t new_pc = pc + 1 + inst.callback(stack);
712  for (int i = inst.outputs.size - 1; i >= 0; --i) {
713  int reg = get(inst.outputs, i);
714  registers[reg] = pop(stack);
715  // std::cout << "pop reg[" << reg << "];\n" << registers[reg] << "\n";
716  }
717  pc = new_pc;
718  } catch (Suspend& e) {
719  // wait() expects a single input
720  AT_ASSERT(inst.inputs.values.size == 1);
721 
722  getOrCreateFuture();
723 
724  if (get(inst.inputs.free_flags, 0)) {
725  // make sure the register is not freed once we are woken up
726  registers[get(inst.inputs.values, 0)] = e.future;
727  }
728 
729  // Make sure adding callback is the last step.
730  // Otherwise if e.future has completed,
731  // the current thread will continue running before it suspends.
732  InterpreterState state(intrusive_from_this());
733  e.future->addCallback([state]() {
734  c10::global_work_queue().run(InterpreterContinuation(state, Stack(),
735  autograd::GradMode::is_enabled()));
736  });
737 
738  return true;
739  } catch (Future::FutureError& e) {
740  // Error from the forked thread.
741  auto msg = e.error_msg; // copy the error for each callback
742  handleError(std::move(msg), false);
743  return false;
744  } catch (std::exception& e) {
745  // Error from the current thread
746  bool is_jit_exception = dynamic_cast<JITException*>(&e);
747  if (instructions[pc].debug_location) {
748  handleError(
749  instructions[pc].debug_location->wrapException(
750  e, "operation failed in interpreter"),
751  is_jit_exception);
752  } else {
753  handleError(e.what(), is_jit_exception);
754  }
755  return false;
756  }
757  }
758  if (future) {
759  auto num_outputs = function->preprocess.n_outputs;
760  if (num_outputs == 1) {
761  future->markCompleted(stack.back());
762  } else {
763  future->markCompleted(
764  Tuple::create(jit::last(stack, num_outputs).vec()));
765  }
766  }
767 
768  return false;
769  }
770 
771  void handleError(std::string&& error_msg, bool is_jit_exception) {
772  if (future) {
773  future->markCompleted(Future::FutureError(std::move(error_msg)));
774  } else if (is_jit_exception) {
775  throw JITException(std::move(error_msg));
776  } else {
777  throw std::runtime_error(std::move(error_msg));
778  }
779  }
780 
781  public:
782  c10::intrusive_ptr<Future> getOrCreateFuture() {
783  if (!future) {
784  future = c10::make_intrusive<Future>();
785  }
786  return future;
787  }
788 
789  c10::intrusive_ptr<Future> runAsync(Stack& stack) {
790  getOrCreateFuture();
791  runImpl(stack);
792  return future;
793  }
794 
795  void run(Stack& stack) {
796  if (runImpl(stack)) {
797  future->wait();
798 
799  auto num_outputs = function->preprocess.n_outputs;
800  if (num_outputs == 1) {
801  push(stack, future->value());
802  } else {
803  auto tuple = future->value().toTuple();
804  for (const auto& value : tuple->elements()) {
805  push(stack, value);
806  }
807  }
808  }
809  }
810 
811  int get(const ListHandle<int>& list, int i) {
812  return int_data[list.start + i];
813  };
814  bool get(const ListHandle<bool>& list, int i) {
815  return bool_data[list.start + i];
816  }
817  void loadTensorsFromRegisters(const UseList& uses, Stack& stack) {
818  for (int i = 0; i < uses.values.size; i++) {
819  int reg = get(uses.values, i);
820  // std::cout << "push reg[" << reg << "];\n" << registers[reg] << "\n\n";
821  if (get(uses.free_flags, i)) {
822  stack.push_back(std::move(registers[reg]));
823  } else {
824  stack.push_back(registers[reg]);
825  }
826  }
827  }
828 
829  // pc is critical for the interpreter to pick up the progress from suspend
830  size_t pc = 0;
831  c10::intrusive_ptr<Future> future;
832  std::shared_ptr<CodeImpl> function; // keep function alive
833  // these are just copies of function to prevent indirections in interpreter
834  int* int_data;
835  const std::vector<bool>& bool_data;
836 
837  // this holds all the tensors for this interpreter run
838  // we don't bother minimizing the size of this vector, since the extra
839  // memory used by the pointers in this will be small
840  // instead we are very aggressive about releasing tensors when they become dead
841  // to make sure memory management happens efficiently.
842 
843  // We optimize for the case where derivatives are run with retain_graph=False.
844  // In the case where it is true, the interpreter and this array get copied;
845  // if this ever becomes a bottleneck then we _should_ consider minimizing the
846  // total number of registers.
847  std::vector<IValue> registers;
848 
849  // single buffer for input/output calls to ATen functions, so that we do not
850  // reallocate
851  Stack stack;
852 };
853 
854 std::ostream& operator<<(std::ostream& out, const Code& code) {
855  out << *code.pImpl->graph << "\n";
856  code.pImpl->dump(out);
857  return out;
858 }
859 
860 Code::Code(const std::shared_ptr<Graph>& graph) : pImpl(new CodeImpl(graph)) {}
861 Code::~Code() = default;
862 
863 const std::vector<GraphExecutor*>& Code::grad_executors() {
864  return pImpl->grad_executors();
865 }
866 
867 InterpreterState::InterpreterState(const Code& code)
868  : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
869 InterpreterState::~InterpreterState() = default;
870 
871 void InterpreterState::run(Stack& stack) {
872  static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
873 }
874 
875 c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
876  return static_cast<InterpreterStateImpl*>(pImpl.get())->runAsync(stack);
877 }
878 
879 c10::intrusive_ptr<Future> InterpreterState::getFuture() {
880  return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
881 }
882 
883 InterpreterState::InterpreterState(
884  c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
885  : pImpl(std::move(pImpl_)) {}
886 
887 void InterpreterContinuation::operator()() {
888  autograd::AutoGradMode grad_mode(grad_mode_enabled);
889  state.runAsync(stack);
890 }
891 } // namespace jit
892 } // namespace torch
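// As a rough usage sketch (assuming a Graph obtained elsewhere, e.g. from script
// compilation; build_graph() below is a hypothetical stand-in), Code and
// InterpreterState are typically driven like this:
//
//   std::shared_ptr<torch::jit::Graph> graph = build_graph();
//   torch::jit::Code code(graph);             // preprocess graph, emit instructions
//   torch::jit::InterpreterState state(code); // per-run registers and program counter
//   torch::jit::Stack stack;
//   torch::jit::push(stack, at::ones({2, 2}));// push graph inputs onto the stack
//   state.run(stack);                         // on return, outputs are on the stack
//   at::Tensor out = torch::jit::pop(stack).toTensor();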