#include <c10/util/Exception.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>

#include <c10/util/Optional.h>
#include <cstring>

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

namespace torch {
namespace jit {

using namespace ::c10::onnx;

bool isRNN(const Node* node) {
  auto k = node->kind();
  return k == onnx::RNN || k == onnx::LSTM || k == onnx::GRU;
}

bool isNopTranspose(const std::vector<int64_t>& perm) {
  for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++)
    if (perm[i] != i)
      return false;
  return true;
}

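// Composes two transpose permutations: the result satisfies
// ret[k] = t1[t2[k]], so transposing by `ret` is equivalent to transposing by
// `t1` first and then by `t2`. For example, composing {0, 2, 1} with
// {1, 0, 2} yields {2, 0, 1}.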
std::vector<int64_t> composeTransposes(
    const std::vector<int64_t>& t1,
    const std::vector<int64_t>& t2) {
  AT_ASSERT(t1.size() == t2.size());
  std::vector<int64_t> ret;
  ret.reserve(t1.size());
  for (const auto& i : t2) {
    AT_ASSERT(i < int64_t(t1.size()));
    ret.push_back(t1[i]);
  }
  return ret;
}

const std::vector<size_t>& getBroadcastPositions(Node* node) {
  // Input positions that ONNX implicit broadcasting applies to, per op kind.
  static std::unordered_map<NodeKind, std::vector<size_t>> broadcast_positions = {
      {onnx::Equal, {0, 1}},
      {onnx::Greater, {0, 1}},
  };
  static std::vector<size_t> no_positions;
  auto iter = broadcast_positions.find(node->kind());
  if (iter != broadcast_positions.end()) {
    return iter->second;
  }
  return no_positions;
}

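// Broadcast-compatibility check used by fuseBroadcast below: `from` can be
// folded into `to` only if, aligned at the trailing dimension, every
// dimension of `from` is either 1 or equal to the matching dimension of `to`
// (e.g. {3, 1} fits into {4, 5, 3, 2}; {3, 2} does not). The helper's name,
// signature, and return value are reconstructed around the surviving
// dimension-check loop.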
c10::optional<size_t> fusibleExpandTo(at::IntList from, at::IntList to) {
  if (from.size() > to.size()) {
    return c10::nullopt;
  }
  for (size_t i = 0; i < from.size(); i++) {
    auto fdim = from[from.size() - 1 - i];
    auto tdim = to[to.size() - 1 - i];
    if (fdim != 1 && fdim != tdim) {
      return c10::nullopt;
    }
  }
  return to.size() - from.size();
}

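// fuseBroadcast: fold an aten::expand (with constant target sizes) that feeds
// a broadcasting ONNX op directly into that op, relying on ONNX implicit
// broadcasting instead of materializing the expanded tensor.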
void fuseBroadcast(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseBroadcast(child_block);
    }

    auto& broadcast_positions = getBroadcastPositions(n);
    if (!broadcast_positions.empty()) {
      // If the node already broadcasts, we can't rebroadcast.
      AT_ASSERT(!n->hasAttribute(attr::axis));
    }

    for (size_t position : broadcast_positions) {
      auto* expand_node = n->input(position)->node();

      // Confirm it is an expand with constant sizes.
      if (expand_node->kind() != aten::expand ||
          expand_node->input(1)->node()->kind() != onnx::Constant ||
          expand_node->input(2)->node()->kind() != onnx::Constant) {
        continue;
      }

      auto* unexpanded_input = expand_node->input(0);

      // Shapes of both sides must be known to decide fusibility.
      if (!unexpanded_input->isTensor() || !n->output()->isTensor())
        continue;

      // Not all broadcasts are supported by ONNX broadcast.
      auto axis = fusibleExpandTo(
          unexpanded_input->type()
              ->expect<CompleteTensorType>()
              ->sizes(), // from
          n->output()->type()->expect<CompleteTensorType>()->sizes()); // to
      if (axis == c10::nullopt)
        continue;

      n->replaceInput(position, unexpanded_input);
      if (!expand_node->hasUses()) {
        expand_node->destroy();
      }
    }
  }
}

void fuseConsecutiveTransposes(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseConsecutiveTransposes(child_block);
    }
    if (n->kind() == onnx::Transpose &&
        n->input()->node()->kind() == onnx::Transpose) {
      auto origInput = n->input();
      n->is_(
          attr::perm,
          composeTransposes(
              origInput->node()->is(attr::perm), n->is(attr::perm)));
      n->replaceInput(0, origInput->node()->input());
      if (origInput->uses().size() == 0) {
        origInput->node()->destroy();
      }
    }
  }
}

void eliminateNopTranspose(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      eliminateNopTranspose(child_block);
    }
    if (n->kind() == onnx::Transpose) {
      if (isNopTranspose(n->is(attr::perm))) {
        n->output()->replaceAllUsesWith(n->input());
        it.destroyCurrent();
        continue;
      }
    }
  }
}

void fuseTransposeIntoGemm(Block* b) {
  static const std::vector<int64_t> simpleTransPerm({1, 0});

  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseTransposeIntoGemm(child_block);
    }
    if (n->kind() == onnx::Gemm) {
      for (size_t i : {0, 1}) {
        auto inp = n->inputs()[i];
        auto trans = i == 0 ? attr::transA : attr::transB;
        if (inp->node()->kind() == onnx::Transpose &&
            inp->node()->is(attr::perm) == simpleTransPerm) {
          n->replaceInput(i, inp->node()->input());
          n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
          if (inp->uses().size() == 0) {
            inp->node()->destroy();
          }
        }
      }
    }
  }
}

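// pushPackingPastRnn: prim::PackPadded / prim::PadPacked are placeholder ops
// recorded when tracing pack_padded_sequence / pad_packed_sequence. When a
// PackPadded directly feeds an RNN/LSTM/GRU, move the packing to after the
// RNN's output (past the trailing Transpose/Reshape or Squeeze), so that the
// PackPadded/PadPacked pair can later cancel out in removeNopPacking.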
void pushPackingPastRnn(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      pushPackingPastRnn(child_block);
    }

    if (n->kind() != prim::PackPadded) {
      continue;
    }
    // Only handle the case where the packed sequence has a single consumer.
    if (n->outputs().at(0)->uses().size() != 1) {
      continue;
    }
    Node* rnn = n->outputs()[0]->uses()[0].user;
    if (!isRNN(rnn)) {
      continue;
    }
    if (rnn->owningBlock() != n->owningBlock())
      continue;

    // If the RNN's output is unused, the packing has no effect; remove it.
    if (rnn->outputs().at(0)->uses().empty() &&
        n->outputs().at(1)->uses().size() == 1) {
      n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));
      n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
      it.destroyCurrent();
      continue;
    }

    // The RNN is followed by a transpose and a reshape (if bidirectional),
    // or by a squeeze (if unidirectional).
    Node* next = rnn->outputs().at(0)->uses().at(0).user;
    if (next->kind() == onnx::Transpose) {
      next = next->outputs().at(0)->uses().at(0).user;
      if (next->kind() != onnx::Reshape) {
        continue;
      }
    } else if (next->kind() != onnx::Squeeze) {
      continue;
    }

    // Remove PackPadded from in front of the RNN.
    n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));
    // There can be multiple uses of the lengths value (multi-layer RNNs feed
    // it to each layer), so only the first use is rewired here.
    n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));

    // Insert a new PackPadded after the RNN and reroute consumers to it.
    Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
    newPackPadded->insertAfter(next);
    next->outputs().at(0)->replaceAllUsesWith(newPackPadded->outputs().at(0));
    n->outputs().at(1)->replaceAllUsesWith(newPackPadded->outputs().at(1));
    newPackPadded->addInput(next->outputs().at(0));
    newPackPadded->addInput(n->inputs().at(1));

    // Fix up the type of `next`'s output: (seq_len, batch, hidden_size).
    CompleteTensorTypePtr oldType =
        rnn->inputs().at(0)->type()->cast<CompleteTensorType>();
    std::vector<int64_t> new_sizes;
    new_sizes.push_back(oldType->sizes().at(0));
    new_sizes.push_back(oldType->sizes().at(1));
    new_sizes.push_back(rnn->i(attr::hidden_size));
    CompleteTensorTypePtr newType = CompleteTensorType::create(
        oldType->scalarType(), oldType->device(), new_sizes);
    next->outputs().at(0)->setType(newType);

    it.destroyCurrent();
  }
}

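// removeNopPacking: a PadPacked node fed directly by a PackPadded node is a
// no-op pair. Forward the PackPadded's inputs to the PadPacked's outputs and
// delete the PadPacked, leaving the PackPadded for later cleanup.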
void removeNopPacking(Block* graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      removeNopPacking(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    if (input->kind() != prim::PackPadded) {
      continue;
    }
    if (input->outputs()[0] != n->inputs()[0]) {
      continue;
    }
    if (input->outputs()[1] != n->inputs()[1]) {
      continue;
    }
    n->outputs()[0]->replaceAllUsesWith(input->inputs()[0]);
    n->outputs()[1]->replaceAllUsesWith(input->inputs()[1]);

    n->removeAllInputs();
    it.destroyCurrent();
  }
}

void hackFixupPadPackedShapes(Block* graph) {
  // FIXME: the shape recorded on the input to the fictional PadPacked node
  // is incorrect; for now, copy the type of the PadPacked output back onto
  // the node that produces its input.
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      removeNopPacking(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    input->outputs()[0]->setType(n->outputs()[0]->type());
  }
}

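// fixDefaultRNNState: when no initial hidden/cell state is passed to an RNN,
// tracing bakes the default state in as a Constant (possibly behind a Slice
// for multi-layer RNNs), which also bakes in the example batch size. Rebuild
// the default state at runtime as a ConstantOfShape of
// [num_directions, batch_size, hidden_size], where batch_size is gathered
// from the RNN input's shape.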
void fixDefaultRNNState(Graph* graph, Node* n, int input_index) {
  auto initial_state = n->inputs()[input_index];

  // A default-initialized state shows up as a Constant, or as a Slice of a
  // Constant for multi-layer RNNs.
  bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
      (initial_state->node()->kind() == onnx::Slice &&
       initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);

  if (!needsFixing) {
    return;
  }

  Node* shape_of_input = graph->create(onnx::Shape, 1);
  shape_of_input->insertBefore(n);
  shape_of_input->addInput(n->inputs()[0]);

  Node* gather_indices = graph->create(onnx::Constant, 1);
  gather_indices->insertBefore(n);
  gather_indices->t_(
      attr::value,
      autograd::make_variable(at::scalar_to_tensor(at::Scalar(1))));

  // batch_size = input_shape[1]
  Node* batch_size = graph->create(onnx::Gather, 1);
  batch_size->insertBefore(n);
  batch_size->addInput(shape_of_input->outputs()[0]);
  batch_size->addInput(gather_indices->outputs()[0]);

  Node* unsqueezed_batch_size = graph->create(onnx::Unsqueeze, 1);
  unsqueezed_batch_size->insertBefore(n);
  unsqueezed_batch_size->addInput(batch_size->outputs()[0]);
  unsqueezed_batch_size->is_(attr::axes, {0});

  Node* hidden_size = graph->create(onnx::Constant, 1);
  hidden_size->insertBefore(n);
  hidden_size->t_(
      attr::value,
      autograd::make_variable(
          at::full({1}, n->i(attr::hidden_size), at::kLong)));

  Node* num_directions = graph->create(onnx::Constant, 1);
  num_directions->insertBefore(n);
  num_directions->t_(
      attr::value,
      autograd::make_variable(at::scalar_to_tensor(at::Scalar(
          n->hasAttribute(attr::direction) &&
                  n->s(attr::direction) == "bidirectional"
              ? 2
              : 1))));

  Node* unsqueezed_num_directions = graph->create(onnx::Unsqueeze, 1);
  unsqueezed_num_directions->insertBefore(n);
  unsqueezed_num_directions->addInput(num_directions->outputs()[0]);
  unsqueezed_num_directions->is_(attr::axes, {0});

  // shape = [num_directions, batch_size, hidden_size]
  Node* concated_dims = graph->create(onnx::Concat, 1);
  concated_dims->insertBefore(n);
  concated_dims->i_(attr::axis, 0);
  concated_dims->addInput(unsqueezed_num_directions->outputs()[0]);
  concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
  concated_dims->addInput(hidden_size->outputs()[0]);

  Node* constant_of_shape = graph->create(onnx::ConstantOfShape, 1);
  constant_of_shape->insertBefore(n);
  constant_of_shape->addInput(concated_dims->outputs()[0]);
  n->replaceInput(input_index, constant_of_shape->outputs()[0]);

  if (initial_state->uses().size() == 0) {
    initial_state->node()->destroy();
  }
}

void fixDefaultRnnHiddenState(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultRnnHiddenState(child_block);
    }
    if (!isRNN(n)) {
      continue;
    }
    // The hidden state is the sixth input for RNN, LSTM and GRU.
    if (n->inputs().size() < 6) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 5);
  }
}

void fixDefaultLstmCellState(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultLstmCellState(child_block);
    }
    if (n->kind() != onnx::LSTM) {
      continue;
    }
    // The cell state is the seventh input for LSTM.
    if (n->inputs().size() < 7) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 6);
  }
}

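// speculateOps: hoist nodes that are safe to execute unconditionally (only
// onnx::Transpose at the moment) out of control-flow blocks, moving them up
// to the block that owns their input so later passes can see across the
// block boundary.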
static bool isSafeToSpeculate(Node* n) {
  return n->kind() == onnx::Transpose;
}

static void speculateOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it; // increment first so it stays valid if `n` is moved
    for (auto b : n->blocks()) {
      speculateOps(b);
    }
    if (!isSafeToSpeculate(n))
      continue;
    // Only works for nodes with a single input.
    auto node_input = n->input()->node();
    if (node_input->owningBlock() == n->owningBlock())
      continue;
    // Find the control-flow node in the same block as node_input that
    // contains `n`, and hoist `n` directly before it.
    auto control_flow_node = n->owningBlock()->owningNode();
    while (control_flow_node->owningBlock() != node_input->owningBlock())
      control_flow_node = control_flow_node->owningBlock()->owningNode();
    n->moveBefore(control_flow_node);
  }
}

static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
  node->removeInput(i);
  for (auto* to_val : to) {
    AT_ASSERT(to_val->owningGraph() == node->owningGraph());
    node->insertInput(i++, to_val);
  }
}

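// eraseListConstruct: ONNX has no list type. An Int[] ListConstruct feeding a
// node is rewritten as an Unsqueeze (axis 0) of each element followed by a
// Concat, producing a single 1-D tensor input; a Tensor[] ListConstruct is
// flattened by splicing its elements directly into the consumer's inputs.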
static void eraseListConstruct(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it;

    for (auto b : n->blocks()) {
      eraseListConstruct(b);
    }
    std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;

    size_t i = 0;
    for (auto* input : n->inputs()) {
      if (input->node()->kind() == prim::ListConstruct) {
        auto* lc_node = input->node();
        TypePtr elem =
            lc_node->output()->type()->cast<ListType>()->getElementType();
        if (elem->cast<IntType>()) {
          // Int[] lists become a single 1-D tensor: Unsqueeze each element to
          // rank 1 and Concat them along axis 0.
          std::vector<Value*> unsqueezed;
          Graph* g = block->owningGraph();
          for (auto* input : lc_node->inputs()) {
            Node* unsqueezed_node = g->create(onnx::Unsqueeze, 1);
            unsqueezed_node->insertBefore(lc_node);
            unsqueezed_node->addInput(input);
            unsqueezed_node->is_(attr::axes, {0});
            unsqueezed.emplace_back(unsqueezed_node->output());
          }
          Node* concat_node = g->create(onnx::Concat, 1);
          concat_node->i_(attr::axis, 0);
          for (auto v : unsqueezed) {
            concat_node->addInput(v);
          }
          concat_node->insertBefore(lc_node);
          replacements.emplace_back(
              i, std::vector<Value*>({concat_node->output()}));
        } else {
          // Tensor lists are spliced directly into the consumer's inputs.
          replacements.emplace_back(
              i,
              std::vector<Value*>(
                  lc_node->inputs().begin(), lc_node->inputs().end()));
        }
      }
      i++;
    }

    // Apply replacements back-to-front so earlier input indices stay valid.
    for (auto ritr = replacements.rbegin(); ritr != replacements.rend();
         ++ritr) {
      replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
    }
  }
}

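// fuseSplitListUnpack: when an onnx::Split output is immediately consumed by
// a prim::ListUnpack, replace the pair with a single Split that has one
// output per unpacked value.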
static void fuseSplitListUnpack(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseSplitListUnpack(child_block);
    }
    if (it->kind() == prim::ListUnpack &&
        it->input()->node()->kind() == onnx::Split) {
      auto origSplitNode = it->input()->node();

      Node* splitNode =
          b->owningGraph()->create(onnx::Split, it->outputs().size());
      for (size_t i = 0; i < splitNode->outputs().size(); ++i) {
        splitNode->outputs()[i]->copyMetadata(it->outputs()[i]);
      }
      splitNode->copyAttributes(*origSplitNode);
      splitNode->insertBefore(origSplitNode);
      splitNode->addInput(origSplitNode->input());
      it->replaceAllUsesWith(splitNode);
      it->removeAllInputs();
      origSplitNode->destroy();
      it.destroyCurrent();
    }
  }
}

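// removeMaxPoolUnusedOutput: onnx::MaxPool may be emitted with a second
// (indices) output; if that output is never used, drop it so the exported
// node has a single output.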
void removeMaxPoolUnusedOutput(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      removeMaxPoolUnusedOutput(child_block);
    }
    if (strcmp(n->kind().toQualString(), "onnx::MaxPool") == 0) {
      if (n->outputs().size() == 2 && n->outputs().at(1)->uses().empty()) {
        n->eraseOutput(1);
      }
    }
  }
}

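// PeepholeOptimizeONNX: entry point that runs the ONNX-specific peephole
// optimizations above, each applied once over the whole graph (nested blocks
// are handled by the per-pass recursion).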
void PeepholeOptimizeONNX(std::shared_ptr<Graph>& graph) {
  hackFixupPadPackedShapes(graph->block());
  pushPackingPastRnn(graph->block());
  removeNopPacking(graph->block());
  fixDefaultRnnHiddenState(graph->block());
  fixDefaultLstmCellState(graph->block());
  fuseBroadcast(graph->block());
  fuseConsecutiveTransposes(graph->block());
  eliminateNopTranspose(graph->block());
  fuseTransposeIntoGemm(graph->block());
  speculateOps(graph->block());
  eraseListConstruct(graph->block());
  fuseSplitListUnpack(graph->block());
  removeMaxPoolUnusedOutput(graph->block());
}

} // namespace jit
} // namespace torch