3 #include <torch/csrc/jit/alias_info.h> 4 #include <torch/csrc/jit/ir.h> 5 #include <torch/csrc/jit/passes/utils/memory_dag.h> 30 TORCH_API
// Constructs an AliasDb from a shared pointer to a Graph. The constructor is
// `explicit`, so a std::shared_ptr<Graph> is never implicitly converted into
// an AliasDb.
// NOTE(review): presumably the alias analysis is run as part of construction
// -- confirm against the implementation.
explicit AliasDb(std::shared_ptr<Graph> graph);
// Const query on a single node `n`; returns a bool.
// NOTE(review): going by the name, this reports whether `n` has side effects
// that the alias database does not model -- confirm against the
// implementation.
41 bool hasUntrackedEffects(
Node* n)
const;
// Query on node `n` against the value set `vs`; returns a bool.
// `recurseBlocks` defaults to false.
// NOTE(review): presumably answers whether `n` writes to memory that may alias
// any value in `vs`, with `recurseBlocks` extending the check to nested blocks
// -- confirm against the implementation.
// NOTE(review): the tail of this declaration (any trailing qualifier and the
// terminating `;`) is not visible in this excerpt.
45 bool writesToAlias(
Node* n,
const ValueSet& vs,
bool recurseBlocks =
false)
// Const query on a pair of values `a` and `b`; returns a bool.
// NOTE(review): presumably true when the two values may refer to overlapping
// memory -- confirm against the implementation.
49 bool mayAlias(
const Value* a,
const Value* b)
const;
// Container-based const query. `T` and `U` are template template parameters;
// the visible parameter `b` is a `U` whose element type is `const Value*`.
// NOTE(review): presumably this is the container overload of mayAlias. Several
// lines of this definition (the remaining template parameters, the function
// name and return type, the first parameter `a`, and parts of the body) are
// missing from this excerpt, so the comments below cover only what is shown.
58 template <
typename,
typename...>
class T,
60 template <
typename,
typename...>
class U>
63 const U<const Value*, Other2...>& b)
const {
// Guard for the case where either input container is empty (the guarded
// statement itself is not visible here).
64 if (a.empty() || b.empty()) {
// Wildcard check: this lambda returns isWildcard(v), and its result is OR-ed
// with a std::any_of over `b` (the body of that second lambda is not shown).
72 [
this](
const Value* v) {
return isWildcard(v); }) ||
73 std::any_of(b.cbegin(), b.cend(), [
this](
const Value* v) {
// Collect the Element mapped to each value of `a`; values that have no entry
// in elementMap_ are skipped.
80 for (
const Value* v : a) {
81 if (elementMap_.count(v)) {
82 aElements.insert(elementMap_.at(v));
// Same collection for `b`, gathered into a U<Element*> container.
86 U<Element*> bElements;
87 for (
const Value* v : b) {
88 if (elementMap_.count(v)) {
89 bElements.insert(elementMap_.at(v));
// The final answer is delegated to memoryDAG_'s mayAlias on the two element
// collections.
93 return memoryDAG_->mayAlias(aElements, bElements);
// Const query on a node `n`; returns a bool. An overload taking a
// `const Value*` is declared further down.
// NOTE(review): presumably true when some node writes to a value that `n`
// uses or produces -- confirm against the implementation.
97 bool hasWriters(
const Node* n)
const;
// Four non-const operations that each take the node to relocate (`n`) and a
// reference position (`movePoint`) and return a bool.
// NOTE(review): presumably the `move*TopologicallyValid` pair actually moves
// `n` after/before `movePoint` when that is legal and reports success, while
// the `couldMove*Topologically` pair only checks whether such a move would be
// legal -- confirm against the implementation.
108 bool moveAfterTopologicallyValid(
Node* n,
Node* movePoint);
109 bool moveBeforeTopologicallyValid(
Node* n,
Node* movePoint);
111 bool couldMoveAfterTopologically(
Node* n,
Node* movePoint);
112 bool couldMoveBeforeTopologically(
Node* n,
Node* movePoint);
// Exported (TORCH_API) const member taking no arguments and returning nothing.
// NOTE(review): presumably prints the database state for debugging; the
// output destination is not visible here -- confirm against the
// implementation.
115 TORCH_API
void dump()
const;
// Scoped enum with two values, BEFORE and AFTER. It is the type of the
// `moveSide` parameter of tryMove, move, and isBeforeOrAfter.
120 enum class MoveSide { BEFORE, AFTER };
// tryMove: non-const; takes the node to move, the reference node, the side of
// the reference node to target, and a `dryRun` flag; returns a bool.
// NOTE(review): presumably `dryRun == true` only checks feasibility without
// mutating the graph -- confirm against the implementation.
121 bool tryMove(
Node* toMove,
Node* movePoint, MoveSide moveSide,
bool dryRun);
// move: non-const; same node/reference/side arguments as tryMove, but with no
// dry-run flag and no return value.
122 void move(
Node* toMove,
Node* movePoint, MoveSide moveSide);
// isBeforeOrAfter: const predicate on a node and a MoveSide.
// NOTE(review): the node that `n` is compared against is not part of the
// signature -- check the implementation for what it is relative to.
123 bool isBeforeOrAfter(
const Node* n, MoveSide moveSide)
const;
// hasWrites: const predicate on a node `n`.
// NOTE(review): presumably true when `n` writes to at least one value --
// confirm against the implementation.
129 bool hasWrites(
Node* n)
const;
// getWrites (node overload): const; returns a ValueSet by value.
// `recurseBlocks` defaults to false.
134 ValueSet getWrites(
Node* n,
bool recurseBlocks =
false)
const;
// getWrites (block overload): const; returns a ValueSet by value for block
// `b`.
135 ValueSet getWrites(
Block* b)
const;
// getWritesImpl (block and node overloads): const helpers that return nothing
// and take the result set `ret` by non-const reference, so results are
// delivered through `ret`. `recurseBlocks` defaults to false in both.
// NOTE(review): presumably these accumulate into `ret` rather than replacing
// its contents -- confirm against the implementation.
136 void getWritesImpl(
Block* b, ValueSet& ret,
bool recurseBlocks =
false)
const;
137 void getWritesImpl(
Node* n, ValueSet& ret,
bool recurseBlocks =
false)
const;
// hasWriters (value overload): const predicate on a single value `v`.
// NOTE(review): presumably true when some node writes to `v` or to something
// it may alias -- confirm against the implementation.
139 bool hasWriters(
const Value* v)
const;
// registerWrite: non-const; takes a value `v` and a node `n`, returns nothing.
// NOTE(review): presumably records that `n` writes to `v` -- confirm against
// the implementation.
141 void registerWrite(
const Value* v,
Node* n);
// getReads: const; returns a ValueSet by value for node `n`.
// `recurseBlocks` defaults to false.
// NOTE(review): presumably the set of values `n` reads, optionally including
// reads inside nested blocks -- confirm against the implementation.
144 ValueSet getReads(
Node* n,
bool recurseBlocks =
false)
const;
// getReadsImpl: const helper that returns nothing and delivers its result
// through the non-const reference `ret`. `recurseBlocks` defaults to false.
145 void getReadsImpl(
Node* n, ValueSet& ret,
bool recurseBlocks =
false)
const;
// writesToInputAlias: const predicate on a node `n`.
// NOTE(review): presumably true when `n` writes to something that may alias a
// graph input -- confirm against the implementation.
148 bool writesToInputAlias(
Node* n)
const;
// writesTo: const predicate on a node `n` and a value `v`.
// NOTE(review): presumably true when `n` writes to `v` -- confirm whether
// aliases of `v` are included.
150 bool writesTo(
Node* n,
const Value* v)
const;
// isWildcard: const predicate on a value `v`.
// NOTE(review): presumably "wildcard" means a value that must be assumed to
// alias anything -- confirm against the implementation.
156 bool isWildcard(
const Value* v)
const;
// setWildcard: non-const; takes a value `v` and returns nothing.
// NOTE(review): presumably marks `v` as a wildcard -- confirm against the
// implementation.
158 void setWildcard(
const Value* v);
// Const accessor that returns a const reference to the wildcardWriters_
// member (no copy is made).
// NOTE(review): the closing brace of this inline definition is not visible in
// this excerpt.
160 const std::unordered_set<Node*>& getWildcardWriters()
const {
161 return wildcardWriters_;
// Const predicate on a node `n`.
// NOTE(review): presumably true when `n` uses or produces a wildcard value --
// confirm against the implementation.
164 bool hasWildcard(
const Node* n)
const;
// analyze: three non-const overloads, one each for a whole graph (taken as a
// const reference to a shared_ptr), a block, and a single node.
// NOTE(review): presumably the graph overload drives the block overload,
// which in turn drives the node overload -- confirm against the
// implementation.
171 void analyze(
const std::shared_ptr<Graph>& graph);
172 void analyze(
Block* block);
173 void analyze(
Node* node);
// analyzeImpl: non-const, per-node, returns nothing.
// NOTE(review): presumably the worker behind analyze(Node*) -- confirm.
174 void analyzeImpl(
Node* node);
// The remaining analyze* members are non-const, take a single node, and
// return nothing.
// NOTE(review): judging by their names each one handles one category of node
// (if, loop, subgraph, creator, extractor, chunk, broadcasting chunk, fork,
// wait, set-attribute); which node kinds map to which handler is not visible
// here -- confirm against the implementation.
175 void analyzeIf(
Node* node);
176 void analyzeLoop(
Node* node);
177 void analyzeSubgraph(
Node* node);
178 void analyzeCreator(
Node* node);
179 void analyzeExtractor(
Node* node);
180 void analyzeChunk(
Node* node);
181 void analyzeBroadcastingChunk(
Node* node);
182 void analyzeFork(
Node* node);
183 void analyzeWait(
Node* node);
184 void analyzeSetAttr(
Node* node);
// makeAllAlias: non-const; takes a vector of values by const reference and
// returns nothing.
// NOTE(review): presumably records that every value in `values` may alias
// every other -- confirm against the implementation.
189 void makeAllAlias(
const std::vector<Value*>& values);
// makePointerTo: non-const; takes two values, `value` and `to`.
// NOTE(review): presumably records that `value` may point to the memory of
// `to` -- confirm the direction against the implementation.
190 void makePointerTo(
const Value* value,
const Value* to);
// giveFreshAlias: non-const; takes a single value.
// NOTE(review): presumably assigns `value` a new memory location that aliases
// nothing else -- confirm against the implementation.
192 void giveFreshAlias(
const Value* value);
// shouldAnnotate: two static overloads, one taking a value and one taking a
// type (TypePtr by const reference); both return a bool. Being static, they
// do not depend on any AliasDb instance.
// NOTE(review): presumably decides whether a value/type takes part in alias
// tracking at all -- confirm against the implementation.
194 static bool shouldAnnotate(
const Value* v);
195 static bool shouldAnnotate(
const TypePtr& type);
// hasUsesAfter: const predicate on an alias symbol and a node `n`.
// NOTE(review): presumably true when something carrying `alias` is used after
// `n` -- confirm against the implementation.
196 bool hasUsesAfter(
Symbol alias,
const Node* n)
const;
// isBeforeSameGraph: const predicate on two nodes, `lhs` and `rhs`.
// NOTE(review): presumably an ordering test between nodes of one graph; how
// nodes in different blocks are handled is not visible here -- confirm.
197 bool isBeforeSameGraph(
const Node* lhs,
const Node* rhs)
const;
// isTracked: const predicate on a value `v`.
// NOTE(review): presumably true when the database holds alias information for
// `v` -- confirm against the implementation.
200 bool isTracked(
const Value* v)
const;
// getMemoryLocations: const; returns a ValueSet by value for `v`.
// NOTE(review): presumably the set of values standing for the memory `v` may
// refer to -- confirm against the implementation.
205 ValueSet getMemoryLocations(
const Value* v)
const;
// Shared-ownership handle to a Graph.
// NOTE(review): presumably the graph passed to the constructor -- confirm.
207 std::shared_ptr<Graph> graph_;
// Maps a graph to a node.
// NOTE(review): by the name, a subgraph to the node that owns it -- confirm.
208 std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
// Uniquely owned MemoryDAG; the container-based alias query forwards its final
// answer to memoryDAG_->mayAlias.
211 std::unique_ptr<MemoryDAG> memoryDAG_;
// Maps each known value to its Element; values without an entry are skipped
// by the container-based alias query.
213 std::unordered_map<const Value*, Element*> elementMap_;
// Set of nodes exposed read-only through getWildcardWriters().
218 std::unordered_set<Node*> wildcardWriters_;
// Set of nodes.
// NOTE(review): presumably the nodes that involve a wildcard value -- confirm.
220 std::unordered_set<const Node*> wildcardNodes_;
// Counter, zero-initialized.
// NOTE(review): presumably the number of registered writes -- confirm.
223 size_t numWrites_ = 0;
// Maps a node to a set of values.
// NOTE(review): presumably the values that node writes to -- confirm.
224 std::unordered_map<Node*, ValueSet> writeIndex_;
// Cache of elements plus its staleness flag. Both are `mutable`, so const
// member functions may update them; the flag starts out true.
225 mutable std::unordered_set<const Element*> writeCache_;
226 mutable bool isWriteCacheStale_ =
true;
// Const member taking no arguments and returning nothing. Being const, it can
// only modify members declared `mutable` (writeCache_ and isWriteCacheStale_
// are).
// NOTE(review): presumably repopulates writeCache_ and clears the stale flag
// -- confirm against the implementation.
227 void rebuildWriteCache()
const;
// Exported (TORCH_API) predicate on a symbol.
// NOTE(review): presumably a free function that reports whether alias
// analysis has dedicated handling for the operator named by `sym`, but the
// enclosing scope is not visible in this excerpt -- confirm.
231 TORCH_API
bool aliasAnalysisHasSpecialCaseFor(
c10::Symbol sym);
ArrayRef - Represents a constant reference to an array (0 or more elements stored consecutively in memory)...