Caffe2 - C++ API
A deep learning, cross-platform ML framework
graph_fuser.cpp
#include <torch/csrc/jit/passes/graph_fuser.h>

#include <ATen/ExpandUtils.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/symbolic_variable.h>

#include <queue>
#include <unordered_map>

namespace torch {
namespace jit {

namespace {

// What is a simple mappable operator? It:
// - Has a single tensor output
// - Output and all tensor inputs have the same shape
// - Output and all tensor inputs have the same scalar type,
//   or all tensor inputs have the same scalar type and the
//   output is identified in PropagateInputShapes
// - Output and all tensor inputs should be on the same device
// - Produces contiguous outputs
// Some of these restrictions may be relaxable, but you should
// carefully read the code first, as we rely on these assumptions.
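//
// For example, aten::mul(Tensor, Tensor) on same-shaped inputs fits these
// criteria, while something like aten::mm does not (its output shape differs
// from its inputs), and neither does an op with a non-constant scalar
// argument (checked below).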
bool isSimpleMap(Node* node) {
  static OperatorSet simple_mappable{{
      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",

      "aten::abs(Tensor self) -> Tensor",
      "aten::acos(Tensor self) -> Tensor",
      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::asin(Tensor self) -> Tensor",
      "aten::atan(Tensor self) -> Tensor",
      "aten::atan2(Tensor self, Tensor other) -> Tensor",
      "aten::ceil(Tensor self) -> Tensor",
      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
      "aten::cos(Tensor self) -> Tensor",
      "aten::cosh(Tensor self) -> Tensor",
      "aten::div(Tensor self, Tensor other) -> Tensor",
      "aten::exp(Tensor self) -> Tensor",
      "aten::expm1(Tensor self) -> Tensor",
      "aten::erf(Tensor self) -> Tensor",
      "aten::erfc(Tensor self) -> Tensor",
      "aten::floor(Tensor self) -> Tensor",
      "aten::fmod(Tensor self, Tensor other) -> Tensor",
      "aten::frac(Tensor self) -> Tensor",
      "aten::lgamma(Tensor self) -> Tensor",
      "aten::log(Tensor self) -> Tensor",
      "aten::log10(Tensor self) -> Tensor",
      "aten::log1p(Tensor self) -> Tensor",
      "aten::log2(Tensor self) -> Tensor",
      "aten::max(Tensor self, Tensor other) -> Tensor",
      "aten::min(Tensor self, Tensor other) -> Tensor",
      "aten::mul(Tensor self, Tensor other) -> Tensor",
      "aten::neg(Tensor self) -> Tensor",
      "aten::pow(Tensor self, Tensor exponent) -> Tensor",
      "aten::pow(Tensor self, Scalar exponent) -> Tensor",
      "aten::rand_like(Tensor self) -> Tensor",
      "aten::reciprocal(Tensor self) -> Tensor",
      "aten::relu(Tensor self) -> Tensor",
      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
      "aten::remainder(Tensor self, Tensor other) -> Tensor",
      "aten::round(Tensor self) -> Tensor",
      "aten::rsqrt(Tensor self) -> Tensor",
      "aten::sigmoid(Tensor self) -> Tensor",
      "aten::sin(Tensor self) -> Tensor",
      "aten::sinh(Tensor self) -> Tensor",
      "aten::sqrt(Tensor self) -> Tensor",
      "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::tan(Tensor self) -> Tensor",
      "aten::tanh(Tensor self) -> Tensor",
      "aten::trunc(Tensor self) -> Tensor",
      "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::mul(Tensor self, Scalar other) -> Tensor",
      "aten::div(Tensor self, Scalar other) -> Tensor",

      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",

      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",

      "aten::type_as(Tensor self, Tensor other) -> Tensor",
  }};
  if (!simple_mappable.find(node)) {
    return false;
  }
  // Check that all non-tensor inputs are constant
  for (Value* input : node->inputs()) {
    if (input->type()->isSubtypeOf(TensorType::get())) {
      continue;
    }
    if (input->node()->kind() != prim::Constant) {
      return false;
    }
  }
  return true;
}

RegisterOperators reg_bn_unsqueeze({Operator(
    "aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor",
    [](const Node* node) {
      return [](Stack& stack) {
        const int64_t ndim = pop(stack).toInt();
        auto self = pop(stack).toTensor();
        c10::SmallVector<int64_t, 8> sizes(ndim, 1);
        AT_ASSERT(self.dim() == 1);
        sizes.at(1) = self.size(0);
        push(stack, self.reshape(sizes));
        return 0;
      };
    })});

// Yes, no, or no value if we can't tell
c10::optional<bool> isDefined(Value* tensor) {
  if (tensor->type()->isSubtypeOf(TensorType::get())) {
    return true;
  }
  if (tensor->node()->mustBeNone()) {
    return false;
  }
  return {};
}

bool isFusableBatchNorm(Node* batch_norm) {
  if (!batch_norm->matches(
          "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
    return false;
  }
  // If we can't statically determine whether weight and bias are defined,
  // there's really no point in decomposing batch norm into simpler ops,
  // since it won't get fused into a single kernel.
  return isDefined(batch_norm->namedInput(attr::weight)).has_value() &&
      isDefined(batch_norm->namedInput(attr::bias)).has_value();
}

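// A quick sketch of what prim::BroadcastSizes computes: given the int[]
// sizes of several tensors, it produces the broadcasted result shape,
// e.g. sizes [3, 1] and [1, 4] broadcast to [3, 4].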
Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
  AT_ASSERT(!sizes.empty());
  Graph* graph = sizes[0]->owningGraph();
  Node* broadcast_n =
      graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
  broadcast_n->output()->setType(ListType::ofInts());
  return broadcast_n->output();
}

struct GraphFuser {
  Block* block_;
  std::unique_ptr<AliasDb> aliasDb_;
  std::shared_ptr<Graph> graph_;

  GraphFuser(Block* block, std::shared_ptr<Graph> graph)
      : block_(block), graph_(std::move(graph)) {}

  value_list tensorInputs(Node* node) {
    return filter(node->inputs(), [](Value* v) {
      return v->type()->isSubtypeOf(TensorType::get());
    });
  }

  bool containsGradSumToSize(Node* fusion_group) {
    auto nodes = getSubgraph(fusion_group).nodes();
    return std::any_of(nodes.begin(), nodes.end(), [](Node* n) {
      return n->kind() == aten::_grad_sum_to_size;
    });
  }

  bool isFusable(Node* node) {
    return isFusableMap(node) || isFusableBatchNorm(node);
  }

  bool isFusableMap(Node* node) {
    // We don't want to bother with cross-block node movements, as they
    // are not necessarily correct.
    if (node->owningBlock() != block_)
      return false;
    if (node->kind() == aten::_grad_sum_to_size) {
      // We only fuse _grad_sum_to_size if
      // - we will fuse its input next (checked here)
      // - we can commute the _grad_sum_to_size with everything
      //   along the computation graph until we reach the outputs,
      //   but this is checked later
      return isFusable(node->inputs()[0]->node());
    }
    return node->kind() == prim::FusionGroup || isSimpleMap(node);
  }

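  // The fusable pattern matched below is, roughly:
  //   %tensors = prim::ListConstruct(%a, %b, ...)
  //   %out = aten::cat(%tensors, %dim)   // %dim a constant
  // with the list having no other uses.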
  bool isFusableCatNode(Node* node) {
    if (node->kind() != aten::cat)
      return false;
    if (!node->is_constant(attr::dim))
      return false;
    auto tensors_node = node->namedInput(attr::tensors)->node();
    if (tensors_node->kind() != prim::ListConstruct)
      return false;
    // NB: Note that technically other uses of the list aren't a big problem
    // for us. It would be enough to place the prim::FusedConcat before the
    // prim::ListConstruct, and allUsersAreThisConsumerOrOccurAfterIt would
    // still be satisfied. However, I don't expect this to be necessary any
    // time soon, and so we're simply assuming that we don't have to deal
    // with it.
    if (tensors_node->output()->uses().size() > 1)
      return false;
    return true;
  }

  bool calculatesSize(Node* node) {
    return node->matches("aten::size(Tensor self) -> int[]");
  }

  bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) {
    auto defining_node = producer->node();
    for (auto o : defining_node->outputs()) {
      for (auto u : o->uses()) {
        if (u.user != consumer && !calculatesSize(u.user))
          return false;
      }
    }
    return true;
  }

  Graph& getSubgraph(Node* n) {
    AT_ASSERT(n->kind() == prim::FusionGroup);
    return *n->g(attr::Subgraph);
  }

  void decomposeBatchNorm(Node* batch_norm) {
    static std::shared_ptr<Graph> bn_graph;
    static std::once_flag flag;
    std::call_once(
        flag,
        [](std::shared_ptr<Graph>* graph_ptr) {
          static const char* source = R"SCRIPT(
        def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
            if training:
                norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
            else:
                norm_mean = torch._unwrap_optional(running_mean)
                norm_var = torch._unwrap_optional(running_var)
            norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
            norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
            norm_invstd = 1 / (eps + torch.sqrt(norm_var))
            return ((input - norm_mean) * norm_invstd)
      )SCRIPT";
          auto module = std::make_shared<script::Module>();
          defineMethodsInModule(
              module, source, script::nativeResolver, /*self=*/c10::nullopt);
          *graph_ptr = module->get_method("batch_norm").graph();
        },
        &bn_graph);

    AT_ASSERT(isFusableBatchNorm(batch_norm));
    WithInsertPoint insert_guard{batch_norm};
    Value* input = batch_norm->namedInput(attr::input);
    Value* input_dim = graph_->insert(aten::dim, {input});
    std::vector<Value*> inputs{input,
                               batch_norm->namedInput(attr::running_mean),
                               batch_norm->namedInput(attr::running_var),
                               batch_norm->namedInput(attr::training),
                               batch_norm->namedInput(attr::momentum),
                               batch_norm->namedInput(attr::eps)};
    Value* new_output =
        SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0);
    auto weight = batch_norm->namedInput(attr::weight);
    auto bias = batch_norm->namedInput(attr::bias);
    if (isDefined(weight).value()) {
      Value* expanded_weight =
          graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
      new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
    }
    if (isDefined(bias).value()) {
      Value* expanded_bias =
          graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
      new_output = graph_->insert(aten::add, {new_output, expanded_bias});
    }
    batch_norm->output()->replaceAllUsesWith(new_output);
    batch_norm->destroy();
  }

  void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
    // Now we have two fusion groups!
    // Revert the fusion - place all inner nodes of producer back in the outer
    // graph.
    std::vector<Node*> temporary_nodes;
    auto producer_subgraph = &getSubgraph(producer_group);

    // Initialize a map of inner graph values to outer graph values
    std::unordered_map<Value*, Value*> inner_to_outer;
    auto inner_inputs = producer_subgraph->inputs();
    auto outer_inputs = producer_group->inputs();
    for (size_t i = 0; i < inner_inputs.size(); ++i) {
      inner_to_outer[inner_inputs[i]] = outer_inputs[i];
    }

    // Clone all nodes
    for (auto inner : producer_subgraph->nodes()) {
      Node* outer = block_->owningGraph()->createClone(
          inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); });
      outer->insertBefore(producer_group);
      temporary_nodes.emplace_back(outer);
      auto inner_outputs = inner->outputs();
      auto outer_outputs = outer->outputs();
      for (size_t i = 0; i < inner_outputs.size(); ++i)
        inner_to_outer[inner_outputs[i]] = outer_outputs[i];
    }

    // Replace uses of producer_group outputs and destroy the producer
    auto subgraph_outputs = producer_subgraph->outputs();
    for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
      auto outer_output = inner_to_outer.at(subgraph_outputs[i]);
      producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
    }
    producer_group->destroy();
    producer_group =
        nullptr; // Just to get a clear error in case someone uses it

    // Inline the temporary nodes into the first group
    auto consumer_subgraph = &getSubgraph(consumer_group);
    for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend();
         ++it) {
      Node* node = *it;
      Node* merged = mergeNodeIntoGroup(consumer_group, node);
      // If any of the outputs are still used then we need to add them
      auto outputs = node->outputs();
      for (size_t i = 0; i < outputs.size(); ++i) {
        auto output = outputs[i];
        if (output->uses().size() == 0)
          continue;
        consumer_subgraph->registerOutput(merged->outputs()[i]);
        auto new_output = consumer_group->addOutput();
        output->replaceAllUsesWith(new_output);
        new_output->setType(output->type());
      }
      node->destroy();
    }
  }

  // insert a producer node into a consuming fusion group.
  // DOES NOT WORK if n is a consumer of an output of the fusion group
  // returns the node _inside_ the group that represents the node
  Node* mergeNodeIntoGroup(Node* group, Node* n) {
    AT_ASSERT(n->kind() != prim::FusionGroup);
    auto& subgraph = getSubgraph(group);
    // map from nodes in the surrounding graph to parameters in the fusion
    // group's subgraph that correspond to them
    std::unordered_map<Value*, Value*> inputs_map;
    size_t i = 0;
    size_t tensor_insert_idx = 0;
    AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
    for (auto input : group->inputs()) {
      inputs_map[input] = subgraph.inputs()[i++];
      if (input->type()->isSubtypeOf(TensorType::get()))
        tensor_insert_idx = i;
    }
    // add n's inputs to the fusion group's input list if we don't already
    // have them; we insert tensors first because the fuser assumes that to
    // be the case (as a legacy from tensors only)
    WithInsertPoint guard(*subgraph.nodes().begin());
    for (auto input : n->inputs()) {
      if (inputs_map.count(input) == 0) {
        if (input->type()->isSubtypeOf(TensorType::get())) {
          auto in_group = subgraph.insertInput(tensor_insert_idx);
          in_group->setType(input->type());
          inputs_map[input] = in_group;
          group->insertInput(tensor_insert_idx, input);
          tensor_insert_idx++;
        } else if (
            n->kind() == aten::_grad_sum_to_size &&
            input->type()->isSubtypeOf(ListType::ofInts())) {
          auto in_group = subgraph.addInput();
          in_group->setType(input->type());
          inputs_map[input] = in_group;
          group->addInput(input);
        } else {
          // We don't support passing in scalars as arguments to fused kernels,
          // so we generally don't allow fusing tensor-scalar operations unless
          // the scalar is constant. In those cases we inline the constants
          // directly in the body of the fused group.
          AT_ASSERT(input->node()->kind() == prim::Constant);
          Node* in_const =
              subgraph.createClone(input->node(), [](Value*) -> Value* {
                throw std::runtime_error("unexpected input");
              });
          subgraph.insertNode(in_const);
          inputs_map[input] = in_const->output();
        }
      }
    }
    // copy n into the graph, remapping its inputs to internal nodes
    Node* in_graph = subgraph.createClone(
        n, [&](Value* k) -> Value* { return inputs_map[k]; });
    // if n's outputs are already inputs to the fusion group,
    // we need to remove them because n is now inside the fusion group.
    //
    // i.e.,
    // x = f(w); group(x, y, z) becomes group(w, y, z).
    // x, y, z = f(w); group(x, y, z) becomes group(w).
    //
    // remapping nodes that used the input to the newly-merged node
    // n is not an input when the fusion group is empty
    auto inputs = group->inputs();
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]);
      if (it != inputs.end()) {
        size_t p = it - inputs.begin();
        group->removeInput(p);
        subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]);
        subgraph.eraseInput(p);
      }
    }
    return subgraph.insertNode(in_graph);
  }

  // turn consumer node n into a fusion group with just n inside
  // to prepare for fusion and replace uses of n with the new group
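  // For example (IR sketch), %y = aten::relu(%x) becomes
  //   %y = prim::FusionGroup_0(%x)
  // whose subgraph contains the single relu node.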
  Node* createSingletonFusionGroup(Node* n) {
    auto group = block_->owningGraph()->createFusionGroup();
    // propagate position information for the new node so we can always
    // have a valid mapping
    group->insertBefore(n);
    Node* mergedNode = mergeNodeIntoGroup(group, n);
    getSubgraph(group).registerOutput(mergedNode->output());
    auto sel = group->addOutput();
    sel->copyMetadata(n->output());
    n->replaceAllUsesWith(group);
    n->destroy();
    return group;
  }

  // TODO: remove this and use WithInsertPoint instead
  void insertAt(Node** insertion_point, Node* n) {
    n->insertAfter(*insertion_point);
    *insertion_point = n;
  }

  at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
    // this handles cases where producer can be moved _into_ the fusion group
    // of consumer.
    // TODO: extend to fusion of consumer into _producer's_ fusion blob
    // if the consumer allInputsAreThisProducer(consumer, producer)
    // we can move the consumer up into the producer.
    // but this requires better handling of merging fusion groups so it is not
    // done now
    bool shouldFuse = isFusable(producer->node()) &&
        // Rearrange nodes such that all uses of producer are after the
        // consumer. Fusion will rewrite those later uses to use the version of
        // producer generated by the fused blob. In this case, producer becomes
        // an output of the fusion group.
        aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer);

    if (!shouldFuse) {
      return at::nullopt;
    }
    if (producer->node()->kind() == aten::_grad_sum_to_size &&
        consumer->kind() == prim::FusionGroup) {
      // check that we will be able to move the _grad_sum_to_size to be fused
      // to the end of the fusion group in the fusion compiler.
      // the difficulty here is that the producer is not part of the fusion
      // group yet
      for (auto& u : producer->uses()) {
        if (u.user == consumer) {
          auto subgraph = &getSubgraph(consumer);
          if (!trackSingleGradSumToSizeToOutputs(
                  subgraph->inputs().at(u.offset), nullptr)) {
            return at::nullopt;
          }
        }
      }
    }

    auto group = consumer;
    if (consumer->kind() != prim::FusionGroup) {
      group = createSingletonFusionGroup(consumer);
    }
    if (producer->node()->matches(
            "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
      // We don't do any fusions in here, but simply decompose the batch norm
      // into a kernel that computes the stats + pointwise ops which will be
      // considered in this fusion next.
      decomposeBatchNorm(producer->node());
      return group;
    }
    if (producer->node()->kind() == prim::FusionGroup) {
      mergeFusionGroups(group, producer->node());
      return group;
    }
    AT_ASSERT(producer->node()->outputs().size() == 1);
    Node* merged = mergeNodeIntoGroup(group, producer->node());
    // remaining uses of this producer can occur because we allow
    // fusion in cases where uses remain after the consumer
    // if these exist, re-route them to the version of producer
    // created in FusionGroup
    if (producer->uses().size() != 0) {
      getSubgraph(group).registerOutput(merged->output());
      Value* new_producer = group->addOutput();
      new_producer->copyMetadata(producer);
      producer->replaceAllUsesWith(new_producer);
    }
    producer->node()->destroy();
    return group;
  }

  bool canFuseChunk(Node* consumer, Value* producer) {
    if (consumer->kind() != prim::FusionGroup) {
      return false;
    }
    // Does the chunk have constant chunks/dim?
    auto* chunk = producer->node();
    if (chunk->kind() != prim::ConstantChunk)
      return false;
    // And all uses of the chunk are in this consumer
    for (auto s : chunk->outputs()) {
      for (auto u : s->uses()) {
        if (u.user != consumer) {
          return false;
        }
      }
    }
    // And isn't a no-op chunk (chunks == 1). Have CSE clean this up.
    // We could fuse this but it's better to just delete the node.
    if (chunk->i(attr::chunks) == 1) {
      return false;
    }
    return true;
  }

  c10::optional<Node*> findFusedChunk(Node* group, Value* input) {
    AT_ASSERT(group->kind() == prim::FusionGroup);
    auto it = std::find(group->inputs().begin(), group->inputs().end(), input);
    if (it == group->inputs().end()) {
      return c10::nullopt;
    }
    size_t input_index = it - group->inputs().begin();
    auto& subgraph = getSubgraph(group);
    auto* subgraph_input = subgraph.inputs().at(input_index);
    // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use
    auto* node = subgraph_input->uses().at(0).user;
    if (node->kind() == prim::ConstantChunk) {
      AT_ASSERT(subgraph_input->uses().size() == 1);
      return node;
    }
    return c10::nullopt;
  }

  void fuseChunkByReusingExistingFusedChunk(
      Node* group,
      Node* chunk,
      Node* existingFusedChunk) {
    if (chunk->outputs().size() != existingFusedChunk->outputs().size()) {
      return;
    }
    auto& subgraph = getSubgraph(group);
    for (size_t i = 0; i < chunk->outputs().size(); ++i) {
      // Find the input to the FusionGroup (group)
      auto* replacement_val = existingFusedChunk->outputs().at(i);
      auto* val = chunk->outputs().at(i);
      auto it = std::find(group->inputs().begin(), group->inputs().end(), val);
      auto input_index = it - group->inputs().begin();

      // Rewrite the graph to use replacement_val
      auto group_input = subgraph.inputs().at(input_index);
      group_input->replaceAllUsesWith(replacement_val);

      // Remove the input, it's no longer needed
      group->removeInput(input_index);
      subgraph.eraseInput(input_index);
    }
    chunk->destroy();
  }

  // There are two invariants for prim::ConstantChunk:
  // (1) the tensor input to prim::ConstantChunk must be an input to the
  //     fusion group
  // (2) no two ConstantChunks in the same FusionGroup can share a tensor
  //     input.
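  // Invariant (2) in an IR sketch: if the subgraph already contains
  //   %a, %b = prim::ConstantChunk[chunks=2](%x)
  // then a second chunk of %x must reuse those outputs instead of adding
  // another ConstantChunk node (see fuseChunkByReusingExistingFusedChunk).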
  graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) {
    auto* chunk = producer->node();
    AT_ASSERT(consumer->kind() == prim::FusionGroup);
    AT_ASSERT(chunk->kind() == prim::ConstantChunk);

    // if producer's input is already an input to a prim::ConstantChunk node,
    // we cannot add a new prim::ConstantChunk node because of invariant (2).
    auto* chunked_tensor = producer->node()->input();
    if (auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) {
      fuseChunkByReusingExistingFusedChunk(
          consumer, chunk, *existingFusedChunk);
      return consumer->reverseIterator();
    }

    // Move prim::ConstantChunk into the FusionGroup
    mergeNodeIntoGroup(consumer, chunk);
    chunk->destroy();
    return consumer->reverseIterator();
  }

  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == block_) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

  graph_node_list::iterator scanNodeForChunks(Node* consumer) {
    if (consumer->kind() == prim::FusionGroup) {
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto producer : inputs) {
        if (!canFuseChunk(consumer, producer)) {
          continue;
        }
        return fuseChunk(consumer, producer);
      }
    }
    return ++consumer->reverseIterator();
  }

  void insertExplicitBroadcast(Node* node) {
    WithInsertPoint insert_guard{node};
    auto tensors = tensorInputs(node);
    auto new_tensors =
        SymbolicVariable::broadcast_tensors(fmap<SymbolicVariable>(tensors));

    // Replace tensors inputs with broadcasted values
    auto new_tensors_it = new_tensors.begin();
    for (size_t i = 0; i < node->inputs().size(); ++i) {
      if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) {
        AT_ASSERT(new_tensors_it != new_tensors.end());
        node->replaceInput(i, *(new_tensors_it++));
      }
    }
  }

  Node* promoteChunkToBroadcastingChunk(Node* chunk) {
    AT_ASSERT(chunk->kind() == prim::ConstantChunk);

    size_t nchunks = chunk->i(attr::chunks);
    Node* bchunk =
        chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
    bchunk->addInput(chunk->input());
    for (size_t i = 0; i < nchunks; ++i) {
      auto* old_output = chunk->outputs().at(i);
      auto* new_output = bchunk->outputs().at(i);
      new_output->copyMetadata(old_output);
      old_output->replaceAllUsesWith(new_output);
    }
    bchunk->copyAttributes(*chunk);
    bchunk->insertAfter(chunk);
    chunk->destroy();
    return bchunk;
  }

  // in places where op can be fused into a consumer but chunk is in the way
  // distribute chunk to op's operands:
  // replace a,b = chunk(op(x,y,z)) with:
  // x', y', z' = broadcast_tensors([x, y, z])
  // x0,x1 = chunk(x') (x0 has a's type, x1 has b's type)
  // y0,y1 = chunk(y') (y0 has a's type, y1 has b's type)
  // z0,z1 = chunk(z') (z0 has a's type, z1 has b's type)
  // a = op(x0,y0,z0) (a and b keep their sizes but are now contiguous)
  // b = op(x1,y1,z1)
  //
  // The graph fuser uses an intermediate prim::BroadcastingChunk node to
  // represent this behavior concisely. BroadcastingChunk(x, y, z) broadcasts
  // all of its inputs and then chunks each input, in order, the same way.
  // The above graph is equivalent to:
  // x0, x1, y0, y1, z0, z1 = BroadcastingChunk(x, y, z)
  // a = op(x0,y0,z0)
  // b = op(x1,y1,z1)
  //
  // NB: The explicit broadcast is important for correctness.
  // Let's say we have:
  // %z = aten::mul(%x, %y)
  // %z.1, %z.2 = aten::chunk(%z, ...)
  // ... = prim::FusionGroup(%z.1, %z.2, ...)
  // It's possible that %x and %y do not have the same size as %z and
  // need to be expanded first so that they can be chunked like %z
  //
  // NB: Chunk motion only occurs with fusable consumers, which implies
  // that there is always some other operation, e.g., a+b, that happens
  // after the chunk, and will be put into the fusion group. This is
  // important, because distributing the chunk changes the contiguity
  // of a and b, and so the results would be invalid, except that we know
  // that simple_mappable operations will restore contiguity before
  // we exit the fusion group.
  //
  // NB: The intermediate BroadcastingChunk is important for moving chunks past
  // more than one operation: the graph fuser is not able to easily move
  // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusible
  // ops
  // x = f(v, w)
  // z = g(x, y)
  // a, b = chunk(z)
  // c = h(a, b)
  // becomes (with the broadcast_tensors + chunk approach):
  // x = f(v, w)
  // x', y' = broadcast_tensors([x, y])
  // ax, bx = chunk(x')
  // ay, by = chunk(y')
  // a = g(ax, ay)
  // b = g(bx, by)
  // c = h(a, b)
  // The broadcast_tensors node makes it harder to move f into the resulting
  // FusionGroup of g, g, and h. Keeping the broadcasting and chunk behavior
  // together results in:
  // x = f(v, w)
  // ax, bx, ay, by = BroadcastingChunk(x, y)
  // a = g(ax, ay)
  // b = g(bx, by)
  // c = h(a, b)
  // making it easier to move f after the BroadcastingChunk:
  // ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
  // ax = f(av, aw)
  // bx = f(bv, bw)
  // a = g(ax, ay)
  // b = g(bx, by)
  // c = h(a, b)

  bool tryToMoveChunk(Node* consumer, Value* producer) {
    // is the output from a chunk/bchunk node?
    auto* chunk = producer->node();
    if (chunk->kind() != prim::ConstantChunk &&
        chunk->kind() != prim::BroadcastingChunk)
      return false;

    // try to find a producer to move after the chunk/bchunk. The producer must
    // be fusible into the consumer.
    auto it = std::find_if(
        chunk->inputs().begin(),
        chunk->inputs().end(),
        [&](Value* producer_for_chunk) {
          return isFusableMap(producer_for_chunk->node()) &&
              allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
        });
    if (it == chunk->inputs().end()) {
      return false;
    }
    Value* producer_for_chunk = *it;
    size_t producer_index = it - chunk->inputs().begin();

    // all uses of the chunk must be in this consumer
    for (auto s : chunk->outputs()) {
      for (auto u : s->uses()) {
        if (u.user != consumer)
          return false;
      }
    }
    // producers with multiple outputs are not supported
    Node* producer_for_chunk_node = producer_for_chunk->node();
    AT_ASSERT(producer_for_chunk_node->outputs().size() == 1);

    // Convert chunk to bchunk, if it isn't one already. The bchunk represents
    // a broadcast and one or more chunk operations.
    auto* bchunk = chunk;
    if (chunk->kind() == prim::ConstantChunk) {
      bchunk = promoteChunkToBroadcastingChunk(chunk);
    }
    size_t nchunks = bchunk->i(attr::chunks);
    WithInsertPoint guard(bchunk->next());

    std::vector<Value*> producer_chunk_outputs;
    for (size_t i = 0; i < nchunks; i++) {
      producer_chunk_outputs.push_back(
          bchunk->output(nchunks * producer_index + i));
    }

    // Add each of op's operands to the bchunk node.
    // chunked_inputs[input_nr][chunk_output_idx]
    //  = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr])
    std::vector<std::vector<Value*>> chunked_inputs;

    for (auto input : producer_for_chunk_node->inputs()) {
      // XXX: we only work with pointwise ops in here, so we know it is valid
      // to push the chunk only through tensor arguments (and all other args
      // can be safely ignored).
      if (!input->type()->isSubtypeOf(TensorType::get()))
        continue;

      // if 'input' is already an input to the bchunk, reuse it.
      auto bchunk_inputs = bchunk->inputs();
      auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input);
      if (it != bchunk_inputs.end()) {
        chunked_inputs.emplace_back();
        auto input_index = std::distance(bchunk_inputs.begin(), it);
        for (size_t chunk = 0; chunk < nchunks; ++chunk) {
          chunked_inputs.back().push_back(
              bchunk->outputs().at(nchunks * input_index + chunk));
        }
        continue;
      }

      // NB: I decided not to use cloneFrom here, because if we make cloneFrom
      // copy selects one day, it is definitely not what you want here (selects
      // have different types).
      // TODO: Perhaps we should use cloneFrom now, as it seems unlikely
      // to copy select nodes now that we have refactored to have a Value
      // distinct from Node.
      bchunk->addInput(input);
      chunked_inputs.emplace_back(); // alas, to not be C++17
      for (auto chunk_sel : producer_chunk_outputs) {
        Value* input_chunk_sel = bchunk->addOutput();
        input_chunk_sel->setType(chunk_sel->type());
        chunked_inputs.back().push_back(input_chunk_sel);
      }
    }

    // apply the op to each chunk of the chunked operands,
    // and then rewrite the graph to use them!
    for (auto chunk_sel : producer_chunk_outputs) {
      auto original_inputs = producer_for_chunk_node->inputs();
      Node* chunked_op =
          block_->owningGraph()->create(producer_for_chunk_node->kind());
      chunked_op->copyAttributes(*producer_for_chunk_node);
      chunked_op->output()->setType(chunk_sel->type());
      auto chunked_inputs_it = chunked_inputs.begin();
      for (Value* original_input : original_inputs) {
        if (original_input->type()->isSubtypeOf(TensorType::get())) {
          AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
          chunked_op->addInput(
              chunked_inputs_it->at(chunk_sel->offset() % nchunks));
          ++chunked_inputs_it;
        } else {
          chunked_op->addInput(original_input);
        }
      }
      bchunk->owningGraph()->insertNode(chunked_op);
      chunk_sel->replaceAllUsesWith(chunked_op->output());
    }

    bchunk->removeInput(producer_index);
    for (size_t i = 0; i < nchunks; i++) {
      bchunk->eraseOutput(nchunks * producer_index);
    }

    // The output of producer_for_chunk_node could have been used in some
    // aten::size operators, so we need to clean those up as well (we simply
    // broadcast all its tensor inputs).
    auto size_calc_uses = producer_for_chunk_node->output()->uses();
    if (!size_calc_uses.empty()) {
      auto tensor_inputs = filter(
          producer_for_chunk_node->inputs(),
          [](Value* v) { return v->type()->isSubtypeOf(TensorType::get()); });
      auto tensor_sizes = fmap(tensor_inputs, [](Value* v) {
        return v->owningGraph()->insert(aten::size, {v});
      });
      AT_ASSERT(!tensor_sizes.empty());
      Value* output_size = tensor_sizes.size() == 1
          ? tensor_sizes[0]
          : broadcastSizes(tensor_sizes);
      for (Use u : size_calc_uses) {
        u.user->output()->replaceAllUsesWith(output_size);
        u.user->destroy();
      }
    }
    producer_for_chunk_node->destroy();
    return true;
  }

  // returns where to continue scanning, and whether any fusion was made
  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
    if (isFusable(consumer)) {
      // handle inputs in reverse topological order as well...
      // otherwise in f(a,a+b) it will appear a is used twice if we consider
      // the f-a fusion before the f-(a+b) fusion.
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto producer : inputs) {
        if (tryToMoveChunk(consumer, producer)) {
          // the chunk before this consumer was re-arranged to allow fusion,
          // we scan this consumer again to perform the fusion
          return std::make_pair(consumer->reverseIterator(), true);
        }
        auto fusion_group = tryFuse(consumer, producer);
        if (fusion_group) {
          // after fusion, consumer moves into a FusionGroup, so inputs is no
          // longer valid so we rescan the new FusionGroup for more fusions...
          return std::make_pair(fusion_group.value()->reverseIterator(), true);
        }
      }
    }
    return std::make_pair(++consumer->reverseIterator(), false);
  }

  void replaceIntermediateBroadcastingChunks() {
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
      auto* node = *it;
      ++it; // We might delete node, so increment the iterator now.
      if (node->kind() != prim::BroadcastingChunk) {
        continue;
      }
      auto* bchunk = node;
      insertExplicitBroadcast(bchunk);

      auto* graph = block_->owningGraph();
      size_t nchunks = bchunk->i(attr::chunks);
      WithInsertPoint guard(bchunk->next());

      // Split the bchunk into bchunks.inputs().size() number of chunk nodes.
      for (size_t input_offset = 0; input_offset < bchunk->inputs().size();
           input_offset++) {
        auto* input = bchunk->inputs().at(input_offset);

        Node* new_chunk =
            graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
        new_chunk->copyAttributes(*bchunk);
        for (size_t output_offset = 0; output_offset < nchunks;
             output_offset++) {
          auto new_output = new_chunk->addOutput();
          auto old_output =
              bchunk->outputs().at(input_offset * nchunks + output_offset);
          new_output->copyMetadata(old_output);
          old_output->replaceAllUsesWith(new_output);
        }
      }
      bchunk->destroy();
    }
  }

  bool usedOnlyInSize(Value* v) {
    const auto& uses = v->uses();
    return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
      return u.user->matches("aten::size(Tensor self) -> int[]");
    });
  }

  // Builds up expressions that compute shapes of all intermediates (and
  // outputs) of the fusion group, based on the sizes of inputs. You should run
  // DCE to remove those that you end up not using.
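  // For a pointwise node %c = aten::mul(%a, %b) inside the group, for
  // instance, this records roughly:
  //   shape_of[%c] = prim::BroadcastSizes(aten::size(%a), aten::size(%b))
  // so the shape of %c can be computed without running the kernel.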
  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
    WithInsertPoint insert_guard{fusion_group->next()};
    std::unordered_map<Value*, Value*> shape_of;

    Graph* graph = fusion_group->owningGraph();
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto inputs = fusion_group->inputs();
    auto sinputs = subgraph->inputs();
    AT_ASSERT(inputs.size() == sinputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
        shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]});
      }
    }

    // When we have a guarantee that an output won't be removed, because it's
    // used in expressions that don't involve size checks, we can use its size
    // instead of computing a long chain of broadcasts, starting from the
    // beginning of the kernel.
    auto outputs = fusion_group->outputs();
    auto soutputs = subgraph->outputs();
    AT_ASSERT(outputs.size() == soutputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (usedOnlyInSize(outputs[i]))
        continue;
      shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
    }

    for (Node* n : subgraph->nodes()) {
      // XXX: Use of shape_of.emplace is crucial to the output shape
      // optimization!
      if (n->kind() == prim::FusedConcat) {
        // This is a bit more involved, because we have to account for the case
        // when inputs have different shapes, but fortunately those tensors are
        // always outputs, and so we can simply avoid replacing their queries,
        // because it won't help us.
        continue;
      }
      if (n->kind() == prim::Constant) {
        continue;
      }
      if (n->kind() == prim::ConstantChunk) {
        Node* sizes_node = graph->insertNode(
            graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
        sizes_node->i_(attr::dim, n->i(attr::dim));
        sizes_node->i_(attr::chunks, n->i(attr::chunks));
        Value* regular_size = sizes_node->outputs().at(0);
        Value* last_size = sizes_node->outputs().at(1);
        regular_size->setType(ListType::ofInts());
        last_size->setType(ListType::ofInts());
        auto outputs = n->outputs();
        for (Value* o : outputs.slice(0, outputs.size() - 1)) {
          shape_of.emplace(o, regular_size);
        }
        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
        continue;
      }
      auto tensor_inputs = filter(n->inputs(), [](Value* v) {
        return v->type()->isSubtypeOf(TensorType::get());
      });
      auto shapes =
          fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
      AT_ASSERT(!shapes.empty());
      shape_of.emplace(
          n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
    }
    return shape_of;
  }

  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
    if (fusion_group->kind() != prim::FusionGroup)
      return;
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto shape_of = buildShapeExpressions(fusion_group);
    auto outputs = fusion_group->outputs().vec();
    auto soutputs = subgraph->outputs().vec();
    // XXX: Iterating in this order is not only good for performance reasons!
    // It is also crucial for correctness (i has to reflect the current true
    // index of outputs[i])!
    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
      auto output = outputs[i];
      auto soutput = soutputs[i];
      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
        auto uses = output->uses();
        for (Use u : uses) {
          AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
          u.user->destroy();
        }
        fusion_group->eraseOutput(i);
        subgraph->eraseOutput(i);
      }
    }
  }

  void refreshAliasDb() {
    aliasDb_ = torch::make_unique<AliasDb>(graph_);
  }

  bool canFuseWithConcat(Value* producer, Node* before_check) {
    if (!isFusable(producer->node())) {
      return false;
    }
    // NB: it is important that this check happens after isFusable, which
    // checks that the blocks match, and it's not a special node like
    // prim::Param
    if (!aliasDb_->couldMoveBeforeTopologically(
            producer->node(), before_check)) {
      return false;
    }
    // Fusion groups can be merged with concat's group if and only if
    // - the value they produce isn't already coming from a concat and
    // - the fusion group does not contain GradSumToSize
    if (producer->node()->kind() == prim::FusionGroup) {
      auto subgraph = producer->node()->g(attr::Subgraph);
      auto* node = subgraph->outputs().at(producer->offset())->node();
      return node->kind() != prim::FusedConcat &&
          !containsGradSumToSize(producer->node());
    }
    return true;
  }

  Node* createFusedConcat(Node* node) {
    AT_ASSERT(node->kind() == aten::cat);

    Graph* graph = node->owningGraph();
    Node* list_construct = node->namedInput(attr::tensors)->node();
    int64_t dim = node->get<int64_t>(attr::dim).value();

    Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())
                          ->i_(attr::dim, dim);
    fused_cat->insertBefore(list_construct);
    fused_cat->output()->copyMetadata(node->output());

    // NB: this deletes the fused_cat node from the original graph
    return createSingletonFusionGroup(fused_cat);
  }

  void fuseConcats() {
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();
         ++it) {
      Node* cat = *it;
      if (!isFusableCatNode(cat)) {
        continue;
      }
      Node* list_construct = cat->namedInput(attr::tensors)->node();
      Node* fused_cat = createFusedConcat(cat);
      Value* fused_cat_out = fused_cat->output();

      auto sorted_inputs = sortReverseTopological(fused_cat->inputs());
      size_t input_idx = 0;
      bool any_fused = false;
      while (input_idx < sorted_inputs.size()) {
        Value* input = sorted_inputs[input_idx++];
        if (!canFuseWithConcat(input, fused_cat)) {
          continue;
        }
        any_fused = true;
        auto maybe_group = tryFuse(fused_cat, input);
        AT_ASSERT(maybe_group && maybe_group == fused_cat);
        // We could have destroyed multiple inputs when performing this fusion,
        // so we have to recompute the list and iterate over it again.
        sorted_inputs = sortReverseTopological(fused_cat->inputs());
        input_idx = 0;
      }

      if (any_fused) {
        cat->output()->replaceAllUsesWith(fused_cat_out);
        it.destroyCurrent();
        if (list_construct->output()->uses().empty()) {
          list_construct->destroy();
        }
      } else {
        fused_cat->destroy();
      }
    }
  }

  void optimizeFusedGraphs() {
    for (Node* node : block_->nodes()) {
      if (node->kind() != prim::FusionGroup) {
        continue;
      }
      auto subgraph = node->g(attr::Subgraph);
      EliminateDeadCode(subgraph);
      EliminateCommonSubexpression(subgraph);
      ConstantPooling(subgraph);
    }
  }

  void run() {
    // Run the pass until no changes are made.
    // This is necessary, because the algorithm can miss out on certain fusion
    // opportunities if run only once. Consider this graph:
    //
    // %1 = f(...)
    // %2 = g(%1)
    // %3 = h(%1)
    // %4 = l(%3)
    // return (%4, %2)
    //
    // where f, g, h, l are simple map ops.
    // The first iteration will fuse %4 and %3, and see that %1 is an input,
    // but can't be fused, because it has a different use before the fusion
    // group in our topological ordering. Then, %2 will be considered, and
    // fused with %1. If we do another iteration, the algorithm will consider
    // the fusion of these two groups and fix the situation.
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      refreshAliasDb();
      for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
        bool changed;
        std::tie(it, changed) = scanNode(*it);
        any_changed |= changed;
      }
    }
    refreshAliasDb();

    fuseConcats();

    optimizeFusedGraphs();

    // The graph fuser can add intermediate prim::BroadcastingChunk nodes.
    // Replace them with broadcasts + chunks.
    replaceIntermediateBroadcastingChunks();

    // Fuse starting chunks into the group.
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
      it = scanNodeForChunks(*it);
    }

    // Remove outputs that have been added only because we need their size
    for (Node* n : block_->nodes()) {
      removeOutputsUsedOnlyInSize(n);
    }

    for (Node* node : block_->nodes()) {
      for (Block* sub_block : node->blocks()) {
        GraphFuser(sub_block, graph_).run();
      }
    }
  }
};

void PeepholeOptimizeShapeExpressions(Block* block) {
  auto nodes = block->nodes();
  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
    Node* node = *it;
    for (Block* subblock : node->blocks()) {
      PeepholeOptimizeShapeExpressions(subblock);
    }
    if (node->kind() == prim::BroadcastSizes) {
      // Remove no-op broadcasts.
      if (node->inputs().size() == 1) {
        node->output()->replaceAllUsesWith(node->input());
        it.destroyCurrent();
        continue;
      }
      // Deduplicate inputs, but use their unique() values to ensure
      // this process only depends on the graph.
      std::map<size_t, Value*> unique_to_value;
      for (Value* input : node->inputs()) {
        unique_to_value.emplace(input->unique(), input);
      }
      if (unique_to_value.size() != node->inputs().size()) {
        std::vector<Value*> inputs;
        inputs.reserve(unique_to_value.size());
        for (auto& entry : unique_to_value) {
          inputs.push_back(entry.second);
        }
        if (inputs.size() == 1) {
          node->output()->replaceAllUsesWith(inputs[0]);
        } else {
          WithInsertPoint insert_guard{node};
          node->output()->replaceAllUsesWith(broadcastSizes(inputs));
        }
        it.destroyCurrent();
        --it; // Revisit the node with deduplicated inputs
        continue;
      }
      // Compose simple chains of broadcasts into a single node.
      const auto& uses = node->output()->uses();
      if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
        Node* user = uses[0].user;
        user->removeInput(uses[0].offset);
        // NB: we don't care about deduplication in here, as we will visit
        // user later.
        for (Value* i : node->inputs()) {
          user->addInput(i);
        }
        it.destroyCurrent();
      }
    }
  }
}

} // anonymous namespace

// This takes a _grad_sum_to_size output and tracks it to the return
// statements that depend on it, checking that it only hits nodes
// that commute with _grad_sum_to_size on its path.
// If a non-nullptr vector pointer outputGradSumToSizes is passed, the sizes
// will be recorded as target sizes for the outputs as applicable.
// In the graph_fuser pass we only need to check that we can go to the
// outputs while in the fuser's compiler we want to record the sizes.
// Note: This will only record a new sum_to_size if there is not one
// already. As we want the last grad_sum_to_size, you need to call
// it in reverse order when recording and removing outputs.
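// Commuting here means the reduction can be delayed past the op. For a
// pointwise op like aten::neg, for example,
//   neg(_grad_sum_to_size(x, size)) == _grad_sum_to_size(neg(x), size)
// because _grad_sum_to_size only sums over dimensions that were broadcast.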
bool trackSingleGradSumToSizeToOutputs(
    Value* gradSumToSizeOutput,
    std::vector<int64_t>* outputGradSumToSizes) {
  static OperatorSet commutes_with_SumToSize{{
      "aten::mul(Tensor self, Tensor other) -> Tensor",
      "aten::div(Tensor self, Tensor other) -> Tensor",
      // for div we might check whether we're the first argument
      "aten::mul(Tensor self, Scalar other) -> Tensor",
      "aten::div(Tensor self, Scalar other) -> Tensor",
      "aten::neg(Tensor self) -> Tensor",
      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
      // this add used to be prim::AutogradAdd
  }};

  std::queue<Use> uses_to_process{};
  auto add_to_uses = [&](const use_list& uses) {
    for (auto u : uses) {
      uses_to_process.push(u);
    }
  };
  add_to_uses(gradSumToSizeOutput->uses());
  while (!uses_to_process.empty()) {
    auto user = uses_to_process.front().user;
    auto offset = uses_to_process.front().offset;
    uses_to_process.pop();
    if (user->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
      // sometimes, a mask or similar is cast to the same type as the gradient,
      // i.e. we see it as other. Then we don't need to do anything, as the
      // shape is not used, only the type.
      // But we might also see it as self, when the gradient is cast; then we
      // want to track it.
      if (offset == 0) {
        add_to_uses(user->output()->uses());
      }
    } else if (commutes_with_SumToSize.find(user)) {
      add_to_uses(user->output()->uses());
    } else if (user->kind() == prim::Return) {
      // During compilation, and only if we don't already have a
      // _grad_sum_to_size for this output, we record the size to sum the
      // output to. We only do this if we didn't see anything yet because we
      // want later (in the graph) nodes to take precedence over earlier ones
      // and we iterate backwards. The implicit assumption is that if we have
      // several _grad_sum_to_size ops "in parallel" (from auto-diff added
      // AutogradAdd as the backward of using an input in multiple places)
      // they are the same. This is because AutogradAdd does not broadcast.
      if (outputGradSumToSizes && (*outputGradSumToSizes)[offset] == -1) {
        // note: we make the assumption that the sizes are inputs to the
        // fusion group (rather than something calculated).
        (*outputGradSumToSizes)[offset] =
            gradSumToSizeOutput->node()->inputs()[1]->offset();
      }
    } else if (user->kind() == aten::_grad_sum_to_size) {
      // do nothing; this case only happens in the graph_fuser step because
      // in the compile step we iterate backwards and delete all
      // _grad_sum_to_size nodes we see
    } else {
      // we found something we do not support. Note that this notably includes
      // prim::FusedConcat, which we do not know how to deal with in
      // conjunction with _grad_sum_to_size
      return false;
    }
  }
  return true;
}

void FuseGraph(std::shared_ptr<Graph>& graph) {
  if (canFuseOnCPU() || canFuseOnGPU()) {
    GraphFuser(graph->block(), graph).run();
    // After FuseGraph some common subexpressions may come back
    EliminateCommonSubexpression(graph);
    // We might have emitted a fair amount of useless shape propagating code,
    // so remove it
    EliminateDeadCode(graph);
    // Improve the quality of shape propagation code that was left
    PeepholeOptimizeShapeExpressions(graph->block());
  }
}

} // namespace jit
} // namespace torch