Caffe2 - C++ API
A deep learning, cross platform ML framework
test_alias_analysis.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_base.h"
4 #include "torch/csrc/jit/custom_operator.h"
5 #include "torch/csrc/jit/passes/alias_analysis.h"
6 #include "torch/csrc/jit/script/compiler.h"
7 #include "torch/csrc/utils/memory.h"
8 
9 namespace torch {
10 namespace jit {
11 
12 // Fixture to set up a graph and make assertions clearer
15  createGraph();
16  aliasDb = torch::make_unique<AliasDb>(graph);
17  }
18 
19  // Nodes are named after their output.
20  // e.g. "a" is an alias for "the node that outputs the value `a`"
21  void createGraph() {
22  graph = std::make_shared<Graph>();
23  createNode("a", {});
24  createNode("b", {"a"});
25  createNode("c", {});
26  createNode("d", {"a", "b"});
27  createNode("e", {"c", "b"});
28  createNode("f", {"e"});
29  createNode("g", {"e"});
30  createNode("h", {"g"});
31  createNode("i", {"g"});
32  createNode("j", {"i"});
33  createNode("k", {"i"});
34  createNode("l", {"a"});
35  createNode("m", {}, {"l"}); // block depends on l
36  createNode("n", {"m"});
37  createNode("o", {"n"});
38  createNode("p", {});
39  createNode("q", {});
40  createNode("r", {"q"});
41  createNode("s", {"q"});
42 
43  graph->lint();
44  }
45 
46  void createNode(
47  const std::string& name,
48  const std::vector<std::string>& inputNames,
49  const std::vector<std::string>& blockInputNames = {}) {
50  std::vector<Value*> inputs;
51  for (const auto name : inputNames) {
52  inputs.push_back(nodes.at(name)->output());
53  }
54  auto node = graph->appendNode(graph->create(prim::AutogradZero, inputs));
55  node->output()->setUniqueName(name);
56  nodes[name] = node;
57 
58  if (blockInputNames.size() != 0) {
59  node->addBlock();
60  std::vector<Value*> blockDeps;
61  for (const auto name : blockInputNames) {
62  blockDeps.push_back(nodes.at(name)->output());
63  }
64 
65  auto block = node->blocks().at(0);
66  block->appendNode(graph->create(prim::AutogradZero, blockDeps));
67  }
68  }
69 
70  bool moveBeforeTopologicallyValid(
71  const std::string& toInsert,
72  const std::string& insertPoint) {
73  std::function<bool(Node*, Node*)> func =
74  [this](Node* toInsert, Node* insertPoint) {
75  return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint);
76  };
77  return moveWithChecks(toInsert, insertPoint, func);
78  }
79 
80  bool moveAfterTopologicallyValid(
81  const std::string& toInsert,
82  const std::string& insertPoint) {
83  std::function<bool(Node*, Node*)> func =
84  [this](Node* toInsert, Node* insertPoint) {
85  return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint);
86  };
87  return moveWithChecks(toInsert, insertPoint, func);
88  }
89 
90  bool moveWithChecks(
91  const std::string& toInsert,
92  const std::string& insertPoint,
93  std::function<bool(Node*, Node*)> func) {
94  auto n = nodes.at(toInsert);
95  auto insert = nodes.at(insertPoint);
96  bool isAfter = n->isAfter(insert);
97 
98  std::vector<Node*> originalOrdering;
99  Node* original = isAfter ? n->next() : n->prev();
100 
101  auto curNode = original;
102  while (curNode != n->owningBlock()->return_node()) {
103  originalOrdering.push_back(curNode);
104  if (isAfter) {
105  curNode = curNode->next();
106  } else {
107  curNode = curNode->prev();
108  }
109  }
110 
111  const auto couldMove = func(n, insert);
112  // Check the graph is okay
113  graph->lint();
114 
115  // If this is the picture of nodes
116  // <some nodes> ... toInsert ... <some more nodes> ... insertPoint
117  // ^----------^ check that these nodes haven't moved
118  curNode = original;
119  size_t idx = 0;
120  while (curNode != n->owningBlock()->return_node()) {
121  AT_ASSERT(originalOrdering[idx] == curNode);
122  if (isAfter) {
123  curNode = curNode->next();
124  } else {
125  curNode = curNode->prev();
126  }
127  idx++;
128  }
129 
130  return couldMove;
131  }
132 
133  void checkPostCondition(
134  const std::string& toInsert,
135  const std::string& insertPoint,
136  bool after) {
137  if (after) {
138  AT_ASSERT(nodes.at(toInsert)->prev() == nodes.at(insertPoint));
139  } else {
140  AT_ASSERT(nodes.at(toInsert)->next() == nodes.at(insertPoint));
141  }
142  }
143 
// Graph under test; (re)assigned by createGraph().
std::shared_ptr<Graph> graph;
// Alias database built over `graph`; the move helpers delegate to it.
std::unique_ptr<AliasDb> aliasDb;
// Node lookup by name (the unique name given to the node's output in
// createNode).
std::unordered_map<std::string, Node*> nodes;
147 };
148 
149 void testTopologicalMove() {
150  {
151  // Check that we are removing `this`'s deps properly when we need to split
152  // `this` and deps (see code for what the hell that means)
153  TopoMoveTestFixture fixture;
154  AT_ASSERT(fixture.moveBeforeTopologicallyValid("q", "s"));
155  fixture.checkPostCondition("q", "s", false);
156  }
157  // Move after
158  {
159  // Simple move backward
160  TopoMoveTestFixture fixture;
161  AT_ASSERT(fixture.moveAfterTopologicallyValid("c", "a"));
162  fixture.checkPostCondition("c", "a", true);
163  }
164  {
165  // simple invalid move backward
166  TopoMoveTestFixture fixture;
167  AT_ASSERT(!fixture.moveAfterTopologicallyValid("d", "a"));
168  }
169  {
170  // doesn't actually move anything
171  TopoMoveTestFixture fixture;
172  AT_ASSERT(fixture.moveAfterTopologicallyValid("f", "e"));
173  fixture.checkPostCondition("f", "e", true);
174  }
175  {
176  // move backward with multiple dependencies
177  TopoMoveTestFixture fixture;
178  AT_ASSERT(fixture.moveAfterTopologicallyValid("e", "c"));
179  fixture.checkPostCondition("e", "c", true);
180  }
181  {
182  // Move backward with non-zero working set
183  TopoMoveTestFixture fixture;
184  AT_ASSERT(fixture.moveAfterTopologicallyValid("k", "f"));
185  fixture.checkPostCondition("k", "f", true);
186  }
187  {
188  // Simple move forward
189  TopoMoveTestFixture fixture;
190  AT_ASSERT(fixture.moveAfterTopologicallyValid("c", "d"));
191  fixture.checkPostCondition("c", "d", true);
192  }
193  {
194  // Move forward with non-zero working set
195  TopoMoveTestFixture fixture;
196  AT_ASSERT(fixture.moveAfterTopologicallyValid("f", "l"));
197  fixture.checkPostCondition("f", "l", true);
198  }
199 
200  // Move before
201  {
202  // Simple move forward
203  TopoMoveTestFixture fixture;
204  AT_ASSERT(fixture.moveBeforeTopologicallyValid("b", "d"));
205  fixture.checkPostCondition("b", "d", false);
206  }
207  {
208  // Simple move backward
209  TopoMoveTestFixture fixture;
210  AT_ASSERT(fixture.moveBeforeTopologicallyValid("c", "a"));
211  fixture.checkPostCondition("c", "a", false);
212  }
213  {
214  // doesn't actually move anything
215  TopoMoveTestFixture fixture;
216  AT_ASSERT(fixture.moveBeforeTopologicallyValid("a", "b"));
217  fixture.checkPostCondition("a", "b", false);
218  }
219  {
220  // move forward with deps
221  TopoMoveTestFixture fixture;
222  AT_ASSERT(fixture.moveBeforeTopologicallyValid("f", "m"));
223  fixture.checkPostCondition("f", "m", false);
224  }
225  {
226  // move backward with deps
227  TopoMoveTestFixture fixture;
228  AT_ASSERT(fixture.moveBeforeTopologicallyValid("l", "f"));
229  fixture.checkPostCondition("l", "f", false);
230  }
231 
232  // check that dependencies in blocks are recognized
233  {
234  TopoMoveTestFixture fixture;
235  AT_ASSERT(!fixture.moveAfterTopologicallyValid("l", "m"));
236  AT_ASSERT(!fixture.moveBeforeTopologicallyValid("m", "l"));
237  AT_ASSERT(!fixture.moveAfterTopologicallyValid("n", "l"));
238  AT_ASSERT(!fixture.moveBeforeTopologicallyValid("l", "n"));
239  }
240 
241  // Test that moveAfter(n) and moveBefore(n->next()) are not necessarily
242  // equivalent. Here, the dependency ordering is n -> o -> p. So we can't
243  // move `n` after `o`, but we can move `n` before `p` (which pushes `o` after
244  // `p`)
245  {
246  TopoMoveTestFixture fixture;
247  AT_ASSERT(!fixture.moveAfterTopologicallyValid("n", "o"));
248  AT_ASSERT(fixture.moveBeforeTopologicallyValid("o", "p"));
249  fixture.checkPostCondition("o", "p", false);
250  }
251 }
252 
253 namespace {
254 Node* insertIf(
255  Graph& g,
256  Value* condValue,
257  std::function<std::vector<Value*>()> trueInst,
258  std::function<std::vector<Value*>()> falseInst) {
259  auto if_ = g.insertNode(g.create(prim::If, 0));
260  if_->addInput(condValue); // condition value
261  auto trueBlock = if_->addBlock();
262  auto falseBlock = if_->addBlock();
263  {
264  // Mutate in true block
265  WithInsertPoint g(trueBlock);
266  auto outputs = trueInst();
267  for (auto output : outputs) {
268  trueBlock->registerOutput(output);
269  }
270  }
271  {
272  WithInsertPoint g(falseBlock);
273  auto outputs = falseInst();
274  for (auto output : outputs) {
275  falseBlock->registerOutput(output);
276  }
277  }
278 
279  AT_ASSERT(trueBlock->outputs().size() == falseBlock->outputs().size());
280  for (auto output : trueBlock->outputs()) {
281  if_->addOutput()->setType(output->type());
282  }
283  return if_;
284 }
285 } // namespace
286 
// Checks that AliasDb::moveAfterTopologicallyValid /
// moveBeforeTopologicallyValid refuse to move a node across an in-place write
// that may affect what the node reads. This includes writes hidden inside a
// prim::If sub-block and inside forked subgraphs that are joined by
// aten::wait.
void testAliasAnalysis() {
  {
    auto graph = std::make_shared<Graph>();
    auto a = graph->addInput();
    auto b = graph->addInput();

    // addsB = b + b
    // c = a + b
    // a += b
    // d = c + c
    auto addsB = graph->insert(aten::add, {b, b});
    auto c = graph->insert(aten::add, {a, b});
    auto aMut = graph->insert(aten::add_, {a, b});
    auto d = graph->insert(aten::add, {c, c});

    graph->lint();

    AliasDb aliasDb(graph);
    // Can't move past a mutation of a used value
    AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(c->node(), aMut->node()));
    AT_ASSERT(aliasDb.moveAfterTopologicallyValid(d->node(), c->node()));

    // b should alias to a (since they are both inputs): addsB only reads b,
    // yet it may not move past the in-place write to a, while moving it past
    // the read-only `c` is fine.
    AT_ASSERT(
        !aliasDb.moveAfterTopologicallyValid(addsB->node(), aMut->node()));
    AT_ASSERT(aliasDb.moveAfterTopologicallyValid(addsB->node(), c->node()));

    graph->lint();
  }
  {
    // fresh = rand(1)
    // usesB = b + fresh
    // aliasesB = a.select(1, 1)
    // aliasesB.add_(fresh)        (mutatesAliasOfB)
    // fresh + aliasesB
    //
    // `aliasesB` is derived from input `a`. Presumably it counts as a possible
    // alias of `b` because graph inputs `a` and `b` may alias each other (see
    // the previous case).
    auto graph = std::make_shared<Graph>();
    auto a = graph->addInput();
    auto b = graph->addInput();

    auto constant = graph->insertConstant(1);
    auto fresh = graph->insert(aten::rand, {constant});
    auto usesB = graph->insert(aten::add, {b, fresh});
    auto aliasesB = graph->insert(aten::select, {a, constant, constant});
    auto mutatesAliasOfB = graph->insert(aten::add_, {aliasesB, fresh});
    graph->insert(aten::add, {fresh, aliasesB});
    graph->lint();

    AliasDb aliasDb(graph);
    // `aliasesB` can't move after mutatesAliasOfB, which consumes its output.
    AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
        aliasesB->node(), mutatesAliasOfB->node()));
    // `usesB` reads `b`, so it can't move past an in-place write to a value
    // that may alias `b`.
    AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
        usesB->node(), mutatesAliasOfB->node()));
  }
  {
    // Test moves across inner blocks

    // a = rand(1)
    // b = rand(1)
    // if True:
    //   a.add_(b)
    // c = a + b
    auto graph = std::make_shared<Graph>();
    auto constant = graph->insertConstant(1);
    auto a = graph->insert(aten::rand, {constant});
    auto b = graph->insert(aten::rand, {constant});

    auto if_ = insertIf(
        *graph,
        constant,
        [&]() -> std::vector<Value*> {
          auto aMut = graph->insert(aten::add_, {a, b});
          return {aMut};
        },
        [&]() -> std::vector<Value*> { return {a}; });

    auto c = graph->insert(aten::add, {a, b});

    graph->lint();

    // we should not be able to move `c` before the if statement, since it
    // may write to `a`.
    AliasDb aliasDb(graph);
    ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(c->node(), if_));
  }
  {
    // test fork/wait

    // a = rand(1)
    // fut = fork(a)
    //    Subgraph is: return a.add_(1)
    // c = wait(fut)
    // d = a + a

    auto graph = std::make_shared<Graph>();
    auto constant = graph->insertConstant(1);
    auto a = graph->insert(aten::rand, {constant});

    auto forkNode = graph->insertNode(graph->create(prim::fork));
    auto forkBlock = forkNode->addBlock();
    {
      WithInsertPoint g(forkBlock);
      auto aMut = graph->insert(aten::add_, {a, constant});
      forkBlock->registerOutput(aMut);
      forkNode->output()->setType(FutureType::create(aMut->type()));
    }
    script::lambdaLiftFork(forkNode);

    auto fut = forkNode->output();
    auto wait = graph->insert(aten::wait, {fut})->node();
    auto d = graph->insert(aten::add, {a, a});

    graph->lint();

    // Should not be able to move `d` before the wait call
    AliasDb aliasDb(graph);
    ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(d->node(), wait));
  }
  {
    // test fork/wait in an if statement

    // a = rand(1)
    // if 1:
    //   fut = fork(a)
    //      Subgraph is: return a.add_(1)
    // else:
    //   fut = fork(a)
    //      Subgraph is: return a.sub_(1)
    // c = wait(fut)     (fut is the output of the if)
    // d = a + a

    auto graph = std::make_shared<Graph>();
    auto constant = graph->insertConstant(1);
    auto a = graph->insert(aten::rand, {constant});
    auto if_ = insertIf(
        *graph,
        constant,
        [&]() -> std::vector<Value*> {
          auto forkNode = graph->insertNode(graph->create(prim::fork));
          auto forkBlock = forkNode->addBlock();
          {
            WithInsertPoint g(forkBlock);
            auto aMut = graph->insert(aten::add_, {a, constant});
            forkBlock->registerOutput(aMut);
            forkNode->output()->setType(FutureType::create(aMut->type()));
          }
          script::lambdaLiftFork(forkNode);
          return {forkNode->output()};
        },
        [&]() -> std::vector<Value*> {
          auto forkNode = graph->insertNode(graph->create(prim::fork));
          auto forkBlock = forkNode->addBlock();
          {
            WithInsertPoint g(forkBlock);
            auto aMut = graph->insert(aten::sub_, {a, constant});
            forkBlock->registerOutput(aMut);
            forkNode->output()->setType(FutureType::create(aMut->type()));
          }
          script::lambdaLiftFork(forkNode);
          return {forkNode->output()};
        });

    auto fut = if_->output();
    auto wait = graph->insert(aten::wait, {fut})->node();
    auto d = graph->insert(aten::add, {a, a});

    graph->lint();

    // Should not be able to move `d` before the wait call
    AliasDb aliasDb(graph);
    ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(d->node(), wait));
  }
}
455 
456 void testWriteTracking() {
457  RegisterOperators reg({createOperator(
458  "foo::creates_alias(Tensor(a) x) -> Tensor(a)",
459  [](at::Tensor a) { return a; })});
460  const auto creates_alias = Symbol::fromQualString("foo::creates_alias");
461  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
462  {
463  auto graph = std::make_shared<Graph>();
464  auto a = graph->addInput();
465  auto b = graph->addInput();
466 
467  // aten::add(%b, %b)
468  // aten::add_(%a, %b)
469  // foo::creates_alias(%a)
470  auto pureNode = graph->insert(aten::add, {b, b})->node();
471  auto writingNode = graph->insert(aten::add_, {a, b})->node();
472  auto node3 = graph->insert(creates_alias, {a})->node();
473  auto aAlias = node3->output();
474 
475  graph->lint();
476 
477  AliasDb aliasDb(graph);
478  ASSERT_TRUE(aliasDb.mayAlias(aAlias, a));
479  ASSERT_TRUE(aliasDb.mayAlias(a, b));
480  ASSERT_FALSE(
481  aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{a}));
482  ASSERT_FALSE(
483  aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{b}));
484  ASSERT_TRUE(aliasDb.writesToAlias(
485  writingNode, std::unordered_set<const Value*>{a}));
486  ASSERT_TRUE(aliasDb.writesToAlias(
487  writingNode, std::unordered_set<const Value*>{a, b}));
488  ASSERT_TRUE(aliasDb.writesToAlias(
489  writingNode, std::unordered_set<const Value*>{aAlias}));
490  }
491 }
492 
// Checks AliasDb handling of wildcard alias sets:
// - the output of foo::returns_wildcard (schema result `Tensor(*)`) may alias
//   both its argument and the graph input `a`, while `a` and the fresh
//   rand() result do not alias;
// - once any write exists in the graph, the wildcard's node has writers;
// - a write through the wildcard counts as a write to `fresh`, `fresh2` and
//   `a`.
void testWildcards() {
  // NOTE(review): foo::writes annotates its argument with alias set `z` but its
  // result with set `a`, which no argument carries. Confirm whether
  // `-> Tensor(z)` was intended; these assertions only inspect the write
  // itself, not the result.
  RegisterOperators reg({createOperator(
                             "foo::returns_wildcard(Tensor a) -> Tensor(*)",
                             [](at::Tensor a) { return a; }),
                         createOperator(
                             "foo::writes(Tensor(z!) a) -> Tensor(a)",
                             [](at::Tensor a) { return a; })});
  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
  const auto writes = Symbol::fromQualString("foo::writes");

  auto graph = std::make_shared<Graph>();
  const auto a = graph->addInput();

  const auto constant = graph->insertConstant(1);
  const auto fresh = graph->insert(aten::rand, {constant});
  const auto fresh2 = graph->insert(aten::rand, {constant});
  const auto wildcard = graph->insert(returns_wildcard, {fresh});

  {
    // No writes in the graph yet.
    graph->lint();
    AliasDb aliasDb(graph);

    ASSERT_FALSE(aliasDb.mayAlias(a, fresh));
    ASSERT_TRUE(aliasDb.mayAlias(wildcard, fresh));
    ASSERT_TRUE(aliasDb.mayAlias(wildcard, a));
    ASSERT_FALSE(aliasDb.mayAlias(
        std::unordered_set<const Value*>({wildcard}),
        std::unordered_set<const Value*>()));
    ASSERT_FALSE(aliasDb.hasWriters(wildcard->node()));
  }

  // Write to an unrelated fresh value. The `->node()` result is discarded;
  // only the inserted write matters.
  graph->insert(writes, {fresh2})->node();
  {
    graph->lint();
    AliasDb aliasDb(graph);
    // Any write should be considered a write to the wildcard
    ASSERT_TRUE(aliasDb.hasWriters(wildcard->node()));
  }

  // Write through the wildcard itself.
  const auto wildcardWrite = graph->insert(writes, {wildcard})->node();
  {
    graph->lint();
    AliasDb aliasDb(graph);
    // Test writes to wildcards
    ASSERT_TRUE(aliasDb.writesToAlias(
        wildcardWrite, std::unordered_set<const Value*>{fresh}));
    ASSERT_TRUE(aliasDb.writesToAlias(
        wildcardWrite, std::unordered_set<const Value*>{fresh2}));
    ASSERT_TRUE(aliasDb.writesToAlias(
        wildcardWrite, std::unordered_set<const Value*>{a}));
    ASSERT_TRUE(aliasDb.hasWriters(wildcard->node()));
  }
}
546 
547 void testMemoryDAG() {
548  auto graph = std::make_shared<Graph>();
549  const Value* aValue = graph->addInput();
550  const Value* bValue = graph->addInput();
551  const Value* cValue = graph->addInput();
552  const Value* dValue = graph->addInput();
553  const Value* eValue = graph->addInput();
554  const Value* fValue = graph->addInput();
555  const Value* gValue = graph->addInput();
556 
557  {
558  // a <- b <- c
559  // b <- d
560  // a <- e
561  // f <- e
562  // g is by itself
563  MemoryDAG t;
564  auto a = t.makeFreshValue(aValue);
565  auto b = t.makeFreshValue(bValue);
566  auto c = t.makeFreshValue(cValue);
567  auto d = t.makeFreshValue(dValue);
568  auto e = t.makeFreshValue(eValue);
569  auto f = t.makeFreshValue(fValue);
570  auto g = t.makeFreshValue(gValue);
571  t.makePointerTo(b, a);
572  t.makePointerTo(c, b);
573  t.makePointerTo(d, b);
574  t.makePointerTo(e, a);
575  t.makePointerTo(e, f);
576 
580  // Values should alias themselves
581  ASSERT_TRUE(t.mayAlias(a, a));
582  ASSERT_TRUE(t.mayAlias(g, g));
583 
584  // Values that point to the same location should alias
585  ASSERT_TRUE(t.mayAlias(a, b));
586  ASSERT_TRUE(t.mayAlias(a, c));
587  ASSERT_TRUE(t.mayAlias(c, d));
588 
589  // e may point to a OR f
590  ASSERT_TRUE(t.mayAlias(e, a));
591  ASSERT_TRUE(t.mayAlias(e, f));
592  // But a and f don't alias
593  ASSERT_FALSE(t.mayAlias(a, f));
594 
598  std::multiset<const Element*> foo{c, c, d};
599  std::multiset<const Element*> bar{e, f};
600  std::unordered_set<const Element*> baz{f, g};
601  ASSERT_TRUE(t.mayAlias(foo, bar));
602  ASSERT_TRUE(t.mayAlias(bar, baz));
603  ASSERT_FALSE(t.mayAlias(foo, baz));
604  }
605 }
606 } // namespace jit
607 } // namespace torch
Alias analysis pass.
Registration class for new operators.
Definition: jit_type.h:17
A utility class for setting temporary insertion points.
Definition: ir.h:1174