1 #include <torch/csrc/jit/interpreter.h>     3 #include <torch/csrc/autograd/edge.h>     4 #include <torch/csrc/autograd/function.h>     5 #include <torch/csrc/autograd/generated/variable_factories.h>     6 #include <torch/csrc/autograd/grad_mode.h>     7 #include <torch/csrc/autograd/profiler.h>     8 #include <torch/csrc/autograd/variable.h>     9 #include <c10/util/Exception.h>    10 #include <torch/csrc/jit/constants.h>    11 #include <torch/csrc/jit/graph_executor.h>    12 #include <torch/csrc/jit/ir.h>    13 #include <ATen/core/ivalue.h>    14 #include <torch/csrc/jit/operator.h>    15 #include <torch/csrc/jit/script/jit_exception.h>    16 #include <c10/core/thread_pool.h>    25 #include <unordered_map>    26 #include <unordered_set>    56 Value* createTripCountConjunctiveCondition(
    58     Value* cur_trip_count,
    59     Value* max_trip_count,
    62   Value* initial_comparison_value =
    63       g->insertNode(g->create(aten::lt, {cur_trip_count, max_trip_count}, 1))
    65           ->setType(BoolType::get());
    71            g->create(aten::__and__, {initial_comparison_value, cond}, 1))
    73           ->setType(BoolType::get());
    80 void desugarTripCounts(Block* b) {
    81   for (
auto n : b->nodes()) {
    82     if (n->kind() == prim::Loop) {
    83       auto g = n->owningGraph();
    84       auto body_block = n->blocks()[0];
    86       Value* block_trip_count_input = body_block->inputs()[0];
    91       Value* max_trip_count_value = n->input(0);
    93         WithInsertPoint guard(n);
    95         Value* initial_trip_count = g->insertConstant(0);
   100         n->insertInput(1, initial_trip_count);
   102         Value* new_cond = createTripCountConjunctiveCondition(
   103             g, initial_trip_count, max_trip_count_value, n->input(0));
   104         n->replaceInput(0, new_cond);
   108         WithInsertPoint guard(body_block);
   113         Value* const_one = g->insertConstant(1);
   115         Value* inc_trip_count =
   117                  g->create(aten::add, {block_trip_count_input, const_one}, 1))
   119                 ->setType(IntType::get());
   120         body_block->insertOutput(1, inc_trip_count);
   122         Value* body_cond = createTripCountConjunctiveCondition(
   123             g, inc_trip_count, max_trip_count_value, body_block->outputs()[0]);
   124         body_block->eraseOutput(0);
   125         body_block->insertOutput(0, body_cond);
   128     for (
auto sb : n->blocks()) {
   129       desugarTripCounts(sb);
   136 static void flattenIO(Graph& graph) {
   137   auto load = graph.prependNode(graph.create(prim::Load, 0));
   138   for (
auto old_input : graph.inputs()) {
   139     auto nv = load->addOutput();
   140     nv->setType(old_input->type());
   141     old_input->replaceAllUsesWith(nv);
   143   graph.appendNode(graph.create(prim::Store, graph.outputs(), 0));
   145   while (graph.inputs().size() > 0)
   146     graph.eraseInput(graph.inputs().size() - 1);
   147   while (graph.outputs().size() > 0)
   148     graph.eraseOutput(graph.outputs().size() - 1);
   156 void dropUnused(Block* b) {
   157   auto createDropIfUnused = [&](ArrayRef<Value*> values) -> Node* {
   158     std::vector<Value*> to_drop;
   159     for (
auto v : values) {
   160       if (v->uses().size() == 0)
   161         to_drop.push_back(v);
   163     if (to_drop.size() == 0)
   165     return b->owningGraph()->create(prim::Drop, to_drop, 0);
   168   if (
auto d = createDropIfUnused(b->inputs())) {
   171   for (
auto n : b->nodes()) {
   172     if (
auto d = createDropIfUnused(n->outputs())) {
   175     for (
auto b : n->blocks())
   181 std::unordered_map<Node*, std::vector<uint8_t>> findLastUses(Graph& g) {
   183   struct FindLastUses {
   186     std::unordered_set<Value*> seen;
   188     std::unordered_map<Node*, std::vector<uint8_t>> move_flags;
   192     std::unordered_map<Node*, Node*> drop_for_node;
   194     FindLastUses(Graph& g) : graph(g) {
   195       scanBlock(graph.block());
   197     void scanBlock(Block* b) {
   198       scanNode(b->return_node());
   199       for (
auto n : b->nodes().reverse()) {
   203     void scanNode(Node* n) {
   204       for (
auto b : n->blocks()) {
   207       move_flags[n].resize(n->inputs().size());
   210       for (
size_t i = n->inputs().size(); i > 0; --i) {
   214     void scanUse(Node* n, 
size_t i) {
   215       auto& move_flags_n = move_flags[n];
   216       auto v = n->inputs()[i];
   217       auto inserted = seen.insert(v).second;
   219         move_flags_n[i] = 
false;
   236       Node* same_depth_node = findOwnerInBlock(n, v->node()->owningBlock());
   242       if (same_depth_node == n) {
   243         move_flags_n[i] = 
true;
   249       move_flags_n[i] = 
false;
   250       addToDropIfNotExists(
   251           findOrCreateDropInstructionForNode(same_depth_node), v);
   260     Node* findOwnerInBlock(Node* n, Block* block) {
   261       while (n != 
nullptr && block != n->owningBlock()) {
   262         n = n->owningBlock()->owningNode();
   267     Node* findOrCreateDropInstructionForNode(Node* n) {
   268       auto it = drop_for_node.find(n);
   269       if (it == drop_for_node.end()) {
   270         auto drop_node = graph.create(prim::Drop, 0);
   271         drop_node->insertAfter(n);
   272         it = drop_for_node.emplace(n, drop_node).first;
   277     void addToDropIfNotExists(Node* drop, Value* v) {
   278       for (
auto i : drop->inputs()) {
   284       move_flags[drop].push_back(
true);
   288   return FindLastUses(g).move_flags;
   295     n_outputs = graph->outputs().size();
   296     desugarTripCounts(graph->block());
   298     dropUnused(graph->block());
   300     move_flags = findLastUses(*graph);
   304   std::shared_ptr<Graph> graph;
   306   std::unordered_map<Node*, std::vector<uint8_t>> move_flags;
   321             at::UndefinedTensorId(),
   328     throw std::runtime_error(
"sizes() on ContainerTensor");
   331     throw std::runtime_error(
"strides() on ContainerTensor");
   333   int64_t 
dim()
 const override {
   334     throw std::runtime_error(
"dim() on ContainerTensor");
   337     throw std::runtime_error(
"storage() on ContainerTensor");
   346 template <
typename T>
   366   std::shared_ptr<SourceLocation> debug_location; 
   369 int relativeJump(
int from_inst, 
int to_inst) {
   370   return to_inst - (from_inst + 1);
   374   CodeImpl(
const std::shared_ptr<Graph>& graph_) : preprocess(*graph_) {
   375     graph = preprocess.graph;
   376     insertNodesFromBlock(graph->block());
   380   void createJumpFalse(
int from_inst, 
int to_inst) {
   381     auto& inst = instructions[from_inst];
   382     AT_ASSERT(inst.debug_name == prim::Placeholder);
   383     auto offset = relativeJump(from_inst, to_inst);
   384     inst.callback = [offset](Stack& stack) {
   385       auto t = pop(stack).toBool();
   386       return t ? 0 : offset;
   388     inst.debug_name = prim::JumpZ;
   392   void createJumpTrue(
int from_inst, 
int to_inst) {
   393     auto& inst = instructions[from_inst];
   394     AT_ASSERT(inst.debug_name == prim::Placeholder);
   395     auto offset = relativeJump(from_inst, to_inst);
   396     inst.callback = [offset](Stack& stack) {
   397       auto t = pop(stack).toBool();
   398       return t ? offset : 0;
   400     inst.debug_name = prim::JumpNZ;
   403   void createJump(
int from_inst, 
int to_inst) {
   404     auto& inst = instructions[from_inst];
   405     AT_ASSERT(inst.debug_name == prim::Placeholder);
   406     auto offset = relativeJump(from_inst, to_inst);
   407     inst.callback = [=](Stack& stack) { 
return offset; };
   408     inst.debug_name = prim::Jump;
   411   void insertNodesFromBlock(
Block* block) {
   412     for (
auto node : block->nodes()) {
   413       const auto& source_location = node->getSourceLocation();
   414       switch (node->kind()) {
   435           auto cond_branch = insertInstruction(
   441           auto then_block = node->blocks()[0];
   442           auto else_block = node->blocks()[1];
   443           insertNodesFromBlock(else_block);
   446               else_block->outputs(),
   447               moveFlags(else_block),
   450               insertInstruction(prim::Placeholder, source_location, {}, {}, {});
   451           auto then_block_start = instructions.size();
   452           insertNodesFromBlock(then_block);
   455               then_block->outputs(),
   456               moveFlags(then_block),
   458           createJump(jump, instructions.size());
   459           createJumpTrue(cond_branch, then_block_start);
   476           auto body_block = node->blocks()[0];
   483               body_block->inputs());
   487               insertInstruction(prim::Placeholder, source_location, {}, {}, {});
   490           auto entry = instructions.size();
   491           insertNodesFromBlock(body_block);
   495               body_block->outputs(),
   496               moveFlags(body_block),
   497               body_block->inputs());
   499           auto cond_branch_end =
   500               insertInstruction(prim::Placeholder, source_location, {}, {}, {});
   503           aliasRegistersTo(node->outputs(), body_block->inputs());
   504           createJumpFalse(cond_branch, instructions.size());
   505           createJumpTrue(cond_branch_end, entry);
   507         default: { insertInstruction(node); } 
break;
   512   size_t insertInstruction(
Node* n) {
   513     auto inst = insertInstruction(
   515         n->getSourceLocation(),
   519     instructions[inst].callback = getOperation(n);
   522   size_t insertInstruction(
   524       std::shared_ptr<SourceLocation> debug_location,
   528     instructions.emplace_back();
   529     auto& inst = instructions.back();
   530     inst.debug_name = sym;
   531     inst.debug_location = std::move(debug_location);
   532     listBegin(inst.inputs.values);
   533     for (
auto input : inputs) {
   534       listInsert(inst.inputs.values, getOrAllocateRegister(input, 
true));
   536     listBegin(inst.inputs.free_flags);
   537     for (
auto flag : move_flags) {
   538       listInsert(inst.inputs.free_flags, flag);
   540     listBegin(inst.outputs);
   541     for (
auto output : outputs) {
   542       listInsert(inst.outputs, getOrAllocateRegister(output));
   544     return instructions.size() - 1;
   547     return preprocess.move_flags.
at(n);
   550     return moveFlags(b->return_node());
   554       std::shared_ptr<SourceLocation> debug_location,
   558     auto inst = insertInstruction(
   559         prim::Assign, std::move(debug_location), inputs, move_flags, outputs);
   564     instructions[inst].callback = [](Stack& stack) { 
return 0; };
   570     return int_data[list.start + i];
   573     return bool_data[list.start + i];
   576     list.start = int_data.size();
   581         list.start + list.size == (
int)int_data.size(),
   582         "another list already started");
   583     int_data.push_back(value);
   587     list.start = bool_data.size();
   592         list.start + list.size == (
int)bool_data.size(),
   593         "another list already started");
   594     bool_data.push_back(value);
   599   void aliasRegistersTo(
   602     AT_ASSERT(new_allocations.
size() == existing_allocations.
size());
   603     for (
size_t i = 0; i < new_allocations.
size(); ++i) {
   604       auto n = new_allocations[i]->unique();
   605       auto e = existing_allocations[i]->unique();
   606       AT_ASSERT(unique_to_reg.count(e) > 0 && unique_to_reg.count(n) == 0);
   607       unique_to_reg[n] = unique_to_reg[e];
   610   int getOrAllocateRegister(
Value* n, 
bool required = 
false) {
   611     size_t u = n->unique();
   612     if (unique_to_reg.count(u) > 0)
   613       return unique_to_reg[u];
   614     AT_ASSERT(!required);
   615     int r = register_size++;
   616     unique_to_reg[u] = r;
   620   const std::vector<GraphExecutor*>& grad_executors() {
   621     if (!grad_executors_) {
   622       grad_executors_.emplace();
   624         if (
auto executor = detail::getGradExecutor(instr.callback)) {
   625           grad_executors_->push_back(executor);
   629     return *grad_executors_;
   632   void dumpInstruction(std::ostream& out, 
size_t pc)
 const {
   634       for (
int i = 0; i < list.size; i++) {
   640     auto writeUseList = [&](
const UseList& list) {
   641       for (
int i = 0; i < list.values.size; i++) {
   644         if (
get(list.free_flags, i))
   645           out << 
"move(" << 
get(list.values, i) << 
")";
   647           out << 
get(list.values, i);
   650     auto& inst = instructions.at(pc);
   651     writeList(inst.outputs);
   654     out << 
" = " << inst.debug_name.toUnqualString() << 
" ";
   655     writeUseList(inst.inputs);
   657   void dump(std::ostream& out)
 const {
   658     for (
size_t i = 0; i < instructions.size(); ++i) {
   659       dumpInstruction(out, i);
   669   std::shared_ptr<Graph> graph;
   673   std::unordered_map<size_t, int>
   677   std::vector<Instruction> instructions;
   678   int register_size = 0;
   682   std::vector<int> int_data;
   683   std::vector<bool> bool_data;
   689       : 
function(code.pImpl),
   690         int_data(function->int_data.data()),
   691         bool_data(function->bool_data),
   692         registers(function->register_size) {}
   696     c10::raw::intrusive_ptr::incref(
this);
   700   bool runImpl(Stack& stack) {
   701     auto& instructions = 
function->instructions;
   702     size_t last = instructions.size();
   708       auto& inst = instructions[pc];
   710         loadTensorsFromRegisters(inst.inputs, stack);
   711         size_t new_pc = pc + 1 + inst.callback(stack);
   712         for (
int i = inst.outputs.size - 1; i >= 0; --i) {
   713           int reg = 
get(inst.outputs, i);
   714           registers[reg] = pop(stack);
   720         AT_ASSERT(inst.inputs.values.size == 1);
   724         if (
get(inst.inputs.free_flags, 0)) {
   726           registers[
get(inst.inputs.values, 0)] = e.future;
   733         e.future->addCallback([state]() {
   735               autograd::GradMode::is_enabled()));
   739       } 
catch (Future::FutureError& e) {
   741         auto msg = e.error_msg; 
   742         handleError(std::move(msg), 
false);
   744       } 
catch (std::exception& e) {
   746         bool is_jit_exception = 
dynamic_cast<JITException*
>(&e);
   747         if (instructions[pc].debug_location) {
   749               instructions[pc].debug_location->wrapException(
   750                   e, 
"operation failed in interpreter"),
   753           handleError(e.what(), is_jit_exception);
   759       auto num_outputs = 
function->preprocess.n_outputs;
   760       if (num_outputs == 1) {
   761         future->markCompleted(stack.back());
   763         future->markCompleted(
   764             Tuple::create(jit::last(stack, num_outputs).vec()));
   771   void handleError(std::string&& error_msg, 
bool is_jit_exception) {
   773       future->markCompleted(Future::FutureError(std::move(error_msg)));
   774     } 
else if (is_jit_exception) {
   777       throw std::runtime_error(std::move(error_msg));
   784       future = c10::make_intrusive<Future>();
   795   void run(Stack& stack) {
   796     if (runImpl(stack)) {
   799       auto num_outputs = 
function->preprocess.n_outputs;
   800       if (num_outputs == 1) {
   801         push(stack, future->value());
   803         auto tuple = future->value().toTuple();
   804         for (
const auto& value : tuple->elements()) {
   812     return int_data[list.start + i];
   815     return bool_data[list.start + i];
   817   void loadTensorsFromRegisters(
const UseList& uses, Stack& stack) {
   818     for (
int i = 0; i < uses.values.size; i++) {
   819       int reg = 
get(uses.values, i);
   821       if (
get(uses.free_flags, i)) {
   822         stack.push_back(std::move(registers[reg]));
   824         stack.push_back(registers[reg]);
   832   std::shared_ptr<CodeImpl> 
function; 
   835   const std::vector<bool>& bool_data;
   847   std::vector<IValue> registers;
   854 std::ostream& operator<<(std::ostream& out, 
const Code& code) {
   855   out << *code.pImpl->graph << 
"\n";
   856   code.pImpl->dump(out);
   860 Code::Code(
const std::shared_ptr<Graph>& graph) : pImpl(
new CodeImpl(graph)) {}
   861 Code::~Code() = 
default;
   863 const std::vector<GraphExecutor*>& Code::grad_executors() {
   864   return pImpl->grad_executors();
   867 InterpreterState::InterpreterState(
const Code& code)
   868     : pImpl(c10::make_intrusive<InterpreterStateImpl>(code)) {}
   869 InterpreterState::~InterpreterState() = 
default;
   871 void InterpreterState::run(Stack& stack) {
   883 InterpreterState::InterpreterState(
   885     : pImpl(std::move(pImpl_)) {}
   887 void InterpreterContinuation::operator()() {
   889   state.runAsync(stack);
 
const at::Storage & storage() const override
Return the underlying storage of a Tensor. 
at::IntArrayRef strides() const override
Return a reference to the strides of this tensor. 
The low-level representation of a tensor, which contains a pointer to a storage (which contains the actual data) and metadata (e.g., sizes and strides) describing this particular view of the data as a tensor.
at::IntArrayRef sizes() const override
Return a reference to the sizes of this tensor. 
constexpr size_t size() const 
size - Get the array size. 
int64_t dim() const override
Return the number of dimensions of this tensor. 
intrusive_ptr<T> is an alternative to shared_ptr<T> that has better performance because it does the refcounting intrusively (i.e., in a member of the object itself).
AT_CPP14_CONSTEXPR const T & at(size_t Index) const 
Vector compatibility. 
static intrusive_ptr reclaim(TTarget *owning_ptr)
Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes over ownership.