Caffe2 - C++ API
A deep learning, cross-platform ML framework
inline_fork_wait.cpp
1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2 #include <torch/csrc/jit/passes/inline_fork_wait.h>
3 
4 namespace torch {
5 namespace jit {
6 
7 void InlineForkWait(
8  Block* b,
9  std::unordered_map<Value*, Value*>& future_remap) {
10  for (auto n : b->nodes()) {
11  if (n->kind() == prim::fork) {
12  WithInsertPoint insert_guard(n);
13  auto graph = b->owningGraph();
14  auto subgraph = n->g(attr::Subgraph);
15 
16  auto output = inlineCallTo(*graph, *subgraph, n->inputs());
17 
18  future_remap[n->output()] = output.at(0);
19  } else if (n->kind() == aten::wait) {
20  AT_ASSERT(n->inputs().size() == 1);
21  AT_ASSERT(n->outputs().size() == 1);
22  n->output()->replaceAllUsesWith(future_remap.at(n->input()));
23  }
24 
25  for (auto sub_b : n->blocks()) {
26  InlineForkWait(sub_b, future_remap);
27  }
28  }
29 }
30 
31 void InlineForkWait(const std::shared_ptr<Graph>& graph) {
32  std::unordered_map<Value*, Value*> future_remap;
33  InlineForkWait(graph->block(), future_remap);
34 }
35 
36 } // namespace jit
37 } // namespace torch
Definition: jit_type.h:17