1 #include "memory_dag.h" 3 #include <torch/csrc/utils/memory.h> 9 bool MemoryDAG::mayAlias(Element* a, Element* b)
const {
// Overload taking non-const Element pointers: passes `a` and `b` straight
// through to mayAliasImpl and returns its result unchanged.
10 return mayAliasImpl(a, b);
// Overload taking const Element pointers: passes `a` and `b` straight through
// to mayAliasImpl and returns its result unchanged.
13 bool MemoryDAG::mayAlias(
const Element* a,
const Element* b)
const {
14 return mayAliasImpl(a, b);
// Shared implementation behind both mayAlias overloads.
//
// Obtains the memory-location sets of `a` and `b` via
// Element::getMemoryLocations() and then visits every (aLoc, bLoc) pair with
// two nested range-for loops, i.e. work proportional to |aMemLoc| * |bMemLoc|.
//
// NOTE(review): the body of the inner loop and the final return statement are
// not visible in this excerpt. Presumably the function returns true as soon as
// a pair compares equal and false otherwise -- confirm against the full file.
17 bool MemoryDAG::mayAliasImpl(
const Element* a,
const Element* b)
const {
// Both pointers are dereferenced without a null check.
18 const auto aMemLoc = a->getMemoryLocations();
19 const auto bMemLoc = b->getMemoryLocations();
// Pairwise scan over the two location sets.
24 for (
const auto aLoc : aMemLoc) {
25 for (
const auto bLoc : bMemLoc) {
// Records a points-to edge from `from` to `to`.
//
// The edge is stored on both endpoints so it can be followed in either
// direction: `to` is inserted into from->pointsTo, and `from` is inserted into
// to->pointedFrom. Both pointers are dereferenced without a null check.
35 void MemoryDAG::makePointerTo(Element* from, Element* to) {
36 from->pointsTo.insert(to);
37 to->pointedFrom.insert(from);
// Creates a new Element and stores it in elements_, keyed by its own address.
//
// NOTE(review): the statement(s) that use `v` and the return statement are not
// visible in this excerpt. Presumably `v` is recorded on the new element and
// the raw pointer is returned -- confirm against the full file.
41 Element* MemoryDAG::makeFreshValue(
const Value* v) {
42 auto el = torch::make_unique<Element>();
// Grab the raw pointer before `el` is moved from; after the std::move below,
// `el` no longer refers to the element.
45 auto rawPtr = el.get();
// Ownership of the element is handed to elements_.
46 elements_.emplace(rawPtr, std::move(el));
// Returns the set of memory locations for this element, with caching.
//
// If cachedMemoryLocations_ is non-empty it is returned (by value) as-is.
// Otherwise a fresh set `ret` is built using a callback that tests whether an
// element's pointsTo set is empty, supplied together with
// BfsDirection::POINTS_TO, and `ret` is then copied into
// cachedMemoryLocations_.
//
// Emptiness of the cache is the only "not yet computed" signal, so a result
// that is itself empty would be recomputed on every call.
//
// NOTE(review): the call that receives the lambda, the body of the `if` inside
// the lambda, and the final return are not visible in this excerpt. Presumably
// this is a bfs() traversal that inserts elements with no outgoing pointers
// into `ret` and then returns it -- confirm against the full file.
50 std::unordered_set<const Element*> Element::getMemoryLocations()
const {
51 if (!cachedMemoryLocations_.empty()) {
52 return cachedMemoryLocations_;
56 std::unordered_set<const Element*> ret;
// The callback captures `ret` by reference.
58 [&](
const Element* el) {
59 if (el->pointsTo.empty()) {
63 BfsDirection::POINTS_TO);
// Store the computed set so later calls can take the early-return path above.
65 cachedMemoryLocations_ = ret;
71 template <
typename Fn>
72 bool Element::bfs(Fn fn, BfsDirection dir)
const {
73 std::queue<const Element*> queue;
74 std::unordered_set<const Element*> seen;
77 while (!queue.empty()) {
78 const auto el = queue.front();
85 case BfsDirection::POINTS_TO: {
86 for (
auto ptr : el->pointsTo) {
87 if (!seen.count(ptr)) {
93 case BfsDirection::POINTED_FROM: {
94 for (
auto ptr : el->pointedFrom) {
95 if (!seen.count(ptr)) {