1 #include <torch/csrc/jit/passes/lower_tuples.h> 2 #include <ATen/core/functional.h> 3 #include <c10/util/Exception.h> 4 #include <torch/csrc/jit/passes/dead_code_elimination.h> 14 std::unordered_set<Symbol> white_list = {
25 void removeTupleNodes(Node* n,
bool must_remove_tuples) {
26 if (n->kind() != prim::TupleUnpack && n->kind() != prim::TupleIndex &&
27 n->kind() != prim::TupleSlice) {
30 auto construct = n->input()->node();
31 if (construct->kind() != prim::TupleConstruct) {
32 if (must_remove_tuples) {
33 AT_ERROR(n->kind().toQualString(),
" not matched to tuple construct");
37 if (n->kind() == prim::TupleUnpack) {
38 for (
size_t i = 0; i < n->outputs().size(); ++i) {
39 n->outputs()[i]->replaceAllUsesWith(construct->inputs().at(i));
41 }
else if (n->kind() == prim::TupleIndex) {
42 auto idx = n->i(attr::index);
43 n->output()->replaceAllUsesWith(construct->inputs().at(idx));
44 }
else if (n->kind() == prim::TupleSlice) {
45 std::vector<Value*> values;
46 int64_t beg = n->i(attr::beg);
47 int64_t end = n->i(attr::end);
48 for (int64_t i = beg; i < end; i += 1) {
49 values.push_back(construct->inputs().at(i));
51 auto graph = n->owningGraph();
52 auto tuple_out = graph->createTuple(values);
53 WithInsertPoint insert(n);
54 graph->insertNode(tuple_out);
55 n->output()->replaceAllUsesWith(tuple_out->output());
61 static void LowerAllTuples(Block* block);
63 static void VisitNode(Node* n, Node* insert_point) {
64 auto& graph = *n->owningGraph();
67 if (n->kind() == prim::TupleConstruct) {
76 if (n->kind() == prim::TupleUnpack || n->kind() == prim::TupleIndex ||
77 n->kind() == prim::TupleSlice) {
78 removeTupleNodes(n,
true);
83 for (
size_t i = 0; i < n->inputs().size();) {
84 auto input = n->inputs()[i];
85 if (TupleTypePtr tt = input->type()->cast<TupleType>()) {
87 white_list.count(n->kind()) > 0,
88 "tuple appears in op that does not forward tuples");
90 input->node()->kind() == prim::TupleConstruct,
91 "tuple use not matched to tuple construct");
92 for (
size_t j = 0; j < tt->elements().size(); ++j) {
93 n->insertInput(i + 1 + j, input->node()->inputs().at(j));
103 for (
auto b : n->blocks()) {
108 for (
size_t i = 0; i < n->outputs().size();) {
109 Value* output = n->outputs()[i];
114 if (TupleTypePtr tt = output->type()->cast<TupleType>()) {
116 white_list.count(n->kind()) > 0,
117 "tuple appears in op that does not forward tuples");
118 for (
size_t j = 0; j < tt->elements().size(); j++) {
119 n->insertOutput(i + 1 + j)->setType(tt->elements()[j]);
122 graph.createTuple(n->outputs().slice(i + 1, tt->elements().size()));
123 new_tup->insertBefore(insert_point);
124 insert_point = new_tup;
125 output->replaceAllUsesWith(new_tup->output());
134 static void LowerAllTuples(Block* block) {
138 VisitNode(block->param_node(), *block->nodes().begin());
139 for (
auto it = block->nodes().begin(), end = block->nodes().end();
148 VisitNode(block->return_node(),
nullptr);
151 static void EnsureNoTuples(ArrayRef<Value*> values) {
152 for (Value* v : values) {
154 v->type()->kind() != TypeKind::TupleType,
"Couldn't lower all tuples.");
158 static void EnsureNoTuples(Block* block) {
159 for (Node* n : block->nodes()) {
160 for (Block* b : n->blocks()) {
163 EnsureNoTuples(n->outputs());
167 void LowerAllTuples(std::shared_ptr<Graph>& graph) {
168 LowerAllTuples(graph->block());
169 EliminateDeadCode(graph->block());
170 EnsureNoTuples(graph->block());
173 void LowerSimpleTuples(Block* block) {
174 for (
auto n : block->nodes()) {
175 removeTupleNodes(n,
false);
176 for (
auto b : n->blocks()) {
177 LowerSimpleTuples(b);
// Graph-level entry point for best-effort tuple removal: runs the Block*
// overload of LowerSimpleTuples on the graph's top-level block (which passes
// must_remove_tuples=false, so unmatched tuple ops are not reported as
// errors), then runs EliminateDeadCode on the whole graph.
182 void LowerSimpleTuples(std::shared_ptr<Graph>& graph) {
183 LowerSimpleTuples(graph->block());
184 EliminateDeadCode(graph);