Caffe2 - C++ API
A deep learning, cross platform ML framework
memory_dag.h
1 #pragma once
2 
3 #include <unordered_set>
4 #include <unordered_map>
5 #include <memory>
6 
7 namespace torch {
8 namespace jit {
9 
struct Element;
struct Value;

// MemoryDAG
//
// Bookkeeping for the "A points to B" relation between values. AliasDb builds
// its higher-level API on top of this class.
//
// The graph is a DAG in which:
//   - vertices ("elements") stand for values and for other aliasing entities
//     (for example, the contents of a list), and
//   - an edge means "points to".
//
// An element without outgoing edges is a leaf and corresponds to one unique
// "memory location". Walking the points-to edges from an element down to the
// leaves therefore yields the memory locations that element may refer to.
class MemoryDAG {
 public:
  // Record that `from` points at `to`.
  void makePointerTo(Element* from, Element* to);

  // Create a fresh element (one that does not point to anything) and return
  // it.
  Element* makeFreshValue(const Value* v);

  // Can `a` and `b` refer to the same memory location?
  // NOTE(review): the non-const pointer overload presumably exists so that
  // calls with `Element*` arguments do not bind to the group template below
  // -- confirm against the implementation file.
  bool mayAlias(const Element* a, const Element* b) const;
  bool mayAlias(Element* a, Element* b) const;

  // Group form: does any element of `a` potentially share a memory location
  // with any element of `b`?
  //
  // `T` and `U` must provide empty(), cbegin(), cend() and count(), and their
  // elements must be pointer-like to something exposing getMemoryLocations().
  // Either argument may be a multiset: runs of equal elements are visited
  // only once.
  //
  // NOTE(review): std::advance is declared in <iterator>, which this header
  // does not include directly.
  template <typename T, typename U>
  bool mayAlias(const T& a, const U& b) const {
    if (a.empty() || b.empty()) {
      return false;
    }

    // Pass 1: union of the memory locations of every element in `a`.
    std::unordered_set<const Element*> locationsOfA;
    auto lhs = a.cbegin();
    while (lhs != a.cend()) {
      const auto& locations = (*lhs)->getMemoryLocations();
      locationsOfA.insert(locations.begin(), locations.end());
      // Step over every copy of the current element in one go.
      std::advance(lhs, a.count(*lhs));
    }

    // Pass 2: the groups may alias as soon as one location of `b` was
    // already seen in pass 1.
    auto rhs = b.cbegin();
    while (rhs != b.cend()) {
      const auto& locations = (*rhs)->getMemoryLocations();
      for (const auto location : locations) {
        if (locationsOfA.find(location) != locationsOfA.end()) {
          return true;
        }
      }
      // Step over every copy of the current element in one go.
      std::advance(rhs, b.count(*rhs));
    }

    // Nothing overlapped: the two groups share no memory location.
    return false;
  }

 private:
  bool mayAliasImpl(const Element* a, const Element* b) const;

  // Owns every element. Keyed by the raw pointer so that an element can be
  // looked up directly; the mapped unique_ptr holds the ownership.
  std::unordered_map<Element*, std::unique_ptr<Element>> elements_;
};
87 
// Selects the direction in which `Element::bfs` traverses the graph.
enum class BfsDirection {
  // NOTE(review): presumably follows `Element::pointsTo` edges -- confirm
  // against the bfs implementation.
  POINTS_TO,
  // NOTE(review): presumably follows `Element::pointedFrom` back-edges --
  // confirm against the bfs implementation.
  POINTED_FROM,
};
92 
93 // `Element` represents the vertex in the points-to graph. It represents
94 // anything that could have an aliasing relationship, mostly IR `Value`s, but
95 // also the "inside of a list", or wildcards.
96 struct Element {
97  // The value that this element corresponds to. May be null if this element
98  // doesn't represent a first-class value.
99  const Value* value = nullptr;
100 
101  // All elements that this element *may* point to. It's possible to have
102  // multiple elements that you might point to due to control flow/complex ops
103  std::unordered_set<Element*> pointsTo;
104  // Backreference for points-to.
105  std::unordered_set<Element*> pointedFrom;
106 
107  // Return the unique memory locations that `Element` might represent.
108  std::unordered_set<const Element*> getMemoryLocations() const;
109  // We do path compression to make repeated memory location queries faster.
110  // An empty cache means it is invalidated (it can never be empty otherwise,
111  // since every element must point to at least one memory location).
112  mutable std::unordered_set<const Element*> cachedMemoryLocations_;
113 
114  // Do a breadth-first search over the graph, starting at `this` and
115  // traversing in the direction `dir`.`fn` will be run on each element.
116  template <typename Fn>
117  bool bfs(Fn fn, BfsDirection dir) const;
118 };
119 
120 } // namespace jit
121 } // namespace torch
Definition: jit_type.h:17