Caffe2 - C++ API
A deep learning, cross platform ML framework
remove_inplace_ops.cpp
1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2 #include <torch/csrc/jit/passes/remove_inplace_ops.h>
3 
4 namespace torch {
5 namespace jit {
6 namespace {
7 static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
8  {aten::add_, aten::add},
9  {aten::sub_, aten::sub},
10  {aten::div_, aten::div},
11  {aten::mul_, aten::mul}};
12 
13 bool isInplaceOp(const Node* node) {
14  return inPlaceToOutOfPlace.count(node->kind()) != 0;
15 }
16 
17 // Remove all in-place ops and replace them with out-of-place equivalents.
18 // e.g.
19 // %foo = aten::add_(%foo, %n)
20 // becomes
21 // %foo.2 = aten::add(%foo, %n)
22 //
23 // NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
24 // another value. This is only to avoid breaking ONNX export; when alias
25 // analysis is done we can emit a warning if someone tries to export.
26 void RemoveInplaceOps(Block* block) {
27  auto graph = block->owningGraph();
28  auto it = block->nodes().begin();
29  while (it != block->nodes().end()) {
30  auto node = *it;
31  ++it;
32  for (auto block : node->blocks()) {
33  RemoveInplaceOps(block);
34  }
35 
36  if (isInplaceOp(node)) {
37  // create a replacement out of place op
38  auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
39  newNode->insertBefore(node);
40 
41  // copy inputs
42  for (auto input : node->inputs()) {
43  newNode->addInput(input);
44  }
45 
46  // Create a new output node and replace all uses of self with it
47  newNode->output()->copyMetadata(node->output());
48  node->replaceAllUsesWith(newNode);
49  node->destroy();
50  }
51  }
52 }
53 } // namespace
54 
55 void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
56  RemoveInplaceOps(graph->block());
57 }
58 } // namespace jit
59 } // namespace torch
Definition: jit_type.h:17