// Caffe2 - C++ API
// A deep learning, cross-platform ML framework
// alias_analysis.cpp
1 #include <torch/csrc/jit/passes/alias_analysis.h>
2 
3 #include <torch/csrc/jit/script/error_report.h>
4 #include <torch/csrc/utils/memory.h>
5 
6 namespace torch {
7 namespace jit {
8 
9 bool AliasDb::shouldAnnotate(const TypePtr& type) {
10  return type->isSubtypeOf(TensorType::get()) ||
11  type->kind() == TypeKind::ListType ||
12  type->kind() == TypeKind::TupleType ||
13  type->kind() == TypeKind::DictType || type->kind() == TypeKind::VarType ||
14  type->kind() == TypeKind::FutureType ||
15  type->kind() == TypeKind::ClassType ||
16  (type->kind() == TypeKind::OptionalType &&
17  shouldAnnotate(type->cast<OptionalType>()->getElementType()));
18 }
19 
20 // We only need to annotate values that either are mutable or could contain
21 // mutable types.
22 bool AliasDb::shouldAnnotate(const Value* v) {
23  return shouldAnnotate(v->type());
24 }
25 
// Defaulted out-of-line — presumably so the smart-pointer members (e.g. the
// MemoryDAG created in the constructor) can be destroyed where the full type
// is visible; TODO confirm against the header.
AliasDb::~AliasDb() = default;
27 
// Build the alias database for `graph`: allocate the memory DAG, then run
// the analysis over the entire graph.
AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
  memoryDAG_ = torch::make_unique<MemoryDAG>();
  analyze(graph_);
}
32 
33 // Does `n` use or write to any wildcard aliases?
34 bool AliasDb::hasWildcard(const Node* n) const {
35  for (const auto input : n->inputs()) {
36  if (isWildcard(input)) {
37  return true;
38  }
39  }
40  for (const auto output : n->outputs()) {
41  if (isWildcard(output)) {
42  return true;
43  }
44  }
45  return false;
46 }
47 
48 bool AliasDb::isWildcard(const Value* v) const {
49  return wildcards_.count(v);
50 }
51 
52 bool AliasDb::writesTo(Node* n, const Value* v) const {
53  if (!shouldAnnotate(v)) {
54  // This is a primitive type
55  return false;
56  }
57  if (isWildcard(v)) {
58  return wildcardWriters_.count(n);
59  }
60 
61  if (!elementMap_.count(v) || !writeIndex_.count(n)) {
62  return false;
63  }
64 
65  // Can short-circuit if we know this node writes directly to `v`
66  if (writeIndex_.at(n).count(v)) {
67  return true;
68  }
69 
70  // Otherwise, check if `v` may alias any of written-to values in `n`
71  const auto vSet = ValueSet{v};
72  return mayAlias(vSet, writeIndex_.at(n));
73 }
74 
75 bool AliasDb::hasWriters(const Node* n) const {
76  for (const auto input : n->inputs()) {
77  if (hasWriters(input)) {
78  return true;
79  }
80  }
81  for (const auto output : n->outputs()) {
82  if (hasWriters(output)) {
83  return true;
84  }
85  }
86  return false;
87 }
88 
89 bool AliasDb::hasWriters(const Value* v) const {
90  if (isWildcard(v)) {
91  // If `n` has a wildcard, any write in the graph may write to it.
92  // So the only way we know there are no writers is if there are no writes
93  // at all.
94  return numWrites_ != 0;
95  }
96 
97  if (!elementMap_.count(v)) {
98  return false;
99  }
100 
101  if (wildcardWriters_.size() > 0) {
102  // A write to the wildcard may be a write to any value.
103  return true;
104  }
105 
106  if (isWriteCacheStale_) {
107  rebuildWriteCache();
108  }
109 
110  for (const auto loc : elementMap_.at(v)->getMemoryLocations()) {
111  if (writeCache_.count(loc)) {
112  return true;
113  }
114  }
115 
116  return false;
117 }
118 
119 bool AliasDb::hasWrites(Node* n) const {
120  for (const auto input : n->inputs()) {
121  if (writesTo(n, input)) {
122  return true;
123  }
124  }
125  for (const auto output : n->outputs()) {
126  if (writesTo(n, output)) {
127  return true;
128  }
129  }
130  return false;
131 }
132 
133 bool AliasDb::writesToInputAlias(Node* n) const {
134  std::vector<const Value*> writes;
135  for (const auto input : n->inputs()) {
136  if (writesTo(n, input)) {
137  writes.push_back(input);
138  }
139  }
140  for (const auto output : n->outputs()) {
141  if (writesTo(n, output)) {
142  writes.push_back(output);
143  }
144  }
145 
146  // For all writes, check if the written value may alias a graph input
147  return std::any_of(writes.cbegin(), writes.cend(), [&](const Value* v) {
148  return std::any_of(
149  graph_->inputs().cbegin(),
150  graph_->inputs().cend(),
151  [&](const Value* graphInput) {
152  return shouldAnnotate(graphInput) && mayAlias(graphInput, v);
153  });
154  });
155 }
156 
157 void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const {
158  for (auto node : b->nodes()) {
159  getWritesImpl(node, ret, recurseBlocks);
160  }
161 }
162 
163 void AliasDb::getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks) const {
164  for (const auto input : n->inputs()) {
165  if (writesTo(n, input)) {
166  ret.insert(input);
167  }
168  }
169  for (const auto output : n->outputs()) {
170  if (writesTo(n, output)) {
171  ret.insert(output);
172  }
173  }
174 
175  if (recurseBlocks) {
176  for (auto block : n->blocks()) {
177  getWritesImpl(block, ret, recurseBlocks);
178  }
179  }
180 }
181 
182 // Get all writes by all nodes in a block, recursively exploring sub-blocks
183 ValueSet AliasDb::getWrites(Block* b) const {
184  ValueSet writes;
185  getWritesImpl(b, writes, /*recurseBlocks=*/true);
186  return writes;
187 }
188 
189 // Does `n` write to an alias of one of the values in `vs`?
190 bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks)
191  const {
192  const auto writtenTo = getWrites(n, recurseBlocks);
193  return mayAlias(vs, writtenTo);
194 }
195 
196 std::unordered_set<const Value*> AliasDb::getWrites(Node* n, bool recurseBlocks)
197  const {
198  ValueSet writes;
199  getWritesImpl(n, writes, recurseBlocks);
200  return writes;
201 }
202 
203 void AliasDb::getReadsImpl(Node* n, ValueSet& ret, bool recurseBlocks) const {
204  for (const auto input : n->inputs()) {
205  ret.insert(input);
206  }
207  for (const auto output : n->outputs()) {
208  ret.insert(output);
209  }
210 
211  if (recurseBlocks) {
212  for (auto block : n->blocks()) {
213  for (auto node : block->nodes()) {
214  getReadsImpl(node, ret, recurseBlocks);
215  }
216  }
217  }
218 }
219 
220 ValueSet AliasDb::getReads(Node* n, bool recurseBlocks) const {
221  ValueSet reads;
222  getReadsImpl(n, reads, recurseBlocks);
223  return reads;
224 }
225 
226 void AliasDb::dump() const {
227  std::cout << "\n===1. GRAPH===\n";
228  graph_->dump();
229 
230  std::cout << "\n===2. ALIAS DB===\n";
231  for (const auto& ptrPair : elementMap_) {
232  const auto element = ptrPair.second;
233  if (element->pointsTo.size() > 0) {
234  std::cout << element->value->uniqueName() << " points to: ";
235  for (const auto pointedTo : element->pointsTo) {
236  std::cout << pointedTo->value->uniqueName() << ", ";
237  }
238  std::cout << "\n";
239  }
240  }
241 
242  std::cout << "\n===3. WILDCARDS===\n";
243  for (const auto wildcard : wildcards_) {
244  std::cout << wildcard->uniqueName() << ", ";
245  }
246  std::cout << "\n";
247 
248  std::cout << "\n===4. Writes===\n";
249  for (const auto& pr : writeIndex_) {
250  const auto node = pr.first;
251  const auto& values = pr.second;
252  std::cout << *node;
253  std::cout << " ";
254  for (const auto value : values) {
255  std::cout << value->uniqueName() << ", ";
256  }
257  std::cout << "\n";
258  }
259  std::cout << "\n";
260 }
261 
262 // TODO: need to create a dummy "graph input alias" value in MemoryDAG for all
263 // inputs of the same type to point to. Currently they all point to the first
264 // element, which is technically wrong.
265 void AliasDb::makeAllAlias(const std::vector<Value*>& values) {
266  if (values.size() > 0) {
267  giveFreshAlias(values[0]);
268  }
269  for (const auto value : values) {
270  makePointerTo(value, values[0]);
271  }
272 }
273 
// Entry point of the analysis: annotate the graph inputs, then walk the
// graph body.
void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
  // Assign aliases to the graph's inputs, assuming that all inputs of a given
  // type may alias to each other.

  // 1. Partition inputs by their type
  std::map<TypeKind, std::vector<Value*>> listTypes;
  std::unordered_map<TupleTypePtr, std::vector<Value*>> tupleTypes;
  std::unordered_map<DictTypePtr, std::vector<Value*>> dictTypes;
  std::unordered_map<ClassTypePtr, std::vector<Value*>> classTypes;
  std::vector<Value*> tensors;

  for (auto input : graph->inputs()) {
    auto inputType = input->type();
    // unwrap optional types: an Optional[T] input is bucketed as a T
    if (inputType->kind() == TypeKind::OptionalType) {
      inputType = inputType->cast<OptionalType>()->getElementType();
    }

    if (inputType->isSubtypeOf(TensorType::get())) {
      tensors.push_back(input);
    } else if (inputType->kind() == TypeKind::ListType) {
      auto containedType = inputType->containedTypes().at(0);
      // All tensor subtypes may alias to each other, so we should consider all
      // lists of them to alias to each other.
      if (containedType->isSubtypeOf(TensorType::get())) {
        containedType = TensorType::get();
      }
      listTypes[containedType->kind()].push_back(input);
    } else if (inputType->kind() == TypeKind::TupleType) {
      auto tupleType = inputType->cast<TupleType>();
      tupleTypes[tupleType].push_back(input);
    } else if (inputType->kind() == TypeKind::DictType) {
      auto dictType = inputType->cast<DictType>();
      dictTypes[dictType].push_back(input);
    } else if (inputType->kind() == TypeKind::ClassType) {
      auto classType = inputType->cast<ClassType>();
      classTypes[classType].push_back(input);
    } else {
      // Anything else must be a primitive that needs no alias annotation.
      AT_ASSERT(!shouldAnnotate(input));
    }
  }

  // 2. Make all partitions alias each other
  for (const auto& pr : listTypes) {
    makeAllAlias(pr.second);
  }
  for (const auto& pr : tupleTypes) {
    makeAllAlias(pr.second);
  }
  for (const auto& pr : dictTypes) {
    makeAllAlias(pr.second);
  }
  for (const auto& pr : classTypes) {
    makeAllAlias(pr.second);
  }
  makeAllAlias(tensors);

  // 3. Run the node-by-node analysis over the graph body.
  analyze(graph->block());
}
333 
334 void AliasDb::analyze(Block* block) {
335  for (auto node : block->nodes()) {
336  analyze(node);
337  }
338 }
339 
340 void AliasDb::analyze(Node* node) {
341  analyzeImpl(node);
342 
343  // After analyzing, update the wildcard index
344  if (hasWildcard(node)) {
345  wildcardNodes_.insert(node);
346  }
347 }
348 
349 // The basic strategy is:
350 // 1. Retrieve alias information for every input.
351 // 2. Use the node's schema's alias annotations to propgagate alias/write
352 // information to the outputs. For unschematized nodes, a special analyzer
353 // will have to be handwritten.
354 void AliasDb::analyzeImpl(Node* node) {
355  // These nodes are not schematized, so we need to handle them specially
356  switch (node->kind()) {
357  case prim::If:
358  return analyzeIf(node);
359  case prim::Loop:
360  return analyzeLoop(node);
361  case prim::FusionGroup:
362  case prim::DifferentiableGraph:
363  return analyzeSubgraph(node);
364  case prim::fork:
365  return analyzeFork(node);
366  case aten::wait:
367  return analyzeWait(node);
368  case prim::Constant:
369  case prim::DictConstruct:
370  case prim::ListConstruct:
371  case prim::TupleConstruct:
372  case prim::AutogradZero:
373  case prim::FusedConcat:
374  case prim::MMTreeReduce:
375  case prim::MMBatchSide:
376  case prim::BroadcastSizes:
377  case prim::ChunkSizes:
378  case prim::Function:
379  case prim::CreateObject:
380  return analyzeCreator(node);
381  case prim::TupleUnpack:
382  case prim::TupleIndex:
383  case prim::DictIndex:
384  case prim::TupleSlice:
385  case prim::ListUnpack:
386  case prim::PythonOp:
387  case prim::GetAttr:
388  return analyzeExtractor(node);
389  case prim::ConstantChunk:
390  return analyzeChunk(node);
391  case prim::BroadcastingChunk:
392  return analyzeBroadcastingChunk(node);
393  case prim::SetAttr:
394  return analyzeSetAttr(node);
395  case aten::add:
396  case aten::sub:
397  case aten::mul:
398  case aten::div: {
399  // This is necessary because we sometimes get unschematized combinations
400  // of Tensor/primitive.
401  auto maybeSchema = node->maybeSchema();
402  if (!maybeSchema) {
403  return analyzeCreator(node);
404  }
405  // If the node has a schema, fall through and analyze it normally
406  break;
407  }
408  case prim::Print:
409  // These ops do nothing
410  return;
411  default:
412  AT_ASSERT(!aliasAnalysisHasSpecialCaseFor(node->kind()));
413  }
414 
415  const auto& schema = node->schema();
416  if (schema.is_vararg() || schema.is_varret()) {
417  const auto hasMutableOutputs = std::any_of(
418  node->outputs().cbegin(),
419  node->outputs().cend(),
420  [](const Value* output) { return shouldAnnotate(output); });
421 
422  // We don't have alias info for this node. Either schematize it, or
423  // add it an analyze* method for it.
424  if (hasMutableOutputs) {
425  throw script::ErrorReport(node->getSourceLocation())
426  << "Alias information not found for node. File a bug report.\n"
427  << "Node: " << *node << "\n";
428  }
429  }
430 
431  // Bind formal alias annotation to actual alias sets
432  std::unordered_map<Symbol, Value*> formalToActual;
433  for (size_t i = 0; i < schema.arguments().size(); i++) {
434  const auto& formal = schema.arguments()[i].alias_info();
435  const auto& actualValue = node->inputs().at(i);
436  // Skip if there's no alias annotation
437  if (!formal) {
438  continue;
439  }
440 
441  // If this type cannot alias, continue. Can occur with a VarType schema
442  if (!shouldAnnotate(actualValue)) {
443  continue;
444  }
445 
446  // We don't support composite types for alias analysis yet.
447  AT_ASSERT(formal->containedTypes().size() == 0);
448  // TODO neither unions nor wildcards make sense on an input. We should
449  // disallow them in function schema
450  AT_ASSERT(!formal->isWildcard())
451  const auto& formalAlias = formal->beforeSet();
452 
453  // skip if we've already bound this alias
454  if (formalToActual.count(formalAlias) != 0) {
455  continue;
456  }
457 
458  // Bind the formal to the actual
459  formalToActual[formalAlias] = actualValue;
460 
461  // Record writes
462  if (formal->isWrite()) {
463  registerWrite(actualValue, node);
464  }
465  }
466 
467  // Use the formal-actual mapping to give aliases to the outputs
468  for (size_t i = 0; i < schema.returns().size(); i++) {
469  const auto actual = node->outputs().at(i);
470  const auto& formal = schema.returns()[i].alias_info();
471  if (!formal) {
472  // This is a fresh tensor
473  giveFreshAlias(actual);
474  continue;
475  }
476 
477  // If this type cannot alias, continue. Can occur with a VarType schema
478  if (!shouldAnnotate(actual)) {
479  continue;
480  }
481 
482  // We don't support composite types for alias analysis yet.
483  AT_ASSERT(formal->containedTypes().size() == 0);
484 
485  if (formal->isWildcard()) {
486  setWildcard(actual);
487  continue;
488  }
489 
490  for (const auto& formalAlias : formal->beforeSets()) {
491  // If we encounter an alias annotation that wasn't in the inputs:
492  if (!formalToActual.count(formalAlias)) {
493  // If this alias is not seen elsewhere and is the only annotation on
494  // the output, it's equivalent to being fresh:
495  // e.g. foo(Tensor(a) self) -> Tensor(b)
496  if (formal->beforeSets().size() == 1) {
497  giveFreshAlias(actual);
498  }
499  // Or it is the form of a|fresh, which we can ignore, taking the
500  // conservative assumption that the output must alias `a`, e.g
501  // aten::cuda(Tensor(a) self) -> Tensor(a|fresh)
502 
503  // Don't assign an alias set in that case.
504  continue;
505  }
506 
507  auto toAlias = formalToActual.at(formalAlias);
508  makePointerTo(actual, toAlias);
509  }
510 
511  // Record writes
512  if (formal->isWrite()) {
513  registerWrite(actual, node);
514  }
515  }
516 }
517 // Register the fact that `n` writes to `v`.
518 void AliasDb::registerWrite(const Value* v, Node* n) {
519  numWrites_++;
520 
521  if (isWildcard(v)) {
522  wildcardWriters_.insert(n);
523  return;
524  }
525 
526  AT_ASSERT(elementMap_.count(v));
527  writeIndex_[n].insert(v);
528 }
529 
530 void AliasDb::analyzeIf(Node* node) {
531  // For if statements, the alias set of an output is the union of the
532  // alias sets generated by the if and else block
533  const auto trueBlock = node->blocks().at(0);
534  const auto falseBlock = node->blocks().at(1);
535  analyze(trueBlock);
536  analyze(falseBlock);
537 
538  for (size_t i = 0; i < node->outputs().size(); i++) {
539  const auto nodeOutput = node->outputs()[i];
540 
541  const auto trueOutput = trueBlock->outputs().at(i);
542  const auto falseOutput = falseBlock->outputs().at(i);
543 
544  makePointerTo(nodeOutput, trueOutput);
545  makePointerTo(nodeOutput, falseOutput);
546  }
547 }
548 
// Loop: wire the loop-carried inputs to the body's block inputs, analyze the
// body once, then wire the body's outputs to the node outputs.
void AliasDb::analyzeLoop(Node* node) {
  const auto bodyBlock = node->blocks().at(0);
  const auto loopCarriedInputs = node->inputs().slice(2); // skip max, cond
  const auto blockInputs = bodyBlock->inputs().slice(1); // skip trip
  const auto blockOutputs = bodyBlock->outputs().slice(1); // skip trip
  AT_ASSERT(loopCarriedInputs.size() == blockInputs.size());
  AT_ASSERT(blockOutputs.size() == node->outputs().size());

  // NOTE(review): this is a single pass over the body, not an
  // iterate-until-convergence fixpoint as an earlier comment claimed.
  // Copy node input aliases to block input
  mapAliases(blockInputs, loopCarriedInputs);

  // Populate block output alias info by analyzing the body
  analyze(bodyBlock);

  // Copy the alias info from the block output to the node output
  mapAliases(node->outputs(), blockOutputs);
}
568 
569 void AliasDb::analyzeSubgraph(Node* node) {
570  const auto subgraph = node->g(attr::Subgraph).get();
571 
572  subgraphToOwner_.insert({subgraph, node});
573 
574  const auto subgraphBlock = subgraph->block();
575  mapAliases(subgraphBlock->inputs(), node->inputs());
576 
577  analyze(subgraphBlock);
578 
579  // TODO(suo): the subgraph outputs and node outputs are NOT NECESSARILY the
580  // same length. Autodifferentiation maybe capture additional outputs in the
581  // subgraph block.
582  AT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size());
583  for (size_t i = 0; i < node->outputs().size(); i++) {
584  makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]);
585  }
586 }
587 
588 // For nodes that generate a fresh value from nothing
589 void AliasDb::analyzeCreator(Node* node) {
590  for (Value* output : node->outputs()) {
591  giveFreshAlias(output);
592  }
593 }
594 
595 // For nodes that extract values from a composite type. Right now, this just
596 // gives up and creates wildcards for everything.
597 void AliasDb::analyzeExtractor(Node* node) {
598  for (const auto output : node->outputs()) {
599  if (shouldAnnotate(output)) {
600  setWildcard(output);
601  }
602  }
603 }
604 
605 // For torch.chunk(), all returned tensors may alias the input tensor
606 void AliasDb::analyzeChunk(Node* node) {
607  for (auto output : node->outputs()) {
608  makePointerTo(output, node->input());
609  }
610 }
611 
612 // Propagate aliasing and write information from the subgraph outputs to the
613 // outputs of the corresponding aten::wait() calls, since that's where the
614 // values will eventually emerge.
615 void AliasDb::analyzeFork(Node* node) {
616  const auto subgraph = node->g(attr::Subgraph).get();
617  subgraphToOwner_.insert({subgraph, node});
618 
619  const auto subgraphBlock = subgraph->block();
620  mapAliases(subgraphBlock->inputs(), node->inputs());
621  analyze(subgraphBlock);
622 
623  // Give the future that the fork emits a fresh value
624  for (const auto output : node->outputs()) {
625  giveFreshAlias(output);
626  }
627 }
628 
// aten::wait: propagate aliasing and write information from the forked
// subgraph's outputs to the outputs of the wait() call, since that's where
// the values eventually emerge.
void AliasDb::analyzeWait(Node* node) {
  const auto fut = node->input();
  AT_ASSERT(fut->type()->kind() == TypeKind::FutureType);

  // If we don't know which fork produced this future, be maximally
  // conservative: every output is a wildcard.
  if (isWildcard(fut)) {
    for (const auto output : node->outputs()) {
      setWildcard(output);
    }
    return;
  }

  // Every memory location the future may point at is (per analyzeFork) the
  // output of some fork node carrying a subgraph.
  const auto originFuts = getMemoryLocations(fut);
  for (const auto originFut : originFuts) {
    const auto subgraphNode = originFut->node();

    const auto subgraph = subgraphNode->g(attr::Subgraph).get();
    const auto subgraphWrites = getWrites(subgraph->block());

    // Retrieve aliasing info from the subgraph
    mapAliases(node->outputs(), subgraph->outputs());

    // Propagate write information to the `wait` node.
    //
    // We need to do this for all writes in the entire subgraph, so that we
    // disallow reorders past a call to "aten::wait".
    //
    // Consider the following Fork where the subgraph writes to %a:
    //
    //   %c : Future[Tensor] = prim::Fork(%a, %b) <-- writes to %a
    //   ...
    //   aten::wait(%c)
    //   aten::use(%a) <-- we can't move this node before the `wait` safely!
    //
    // Say we define the "live interval" of a fork the interval between the
    // `fork` and its first corresponding `wait` (inclusive).
    //
    // Any writes in the subgraph can happen at any point in the live interval,
    // so it's not safe to re-order any reads to those memory locations from
    // outside the live interval to inside.
    //
    // In reality, any reads *inside* the live interval are undefined behavior,
    // since the writes may or may not have been executed yet. But we'll let
    // users do that and shoot themselves in the foot for now.
    for (const auto write : subgraphWrites) {
      registerWrite(write, node);
    }
  }
}
677 
678 // SetAttr: writes to the `self` field
679 void AliasDb::analyzeSetAttr(Node* node) {
680  const auto self = node->inputs().at(0);
681  AT_ASSERT(self->type()->kind() == TypeKind::ClassType);
682  registerWrite(self, node);
683 }
684 
685 // BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
686 // This is an intermediate node used only in the graph fuser.
687 void AliasDb::analyzeBroadcastingChunk(Node* node) {
688  auto inputs = node->inputs();
689  auto outputs = node->outputs();
690  auto nchunks = node->i(attr::chunks);
691  for (size_t index = 0; index < inputs.size(); ++index) {
692  // Each inputs[i] is aliased by exactly `nchunks` distinct output tensors:
693  // inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks)
694  auto output_begin = outputs.begin() + index * nchunks;
695  for (auto it = output_begin; it != output_begin + nchunks; ++it) {
696  makePointerTo(*it, inputs.at(index));
697  }
698  }
699 }
700 
// Register the fact that `from` is a pointer to `to` in the memory DAG.
void AliasDb::makePointerTo(const Value* from, const Value* to) {
  // Primitive types can't alias; if `from` needs no annotation, `to` must
  // not need one either.
  if (!shouldAnnotate(from)) {
    AT_ASSERT(!shouldAnnotate(to));
    return;
  }

  // A self-pointer carries no information.
  if (from == to) {
    return;
  }

  // Special case: if `from` is an optional, `to` could be a None. Don't
  // create a pointer in that case
  if (from->type()->kind() == TypeKind::OptionalType &&
      to->type()->kind() == TypeKind::NoneType) {
    return;
  }

  // At this point, we should be dealing with two mutable types.
  AT_ASSERT(shouldAnnotate(from) && shouldAnnotate(to));

  // If either value is a wildcard, don't insert anything into the graph;
  // wildcards are tracked separately since they have different aliasing rules.
  if (isWildcard(to) || isWildcard(from)) {
    setWildcard(from);
    return;
  }

  // Lazily create memory-DAG elements for both endpoints.
  if (!isTracked(from)) {
    giveFreshAlias(from);
  }
  if (!isTracked(to)) {
    giveFreshAlias(to);
  }
  auto fromEl = elementMap_.at(from);
  auto toEl = elementMap_.at(to);
  memoryDAG_->makePointerTo(fromEl, toEl);
}
739 
740 bool AliasDb::mayAlias(const Value* a, const Value* b) const {
741  if (isWildcard(a) || isWildcard(b)) {
742  return true;
743  }
744 
745  if (!elementMap_.count(a) || !elementMap_.count(b)) {
746  return false;
747  }
748 
749  return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
750 }
751 
752 // Make each value in the `from` list point to its partner in the `to` list
753 void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
754  AT_ASSERT(to.size() == from.size());
755  for (size_t i = 0; i < to.size(); i++) {
756  makePointerTo(from[i], to[i]);
757  }
758 }
759 
760 void AliasDb::giveFreshAlias(const Value* value) {
761  if (!shouldAnnotate(value)) {
762  return;
763  }
764 
765  if (isTracked(value)) {
766  // Inside a loop, we may have given a fresh alias to this value already, so
767  // skip
768  return;
769  }
770 
771  elementMap_[value] = memoryDAG_->makeFreshValue(value);
772 }
773 
774 bool AliasDb::isTracked(const Value* v) const {
775  return isWildcard(v) || elementMap_.count(v);
776 }
777 
// Move `n` after `movePoint` if topologically safe, shuffling dependencies
// as needed. Returns true on success (move performed), false otherwise.
bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
  return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false);
}
781 
// Like moveAfterTopologicallyValid, but a dry run: reports feasibility
// without mutating the graph.
bool AliasDb::couldMoveAfterTopologically(Node* n, Node* movePoint) {
  return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/true);
}
785 
// Move `n` before `movePoint` if topologically safe. Returns true on success.
bool AliasDb::moveBeforeTopologicallyValid(Node* n, Node* movePoint) {
  // We have to distinguish the move side (instead of just moving after
  // n->prev()). Consider the following example:
  // If the dependency graph looks like
  //   n -> movePoint -> o
  // then moveBefore(o) will end up with
  //   n, o, movePoint
  // but moveAfter(n) will return false.
  return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/false);
}
796 
// Like moveBeforeTopologicallyValid, but a dry run: reports feasibility
// without mutating the graph.
bool AliasDb::couldMoveBeforeTopologically(Node* n, Node* movePoint) {
  return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/true);
}
800 
// Helper for topologically-safe node moves. See `tryMove()` for details.
// Maintains the set of nodes being moved together, plus indices of their
// users, reads, and writes so dependency queries are cheap.
 public:
  // Seed the working set with the node we are trying to move.
  explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
    add(mover);
  }

  // Add `n` to the working set
  void add(Node* n) {
    nodes_.push_back(n);
    // Index all same-block users of `n`'s outputs...
    for (const auto user : getUsersSameBlock(n)) {
      users_.insert(user);
    }

    // ...and everything `n` (including its sub-blocks) writes and reads.
    for (const auto& write : aliasDb_.getWrites(n, /*recurseBlocks=*/true)) {
      writes_.insert(write);
    }
    for (const auto& read : aliasDb_.getReads(n, /*recurseBlocks=*/true)) {
      reads_.insert(read);
    }
    if (aliasDb_.hasWildcard(n)) {
      numWildcards_++;
    }
  }

  // Remove the mover (the first node that was added) from all the indices.
  // The indices are multisets, so entries contributed by other working-set
  // nodes survive this erasure — only one occurrence per item is removed.
  void eraseMover() {
    auto mover = nodes_.front();
    for (const auto user : getUsersSameBlock(mover)) {
      // find+erase removes a single multiset occurrence, not all of them.
      const auto it = users_.find(user);
      if (it != users_.end()) {
        users_.erase(it);
      }
    }

    for (const auto& write :
         aliasDb_.getWrites(mover, /*recurseBlocks=*/true)) {
      const auto it = writes_.find(write);
      if (it != writes_.end()) {
        writes_.erase(it);
      }
    }
    for (const auto& read : aliasDb_.getReads(mover, /*recurseBlocks=*/true)) {
      const auto it = reads_.find(read);
      if (it != reads_.end()) {
        reads_.erase(it);
      }
    }
    if (aliasDb_.hasWildcard(mover)) {
      numWildcards_--;
    }
    nodes_.pop_front();
  }

  // The nodes currently in the working set, in insertion order.
  const std::list<Node*>& nodes() {
    return nodes_;
  }

  // Does the working set depend on `n`?
  bool dependsOn(Node* n) const {
    if (nodes_.empty()) {
      return false;
    }

    return hasDataDependency(n) || hasMutabilityDependency(n);
  }

 private:
  // Data dependency is directional: a node after the working set can only
  // consume its values; a node before it can only produce values for it.
  bool hasDataDependency(Node* n) const {
    if (n->isAfter(nodes_.front())) {
      return producesFor(n);
    } else {
      return consumesFrom(n);
    }
  }

  bool hasMutabilityDependency(Node* n) const {
    // 1. Handle wildcard dependencies:
    // If the working set has a wildcard, `n` can't write to anything.
    if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) {
      return true;
    }

    // If `n` has a wildcard, the working set can't write to anything.
    if (aliasDb_.hasWildcard(n) && writes_.size() > 0) {
      return true;
    }

    // 2. Handle regular mutable dependencies
    // Check that `n` does not write to anything used by the working set
    const auto nWrites = aliasDb_.getWrites(n, /*recurseBlocks=*/true);
    if (aliasDb_.mayAlias(nWrites, reads_)) {
      return true;
    }

    // Check that the working set doesn't write to anything that `n` uses.
    const auto nReads = aliasDb_.getReads(n, /*recurseBlocks=*/true);
    if (aliasDb_.mayAlias(writes_, nReads)) {
      return true;
    }
    return false;
  }

  // Does the working set produce any values consumed by `n`?
  bool producesFor(Node* n) const {
    // This equivalent to asking: does the total use-set of all the nodes in
    // the working set include `n`?
    return users_.count(n) != 0;
  }

  // Does the working set consume any values produced by `n`?
  bool consumesFrom(Node* n) const {
    const auto users = getUsersSameBlock(n);
    return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
      return users.count(node) != 0;
    });
  }

  // Get all users of outputs of `n`, in the same block as `n`.
  // This means if there is an `if` node that uses an output of `n` in some
  // inner sub-block, we will consider the whole `if` node a user of `n`.
  std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
    std::unordered_set<Node*> users;
    for (const auto output : n->outputs()) {
      for (const auto& use : output->uses()) {
        if (auto sameBlock = findSameBlock(use.user, n)) {
          users.insert(sameBlock);
        }
      }
    }
    return users;
  }

  // Traverse `target`'s blockchain upward until we find a node that shares a
  // block with `n`.
  //
  // If one can't be found (say, because `n` is an inner block and target is
  // outside), then return nullptr. Since we can only reorder nodes within a
  // block, `target` would be irrelevant.
  static Node* findSameBlock(Node* target, Node* n) {
    AT_ASSERT(target->owningGraph() == n->owningGraph());
    if (target->owningBlock() == n->owningBlock()) {
      return target;
    } else {
      // This user is in a sub-block. Traverse the blockchain upward until
      // we arrive at a node that shares a block with `this`
      auto curNode = target;
      while (curNode->owningBlock() != n->owningBlock()) {
        curNode = curNode->owningBlock()->owningNode();
        if (curNode == nullptr) {
          return curNode;
        }
      }
      return curNode;
    }
  }

  const AliasDb& aliasDb_;
  std::list<Node*> nodes_;
  // users => # of working set nodes it uses
  std::unordered_multiset<Node*> users_;
  // Values written to by the working set => number of nodes writing to value
  std::unordered_multiset<const Value*> writes_;
  std::unordered_multiset<const Value*> reads_;
  size_t numWildcards_ = 0;
};
966 
967 // Try to move `toMove` before/after `movePoint` while preserving value
968 // dependencies. Returns false iff such a move could not be made.
969 //
970 // If `dryRun` is set, don't actually execute the move, just check if the move
971 // is possible
972 //
973 // The basic approach is: have a "working set" that we are moving forward, one
974 // node at a time. When we can't move past a node (because it depends on the
975 // working set), then add it to the working set and keep moving until we hit
976 // `moveAfter`.
bool AliasDb::tryMove(
    Node* toMove,
    Node* movePoint,
    MoveSide moveSide,
    bool dryRun) {
  // Both nodes must live in the same block; `tryMove` only reorders siblings.
  AT_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
  if (toMove == movePoint) {
    // Moving a node to itself is trivially a no-op success.
    return true;
  }

  // 1. Move from `this` toward movePoint, building up the working set of
  // dependencies
  WorkingSet workingSet(toMove, *this);

  // Walk toward `movePoint` in whichever direction it lies.
  int direction;
  if (toMove->isAfter(movePoint)) {
    direction = kPrevDirection;
  } else {
    direction = kNextDirection;
  }

  auto curNode = toMove->next_in_graph[direction];
  // Move forward one node at a time
  while (curNode != movePoint) {
    if (workingSet.dependsOn(curNode)) {
      // If we can't move past this node, add it to the working set
      workingSet.add(curNode);
    }
    curNode = curNode->next_in_graph[direction];
  }

  // 2. Decide whether we can move it all to `movePoint`.

  // Say we are moving directly before movePoint and `toMove` starts before
  // movePoint in the graph. The move looks like
  //
  //  `toMove`            `toMove`         |
  //  <dependencies>  ->  `movePoint`      | `toMove` and deps are split
  //  `movePoint`         <dependencies>   |
  //
  // Contrast with the case where `toMove` starts AFTER movePoint:
  //
  //  `movePoint`           <dependencies> |
  //  <dependencies>  ->    `toMove`       | `toMove` and deps are together
  //  `toMove`              `movePoint`    |
  //
  // In the first case, we need to split `this` off from its dependencies, so we
  // can move the dependencies below `movePoint` and keep `toMove` above.
  const bool splitToMoveAndDeps =
      (moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
      (moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));

  if (splitToMoveAndDeps) {
    // remove `this` from dependencies to be moved past `movePoint`
    workingSet.eraseMover();
  }

  // Check if we can move the working set past the move point
  if (workingSet.dependsOn(movePoint)) {
    // if we can't, then there are intermediate dependencies between the
    // `this` and `movePoint`, so we can't do the move
    return false;
  }

  if (dryRun) {
    // Caller only wanted feasibility; the move is possible but not executed.
    return true;
  }

  // 3. Execute the move
  // The scan loop above terminated exactly at `movePoint`.
  AT_ASSERT(curNode == movePoint);
  if (splitToMoveAndDeps) {
    // Move `toMove`
    move(toMove, movePoint, moveSide);

    // Then move all of its dependencies on the other side of `movePoint`
    const auto reversed =
        moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
    for (auto n : workingSet.nodes()) {
      // Each dependency is placed immediately after the previous one, so the
      // working set's relative order is preserved.
      move(n, curNode, reversed);
      curNode = n;
    }
  } else {
    // Just append/prepend everything to `movePoint`
    for (auto n : workingSet.nodes()) {
      move(n, curNode, moveSide);
      curNode = n;
    }
  }
  return true;
}
1067 
1068 // Helper function so we can generalize `tryMove`
1069 void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
1070  switch (moveSide) {
1071  case MoveSide::BEFORE:
1072  toMove->moveBefore(movePoint);
1073  break;
1074  case MoveSide::AFTER:
1075  toMove->moveAfter(movePoint);
1076  break;
1077  }
1078 }
1079 
1080 bool AliasDb::hasUntrackedEffects(Node* node) const {
1081  bool touchesWildcard = false;
1082  if (const auto lastWildcard = getLastWildcard()) {
1083  touchesWildcard = hasWrites(node) &&
1084  (isBeforeSameGraph(node, *lastWildcard) || node == *lastWildcard);
1085  }
1086 
1087  return writesToInputAlias(node) || touchesWildcard;
1088 }
1089 
// Nodes must be in the same graph in order to do `isBefore` or `isAfter`. This
// traverses the subgraph "chain" upward until we find two nodes that share an
// owning graph.
//
// NOTE: this is n^2 in subgraph depth. Right now the maximum depth is like 2,
// but if we ever do huge nested subgraphs we'll need to reconsider this.
bool AliasDb::isBeforeSameGraph(const Node* a, const Node* b) const {
  auto lhs = a;
  while (true) {
    auto rhs = b;
    while (true) {
      // Once both sides live in the same graph, compare their positions.
      if (lhs->owningGraph() == rhs->owningGraph()) {
        return lhs->isBefore(rhs);
      }
      // Otherwise hoist `rhs` to the node owning its subgraph, if any.
      if (!subgraphToOwner_.count(rhs->owningGraph())) {
        break;
      }
      rhs = subgraphToOwner_.at(rhs->owningGraph());
    }
    // `rhs` has no more owners at this level; hoist `lhs` one level and
    // restart the inner scan from the original `b`.
    if (!subgraphToOwner_.count(lhs->owningGraph())) {
      break;
    }
    lhs = subgraphToOwner_.at(lhs->owningGraph());
  }
  // No common owning graph was found anywhere in either chain; callers are
  // expected to only compare nodes within the same subgraph hierarchy.
  AT_ASSERT(false);
}
1116 
1117 c10::optional<const Node*> AliasDb::getLastWildcard() const {
1118  auto it = std::max_element(
1119  wildcardNodes_.cbegin(),
1120  wildcardNodes_.cend(),
1121  [this](const Node* a, const Node* b) { return isBeforeSameGraph(a, b); });
1122  if (it != wildcardNodes_.end()) {
1123  return *it;
1124  } else {
1125  return c10::nullopt;
1126  }
1127 }
1128 
1129 TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
1130  // WARNING: by adding a case to this list, you are asserting that you have
1131  // added a case for the unschematized node in AliasDb::analyze
1132  const static std::unordered_set<Symbol> handled = {
1133  prim::If,
1134  prim::Loop,
1135  prim::FusionGroup,
1136  prim::DifferentiableGraph,
1137  prim::Constant,
1138  prim::DictConstruct,
1139  prim::ListConstruct,
1140  prim::TupleConstruct,
1141  prim::AutogradZero,
1142  prim::FusedConcat,
1143  prim::MMTreeReduce,
1144  prim::MMBatchSide,
1145  prim::BroadcastSizes,
1146  prim::ChunkSizes,
1147  prim::Function,
1148  prim::TupleUnpack,
1149  prim::TupleIndex,
1150  prim::DictIndex,
1151  prim::TupleSlice,
1152  prim::ListUnpack,
1153  prim::PythonOp,
1154  prim::ConstantChunk,
1155  prim::BroadcastingChunk,
1156  prim::fork,
1157  prim::CreateObject,
1158  prim::GetAttr,
1159  prim::SetAttr,
1160  aten::wait,
1161  aten::add,
1162  aten::sub,
1163  aten::mul,
1164  aten::div,
1165  };
1166 
1167  // Operators that should not be used by alias analysis
1168  const static std::unordered_set<Symbol> purposefully_not_handled = {
1169  prim::Print,
1170  prim::Load,
1171  prim::Store,
1172  prim::Drop,
1173  at::onnx::Reshape,
1174  at::onnx::Shape,
1175  prim::AutogradAnyNonZero,
1176  prim::AutogradAdd,
1177  };
1178 
1179  return handled.count(symbol) || purposefully_not_handled.count(symbol);
1180 }
1181 
1182 // Register `v` as a wildcard value.
1183 void AliasDb::setWildcard(const Value* v) {
1184  if (!shouldAnnotate(v)) {
1185  return;
1186  }
1187  wildcards_.insert(v);
1188 }
1189 
// Repopulate `writeCache_` — the set of memory locations written to anywhere
// in the graph — from `writeIndex_` (node -> values it writes), expanding
// each written value to all memory locations it may alias.
void AliasDb::rebuildWriteCache() const {
  for (const auto& pr : writeIndex_) {
    const auto& writtenValues = pr.second;

    for (const auto value : writtenValues) {
      // Every written value must already have an entry in the element map;
      // `at` will throw otherwise.
      for (const auto loc : elementMap_.at(value)->getMemoryLocations()) {
        writeCache_.insert(loc);
      }
    }
  }
  // NOTE(review): the cache is not cleared before being rebuilt. This is only
  // correct if `writeIndex_` entries are never removed (i.e. the cache can
  // only grow stale by missing entries, not by containing extra ones) —
  // verify against the mutation sites of writeIndex_.
  isWriteCacheStale_ = false;
}
1202 
1203 ValueSet AliasDb::getMemoryLocations(const Value* v) const {
1204  ValueSet ret;
1205  if (!elementMap_.count(v)) {
1206  return ret;
1207  }
1208 
1209  for (const auto el : elementMap_.at(v)->getMemoryLocations()) {
1210  ret.insert(el->value);
1211  }
1212  return ret;
1213 }
1214 } // namespace jit
1215 } // namespace torch
Alias analysis pass.
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41