// NOTE(review): extraction residue. This single line fuses the file's
// #include directives (embedded original line numbers 3-7) with one statement
// (line 16) that builds `aliasDb` over `graph`. Lines 8-15 and 17-21,
// including any namespace/struct/constructor headers, are missing and must be
// restored from upstream before this compiles.
3 #include "test/cpp/jit/test_base.h" 4 #include "torch/csrc/jit/custom_operator.h" 5 #include "torch/csrc/jit/passes/alias_analysis.h" 6 #include "torch/csrc/jit/script/compiler.h" 7 #include "torch/csrc/utils/memory.h" 16 aliasDb = torch::make_unique<AliasDb>(graph);
// Resets `graph` to a fresh Graph, then appends one node per createNode call.
// Each node is named after its output value. The second argument lists the
// names of earlier nodes whose outputs the new node consumes.
// NOTE(review): the enclosing function header is not visible, and the
// embedded line numbers skip 23, 25 and 38-39. Nodes "a", "c" and "q" are
// consumed below (and "p" is referenced by the tests), but no visible
// createNode call creates them. Restore the missing lines from upstream.
22 graph = std::make_shared<Graph>();
24 createNode(
"b", {
"a"});
26 createNode(
"d", {
"a",
"b"});
27 createNode(
"e", {
"c",
"b"});
28 createNode(
"f", {
"e"});
29 createNode(
"g", {
"e"});
30 createNode(
"h", {
"g"});
31 createNode(
"i", {
"g"});
32 createNode(
"j", {
"i"});
33 createNode(
"k", {
"i"});
34 createNode(
"l", {
"a"});
// "m" takes no direct inputs. Its third argument makes the node placed
// inside m's nested block consume "l".
35 createNode(
"m", {}, {
"l"});
36 createNode(
"n", {
"m"});
37 createNode(
"o", {
"n"});
// "r" and "s" both consume "q".
40 createNode(
"r", {
"q"});
41 createNode(
"s", {
"q"});
// Appends a prim::AutogradZero node to `graph` and names its output `name`.
// Its inputs are the outputs of the already-created nodes listed in
// `inputNames`; nodes.at() throws std::out_of_range for an unknown name.
// When `blockInputNames` is non-empty, a second prim::AutogradZero node that
// consumes those nodes' outputs is appended to the node's first block.
// NOTE(review): the function header (line 46) and lines 53, 56-57, 59 and
// 63-64 are missing. The visible code never stores `node` in `nodes` and
// never adds a block before calling blocks().at(0). Presumably the missing
// lines do both; confirm against upstream.
47 const std::string& name,
48 const std::vector<std::string>& inputNames,
49 const std::vector<std::string>& blockInputNames = {}) {
50 std::vector<Value*> inputs;
// NOTE(review): `const auto name` copies each std::string and shadows the
// `name` parameter. `const auto& inputName` would avoid both.
51 for (
const auto name : inputNames) {
52 inputs.push_back(nodes.at(name)->output());
54 auto node = graph->appendNode(graph->create(prim::AutogradZero, inputs));
// The loop variable is out of scope here, so this is the `name` parameter.
55 node->output()->setUniqueName(name);
// NOTE(review): `!blockInputNames.empty()` would state this check directly.
58 if (blockInputNames.size() != 0) {
60 std::vector<Value*> blockDeps;
// NOTE(review): same copy/shadowing issue as the loop above.
61 for (
const auto name : blockInputNames) {
62 blockDeps.push_back(nodes.at(name)->output());
65 auto block = node->blocks().at(0);
66 block->appendNode(graph->create(prim::AutogradZero, blockDeps));
// Wraps aliasDb->moveBeforeTopologicallyValid in a std::function and hands
// it, along with the two node names, to moveWithChecks. moveWithChecks
// resolves the names to nodes and performs the move. Returns
// moveWithChecks' result (presumably AliasDb's verdict on whether the move
// was valid).
// NOTE(review): the lambda's Node* parameters shadow the std::string
// parameters of the same names. The closing tokens (lines 76 and 78) are
// missing from this extraction.
70 bool moveBeforeTopologicallyValid(
71 const std::string& toInsert,
72 const std::string& insertPoint) {
73 std::function<bool(Node*, Node*)> func =
74 [
this](
Node* toInsert,
Node* insertPoint) {
75 return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint);
77 return moveWithChecks(toInsert, insertPoint, func);
// Wraps aliasDb->moveAfterTopologicallyValid in a std::function and hands
// it, along with the two node names, to moveWithChecks. moveWithChecks
// resolves the names to nodes and performs the move. Returns
// moveWithChecks' result (presumably AliasDb's verdict on whether the move
// was valid).
// NOTE(review): the lambda's Node* parameters shadow the std::string
// parameters of the same names. The closing tokens (lines 86 and 88) are
// missing from this extraction.
80 bool moveAfterTopologicallyValid(
81 const std::string& toInsert,
82 const std::string& insertPoint) {
83 std::function<bool(Node*, Node*)> func =
84 [
this](
Node* toInsert,
Node* insertPoint) {
85 return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint);
87 return moveWithChecks(toInsert, insertPoint, func);
// Looks up the nodes named `toInsert` (n) and `insertPoint` (insert) and
// records a node ordering before the move. The walk starts at n->next() when
// n is after insert, otherwise at n->prev(), and continues until it reaches
// the owning block's return node. It then calls func(n, insert), keeps the
// result in `couldMove`, and re-walks the block, asserting that each visited
// node matches the recorded ordering.
// NOTE(review): the function header (line 90) and several lines are missing
// (e.g. 104, 106, 108-110, 112-119), including:
//   - the branches choosing next() vs. prev() (presumably on `isAfter`);
//   - the reset of `curNode` before the second walk;
//   - the declaration and increment of `idx`;
//   - the final return.
// Also, `func` is taken by value; a `const std::function<...>&` parameter
// would avoid copying it.
91 const std::string& toInsert,
92 const std::string& insertPoint,
93 std::function<
bool(
Node*,
Node*)> func) {
94 auto n = nodes.at(toInsert);
95 auto insert = nodes.at(insertPoint);
96 bool isAfter = n->isAfter(insert);
98 std::vector<Node*> originalOrdering;
// Start from n's neighbour on the side facing away from `insert`.
99 Node* original = isAfter ? n->next() : n->prev();
101 auto curNode = original;
102 while (curNode != n->owningBlock()->return_node()) {
103 originalOrdering.push_back(curNode);
105 curNode = curNode->next();
107 curNode = curNode->prev();
111 const auto couldMove = func(n, insert);
// Second walk: every node visited must match the one recorded before the
// move.
120 while (curNode != n->owningBlock()->return_node()) {
121 AT_ASSERT(originalOrdering[idx] == curNode);
123 curNode = curNode->next();
125 curNode = curNode->prev();
// Asserts where `toInsert` ended up relative to `insertPoint`: either
// directly after it (prev() == insertPoint) or directly before it
// (next() == insertPoint).
// NOTE(review): the third parameter (line 136) and the if/else selecting
// between the two assertions (lines 137, 139) are missing. Callers pass
// `true` after a moveAfter... call and `false` after a moveBefore... call,
// so presumably `true` selects the prev() check.
133 void checkPostCondition(
134 const std::string& toInsert,
135 const std::string& insertPoint,
138 AT_ASSERT(nodes.at(toInsert)->prev() == nodes.at(insertPoint));
140 AT_ASSERT(nodes.at(toInsert)->next() == nodes.at(insertPoint));
// Graph under test.
144 std::shared_ptr<Graph> graph;
// Alias database constructed over `graph`.
145 std::unique_ptr<AliasDb> aliasDb;
// Name -> node lookup used by the helpers (presumably keyed by the names
// passed to createNode).
146 std::unordered_map<std::string, Node*> nodes;
// Exercises the fixture's moveBefore/AfterTopologicallyValid helpers on the
// fixture's named-node graph. Valid moves are followed by
// checkPostCondition; moves that would violate a data dependency are
// asserted to fail.
// NOTE(review): the embedded numbering skips lines between cases (e.g.
// 150-153, 156-160). No visible line declares `fixture` or scopes the cases,
// so whether each case starts from a fresh fixture cannot be confirmed here.
149 void testTopologicalMove() {
// q feeds both r and s.
154 AT_ASSERT(fixture.moveBeforeTopologicallyValid(
"q",
"s"));
155 fixture.checkPostCondition(
"q",
"s",
false);
161 AT_ASSERT(fixture.moveAfterTopologicallyValid(
"c",
"a"));
162 fixture.checkPostCondition(
"c",
"a",
true);
// d consumes a and b, and b itself consumes a. Placing d directly after a
// would put it ahead of its input b.
167 AT_ASSERT(!fixture.moveAfterTopologicallyValid(
"d",
"a"));
172 AT_ASSERT(fixture.moveAfterTopologicallyValid(
"f",
"e"));
173 fixture.checkPostCondition(
"f",
"e",
true);
178 AT_ASSERT(fixture.moveAfterTopologicallyValid(
"e",
"c"));
179 fixture.checkPostCondition(
"e",
"c",
true);
184 AT_ASSERT(fixture.moveAfterTopologicallyValid(
"k",
"f"));
185 fixture.checkPostCondition(
"k",
"f",
true);
190 AT_ASSERT(fixture.moveAfterTopologicallyValid(
"c",
"d"));
191 fixture.checkPostCondition(
"c",
"d",
true);
196 AT_ASSERT(fixture.moveAfterTopologicallyValid(
"f",
"l"));
197 fixture.checkPostCondition(
"f",
"l",
true);
204 AT_ASSERT(fixture.moveBeforeTopologicallyValid(
"b",
"d"));
205 fixture.checkPostCondition(
"b",
"d",
false);
210 AT_ASSERT(fixture.moveBeforeTopologicallyValid(
"c",
"a"));
211 fixture.checkPostCondition(
"c",
"a",
false);
216 AT_ASSERT(fixture.moveBeforeTopologicallyValid(
"a",
"b"));
217 fixture.checkPostCondition(
"a",
"b",
false);
222 AT_ASSERT(fixture.moveBeforeTopologicallyValid(
"f",
"m"));
223 fixture.checkPostCondition(
"f",
"m",
false);
228 AT_ASSERT(fixture.moveBeforeTopologicallyValid(
"l",
"f"));
229 fixture.checkPostCondition(
"l",
"f",
false);
// m's nested block consumes l and n consumes m, so each of these moves is
// asserted to fail.
235 AT_ASSERT(!fixture.moveAfterTopologicallyValid(
"l",
"m"));
236 AT_ASSERT(!fixture.moveBeforeTopologicallyValid(
"m",
"l"));
237 AT_ASSERT(!fixture.moveAfterTopologicallyValid(
"n",
"l"));
238 AT_ASSERT(!fixture.moveBeforeTopologicallyValid(
"l",
"n"));
// o consumes n, so n cannot be moved to directly after o.
247 AT_ASSERT(!fixture.moveAfterTopologicallyValid(
"n",
"o"));
248 AT_ASSERT(fixture.moveBeforeTopologicallyValid(
"o",
"p"));
249 fixture.checkPostCondition(
"o",
"p",
false);
// Builds a prim::If node on `g`, created with 0 outputs, with `condValue`
// added as its input and two blocks (true, false). The values returned by
// trueInst() and falseInst() are registered as the outputs of the respective
// blocks. One If output is then added per true-block output, copying that
// output's type.
// NOTE(review): the function name, return type and leading parameters
// (lines 255-256, presumably `g` and `condValue`) are missing. So are lines
// 263-265, 269-272 and 276-278, presumably where the insertion point is
// moved into each block before trueInst()/falseInst() run. Confirm against
// upstream.
257 std::function<std::vector<Value*>()> trueInst,
258 std::function<std::vector<Value*>()> falseInst) {
259 auto if_ = g.insertNode(g.create(prim::If, 0));
260 if_->addInput(condValue);
261 auto trueBlock = if_->addBlock();
262 auto falseBlock = if_->addBlock();
266 auto outputs = trueInst();
267 for (
auto output : outputs) {
268 trueBlock->registerOutput(output);
273 auto outputs = falseInst();
274 for (
auto output : outputs) {
275 falseBlock->registerOutput(output);
// Both branches must yield the same number of values.
279 AT_ASSERT(trueBlock->outputs().size() == falseBlock->outputs().size());
280 for (
auto output : trueBlock->outputs()) {
281 if_->addOutput()->setType(output->type());
// Checks AliasDb's move-validity queries in several independent scenarios:
//   - plain in-place writes;
//   - a write through a value derived from a graph input;
//   - an in-place write inside an If branch;
//   - writes inside forked subgraphs.
// Each scenario re-declares `graph`, so the scope braces separating them are
// among the missing lines.
// NOTE(review): the `aliasDb` objects queried below, and `if_` in the If
// scenarios, are declared on lines missing from this extraction.
287 void testAliasAnalysis() {
289 auto graph = std::make_shared<Graph>();
290 auto a = graph->addInput();
291 auto b = graph->addInput();
297 auto addsB = graph->insert(aten::add, {b, b});
298 auto c = graph->insert(aten::add, {a, b});
299 auto aMut = graph->insert(aten::add_, {a, b});
300 auto d = graph->insert(aten::add, {c, c});
// c reads a, which aMut (aten::add_) presumably mutates in place, so c
// cannot move after aMut. d reads only c, so it may be moved to directly
// after c.
306 AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(c->node(), aMut->node()));
307 AT_ASSERT(aliasDb.moveAfterTopologicallyValid(d->node(), c->node()));
// addsB reads only b, yet may not move after aMut, presumably because the
// graph inputs a and b are treated as possibly aliasing. Moving it after c
// is allowed.
// NOTE(review): the `AT_ASSERT(` opening the next assertion (line 310) is
// missing.
311 !aliasDb.moveAfterTopologicallyValid(addsB->node(), aMut->node()));
312 AT_ASSERT(aliasDb.moveAfterTopologicallyValid(addsB->node(), c->node()));
317 auto graph = std::make_shared<Graph>();
318 auto a = graph->addInput();
319 auto b = graph->addInput();
321 auto constant = graph->insertConstant(1);
322 auto fresh = graph->insert(aten::rand, {constant});
323 auto usesB = graph->insert(aten::add, {b, fresh});
// NOTE(review): despite its name, aliasesB is selected from `a`, not `b`.
// The name only holds if a and b may alias.
324 auto aliasesB = graph->insert(aten::select, {a, constant, constant});
325 auto mutatesAliasOfB = graph->insert(aten::add_, {aliasesB, fresh});
326 graph->insert(aten::add, {fresh, aliasesB});
// aliasesB is an input of mutatesAliasOfB, so it cannot move after it.
// usesB reads b while mutatesAliasOfB writes through aliasesB; rejecting
// this move presumably relies on a and b possibly aliasing.
330 AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
331 aliasesB->node(), mutatesAliasOfB->node()));
332 AT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
333 usesB->node(), mutatesAliasOfB->node()));
343 auto graph = std::make_shared<Graph>();
344 auto constant = graph->insertConstant(1);
345 auto a = graph->insert(aten::rand, {constant});
346 auto b = graph->insert(aten::rand, {constant});
// True branch: in-place add on a. Its return statement (lines 353-354) is
// missing.
351 [&]() -> std::vector<Value*> {
352 auto aMut = graph->insert(aten::add_, {a, b});
// False branch: returns a unchanged.
355 [&]() -> std::vector<Value*> {
return {a}; });
357 auto c = graph->insert(aten::add, {a, b});
// The If's true branch writes a and c reads a, so c cannot move before the
// If.
364 ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(c->node(), if_));
376 auto graph = std::make_shared<Graph>();
377 auto constant = graph->insertConstant(1);
378 auto a = graph->insert(aten::rand, {constant});
380 auto forkNode = graph->insertNode(graph->create(prim::fork));
381 auto forkBlock = forkNode->addBlock();
// The forked block performs an in-place add on a.
384 auto aMut = graph->insert(aten::add_, {a, constant});
385 forkBlock->registerOutput(aMut);
386 forkNode->output()->setType(FutureType::create(aMut->type()));
388 script::lambdaLiftFork(forkNode);
390 auto fut = forkNode->output();
391 auto wait = graph->insert(aten::wait, {fut})->node();
392 auto d = graph->insert(aten::add, {a, a});
// d reads a, which the forked block writes, so d cannot move before the
// wait on the fork's future.
398 ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(d->node(), wait));
413 auto graph = std::make_shared<Graph>();
414 auto constant = graph->insertConstant(1);
415 auto a = graph->insert(aten::rand, {constant});
// Each If branch forks a block that writes a (add_ in one, sub_ in the
// other) and returns the fork's future.
419 [&]() -> std::vector<Value*> {
420 auto forkNode = graph->insertNode(graph->create(prim::fork));
421 auto forkBlock = forkNode->addBlock();
424 auto aMut = graph->insert(aten::add_, {a, constant});
425 forkBlock->registerOutput(aMut);
426 forkNode->output()->setType(FutureType::create(aMut->type()));
428 script::lambdaLiftFork(forkNode);
429 return {forkNode->output()};
431 [&]() -> std::vector<Value*> {
432 auto forkNode = graph->insertNode(graph->create(prim::fork));
433 auto forkBlock = forkNode->addBlock();
436 auto aMut = graph->insert(aten::sub_, {a, constant});
437 forkBlock->registerOutput(aMut);
438 forkNode->output()->setType(FutureType::create(aMut->type()));
440 script::lambdaLiftFork(forkNode);
441 return {forkNode->output()};
444 auto fut = if_->output();
445 auto wait = graph->insert(aten::wait, {fut})->node();
446 auto d = graph->insert(aten::add, {a, a});
// d reads a, so it cannot move before the wait on the future that comes out
// of the If.
452 ASSERT_FALSE(aliasDb.moveBeforeTopologicallyValid(d->node(), wait));
// Checks AliasDb::mayAlias and AliasDb::writesToAlias for three cases: a pure
// node, an in-place node, and an output of a custom operator that aliases
// its input.
// NOTE(review): the operator registration around the schema string (lines
// 457, 459) is missing, as are lines 466-469 and 474-477 (presumably
// including the AliasDb construction). `returns_wildcard` is declared but
// not used in any visible line.
456 void testWriteTracking() {
// Schema annotation: the output shares alias set `a` with input `x`.
458 "foo::creates_alias(Tensor(a) x) -> Tensor(a)",
460 const auto creates_alias = Symbol::fromQualString(
"foo::creates_alias");
461 const auto returns_wildcard = Symbol::fromQualString(
"foo::returns_wildcard");
463 auto graph = std::make_shared<Graph>();
464 auto a = graph->addInput();
465 auto b = graph->addInput();
470 auto pureNode = graph->insert(aten::add, {b, b})->node();
471 auto writingNode = graph->insert(aten::add_, {a, b})->node();
472 auto node3 = graph->insert(creates_alias, {a})->node();
473 auto aAlias = node3->output();
478 ASSERT_TRUE(aliasDb.mayAlias(aAlias, a));
// The two graph inputs are asserted to possibly alias each other.
479 ASSERT_TRUE(aliasDb.mayAlias(a, b));
// pureNode (b + b) is presumably asserted not to write a or b. The macro
// opening each of these two checks (lines 480, 482) is missing.
481 aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{a}));
483 aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{b}));
484 ASSERT_TRUE(aliasDb.writesToAlias(
485 writingNode, std::unordered_set<const Value*>{a}));
486 ASSERT_TRUE(aliasDb.writesToAlias(
487 writingNode, std::unordered_set<const Value*>{a, b}));
// aAlias may alias a, so writingNode's write to a also counts as a write to
// aAlias.
488 ASSERT_TRUE(aliasDb.writesToAlias(
489 writingNode, std::unordered_set<const Value*>{aAlias}));
// Checks how the output of an operator whose result is annotated Tensor(*)
// ("wildcard") behaves under mayAlias, hasWriters and writesToAlias.
// NOTE(review): the operator-registration lines around the schema strings
// and the AliasDb construction/rebuilds (e.g. lines 510-514, 525-528,
// 533-536) are missing from this extraction.
493 void testWildcards() {
495 "foo::returns_wildcard(Tensor a) -> Tensor(*)",
// NOTE(review): the output's alias set `a` differs from the input's `z!`
// annotation; confirm this is intended.
498 "foo::writes(Tensor(z!) a) -> Tensor(a)",
500 const auto returns_wildcard = Symbol::fromQualString(
"foo::returns_wildcard");
501 const auto writes = Symbol::fromQualString(
"foo::writes");
503 auto graph = std::make_shared<Graph>();
504 const auto a = graph->addInput();
506 const auto constant = graph->insertConstant(1);
507 const auto fresh = graph->insert(aten::rand, {constant});
508 const auto fresh2 = graph->insert(aten::rand, {constant});
509 const auto wildcard = graph->insert(returns_wildcard, {fresh});
// Graph input a and the fresh rand output do not alias each other, but the
// wildcard may alias both. An empty set aliases nothing.
515 ASSERT_FALSE(aliasDb.mayAlias(a, fresh));
516 ASSERT_TRUE(aliasDb.mayAlias(wildcard, fresh));
517 ASSERT_TRUE(aliasDb.mayAlias(wildcard, a));
518 ASSERT_FALSE(aliasDb.mayAlias(
519 std::unordered_set<const Value*>({wildcard}),
520 std::unordered_set<const Value*>()));
// Nothing writes to anything yet.
521 ASSERT_FALSE(aliasDb.hasWriters(wildcard->node()));
// NOTE(review): the ->node() result is discarded; only the inserted write
// matters.
524 graph->insert(writes, {fresh2})->node();
// After a write to fresh2, the wildcard's node is reported as having
// writers, presumably because a wildcard may alias any value.
529 ASSERT_TRUE(aliasDb.hasWriters(wildcard->node()));
532 const auto wildcardWrite = graph->insert(writes, {wildcard})->node();
// A write through the wildcard is reported as writing fresh, fresh2 and a.
537 ASSERT_TRUE(aliasDb.writesToAlias(
538 wildcardWrite, std::unordered_set<const Value*>{fresh}));
539 ASSERT_TRUE(aliasDb.writesToAlias(
540 wildcardWrite, std::unordered_set<const Value*>{fresh2}));
541 ASSERT_TRUE(aliasDb.writesToAlias(
542 wildcardWrite, std::unordered_set<const Value*>{a}));
543 ASSERT_TRUE(aliasDb.hasWriters(wildcard->node()));
// Exercises the points-to structure `t` directly. It creates one fresh
// element per graph input, adds points-to edges, and checks mayAlias for
// single elements and for element sets.
// NOTE(review): `t` (presumably a MemoryDAG) and the `Element` type are
// declared on lines missing from this extraction (556-563).
547 void testMemoryDAG() {
548 auto graph = std::make_shared<Graph>();
549 const Value* aValue = graph->addInput();
550 const Value* bValue = graph->addInput();
551 const Value* cValue = graph->addInput();
552 const Value* dValue = graph->addInput();
553 const Value* eValue = graph->addInput();
554 const Value* fValue = graph->addInput();
555 const Value* gValue = graph->addInput();
564 auto a = t.makeFreshValue(aValue);
565 auto b = t.makeFreshValue(bValue);
566 auto c = t.makeFreshValue(cValue);
567 auto d = t.makeFreshValue(dValue);
568 auto e = t.makeFreshValue(eValue);
569 auto f = t.makeFreshValue(fValue);
570 auto g = t.makeFreshValue(gValue);
// Edges: b->a, c->b, d->b, e->a, e->f. g has no edges.
571 t.makePointerTo(b, a);
572 t.makePointerTo(c, b);
573 t.makePointerTo(d, b);
574 t.makePointerTo(e, a);
575 t.makePointerTo(e, f);
// Every element aliases itself, including the unconnected g.
581 ASSERT_TRUE(t.mayAlias(a, a));
582 ASSERT_TRUE(t.mayAlias(g, g));
// a, b, c and d are all connected through chains ending at a.
585 ASSERT_TRUE(t.mayAlias(a, b));
586 ASSERT_TRUE(t.mayAlias(a, c));
587 ASSERT_TRUE(t.mayAlias(c, d));
// e points to both a and f, so it aliases each of them...
590 ASSERT_TRUE(t.mayAlias(e, a));
591 ASSERT_TRUE(t.mayAlias(e, f));
// ...but a and f are asserted not to alias each other.
593 ASSERT_FALSE(t.mayAlias(a, f));
// Set queries: both std::multiset and std::unordered_set are accepted.
// foo and bar alias (e reaches a), bar and baz share f, and foo and baz are
// asserted not to alias.
598 std::multiset<const Element*> foo{c, c, d};
599 std::multiset<const Element*> bar{e, f};
600 std::unordered_set<const Element*> baz{f, g};
601 ASSERT_TRUE(t.mayAlias(foo, bar));
602 ASSERT_TRUE(t.mayAlias(bar, baz));
603 ASSERT_FALSE(t.mayAlias(foo, baz));
Registration class for new operators.
A utility class for setting temporary insertion points.