Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_mm.cpp
1 #include <torch/csrc/jit/passes/batch_mm.h>
2 
3 #include <ATen/core/functional.h>
4 #include <ATen/core/interned_strings.h>
5 #include <c10/util/Exception.h>
6 #include <torch/csrc/jit/constants.h>
7 #include <torch/csrc/jit/custom_operator.h>
8 #include <torch/csrc/jit/passes/alias_analysis.h>
9 #include <torch/csrc/jit/passes/dead_code_elimination.h>
10 #include <torch/csrc/jit/passes/peephole.h>
11 #include <torch/csrc/jit/symbolic_variable.h>
12 
13 #include <ATen/ATen.h>
14 #include <algorithm>
15 #include <unordered_map>
16 
17 namespace torch {
18 namespace jit {
19 
20 // This pass looks for trees in the graph, where leaves are mm ops, and the
21 // inner vertices are add nodes. Once we have such a tree they can be reduced to
22 // two concats and a single mm (basically into a single multiply of a wide
23 // matrix, with a tall matrix). Such patterns show up mostly in backward of
24 // RNNs, since the derivative of many uses of matrix multiplies with same
25 // weights forms exactly such a tree (note that it's usually also highly
26 // imbalanced i.e. has O(n) depth).
27 //
28 // This (or any tree of adds of MMs):
29 //
30 // +------+ +------+ +------+ +------+ +------+
31 // | | | | | | | | | |
32 // | L1 | | R1 | + | L2 | | R2 | = | O |
33 // | | | | | | | | | |
34 // +------+ +------+ +------+ +------+ +------+
35 //
36 // can be basically transformed into a single MM which looks like this
37 // (we concat all lhs operands, concat rhs operands, do mm):
38 //
39 // +------+
40 // | |
41 // | R1 |
42 // | |
43 // +------+
44 // | |
45 // | R2 |
46 // | |
47 // +------+
48 // +------+------+ +------+
49 // | | | | |
50 // | L1 | L2 | | O |
51 // | | | | |
52 // +------+------+ +------+
53 
54 // Note [Further optimizations]
55 // It would be straightforward to extend the TreeToken class to also detect if
56 // all MMs had the same lhs/rhs. In such case it's more efficient to expand the
57 // lhs and use bmm + sum instead of repeating it in memory via concat.
58 
59 // Note [Overlapping trees]
60 // Additionally it wouldn't be too hard to add support for partially overlapping
61 // trees. Right now it's forbidden in the algorithm (only a single tree will
62 // be allowed), so theoretically we might miss some optimization options,
63 // especially since the rejected tree could be much larger. I didn't implement
64 // that because it's not necessary for the simple RNN cases I saw, so I decided
65 // to keep stuff simple. If we ever get around to implementing this, the right
66 // solution is probably to fuse MMs for the common part, and assume it's an
67 // input leaf for the outer two parts (I don't think it's beneficial to
68 // recompute, unless the subtree is super small, but let's not get into such
69 // details).
70 
71 // The algorithm we're using is simple. We're iterating through the graph in the
72 // topological order and labeling nodes with TreeTokens. Then, we look for roots
73 // of the trees we formed and fuse them.
74 
// Minimum number of mm leaves a tree must have before BatchMMTreeReduce
// replaces it with a prim::MMTreeReduce node.
// Tunable parameter. Set to something larger if it turns out to be better.
static constexpr size_t min_fusion_size{4};
77 
78 bool have_same_shape(at::TensorList inputs) {
79  auto expected_sizes = inputs[0].sizes();
80  return std::all_of(
81  inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
82  return t.sizes() == expected_sizes;
83  });
84 }
85 
86 bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
87  size_t l = lhs.size(0);
88  size_t m = lhs.size(1);
89  size_t r = rhs.size(1);
90  // Numbers obtained by some simple benchmarks of fp32 gemms on a TITAN V
91  return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
92 }
93 
94 RegisterOperators mm_tree_reduction_reg(
95  {Operator(prim::MMTreeReduce, [](const Node* node) {
96  size_t num_inputs = node->inputs().size();
97  return [num_inputs](Stack& stack) {
98  std::vector<at::Tensor> inputs;
99  inputs.reserve(num_inputs);
100  for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
101  inputs.push_back(std::move(*it).toTensor());
102  }
103  drop(stack, num_inputs);
104 
105  AT_ASSERT(inputs.size() > 0);
106  AT_ASSERT(inputs.size() % 2 == 0);
107  size_t side_num_elems = inputs.size() / 2;
108  auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
109  auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
110  // TODO: checking this is not free, so we should stop if this keeps
111  // failing
112  if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
113  shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
114  auto lhs = at::cat(lhs_inputs, /*dim=*/1);
115  auto rhs = at::cat(rhs_inputs, /*dim=*/0);
116  push(stack, at::mm(lhs, rhs));
117  } else {
118  auto acc = at::mm(inputs[0], inputs[side_num_elems]);
119  for (size_t i = 1; i < side_num_elems; ++i) {
120  acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
121  }
122  push(stack, std::move(acc));
123  }
124  return 0;
125  };
126  })});
127 
128 // TreeTokens will be used to label nodes of the graph, if the nodes will fit
129 // our mm/add tree pattern. Basically we do dynamic programming on DAGs, where
130 // when we reach node N with inputs A and B, then A and B have already been
131 // procesed, and we can try to unify their TreeTokens (if they have them)
132 // and build a larger tree.
133 struct TreeToken {
134  uint64_t tree_size = 0; // NOTE: measured in number of leaves i.e. mm ops
135  Node* node = nullptr;
136  bool is_root = false;
137 
138  static TreeToken mm(Node* mm) {
139  TreeToken token;
140  token.tree_size = 1;
141  token.node = mm;
142  token.is_root = true;
143  return token;
144  }
145 
146  // NB: the returned token might be invalid, so make sure to check its boolean
147  // value!
148  static TreeToken transpose(Node* t, TreeToken& inp_token) {
149  TreeToken token;
150  if (!inp_token.node->matches(
151  "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
152  return token;
153  }
154  token.tree_size = 1;
155  token.node = t;
156  token.is_root = true;
157  inp_token.is_root = false;
158  return token;
159  }
160 
161  // NB: the returned token might be invalid, so make sure to check its boolean
162  // value!
163  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
164  TreeToken token;
165  // See Note [Overlapping trees]
166  if (&l == &r || !l.is_root || !r.is_root)
167  return token;
168  token.tree_size = l.tree_size + r.tree_size;
169  token.node = add;
170  token.is_root = true;
171  l.is_root = r.is_root =
172  false; // Reserve the subtrees, so they can't be used again.
173  return token;
174  }
175 
176  explicit operator bool() {
177  return is_root;
178  }
179 
180  std::vector<Node*> removeTransposesAndGatherMatmuls() {
181  std::vector<Node*> matmuls;
182  std::vector<Node*> queue{node};
183  Graph* graph = node->owningGraph();
184  while (!queue.empty()) {
185  auto n = queue.back();
186  queue.pop_back();
187  if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
188  matmuls.push_back(n);
189  } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
190  Node* input_node = n->input()->node();
191  AT_ASSERT(input_node->matches(
192  "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
193  // (AB)^T == B^TA^T
194  WithInsertPoint insert_guard{input_node};
195  Value* A = input_node->inputs()[0];
196  Value* B = input_node->inputs()[1];
197  Value* AT = graph->insert(aten::t, {A});
198  Value* BT = graph->insert(aten::t, {B});
199  Value* BTAT = graph->insert(aten::mm, {BT, AT});
200  n->output()->replaceAllUsesWith(BTAT);
201  matmuls.push_back(BTAT->node());
202  } else if (
203  n->matches(
204  "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
205  queue.push_back(n->inputs()[0]->node());
206  queue.push_back(n->inputs()[1]->node());
207  } else {
208  AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
209  }
210  }
211  return matmuls;
212  }
213 };
214 
// Which operand of aten::mm is shared by a batched group in BatchMMSide
// (LHS: self, RHS: mat2). The values round-trip through the integer "side"
// attribute of prim::MMBatchSide, so they must stay 0 (LHS) and 1 (RHS).
enum class Side { LHS = 0, RHS = 1 };
216 
217 void BatchMMTreeReduce(Block* block) {
218  auto graph = block->owningGraph();
219 
220  // Look for trees in the block
221  std::unordered_map<Node*, TreeToken> tokens;
222  for (auto node : block->nodes()) {
223  if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
224  tokens[node] = TreeToken::mm(node);
225  } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
226  auto input_it = tokens.find(node->input()->node());
227  if (input_it != tokens.end()) {
228  tokens[node] = TreeToken::transpose(node, input_it->second);
229  }
230  } else if (
231  node->matches(
232  "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
233  Node* lhs = node->inputs()[0]->node();
234  Node* rhs = node->inputs()[1]->node();
235  auto lhs_it = tokens.find(lhs);
236  auto rhs_it = tokens.find(rhs);
237  // See Note [Overlapping trees] (regarding the uses().size() == 1 check)
238  // We could treat a subtree with multiple uses as if it was overlapping.
239  // XXX: uses().size() == 1 is also something that guarantees that this
240  // transform is valid, because we know for sure that the none of these
241  // operands depend on the result of the other. If we were to remove this,
242  // we need to compute a transitive closure and actually check the
243  // dependencies.
244  if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
245  lhs->output()->uses().size() == 1 &&
246  rhs->output()->uses().size() == 1) {
247  if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
248  tokens[node] = token;
249  }
250  }
251  } else {
252  for (auto block : node->blocks()) {
253  BatchMMTreeReduce(block);
254  }
255  }
256  }
257 
258  // Merge trees we've found
259  for (auto& item : tokens) {
260  auto& root = item.second;
261  if (!root || root.tree_size < min_fusion_size)
262  continue;
263  auto matmuls = root.removeTransposesAndGatherMatmuls();
264  WithInsertPoint insert_guard{root.node};
265  Node* tree_reduce =
266  graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
267  for (Node* matmul : matmuls) {
268  tree_reduce->addInput(matmul->inputs().at(0));
269  }
270  for (Node* matmul : matmuls) {
271  tree_reduce->addInput(matmul->inputs().at(1));
272  }
273  root.node->output()->replaceAllUsesWith(tree_reduce->output());
274  // NB: don't bother with cleaning up after yourself. We'll use DCE for that.
275  }
276 }
277 
278 bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
279  // Cutoff chosed by benchmarking on a TITAN V
280  return other_side_input.numel() <= 1024 * 2048;
281 }
282 
283 RegisterOperators mm_batch_side_reg(
284  {Operator(prim::MMBatchSide, [](const Node* node) {
285  size_t num_other_side_inputs = node->inputs().size() - 1;
286  Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
287  return [num_other_side_inputs, single_side](Stack& stack) {
288  at::Tensor side_input;
289  std::vector<at::Tensor> other_side_inputs;
290  other_side_inputs.reserve(num_other_side_inputs);
291  for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
292  ++it) {
293  other_side_inputs.push_back(std::move(*it).toTensor());
294  }
295  drop(stack, num_other_side_inputs);
296  pop(stack, side_input);
297 
298  auto any_other_input = other_side_inputs[0];
299  if (have_same_shape(other_side_inputs) &&
300  shape_is_fast_for_side(other_side_inputs[0])) {
301  auto other_side_input =
302  at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
303  auto mm_out = single_side == Side::LHS
304  ? side_input.mm(other_side_input)
305  : other_side_input.mm(side_input);
306  auto outputs = at::chunk(
307  mm_out,
308  num_other_side_inputs,
309  /*dim=*/single_side == Side::LHS ? 1 : 0);
310  stack.insert(
311  stack.end(),
312  std::make_move_iterator(outputs.begin()),
313  std::make_move_iterator(outputs.end()));
314  } else {
315  if (single_side == Side::LHS) {
316  for (at::Tensor& other : other_side_inputs) {
317  stack.emplace_back(side_input.mm(other));
318  }
319  } else {
320  for (at::Tensor& other : other_side_inputs) {
321  stack.emplace_back(other.mm(side_input));
322  }
323  }
324  }
325 
326  return 0;
327  };
328  })});
329 
330 std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
331  Value* value,
332  AliasDb& alias_db) {
333  const auto postprocess = [&](std::vector<Node*> mms) {
334  if (mms.size() == 0) {
335  return mms;
336  }
337  std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
338  return n->isBefore(m);
339  });
340  // Filter out dependent MMs. This algorithm might do very badly if e.g. you
341  // have a lot of independent MMs, that depend on the first one, but I doubt
342  // this will be a common scenario.
343  for (size_t i = 0; i < mms.size(); ++i) {
344  if (mms[i] == nullptr)
345  continue;
346  for (size_t j = i + 1; j < mms.size(); ++j) {
347  if (mms[j] == nullptr)
348  continue;
349  if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
350  mms[j] = nullptr;
351  }
352  }
353  }
354  return c10::filter(mms, [](Node* n) { return n != nullptr; });
355  };
356 
357  Block* block = value->node()->owningBlock();
358  std::vector<Node*> lhses; // Will contain nodes where value is used as an lhs
359  std::vector<Node*> rhses; // Like above, but rhs
360  for (Use u : value->uses()) {
361  if (u.user->owningBlock() == block &&
362  u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
363  if (u.offset == 0 && u.user->inputs()[1] != value) {
364  lhses.push_back(u.user);
365  } else if (u.offset == 1 && u.user->inputs()[0] != value) {
366  rhses.push_back(u.user);
367  }
368  }
369  }
370  return std::make_pair(postprocess(lhses), postprocess(rhses));
371 }
372 
373 void BatchMMSide(Block* block, AliasDb& alias_db) {
374  // NB: 8 is the current loop unrolling factor
375  static constexpr size_t how_many_is_many = 8;
376  const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
377  AT_ASSERT(!mms.empty());
378  for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
379  bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
380  AT_ASSERT(move_ok);
381  }
382  WithInsertPoint insert_guard{mms[0]};
383  Graph* graph = mms[0]->owningGraph();
384  Node* batch_mm = graph->create(
385  prim::MMBatchSide,
386  /*inputs=*/{},
387  /*num_outputs=*/mms.size());
388  graph->insertNode(batch_mm);
389  batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
390  Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
391  batch_mm->addInput(const_side);
392  for (size_t i = 0; i < mms.size(); ++i) {
393  batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
394  mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
395  }
396  };
397 
398  std::unordered_set<Value*> considered_values;
399  for (Node* node : block->nodes()) {
400  if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
401  for (Value* input : node->inputs()) {
402  if (/*bool not_inserted = */ !considered_values.emplace(input).second) {
403  continue;
404  }
405  auto uses_with_many = gatherIndependentMMUses(input, alias_db);
406  if (uses_with_many.first.size() >= how_many_is_many) {
407  batch_side(uses_with_many.first, Side::LHS);
408  }
409  if (uses_with_many.second.size() >= how_many_is_many) {
410  batch_side(uses_with_many.second, Side::RHS);
411  }
412  }
413  } else {
414  for (Block* subblock : node->blocks()) {
415  BatchMMSide(subblock, alias_db);
416  }
417  }
418  }
419 }
420 
421 bool hasMutableOperators(Block* block) {
422  for (auto n : block->nodes()) {
423  if (n->kind().is_aten() && n->schema().is_mutable())
424  return true;
425  for (auto b : n->blocks()) {
426  if (hasMutableOperators(b))
427  return true;
428  }
429  }
430  return false;
431 }
432 
433 void BatchMM(std::shared_ptr<Graph>& graph) {
434  if (hasMutableOperators(graph->block())) {
435  // TODO(suo): make BatchMM mutability-safe
436  return;
437  }
438  AliasDb alias_db(graph);
439  BatchMMTreeReduce(graph->block());
440  BatchMMSide(graph->block(), alias_db);
441  EliminateDeadCode(graph);
442  // It's possible that transpose rearrangements have created sequences of
443  // consecutive transposes that didn't exist before.
444  PeepholeOptimize(graph);
445 }
446 
447 } // namespace jit
448 } // namespace torch
Alias analysis pass.
AT_CPP14_CONSTEXPR ArrayRef< T > slice(size_t N, size_t M) const
slice(n, m) - Chop off the first N elements of the array, and keep M elements in the array...
Definition: ArrayRef.h:161
Registration class for new operators.
Definition: static.cpp:52
Definition: jit_type.h:17
A utility class for setting temporary insertion points.
Definition: ir.h:1174
Definition: static.cpp:58
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
Flush-To-Zero and Denormals-Are-Zero mode.