Caffe2 - C++ API
A deep learning, cross platform ML framework
memory_dag.cpp
1 #include "memory_dag.h"
2 
3 #include <torch/csrc/utils/memory.h>
4 #include <queue>
5 
6 namespace torch {
7 namespace jit {
8 
9 bool MemoryDAG::mayAlias(Element* a, Element* b) const {
10  return mayAliasImpl(a, b);
11 }
12 
13 bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
14  return mayAliasImpl(a, b);
15 }
16 
17 bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const {
18  const auto aMemLoc = a->getMemoryLocations();
19  const auto bMemLoc = b->getMemoryLocations();
20 
21  // XXX: This could be more efficiently done as a bitwise AND on two bitfields
22  // that represent memory location membership. If these comparisons end up
23  // being a bottleneck, consider implementing it that way.
24  for (const auto aLoc : aMemLoc) {
25  for (const auto bLoc : bMemLoc) {
26  if (aLoc == bLoc) {
27  return true;
28  }
29  }
30  }
31  return false;
32 }
33 
34 // Make `v` point at `to`.
35 void MemoryDAG::makePointerTo(Element* from, Element* to) {
36  from->pointsTo.insert(to);
37  to->pointedFrom.insert(from);
38 }
39 
40 // Give `v` a fresh alias (i.e. it does not point to any value)
41 Element* MemoryDAG::makeFreshValue(const Value* v) {
42  auto el = torch::make_unique<Element>();
43  el->value = v;
44 
45  auto rawPtr = el.get();
46  elements_.emplace(rawPtr, std::move(el));
47  return rawPtr;
48 }
49 
50 std::unordered_set<const Element*> Element::getMemoryLocations() const {
51  if (!cachedMemoryLocations_.empty()) {
52  return cachedMemoryLocations_;
53  }
54 
55  // Do a BFS in the `points-to` direction, collecting all memory locations
56  std::unordered_set<const Element*> ret;
57  this->bfs(
58  [&](const Element* el) {
59  if (el->pointsTo.empty()) {
60  ret.insert(el);
61  }
62  },
63  BfsDirection::POINTS_TO);
64 
65  cachedMemoryLocations_ = ret;
66  return ret;
67 }
68 
69 // Do a breadth-first search over the graph, starting at `this` and
70 // traversing in the direction `dir`.`fn` will be run on each element.
71 template <typename Fn>
72 bool Element::bfs(Fn fn, BfsDirection dir) const {
73  std::queue<const Element*> queue;
74  std::unordered_set<const Element*> seen;
75 
76  queue.push(this);
77  while (!queue.empty()) {
78  const auto el = queue.front();
79  queue.pop();
80  seen.insert(el);
81 
82  fn(el);
83 
84  switch (dir) {
85  case BfsDirection::POINTS_TO: {
86  for (auto ptr : el->pointsTo) {
87  if (!seen.count(ptr)) {
88  queue.push(ptr);
89  }
90  }
91  } break;
92 
93  case BfsDirection::POINTED_FROM: {
94  for (auto ptr : el->pointedFrom) {
95  if (!seen.count(ptr)) {
96  queue.push(ptr);
97  }
98  }
99  } break;
100  }
101  }
102  return false;
103 }
104 } // namespace jit
105 } // namespace torch
Definition: jit_type.h:17