// Caffe2 - C++ API
// A deep learning, cross-platform ML framework
// lower_tuples.cpp
1 #include <torch/csrc/jit/passes/lower_tuples.h>
2 #include <ATen/core/functional.h>
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 
6 namespace torch {
7 namespace jit {
8 
9 namespace {
10 
11 // operators where we expect to find tuples as inputs/outputs
12 // this is to assert we are only doing modifications when we know
13 // we can flatten tuples
14 std::unordered_set<Symbol> white_list = {
15  prim::If,
16  prim::Loop,
17  prim::TupleUnpack,
18  prim::TupleConstruct,
19  prim::TupleIndex,
20  prim::TupleSlice,
21  prim::Param,
22  prim::Return,
23 };
24 
25 void removeTupleNodes(Node* n, bool must_remove_tuples) {
26  if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
27  n->kind() != prim::TupleSlice) {
28  return;
29  }
30  auto construct = n->input()->node();
31  if (construct->kind() != prim::TupleConstruct) {
32  if (must_remove_tuples) {
33  AT_ERROR(n->kind().toQualString(), " not matched to tuple construct");
34  }
35  return;
36  }
37  if (n->kind() == prim::TupleUnpack) {
38  for (size_t i = 0; i < n->outputs().size(); ++i) {
39  n->outputs()[i]->replaceAllUsesWith(construct->inputs().at(i));
40  }
41  } else if (n->kind() == prim::TupleIndex) {
42  auto idx = n->i(attr::index);
43  n->output()->replaceAllUsesWith(construct->inputs().at(idx));
44  } else if (n->kind() == prim::TupleSlice) {
45  std::vector<Value*> values;
46  int64_t beg = n->i(attr::beg);
47  int64_t end = n->i(attr::end);
48  for (int64_t i = beg; i < end; i += 1) {
49  values.push_back(construct->inputs().at(i));
50  }
51  auto graph = n->owningGraph();
52  auto tuple_out = graph->createTuple(values);
53  WithInsertPoint insert(n);
54  graph->insertNode(tuple_out);
55  n->output()->replaceAllUsesWith(tuple_out->output());
56  }
57 }
58 
59 } // anonymous namespace
60 
61 static void LowerAllTuples(Block* block);
62 
// Flattens tuples in the signature of a single node: tuple-consuming nodes
// (unpack/index/slice) are forwarded to their matching TupleConstruct, and
// tuple-typed inputs/outputs of whitelisted ops are replaced by their element
// lists. `insert_point` is where replacement TupleConstruct nodes for
// flattened outputs are inserted. Errors (AT_CHECK/AT_ERROR) if a tuple
// appears somewhere it cannot be flattened.
static void VisitNode(Node* n, Node* insert_point) {
  auto& graph = *n->owningGraph();

  // tuple construction operators will become dead when the unpacks are replaced
  if (n->kind() == prim::TupleConstruct) {
    return;
  }

  // note: changing the second argument to false changes this pass from a
  // complete lowering pass to one that removes tuples when possible. When
  // tuples are first-class in the interpreter, we should still run this pass to
  // remove extraneous uses

  if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
      n->kind() == prim::TupleSlice) {
    removeTupleNodes(n, /*must_remove_tuples*/ true);
    return;
  }

  // flatten the input list op(a, tup, b) --> op(a, t0, t1, b)
  for (size_t i = 0; i < n->inputs().size();) {
    auto input = n->inputs()[i];
    if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
      // Only whitelisted ops are known to tolerate having a tuple argument
      // replaced by its flattened element list.
      AT_CHECK(
          white_list.count(n->kind()) > 0,
          "tuple appears in op that does not forward tuples");
      // The tuple must come straight from a TupleConstruct so we can read
      // its elements directly.
      AT_CHECK(
          input->node()->kind() == prim::TupleConstruct,
          "tuple use not matched to tuple construct");
      // Splice the construct's elements in right after the tuple input,
      // then drop the tuple input itself.
      for (size_t j = 0; j < tt->elements().size(); ++j) {
        n->insertInput(i + 1 + j, input->node()->inputs().at(j));
      }
      n->removeInput(i);
      // note: no update to i
      // since tuples might be nested we need to recursively scan
      // the new flattened inputs
    } else {
      ++i;
    }
  }
  // Recurse into sub-blocks (e.g. If/Loop bodies) before touching outputs.
  for (auto b : n->blocks()) {
    LowerAllTuples(b);
  }

  // flatten the outputs list
  for (size_t i = 0; i < n->outputs().size();) {
    Value* output = n->outputs()[i];
    // (a, b, tup, c) -> (a, b, t0, t1, c)
    // and:
    // tup = (t0, t1)
    // is placed at the current insertion point
    if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
      AT_CHECK(
          white_list.count(n->kind()) > 0,
          "tuple appears in op that does not forward tuples");
      // Add one typed output per tuple element, directly after the tuple
      // output being flattened.
      for (size_t j = 0; j < tt->elements().size(); j++) {
        n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
      }
      // Re-pack the new outputs into a TupleConstruct so existing users of
      // the tuple value keep working; later visits will flatten those uses.
      auto new_tup =
          graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
      new_tup->insertBefore(insert_point);
      insert_point = new_tup;
      output->replaceAllUsesWith(new_tup->output());
      n->eraseOutput(i);
      // note: no update to i to handle nested tuples
    } else {
      ++i;
    }
  }
}
133 
// Lowers tuples throughout one block: its parameter list, every node in
// order, and its return list.
static void LowerAllTuples(Block* block) {
  // tuples in parameter lists of a block behave exactly the same as
  // _outputs_ of normal instructions, since the param_node represents the
  // parameters as outputs, we can handle it by simply visiting the node
  VisitNode(block->param_node(), *block->nodes().begin());
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    auto n = *it++;
    // NOTE(review): `*it` is dereferenced even when it == end; this appears to
    // rely on the node list being circular with a valid sentinel node at
    // end() to serve as the insert point — confirm against the Node list
    // implementation in ir.h.
    VisitNode(n, *it);
  }
  // tuples in return lists of blocks behave exactly the same as
  // _inputs_ of normal instructions, so we can use VisitNode here as well
  // insert_point is null because it will never be used since return nodes
  // have no outputs
  VisitNode(block->return_node(), nullptr);
}
150 
151 static void EnsureNoTuples(ArrayRef<Value*> values) {
152  for (Value* v : values) {
153  AT_CHECK(
154  v->type()->kind() != TypeKind::TupleType, "Couldn't lower all tuples.");
155  }
156 }
157 
158 static void EnsureNoTuples(Block* block) {
159  for (Node* n : block->nodes()) {
160  for (Block* b : n->blocks()) {
161  EnsureNoTuples(b);
162  }
163  EnsureNoTuples(n->outputs());
164  }
165 }
166 
// Public entry point: completely removes tuples from the graph and verifies
// that none remain. Errors (via AT_CHECK/AT_ERROR) if any tuple could not
// be lowered.
void LowerAllTuples(std::shared_ptr<Graph>& graph) {
  LowerAllTuples(graph->block());
  // The lowering leaves dead TupleConstruct/consumer nodes behind; clean
  // them up before checking that no tuple-typed values survive.
  EliminateDeadCode(graph->block());
  EnsureNoTuples(graph->block());
}
172 
173 void LowerSimpleTuples(Block* block) {
174  for (auto n : block->nodes()) {
175  removeTupleNodes(n, /*must_remove_tuples*/ false);
176  for (auto b : n->blocks()) {
177  LowerSimpleTuples(b);
178  }
179  }
180 }
181 
// Public entry point for the best-effort variant: removes only those tuple
// uses that are directly matched to a TupleConstruct; unmatched tuples are
// left untouched.
void LowerSimpleTuples(std::shared_ptr<Graph>& graph) {
  LowerSimpleTuples(graph->block());
  // Matched TupleConstruct nodes become dead after folding; remove them.
  EliminateDeadCode(graph);
}
186 
187 } // namespace jit
188 } // namespace torch
// Definition: jit_type.h:17