1 #include <torch/csrc/jit/passes/dead_code_elimination.h> 2 #include <torch/csrc/jit/passes/inline_fork_wait.h> 9 std::unordered_map<Value*, Value*>& future_remap) {
10 for (
auto n : b->nodes()) {
11 if (n->kind() == prim::fork) {
12 WithInsertPoint insert_guard(n);
13 auto graph = b->owningGraph();
14 auto subgraph = n->g(attr::Subgraph);
16 auto output = inlineCallTo(*graph, *subgraph, n->inputs());
18 future_remap[n->output()] = output.at(0);
19 }
else if (n->kind() == aten::wait) {
20 AT_ASSERT(n->inputs().size() == 1);
21 AT_ASSERT(n->outputs().size() == 1);
22 n->output()->replaceAllUsesWith(future_remap.at(n->input()));
25 for (
auto sub_b : n->blocks()) {
26 InlineForkWait(sub_b, future_remap);
31 void InlineForkWait(
const std::shared_ptr<Graph>& graph) {
32 std::unordered_map<Value*, Value*> future_remap;
33 InlineForkWait(graph->block(), future_remap);