1 #include <torch/csrc/jit/passes/constant_propagation.h> 2 #include <ATen/core/functional.h> 3 #include <ATen/core/ivalue.h> 4 #include <torch/csrc/autograd/variable.h> 5 #include <torch/csrc/jit/constants.h> 6 #include <torch/csrc/jit/interpreter.h> 7 #include <torch/csrc/jit/ir.h> 8 #include <torch/csrc/jit/operator.h> 9 #include <torch/csrc/jit/passes/alias_analysis.h> 10 #include <torch/csrc/jit/passes/dead_code_elimination.h> 17 std::unordered_set<Symbol> skip_list = {
22 prim::unchecked_unwrap_optional,
27 std::vector<IValue> runNode(Node* n) {
28 auto op = getOperation(n);
30 for (
auto input : n->inputs()) {
31 stack.push_back(*(toIValue(input)));
34 auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
36 auto t = std::move(v).toTensor();
38 return IValue(autograd::as_variable_ref(t).data());
49 void propagateNode(Node* n) {
50 std::vector<IValue> outputs;
58 auto graph = n->owningGraph();
59 WithInsertPoint guard(n);
60 for (
size_t i = 0; i < outputs.size(); ++i) {
62 auto new_output = graph->insertConstant(outputs[i]);
63 if (outputs[i].isNone()) {
64 new_output->setType(n->outputs()[i]->type());
66 n->outputs()[i]->replaceAllUsesWith(new_output);
67 }
catch (constant_not_supported_error& err) {
75 void removeLoopNode(Node* n) {
76 auto loop_input_offset = 2;
77 for (
size_t i = 0; i < n->outputs().size(); ++i) {
78 n->outputs().at(i)->replaceAllUsesWith(
79 n->inputs().at(i + loop_input_offset));
84 bool loopWillNotRun(Node* node) {
85 Value* trip_count = node->inputs().at(0);
86 int64_t iter_len = constant_as<int64_t>(trip_count).value_or(1);
88 Value* start_cond = node->inputs().at(1);
89 bool cond_val = constant_as<bool>(start_cond).value_or(
true);
91 bool loop_might_run = cond_val && iter_len > 0;
92 return !loop_might_run;
95 void ConstantPropagation(Block* block,
const AliasDb& aliasDb);
97 void inlineIfBody(Block* body) {
98 Node* n = body->owningNode();
99 for (
auto it = body->nodes().begin(); it != body->nodes().end();) {
100 Node* body_node = *it;
104 body_node->moveBefore(n);
106 for (
size_t i = 0; i < n->outputs().size(); ++i) {
107 n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
114 void inlineIf(Node* n,
const AliasDb& aliasDb) {
115 auto input_bool = constant_as<bool>(n->input());
116 AT_ASSERT(input_bool);
117 size_t block_index = *input_bool ? 0 : 1;
118 ConstantPropagation(n->blocks().at(block_index), aliasDb);
119 inlineIfBody(n->blocks().at(block_index));
123 bool removeExtraIfOutputs(Node* n) {
124 AT_CHECK(n->kind() == prim::If,
"Only supported for If nodes");
125 auto true_block = n->blocks()[0];
126 auto false_block = n->blocks()[1];
127 auto initial_outputs = true_block->outputs().size();
128 for (
size_t i = 0; i < true_block->outputs().size();) {
130 if (true_block->outputs()[i] == false_block->outputs()[i]) {
131 n->outputs().at(i)->replaceAllUsesWith(true_block->outputs()[i]);
133 true_block->eraseOutput(i);
134 false_block->eraseOutput(i);
140 return initial_outputs != true_block->outputs().size();
144 void removeExtraLoopOutputs(Node* node) {
145 auto loop_body = node->blocks().at(0);
146 auto loop_input_offset = 2;
147 auto loop_body_offset =
149 for (
size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
152 if (loop_body->inputs().at(loop_body_offset + i) ==
153 loop_body->outputs().at(loop_body_offset + i)) {
154 auto node_input = node->inputs().at(loop_input_offset + i);
155 node->outputs().at(i)->replaceAllUsesWith(node_input);
157 .at(loop_body_offset + i)
158 ->replaceAllUsesWith(node_input);
159 node->eraseOutput(i);
160 node->removeInput(loop_input_offset + i);
161 loop_body->eraseInput(loop_body_offset + i);
162 loop_body->eraseOutput(loop_body_offset + i);
167 void ConstantPropagation(Node* n,
const AliasDb& aliasDb) {
168 bool constant_inputs =
169 std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
170 return v->node()->kind() == prim::Constant;
172 bool supported_node = !n->kind().is_onnx() &&
173 skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
174 !n->hasSideEffects() && !aliasDb.hasWriters(n);
175 auto run_blocks = [&]() {
176 for (Block* block : n->blocks()) {
177 ConstantPropagation(block, aliasDb);
180 if (n->kind() == prim::If) {
182 if (constant_inputs) {
183 inlineIf(n, aliasDb);
186 removeExtraIfOutputs(n);
188 }
else if (n->kind() == prim::Loop) {
189 if (loopWillNotRun(n)) {
193 removeExtraLoopOutputs(n);
195 }
else if (constant_inputs && supported_node) {
202 void ConstantPropagation(Block* block,
const AliasDb& aliasDb) {
203 for (
auto it = block->nodes().begin(); it != block->nodes().end();) {
206 ConstantPropagation(n, aliasDb);
211 void ConstantPropagation(std::shared_ptr<Graph>& graph) {
212 AliasDb aliasDb(graph);
213 ConstantPropagation(graph->block(), aliasDb);
214 EliminateDeadCode(graph);
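// NOTE(review): stray text in this copy of the file, apparently a scraped
// tooltip for c10::Error: "The primary ATen error class."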