#include <torch/csrc/jit/interpreter.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/ir.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
#include <c10/core/thread_pool.h>

#include <unordered_map>
#include <unordered_set>

namespace torch {
namespace jit {

// Before translating to interpreter instructions, the graph is preprocessed:
// trip counts are desugared into ordinary arithmetic, graph inputs/outputs
// are replaced by Load/Store nodes, Drop nodes are inserted for unused
// values, and the last use of each value is computed so the interpreter can
// emit moves instead of copies.

// Emits new_cond = (cur_trip_count < max_trip_count) && cond
Value* createTripCountConjunctiveCondition(
    Graph* g,
    Value* cur_trip_count,
    Value* max_trip_count,
    Value* cond) {
  // initial_comparison_value = cur_trip_count < max_trip_count
  Value* initial_comparison_value =
      g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1))
          ->output()
          ->setType(BoolType::get());
  // conjoin the trip-count comparison with the user-provided condition
  Value* new_cond =
      g->insertNode(
           g->create(aten::__and__, {initial_comparison_value, cond}, 1))
          ->output()
          ->setType(BoolType::get());
  return new_cond;
}
// Desugar the trip count of every prim::Loop into ordinary arithmetic:
// the iteration number becomes a loop-carried value that is incremented in
// the body and conjoined with the user condition at every check.
void desugarTripCounts(Block* b) {
  for (auto n : b->nodes()) {
    if (n->kind() == prim::Loop) {
      auto g = n->owningGraph();
      auto body_block = n->blocks()[0];

      Value* block_trip_count_input = body_block->inputs()[0];
      Value* max_trip_count_value = n->input(0);
      {
        WithInsertPoint guard(n);
        // int i = 0
        Value* initial_trip_count = g->insertConstant(0);
        // Set up the initial iteration number as a loop-carried dependency.
        // Input 0 becomes the initial termination condition; loop-carried
        // dependencies start at index 1.
        n->removeInput(0);
        n->insertInput(1, initial_trip_count);

        Value* new_cond = createTripCountConjunctiveCondition(
            g, initial_trip_count, max_trip_count_value, n->input(0));
        n->replaceInput(0, new_cond);
      }
      {
        WithInsertPoint guard(body_block);
        // Increment the trip count at the end of the body, then emit the
        // same conjunctive stopping condition as above.
        Value* const_one = g->insertConstant(1);

        Value* inc_trip_count =
            g->insertNode(
                 g->create(aten::add, {block_trip_count_input, const_one}, 1))
                ->output()
                ->setType(IntType::get());
        body_block->insertOutput(1, inc_trip_count);

        Value* body_cond = createTripCountConjunctiveCondition(
            g, inc_trip_count, max_trip_count_value, body_block->outputs()[0]);
        body_block->eraseOutput(0);
        body_block->insertOutput(0, body_cond);
      }
    }
    for (auto sb : n->blocks()) {
      desugarTripCounts(sb);
    }
  }
}
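// For illustration (a sketch, not verbatim IR), a loop of the form
//   %y = prim::Loop(%max, %cond0, %init)
//     block0(%i, %x): ... -> (%cond_next, %x_next)
// is rewritten so the trip count is an explicit loop-carried value:
//   %c = aten::__and__(aten::lt(0, %max), %cond0)
//   %y = prim::Loop(%c, 0, %init)
//     block0(%i, %x):
//       ...
//       %i2 = aten::add(%i, 1)
//       %c2 = aten::__and__(aten::lt(%i2, %max), %cond_next)
//       -> (%c2, %i2, %x_next)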
// Removes all inputs and outputs from a graph, replacing them with
// prim::Load and prim::Store nodes; the interpreter then communicates
// values through its stack instead of through graph I/O.
static void flattenIO(Graph& graph) {
  auto load = graph.prependNode(graph.create(prim::Load, 0));
  for (auto old_input : graph.inputs()) {
    auto nv = load->addOutput();
    nv->setType(old_input->type());
    old_input->replaceAllUsesWith(nv);
  }
  graph.appendNode(graph.create(prim::Store, graph.outputs(), 0));

  while (graph.inputs().size() > 0)
    graph.eraseInput(graph.inputs().size() - 1);
  while (graph.outputs().size() > 0)
    graph.eraseOutput(graph.outputs().size() - 1);
}
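// E.g. (illustrative): graph(%x, %y): %z = aten::mul(%x, %y); return (%z)
// becomes
//   graph():
//     %x, %y = prim::Load()
//     %z = aten::mul(%x, %y)
//     prim::Store(%z)
//     return ()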
// Insert Drop nodes to kill references for anything unused. This happens,
// e.g., when a node returns multiple values but only some are used:
//   a, b = foo()
//   return a
void dropUnused(Block* b) {
  auto createDropIfUnused = [&](ArrayRef<Value*> values) -> Node* {
    std::vector<Value*> to_drop;
    for (auto v : values) {
      if (v->uses().size() == 0)
        to_drop.push_back(v);
    }
    if (to_drop.size() == 0)
      return nullptr;
    return b->owningGraph()->create(prim::Drop, to_drop, 0);
  };

  if (auto d = createDropIfUnused(b->inputs())) {
    b->prependNode(d);
  }
  for (auto n : b->nodes()) {
    if (auto d = createDropIfUnused(n->outputs())) {
      d->insertAfter(n);
    }
    for (auto b : n->blocks())
      dropUnused(b);
  }
}
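// Illustrative: in the example above, %b has no uses, so prim::Drop(%b) is
// inserted immediately after foo(), ensuring the interpreter releases the
// reference instead of keeping it alive to the end of the function.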
// For each input of each node: should the interpreter move rather than copy
// the value? An input can be moved exactly at its last use.
std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph& g) {
  // struct to share common data structures
  struct FindLastUses {
    Graph& graph;
    // have we seen this value yet? if not, the current (reverse-order) use
    // is the last use of the value
    std::unordered_set<Value*> seen;

    std::unordered_map<Node*, std::vector<uint8_t>> move_flags;
    // drop instructions that kill values whose last use is inside nested
    // control flow, keyed by the node after which the drop was inserted
    std::unordered_map<Node*, Node*> drop_for_node;

    FindLastUses(Graph& g) : graph(g) {
      scanBlock(graph.block());
    }
    void scanBlock(Block* b) {
      scanNode(b->return_node());
      for (auto n : b->nodes().reverse()) {
        scanNode(n);
      }
    }
    void scanNode(Node* n) {
      for (auto b : n->blocks()) {
        scanBlock(b);
      }
      move_flags[n].resize(n->inputs().size());
      // scan backwards so that if a value appears twice in the input list,
      // only the last occurrence becomes a move
      for (size_t i = n->inputs().size(); i > 0; --i) {
        scanUse(n, i - 1);
      }
    }
    void scanUse(Node* n, size_t i) {
      auto& move_flags_n = move_flags[n];
      auto v = n->inputs()[i];
      auto inserted = seen.insert(v).second;
      if (!inserted) {
        // not the last use of v
        move_flags_n[i] = false;
        return;
      }

      // the last use of v may be inside a nested block of an If or Loop, so
      // find 'same_depth_node', the node at the same depth as the definition
      // of v; that node is treated as the last use, so values are not freed
      // inside control flow that may run zero or multiple times
      Node* same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
      AT_ASSERT(same_depth_node); // failure means v is not in scope for n

      // v and n are in the same block: mark the input as a move
      if (same_depth_node == n) {
        move_flags_n[i] = true;
        return;
      }

      // the use is nested in a block: add a drop instruction after the
      // control-flow node that drops 'v' once the block has reconverged
      move_flags_n[i] = false;
      addToDropIfNotExists(
          findOrCreateDropInstructionForNode(same_depth_node), v);
    }

    // finds the node in 'block' that contains 'n', or nullptr if none, e.g.:
    // n0: a = 4
    // n1: if <cond>:
    // n2:   b = a + a
    // findOwnerInBlock(n2, n0.block()) == n1
    Node* findOwnerInBlock(Node* n, Block* block) {
      while (n != nullptr && block != n->owningBlock()) {
        n = n->owningBlock()->owningNode();
      }
      return n;
    }

    Node* findOrCreateDropInstructionForNode(Node* n) {
      auto it = drop_for_node.find(n);
      if (it == drop_for_node.end()) {
        auto drop_node = graph.create(prim::Drop, 0);
        drop_node->insertAfter(n);
        it = drop_for_node.emplace(n, drop_node).first;
      }
      return it->second;
    }

    void addToDropIfNotExists(Node* drop, Value* v) {
      for (auto i : drop->inputs()) {
        // we already accounted for this use
        if (i == v)
          return;
      }
      drop->addInput(v);
      move_flags[drop].push_back(true);
    }
  };

  return FindLastUses(g).move_flags;
}
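// Illustrative: for
//   %c = aten::add(%a, %b)
//   %d = aten::mul(%c, %b)
// scanning in reverse marks mul's inputs {%c: move, %b: move} and add's
// inputs {%a: move, %b: copy}, since %b is still needed by mul.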
// Preprocessing that happens once per graph.
struct PreprocessGraph {
  PreprocessGraph(Graph& g) : graph(g.copy()) {
    n_outputs = graph->outputs().size();
    desugarTripCounts(graph->block());
    flattenIO(*graph);
    dropUnused(graph->block());
    // fill in move_flags by scanning blocks
    move_flags = findLastUses(*graph);
  }
  // outputs of the preprocessing:
  std::shared_ptr<Graph> graph;
  // for every input: should we move rather than copy it?
  std::unordered_map<Node*, std::vector<uint8_t>> move_flags;
  size_t n_outputs;
};
// A fake subclass of TensorImpl that can be used to carry non-tensor values
// through interpreter registers; every Tensor query on it throws.
struct ContainerTensor : public at::TensorImpl {
 public:
  ContainerTensor()
      : TensorImpl(
            at::UndefinedTensorId(),
            caffe2::TypeMeta(),
            nullptr,
            /* is_variable */ false) {}

  ~ContainerTensor() override = default;
  at::IntArrayRef sizes() const override {
    throw std::runtime_error("sizes() on ContainerTensor");
  }
  at::IntArrayRef strides() const override {
    throw std::runtime_error("strides() on ContainerTensor");
  }
  int64_t dim() const override {
    throw std::runtime_error("dim() on ContainerTensor");
  }
  const at::Storage& storage() const override {
    throw std::runtime_error("storage() on ContainerTensor");
  }
};
// Lists of input/output registers are kept in one flat pool per Code object;
// a ListHandle is just a (start, size) slice into that pool.
template <typename T>
struct ListHandle {
  int start;
  int size;
};

struct UseList {
  // values to be used
  ListHandle<int> values;
  // boolean flags indicating whether to free the value after this use
  ListHandle<bool> free_flags;
};

// one instruction plus metadata
struct Instruction {
  Operation callback;
  UseList inputs;
  ListHandle<int> outputs;
  Symbol debug_name; // used in dump to understand the generated code
  std::shared_ptr<SourceLocation> debug_location; // for error reporting
};

// Jump offsets are relative to the instruction *after* the jump, because the
// program counter has already advanced when the callback's result is applied.
int relativeJump(int from_inst, int to_inst) {
  return to_inst - (from_inst + 1);
}
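// E.g.: a jump placed at instruction 5 targeting instruction 9 stores
// relativeJump(5, 9) == 3, and the interpreter resumes at
// pc + 1 + offset = 5 + 1 + 3 = 9.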
struct CodeImpl {
  CodeImpl(const std::shared_ptr<Graph>& graph_) : preprocess(*graph_) {
    graph = preprocess.graph;
    insertNodesFromBlock(graph->block());
  }
  // Placeholder instructions are patched into concrete jumps once their
  // target is known. JumpZ branches when the popped bool is false.
  void createJumpFalse(int from_inst, int to_inst) {
    auto& inst = instructions[from_inst];
    AT_ASSERT(inst.debug_name == prim::Placeholder);
    auto offset = relativeJump(from_inst, to_inst);
    inst.callback = [offset](Stack& stack) {
      auto t = pop(stack).toBool();
      return t ? 0 : offset;
    };
    inst.debug_name = prim::JumpZ;
  }

  // JumpNZ branches when the popped bool is true.
  void createJumpTrue(int from_inst, int to_inst) {
    auto& inst = instructions[from_inst];
    AT_ASSERT(inst.debug_name == prim::Placeholder);
    auto offset = relativeJump(from_inst, to_inst);
    inst.callback = [offset](Stack& stack) {
      auto t = pop(stack).toBool();
      return t ? offset : 0;
    };
    inst.debug_name = prim::JumpNZ;
  }

  // unconditional jump
  void createJump(int from_inst, int to_inst) {
    auto& inst = instructions[from_inst];
    AT_ASSERT(inst.debug_name == prim::Placeholder);
    auto offset = relativeJump(from_inst, to_inst);
    inst.callback = [=](Stack& stack) { return offset; };
    inst.debug_name = prim::Jump;
  }
  void insertNodesFromBlock(Block* block) {
    for (auto node : block->nodes()) {
      const auto& source_location = node->getSourceLocation();
      switch (node->kind()) {
        case prim::If: {
          // emit the else block first, then an unconditional jump over the
          // then block; a JumpNZ on the condition selects the then block
          auto cond_branch = insertInstruction(
              prim::Placeholder,
              source_location,
              node->inputs(),
              moveFlags(node),
              {});
          auto then_block = node->blocks()[0];
          auto else_block = node->blocks()[1];
          insertNodesFromBlock(else_block);
          insertAssign(
              source_location,
              else_block->outputs(),
              moveFlags(else_block),
              node->outputs());
          auto jump =
              insertInstruction(prim::Placeholder, source_location, {}, {}, {});
          auto then_block_start = instructions.size();
          insertNodesFromBlock(then_block);
          insertAssign(
              source_location,
              then_block->outputs(),
              moveFlags(then_block),
              node->outputs());
          createJump(jump, instructions.size());
          createJumpTrue(cond_branch, then_block_start);
        } break;
        case prim::Loop: {
          auto body_block = node->blocks()[0];
          // assign the loop-carried inputs, then test the condition: a JumpZ
          // exits the loop, and a trailing JumpNZ re-enters the body
          insertAssign(
              source_location,
              node->inputs(),
              moveFlags(node),
              body_block->inputs());
          auto cond_branch =
              insertInstruction(prim::Placeholder, source_location, {}, {}, {});

          auto entry = instructions.size();
          insertNodesFromBlock(body_block);
          insertAssign(
              source_location,
              body_block->outputs(),
              moveFlags(body_block),
              body_block->inputs());
          auto cond_branch_end =
              insertInstruction(prim::Placeholder, source_location, {}, {}, {});

          aliasRegistersTo(node->outputs(), body_block->inputs());
          createJumpFalse(cond_branch, instructions.size());
          createJumpTrue(cond_branch_end, entry);
        } break;
        default: {
          insertInstruction(node);
        } break;
      }
    }
  }
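  // Illustrative layout for `x = if c: vt else: vf` (a sketch):
  //   k+0: JumpNZ c, k+3
  //   k+1: <else_block>; x = vf
  //   k+2: Jump k+4
  //   k+3: <then_block>; x = vt
  //   k+4: ...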
  size_t insertInstruction(Node* n) {
    auto inst = insertInstruction(
        n->kind(),
        n->getSourceLocation(),
        n->inputs(),
        moveFlags(n),
        n->outputs());
    instructions[inst].callback = getOperation(n);
    return inst;
  }
  size_t insertInstruction(
      Symbol sym,
      std::shared_ptr<SourceLocation> debug_location,
      ArrayRef<Value*> inputs,
      ArrayRef<uint8_t> move_flags,
      ArrayRef<Value*> outputs) {
    instructions.emplace_back();
    auto& inst = instructions.back();
    inst.debug_name = sym;
    inst.debug_location = std::move(debug_location);
    listBegin(inst.inputs.values);
    for (auto input : inputs) {
      listInsert(inst.inputs.values, getOrAllocateRegister(input, true));
    }
    listBegin(inst.inputs.free_flags);
    for (auto flag : move_flags) {
      listInsert(inst.inputs.free_flags, flag);
    }
    listBegin(inst.outputs);
    for (auto output : outputs) {
      listInsert(inst.outputs, getOrAllocateRegister(output));
    }
    return instructions.size() - 1;
  }
  ArrayRef<uint8_t> moveFlags(Node* n) {
    return preprocess.move_flags.at(n);
  }
  ArrayRef<uint8_t> moveFlags(Block* b) {
    return moveFlags(b->return_node());
  }
  size_t insertAssign(
      std::shared_ptr<SourceLocation> debug_location,
      ArrayRef<Value*> inputs,
      ArrayRef<uint8_t> move_flags,
      ArrayRef<Value*> outputs) {
    auto inst = insertInstruction(
        prim::Assign, std::move(debug_location), inputs, move_flags, outputs);
    // Assign only shuffles values between registers; all inputs are also
    // outputs, so the stack itself does not need to be touched.
    instructions[inst].callback = [](Stack& stack) { return 0; };
    return inst;
  }
  // helpers to build and read the flattened int/bool list pools
  int get(const ListHandle<int>& list, int i) const {
    return int_data[list.start + i];
  }
  bool get(const ListHandle<bool>& list, int i) const {
    return bool_data[list.start + i];
  }
  void listBegin(ListHandle<int>& list) {
    list.start = int_data.size();
    list.size = 0;
  }
  void listInsert(ListHandle<int>& list, int value) {
    AT_ASSERTM(
        list.start + list.size == (int)int_data.size(),
        "another list already started");
    int_data.push_back(value);
    list.size++;
  }
  void listBegin(ListHandle<bool>& list) {
    list.start = bool_data.size();
    list.size = 0;
  }
  void listInsert(ListHandle<bool>& list, bool value) {
    AT_ASSERTM(
        list.start + list.size == (int)bool_data.size(),
        "another list already started");
    bool_data.push_back(value);
    list.size++;
  }
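  // Illustrative: after inserting an instruction with input registers {2, 5}
  // and output register {7}, int_data ends in [..., 2, 5, 7] and the
  // instruction stores inputs.values = {start, 2} and outputs = {start+2, 1};
  // handles are slices, so no per-instruction vectors are allocated.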
  // must be called before any new_allocations are used, otherwise they will
  // already have registers assigned
  void aliasRegistersTo(
      ArrayRef<Value*> new_allocations,
      ArrayRef<Value*> existing_allocations) {
    AT_ASSERT(new_allocations.size() == existing_allocations.size());
    for (size_t i = 0; i < new_allocations.size(); ++i) {
      auto n = new_allocations[i]->unique();
      auto e = existing_allocations[i]->unique();
      AT_ASSERT(unique_to_reg.count(e) > 0 && unique_to_reg.count(n) == 0);
      unique_to_reg[n] = unique_to_reg[e];
    }
  }
  int getOrAllocateRegister(Value* n, bool required = false) {
    size_t u = n->unique();
    if (unique_to_reg.count(u) > 0)
      return unique_to_reg[u];
    AT_ASSERT(!required);
    int r = register_size++;
    unique_to_reg[u] = r;
    return r;
  }
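  // Note (illustrative): registers are allocated on first sight of a value,
  // so passing required = true for instruction *inputs* asserts that every
  // input was already defined by a previously emitted instruction.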
  const std::vector<GraphExecutor*>& grad_executors() {
    if (!grad_executors_) {
      grad_executors_.emplace();
      for (Instruction& instr : instructions) {
        if (auto executor = detail::getGradExecutor(instr.callback)) {
          grad_executors_->push_back(executor);
        }
      }
    }
    return *grad_executors_;
  }
  void dumpInstruction(std::ostream& out, size_t pc) const {
    auto writeList = [&](const ListHandle<int>& list) {
      for (int i = 0; i < list.size; i++) {
        if (i > 0)
          out << ", ";
        out << get(list, i);
      }
    };
    auto writeUseList = [&](const UseList& list) {
      for (int i = 0; i < list.values.size; i++) {
        if (i > 0)
          out << ", ";
        if (get(list.free_flags, i))
          out << "move(" << get(list.values, i) << ")";
        else
          out << get(list.values, i);
      }
    };
    auto& inst = instructions.at(pc);
    writeList(inst.outputs);
    out << " = " << inst.debug_name.toUnqualString() << " ";
    writeUseList(inst.inputs);
  }
  void dump(std::ostream& out) const {
    for (size_t i = 0; i < instructions.size(); ++i) {
      dumpInstruction(out, i);
      out << "\n";
    }
  }
  // We MUST hold onto the graph, because some Operators stored in the
  // instruction list depend on metadata owned by it.
  std::shared_ptr<Graph> graph;
  PreprocessGraph preprocess;

  // map from unique() of a Value to a register in the register table
  std::unordered_map<size_t, int> unique_to_reg;

  std::vector<Instruction> instructions;
  int register_size = 0;

  // all ListHandles are slices of these pools, so the generated code is
  // mutation-free once built
  std::vector<int> int_data;
  std::vector<bool> bool_data;

  c10::optional<std::vector<GraphExecutor*>> grad_executors_;
};
689 :
function(code.pImpl),
690 int_data(function->int_data.data()),
691 bool_data(function->bool_data),
692 registers(function->register_size) {}
696 c10::raw::intrusive_ptr::incref(
this);
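  // Note: intrusive_ptr::reclaim() takes over an existing +1 reference, so
  // the explicit incref above is what keeps `this` alive once the returned
  // pointer is handed to an asynchronous continuation.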
  bool runImpl(Stack& stack) {
    auto& instructions = function->instructions;
    size_t last = instructions.size();

    while (pc < last) {
      auto& inst = instructions[pc];
      try {
        loadTensorsFromRegisters(inst.inputs, stack);
        size_t new_pc = pc + 1 + inst.callback(stack);
        // outputs are popped off the stack in reverse order into registers
        for (int i = inst.outputs.size - 1; i >= 0; --i) {
          int reg = get(inst.outputs, i);
          registers[reg] = pop(stack);
        }
        pc = new_pc;
      } catch (Suspend& e) {
        // wait() expects a single input
        AT_ASSERT(inst.inputs.values.size == 1);
        getOrCreateFuture();
        if (get(inst.inputs.free_flags, 0)) {
          // make sure the register is not freed once we are woken up
          registers[get(inst.inputs.values, 0)] = e.future;
        }
        // Adding the callback must be the last step: if e.future has already
        // completed, the continuation could run before this thread suspends.
        InterpreterState state(intrusive_from_this());
        e.future->addCallback([state]() {
          c10::global_work_queue().run(InterpreterContinuation(
              state, Stack(), autograd::GradMode::is_enabled()));
        });
        return true;
      } catch (Future::FutureError& e) {
        // error from the forked thread
        auto msg = e.error_msg; // copy the error for each callback
        handleError(std::move(msg), false);
        return false;
      } catch (std::exception& e) {
        // error from the current thread
        bool is_jit_exception = dynamic_cast<JITException*>(&e);
        if (instructions[pc].debug_location) {
          handleError(
              instructions[pc].debug_location->wrapException(
                  e, "operation failed in interpreter"),
              is_jit_exception);
        } else {
          handleError(e.what(), is_jit_exception);
        }
        return false;
      }
    }
    if (future) {
      auto num_outputs = function->preprocess.n_outputs;
      if (num_outputs == 1) {
        future->markCompleted(stack.back());
      } else {
        future->markCompleted(
            Tuple::create(jit::last(stack, num_outputs).vec()));
      }
    }
    return false;
  }
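  // Illustrative control flow: runImpl returns true only when an op threw
  // Suspend (e.g. a wait on an incomplete future); the saved pc lets a
  // continuation re-enter runImpl later and resume at the suspended
  // instruction. A false return means the run finished or errored.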
  void handleError(std::string&& error_msg, bool is_jit_exception) {
    if (future) {
      future->markCompleted(Future::FutureError(std::move(error_msg)));
    } else if (is_jit_exception) {
      throw JITException(std::move(error_msg));
    } else {
      throw std::runtime_error(std::move(error_msg));
    }
  }
 public:
  c10::intrusive_ptr<Future> getOrCreateFuture() {
    // created lazily, in case execution is never asynchronous
    if (!future) {
      future = c10::make_intrusive<Future>();
    }
    return future;
  }

  c10::intrusive_ptr<Future> runAsync(Stack& stack) {
    getOrCreateFuture();
    runImpl(stack);
    return future;
  }
  void run(Stack& stack) {
    // a true return from runImpl means the run suspended; block on the
    // future and then unpack its value(s) onto the stack
    if (runImpl(stack)) {
      future->wait();

      auto num_outputs = function->preprocess.n_outputs;
      if (num_outputs == 1) {
        push(stack, future->value());
      } else {
        auto tuple = future->value().toTuple();
        for (const auto& value : tuple->elements()) {
          push(stack, value);
        }
      }
    }
  }
  int get(const ListHandle<int>& list, int i) {
    return int_data[list.start + i];
  }
  bool get(const ListHandle<bool>& list, int i) {
    return bool_data[list.start + i];
  }
  void loadTensorsFromRegisters(const UseList& uses, Stack& stack) {
    for (int i = 0; i < uses.values.size; i++) {
      int reg = get(uses.values, i);
      // move the value out of the register when this is its last use
      if (get(uses.free_flags, i)) {
        stack.push_back(std::move(registers[reg]));
      } else {
        stack.push_back(registers[reg]);
      }
    }
  }
  // keep function alive: the instruction callbacks point into it
  std::shared_ptr<CodeImpl> function;
  // copies of data in function, to avoid an extra indirection per instruction
  int* int_data;
  const std::vector<bool>& bool_data;

  // the program counter; preserved across Suspend so a continuation can
  // resume where the run left off
  size_t pc = 0;

  // if asynchronous execution was requested, the result is delivered here
  c10::intrusive_ptr<Future> future;

  std::vector<IValue> registers;
};
std::ostream& operator<<(std::ostream& out, const Code& code) {
  out << *code.pImpl->graph << "\n";
  code.pImpl->dump(out);
  return out;
}
Code::Code(const std::shared_ptr<Graph>& graph) : pImpl(new CodeImpl(graph)) {}
Code::~Code() = default;

const std::vector<GraphExecutor*>& Code::grad_executors() {
  return pImpl->grad_executors();
}

InterpreterState::InterpreterState(const Code& code)
    : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
InterpreterState::~InterpreterState() = default;

void InterpreterState::run(Stack& stack) {
  static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
}

c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
  return static_cast<InterpreterStateImpl*>(pImpl.get())->runAsync(stack);
}

InterpreterState::InterpreterState(
    c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
    : pImpl(std::move(pImpl_)) {}

void InterpreterContinuation::operator()() {
  autograd::AutoGradMode grad_mode(grad_mode_enabled);
  state.runAsync(stack);
}

} // namespace jit
} // namespace torch