Caffe2 - C++ API
A deep learning, cross platform ML framework
peephole.cpp
#include <torch/csrc/jit/passes/peephole.h>

#include <torch/csrc/jit/symbolic_variable.h>

#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch {
namespace jit {

// The intent for this optimization pass is to catch all of the small, easy-to-
// catch peephole optimizations you might be interested in doing.
//
// Right now, it does:
//    - Eliminate no-op 'expand' nodes
//    - Simplify x.t().t() to x
//
// TODO: Decide what kind of fixed point strategy we will have
//
// The parameter `addmm_fusion_enabled` exists because, as it is today, fusing
// add + mm has no benefit within PyTorch running ATen ops. However, we rely on
// seeing the fused version of addmm for ONNX export, since after ONNX
// translation we would see redundant Gemm ops with sub-optimal inputs. This
// flag is exposed so that ONNX export can pass `true` to get the fused
// behavior, but normal JIT peephole optimization is left alone.
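//
// For example (illustrative IR, not emitted by this file), a double transpose
//   %1 : Tensor = aten::t(%x)
//   %2 : Tensor = aten::t(%1)
// is simplified by redirecting every use of %2 to %x; the then-unused aten::t
// nodes are cleaned up afterwards by the dead code elimination run at the end
// of PeepholeOptimize.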
void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    auto* node = *it;

    for (Block* sub_block : node->blocks()) {
      PeepholeOptimizeImpl(sub_block, addmm_fusion_enabled);
    }

    // XXX: remember that if you want to simplify an expression by combining
    // multiple nodes into a different one, then you need to check that they
    // all belong to the given block
    if (node->matches(
            "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
            /*const_inputs=*/attr::size)) {
      // x.expand(x.size()) == x
      if (auto input_type = node->namedInput(attr::self)
                                ->type()
                                ->cast<CompleteTensorType>()) {
        auto expanded_sizes = node->get<std::vector<int64_t>>(attr::size);
        if (expanded_sizes == input_type->sizes()) {
          node->output()->replaceAllUsesWith(node->namedInput(attr::self));
        }
      }
    } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
      // x.t().t() == x
      Node* input_node = node->input()->node();
      if (input_node->matches("aten::t(Tensor self) -> Tensor")) {
        node->output()->replaceAllUsesWith(input_node->input());
      }
    } else if (node->matches(
                   "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
      // x.type_as(y) == x iff x.type() == y.type()
      auto self_type = node->input(0)->type()->cast<DimensionedTensorType>();
      auto other_type = node->input(1)->type()->cast<DimensionedTensorType>();
      if (self_type && other_type &&
          self_type->scalarType() == other_type->scalarType() &&
          self_type->device() == other_type->device()) {
        node->output()->replaceAllUsesWith(node->input(0));
      }
    } else if (
        node->matches(
            "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
            /*const_inputs=*/attr::alpha)) {
      // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z
      // This optimization has been disabled at the moment, because it's not
      // helpful at all until we are able to represent torch.addmm(a, b, c,
      // out=a). That's because addmm dispatches internally to gemm, which
      // computes:
      //   C = beta * C + alpha * A @ B
      // but aten::addmm(a, b, c, 1, 1) is really:
      //   D = beta * C + alpha * A @ B
      // and because it works out of place on C, we're only trading off an
      // explicit add for a copy inside the addmm function. Note that it
      // doesn't even result in fewer reads, because mm won't even load C
      // (because beta == 0 for it).
      if (addmm_fusion_enabled &&
          node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
        // Look for mm from both sides of the add
        for (size_t mm_side = 0; mm_side < 2; mm_side++) {
          // Add will accept tensors of mismatched scalar types, as long as
          // one of them is a scalar. Addmm will throw in that case, so we can
          // only perform this fusion if we're sure that it is correct, and
          // for that we need the add_mat_type. An alternative would be to
          // insert a type_as conditional on the tensor shape being a scalar,
          // but that might add overhead, and make analysis harder.
          auto add_mat_type =
              node->input(1 - mm_side)->type()->cast<DimensionedTensorType>();
          if (!add_mat_type)
            continue;

          if (node->input(mm_side)->node()->matches(
                  "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
            WithInsertPoint guard(node);

            auto mm_node = node->input(mm_side)->node();
            SymbolicVariable add_mat(node->input(1 - mm_side));
            SymbolicVariable mat1(mm_node->input(0));
            SymbolicVariable mat2(mm_node->input(1));

            auto mat_type =
                mat1.value()->type()->cast<DimensionedTensorType>();
            if (!mat_type) {
              mat_type = mat2.value()->type()->cast<DimensionedTensorType>();
            }
            // We insert the type_as if we're sure that the added element is a
            // scalar, and we either don't know the type of the multiplied
            // matrices, or we know the type and know that it's mismatched.
            if (add_mat_type->dim() == 0 &&
                (!mat_type ||
                 add_mat_type->scalarType() != mat_type->scalarType())) {
              add_mat = add_mat.type_as(mat1);
            }

            SymbolicVariable addmm_value = add_mat.addmm(mat1, mat2);

            // Copy shape information from output node
            ((Value*)addmm_value)->copyMetadata(node->output());
            node->output()->replaceAllUsesWith(addmm_value);
          }
        }
      }
      // TODO: this doesn't work with Scalar-Tensor ops! We should
      // canonicalize those
    } else if (
        node->matches(
            "aten::mul(Tensor self, Scalar other) -> Tensor",
            /*const_inputs=*/attr::other) ||
        node->matches(
            "aten::div(Tensor self, Scalar other) -> Tensor",
            /*const_inputs=*/attr::other)) {
      // x * 1 == x / 1 == x
      if (node->get<at::Scalar>(attr::other)->toDouble() == 1) {
        node->output()->replaceAllUsesWith(node->input(0));
      }
    } else if (
        node->matches(
            "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
            /*const_inputs=*/{attr::alpha, attr::other}) ||
        node->matches(
            "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
            /*const_inputs=*/{attr::alpha, attr::other})) {
      // x + 0 == x - 0 == x
      if (node->get<at::Scalar>(attr::alpha)->toDouble() == 1 &&
          node->get<at::Scalar>(attr::other)->toDouble() == 0) {
        node->output()->replaceAllUsesWith(node->input(0));
      }
    } else if (
        node->kind() == prim::Float || node->kind() == prim::Int ||
        node->kind() == prim::ImplicitTensorToNum) {
      Node* input_node = node->input()->node();
      if (input_node->kind() == prim::NumToTensor) {
        node->output()->replaceAllUsesWith(input_node->input());
      }
    } else if (
        node->matches(
            "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
      auto uses = node->output()->uses();
      for (Use u : uses) {
        if (u.user->matches(
                "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
          u.user->replaceInput(0, node->inputs().at(0));
        }
      }
    }
  }
}

void PeepholeOptimize(Block* block, bool addmm_fusion_enabled) {
  PeepholeOptimizeImpl(block, addmm_fusion_enabled);
  // Eliminate dead code created by any peephole passes we've just done
  EliminateDeadCode(block);
}

void PeepholeOptimize(
    const std::shared_ptr<Graph>& graph,
    bool addmm_fusion_enabled) {
  PeepholeOptimize(graph->block(), addmm_fusion_enabled);
}

} // namespace jit
} // namespace torch
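The following is a minimal sketch, not part of peephole.cpp itself, showing how the pass might be driven from C++ on a hand-written graph. It assumes the IR-parser helper torch::jit::script::parseIR from torch/csrc/jit/irparser.h and linking against libtorch; header paths and namespaces vary across PyTorch versions.

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/irparser.h>
#include <torch/csrc/jit/passes/peephole.h>

#include <iostream>

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();
  // Build a graph containing the x.t().t() pattern handled above.
  // parseIR is assumed to be available; it fills `graph` from the textual IR.
  torch::jit::script::parseIR(R"IR(
graph(%x : Tensor):
  %1 : Tensor = aten::t(%x)
  %2 : Tensor = aten::t(%1)
  return (%2))IR",
      graph.get());

  // The peephole pass redirects the graph output to %x; the now-dead aten::t
  // nodes are removed by the EliminateDeadCode call inside PeepholeOptimize.
  torch::jit::PeepholeOptimize(graph, /*addmm_fusion_enabled=*/false);
  std::cout << *graph; // prints a graph that simply returns its input
  return 0;
}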