Caffe2 - C++ API
A deep learning, cross-platform ML framework
constant_propagation.cpp
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <ATen/core/functional.h>
#include <ATen/core/ivalue.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

#include <algorithm>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {

namespace {

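// Node kinds that this pass never attempts to fold into constants.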
std::unordered_set<Symbol> skip_list = {
    prim::If,
    prim::Loop,
    prim::Constant,
    prim::AutogradZero,
    prim::unchecked_unwrap_optional, // TODO remove
    // TODO (zach): we should consider skipping tensor factories in the cases
    // where the constant tensor would be large but cheap to create.
};

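// Run a node whose inputs are all constants through its operator and return
// the resulting outputs, unwrapping Variable outputs to plain tensors.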
std::vector<IValue> runNode(Node* n) {
  auto op = getOperation(n);
  Stack stack;
  for (auto input : n->inputs()) {
    stack.push_back(*(toIValue(input)));
  }
  op(stack);
  auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
    if (v.isTensor()) {
      auto t = std::move(v).toTensor();
      if (t.defined()) {
        return IValue(autograd::as_variable_ref(t).data());
      } else {
        return t;
      }
    } else {
      return v;
    }
  });
  return var_outputs;
}

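// Evaluate the node and replace each of its outputs with a constant holding
// the computed value; the now-unused node is left for dead code elimination.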
void propagateNode(Node* n) {
  std::vector<IValue> outputs;
  try {
    outputs = runNode(n);
  } catch (const c10::Error& e) {
    // catch AT_ASSERT errors. This op may never be reached at runtime,
    // so catch the error here & leave the op in the graph
    return;
  }
  auto graph = n->owningGraph();
  WithInsertPoint guard(n);
  for (size_t i = 0; i < outputs.size(); ++i) {
    try {
      auto new_output = graph->insertConstant(outputs[i]);
      if (outputs[i].isNone()) {
        new_output->setType(n->outputs()[i]->type());
      }
      n->outputs()[i]->replaceAllUsesWith(new_output);
    } catch (constant_not_supported_error& err) {
      // we cannot actually represent the IValue as a constant node,
      // so we give up replacing it
    }
    // let dead code elimination remove n
  }
}

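// The loop is known not to run: replace each output with the corresponding
// loop-carried input and delete the Loop node.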
void removeLoopNode(Node* n) {
  auto loop_input_offset = 2; // offset of loop carried deps in input list
  for (size_t i = 0; i < n->outputs().size(); ++i) {
    n->outputs().at(i)->replaceAllUsesWith(
        n->inputs().at(i + loop_input_offset));
  }
  n->destroy();
}

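// Returns true only when the trip count or the starting condition is a
// constant that guarantees the loop body never executes.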
bool loopWillNotRun(Node* node) {
  Value* trip_count = node->inputs().at(0);
  int64_t iter_len = constant_as<int64_t>(trip_count).value_or(1);

  Value* start_cond = node->inputs().at(1);
  bool cond_val = constant_as<bool>(start_cond).value_or(true);

  bool loop_might_run = cond_val && iter_len > 0;
  return !loop_might_run;
}

void ConstantPropagation(Block* block, const AliasDb& aliasDb);

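// Hoist the chosen branch's nodes out in front of the If node, rewire the
// If outputs to the branch outputs, and destroy the If node.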
void inlineIfBody(Block* body) {
  Node* n = body->owningNode();
  for (auto it = body->nodes().begin(); it != body->nodes().end();) {
    Node* body_node = *it;
    // advance the iterator first: once body_node is moved, its next pointer
    // will point to n
    it++;
    body_node->moveBefore(n);
  }
  for (size_t i = 0; i < n->outputs().size(); ++i) {
    n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
  }
  // NB: destroy the node here, because it might contain side effects, like
  // print
  n->destroy();
}

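// The If condition is a known constant: constant-propagate the taken branch
// and then inline it in place of the If node.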
void inlineIf(Node* n, const AliasDb& aliasDb) {
  auto input_bool = constant_as<bool>(n->input());
  AT_ASSERT(input_bool);
  size_t block_index = *input_bool ? 0 : 1;
  ConstantPropagation(n->blocks().at(block_index), aliasDb);
  inlineIfBody(n->blocks().at(block_index));
}

// Remove If outputs that carry the same value in both branches.
bool removeExtraIfOutputs(Node* n) {
  AT_CHECK(n->kind() == prim::If, "Only supported for If nodes");
  auto true_block = n->blocks()[0];
  auto false_block = n->blocks()[1];
  auto initial_outputs = true_block->outputs().size();
  for (size_t i = 0; i < true_block->outputs().size();) {
    // neither block changes the output value
    if (true_block->outputs()[i] == false_block->outputs()[i]) {
      n->outputs().at(i)->replaceAllUsesWith(true_block->outputs()[i]);
      n->eraseOutput(i);
      true_block->eraseOutput(i);
      false_block->eraseOutput(i);
    } else {
      i++; // increment because we didn't remove the current index
    }
  }
  // return whether any output was removed
  return initial_outputs != true_block->outputs().size();
}

// Remove loop outputs (and the matching loop-carried inputs) whose value is
// never changed by the loop body.
void removeExtraLoopOutputs(Node* node) {
  auto loop_body = node->blocks().at(0);
  auto loop_input_offset = 2; // offset of loop carried deps in input list
  auto loop_body_offset =
      1; // offset to the loop carried dependencies in block inputs/outputs
  for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
    size_t i = i_1 - 1;
    // if the value is no longer changed remove output
    if (loop_body->inputs().at(loop_body_offset + i) ==
        loop_body->outputs().at(loop_body_offset + i)) {
      auto node_input = node->inputs().at(loop_input_offset + i);
      node->outputs().at(i)->replaceAllUsesWith(node_input);
      loop_body->inputs()
          .at(loop_body_offset + i)
          ->replaceAllUsesWith(node_input);
      node->eraseOutput(i);
      node->removeInput(loop_input_offset + i);
      loop_body->eraseInput(loop_body_offset + i);
      loop_body->eraseOutput(loop_body_offset + i);
    }
  }
}

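// Core dispatch: fold nodes with all-constant inputs, specialize If and Loop
// nodes, and otherwise recurse into sub-blocks.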
void ConstantPropagation(Node* n, const AliasDb& aliasDb) {
  bool constant_inputs =
      std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
        return v->node()->kind() == prim::Constant;
      });
  bool supported_node = !n->kind().is_onnx() &&
      skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
      !n->hasSideEffects() && !aliasDb.hasWriters(n);
  auto run_blocks = [&]() {
    for (Block* block : n->blocks()) {
      ConstantPropagation(block, aliasDb);
    }
  };
  if (n->kind() == prim::If) {
    // inline node if we can, otherwise check for simplified outputs
    if (constant_inputs) {
      inlineIf(n, aliasDb);
    } else {
      run_blocks();
      removeExtraIfOutputs(n);
    }
  } else if (n->kind() == prim::Loop) {
    if (loopWillNotRun(n)) {
      removeLoopNode(n);
    } else {
      run_blocks();
      removeExtraLoopOutputs(n);
    }
  } else if (constant_inputs && supported_node) {
    propagateNode(n);
  } else {
    run_blocks();
  }
}

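// Walk a block's nodes in order, constant-propagating each one.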
void ConstantPropagation(Block* block, const AliasDb& aliasDb) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    Node* n = *it;
    it++; // advance the iterator because the current node may be destroyed
    ConstantPropagation(n, aliasDb);
  }
}
} // anonymous namespace

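// Public entry point: constant-propagate the whole graph, then remove the
// nodes made dead by the folding.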
void ConstantPropagation(std::shared_ptr<Graph>& graph) {
  AliasDb aliasDb(graph);
  ConstantPropagation(graph->block(), aliasDb);
  EliminateDeadCode(graph);
}

} // namespace jit
} // namespace torch
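
A rough usage sketch (not part of the file above): a caller that already has a TorchScript graph, for example one obtained from a scripted module, can invoke the pass through the public entry point declared in constant_propagation.h. The wrapper name foldConstants below is illustrative only.

#include <memory>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>

// Illustrative wrapper: runs the pass in place on a graph owned by the caller.
// ConstantPropagation folds all-constant nodes, inlines Ifs with constant
// conditions, drops loops that never run, and finishes with dead code
// elimination.
void foldConstants(std::shared_ptr<torch::jit::Graph>& graph) {
  torch::jit::ConstantPropagation(graph);
}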