1 #include <torch/csrc/jit/passes/dead_code_elimination.h> 2 #include <torch/csrc/jit/passes/remove_inplace_ops.h> 7 static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
8 {aten::add_, aten::add},
9 {aten::sub_, aten::sub},
10 {aten::div_, aten::div},
11 {aten::mul_, aten::mul}};
13 bool isInplaceOp(
const Node* node) {
14 return inPlaceToOutOfPlace.count(node->kind()) != 0;
26 void RemoveInplaceOps(Block* block) {
27 auto graph = block->owningGraph();
28 auto it = block->nodes().begin();
29 while (it != block->nodes().end()) {
32 for (
auto block : node->blocks()) {
33 RemoveInplaceOps(block);
36 if (isInplaceOp(node)) {
38 auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
39 newNode->insertBefore(node);
42 for (
auto input : node->inputs()) {
43 newNode->addInput(input);
47 newNode->output()->copyMetadata(node->output());
48 node->replaceAllUsesWith(newNode);
55 void RemoveInplaceOps(
const std::shared_ptr<Graph>& graph) {
56 RemoveInplaceOps(graph->block());