Caffe2 - C++ API
A deep learning, cross platform ML framework
alias_analysis.h
1 #pragma once
2 
3 #include <torch/csrc/jit/alias_info.h>
4 #include <torch/csrc/jit/ir.h>
5 #include <torch/csrc/jit/passes/utils/memory_dag.h>
6 
7 namespace torch {
8 namespace jit {
9 
28 class AliasDb {
29  public:
30  TORCH_API explicit AliasDb(std::shared_ptr<Graph> graph);
31  TORCH_API ~AliasDb();
32 
33  // There are limitations to what effects the alias analysis can track. Two
34  // kinds of nodes may have untracked effects:
35  // 1. Nodes that write to a value that may alias the graph inputs (since
36  // the inputs can be used outside the graph).
37  // 2. Nodes that write to something in the wildcard set.
38  //
39  // These nodes are considered not safe to eliminate or mutate under any
40  // circumstances.
41  bool hasUntrackedEffects(Node* n) const;
42 
43  // Does `n` write to an alias of one of the values in `vs`?
44  // if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
45  bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
46  const;
47 
48  // Do `a` and `b` potentially share a memory location?
49  bool mayAlias(const Value* a, const Value* b) const;
50  // Do any values in group `a` potentially share a memory location with any
51  // value in group `b`? i.e. may they overlap?
52  //
53  // NOTE: Bit of ugly templating, but this is just to make sure we can
54  // transform an arbitrary container of `Values` to the same container of
55  // `Elements`.
56  template <
57  typename... Other1,
58  template <typename, typename...> class T,
59  typename... Other2,
60  template <typename, typename...> class U>
61  bool mayAlias(
63  const U<const Value*, Other2...>& b) const {
64  if (a.empty() || b.empty()) {
65  return false;
66  }
67  // Short-circuit for special case: if any value is a wildcard, the two sets
68  // may alias
69  if (std::any_of(
70  a.cbegin(),
71  a.cend(),
72  [this](const Value* v) { return isWildcard(v); }) ||
73  std::any_of(b.cbegin(), b.cend(), [this](const Value* v) {
74  return isWildcard(v);
75  })) {
76  return true;
77  }
78 
79  T<Element*> aElements;
80  for (const Value* v : a) {
81  if (elementMap_.count(v)) {
82  aElements.insert(elementMap_.at(v));
83  }
84  }
85 
86  U<Element*> bElements;
87  for (const Value* v : b) {
88  if (elementMap_.count(v)) {
89  bElements.insert(elementMap_.at(v));
90  }
91  }
92 
93  return memoryDAG_->mayAlias(aElements, bElements);
94  }
95 
96  // Do any nodes write to an alias set inputed/outputed by `n`?
97  bool hasWriters(const Node* n) const;
98 
99  // Move 'n' (already in the graph) after 'movePoint' in the topological order.
100  //
101  // Tries to preserve value dependencies, so other nodes might be moved. We
102  // make two gurantees about the postcondition of the node list:
103  // - `n` is directly after `movePoint`.
104  // - only nodes between `n` and `movePoint` have been moved.
105  //
106  // Returns `false` if it's impossible to move `n` after `MovePoint` without
107  // violating dependencies, otherwise executes the move and returns `true`
108  bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
109  bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
110 
111  bool couldMoveAfterTopologically(Node* n, Node* movePoint);
112  bool couldMoveBeforeTopologically(Node* n, Node* movePoint);
113 
114  // For debugging: print alias db state to stdout
115  TORCH_API void dump() const;
116 
117  private:
118  // Helper for topologically-safe node moves.
119  class WorkingSet;
120  enum class MoveSide { BEFORE, AFTER };
121  bool tryMove(Node* toMove, Node* movePoint, MoveSide moveSide, bool dryRun);
122  void move(Node* toMove, Node* movePoint, MoveSide moveSide);
123  bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
124 
128  // Does `n` write to any alias sets?
129  bool hasWrites(Node* n) const;
130  // Get all the values that `n` writes to.
131  // NOTE: this only returns values directly written to, not aliases thereof
132  //
133  // if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks
134  ValueSet getWrites(Node* n, bool recurseBlocks = false) const;
135  ValueSet getWrites(Block* b) const;
136  void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
137  void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
138  // Do any nodes write to `v`s memory location?
139  bool hasWriters(const Value* v) const;
140  // Register the fact that `n` writes to `v`.
141  void registerWrite(const Value* v, Node* n);
142  // Get all the values that `n` reads from.
143  // if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks
144  ValueSet getReads(Node* n, bool recurseBlocks = false) const;
145  void getReadsImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
146 
147  // Does `n` write to a value that may alias one of the graph inputs?
148  bool writesToInputAlias(Node* n) const;
149  // Does `n` write to `v` or any aliases of `v`?
150  bool writesTo(Node* n, const Value* v) const;
151 
155  // is `v` a wildcard?
156  bool isWildcard(const Value* v) const;
157  // Register `v` as a wildcard value.
158  void setWildcard(const Value* v);
159  // Get all nodes that write to a wildcard value.
160  const std::unordered_set<Node*>& getWildcardWriters() const {
161  return wildcardWriters_;
162  }
163  // Does `n` use or write to any wildcard aliases?
164  bool hasWildcard(const Node* n) const;
165  // Returns nullopt if there are no wildcard nodes
166  c10::optional<const Node*> getLastWildcard() const;
167 
171  void analyze(const std::shared_ptr<Graph>& graph);
172  void analyze(Block* block);
173  void analyze(Node* node);
174  void analyzeImpl(Node* node);
175  void analyzeIf(Node* node);
176  void analyzeLoop(Node* node);
177  void analyzeSubgraph(Node* node);
178  void analyzeCreator(Node* node);
179  void analyzeExtractor(Node* node);
180  void analyzeChunk(Node* node);
181  void analyzeBroadcastingChunk(Node* node);
182  void analyzeFork(Node* node);
183  void analyzeWait(Node* node);
184  void analyzeSetAttr(Node* node);
185 
189  void makeAllAlias(const std::vector<Value*>& values);
190  void makePointerTo(const Value* value, const Value* to);
191  void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
192  void giveFreshAlias(const Value* value);
193 
194  static bool shouldAnnotate(const Value* v);
195  static bool shouldAnnotate(const TypePtr& type);
196  bool hasUsesAfter(Symbol alias, const Node* n) const;
197  bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const;
198 
199  // Returns true iff `v` is part of the alias tracker/is a wildcard
200  bool isTracked(const Value* v) const;
201 
202  // Get the values that represent the memory locations that `v` may point to.
203  // Return values are guaranteed to be "fresh" tensors--they do not point to
204  // anything else.
205  ValueSet getMemoryLocations(const Value* v) const;
206 
207  std::shared_ptr<Graph> graph_;
208  std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
209 
210  // The points-to graph that stores aliasing relationships
211  std::unique_ptr<MemoryDAG> memoryDAG_;
212  // Mapping of values to MemoryDAG elements
213  std::unordered_map<const Value*, Element*> elementMap_;
214 
215  // All values that may point to a wildcard value.
216  ValueSet wildcards_;
217  // All nodes that write to a wildcard
218  std::unordered_set<Node*> wildcardWriters_;
219  // All nodes that contain a wildcard
220  std::unordered_set<const Node*> wildcardNodes_;
221 
222  // State for tracking write info
223  size_t numWrites_ = 0;
224  std::unordered_map<Node*, ValueSet> writeIndex_;
225  mutable std::unordered_set<const Element*> writeCache_;
226  mutable bool isWriteCacheStale_ = true;
227  void rebuildWriteCache() const;
228 };
229 
230 // Used to assert that unschematized operators have an analysis method written
231 TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym);
232 } // namespace jit
233 } // namespace torch
Alias analysis pass.
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41