3 #include "test/cpp/jit/test_base.h" 4 #include "test/cpp/jit/test_utils.h" 6 #include "torch/csrc/jit/dynamic_dag.h" 12 std::unique_ptr<detail::DynamicDAG<std::string>> newDynamicDAG() {
13 return std::unique_ptr<detail::DynamicDAG<std::string>>(
14 new detail::DynamicDAG<std::string>());
17 void testNewVertex() {
18 auto graph = newDynamicDAG();
19 AT_ASSERT(graph->debugNumVertices() == 0);
20 auto a = graph->newVertex(
"a");
21 AT_ASSERT(graph->debugNumVertices() == 1);
22 AT_ASSERT(a->ord == 0);
23 AT_ASSERT(a->data.size() == 1);
24 AT_ASSERT(a->data[0] ==
"a");
25 AT_ASSERT(a->in_edges().size() == 0);
26 AT_ASSERT(a->out_edges().size() == 0);
27 auto b = graph->newVertex(
"b");
28 auto c = graph->newVertex(
"c");
29 AT_ASSERT(graph->debugNumVertices() == 3);
30 AT_ASSERT(b->ord == 1);
31 AT_ASSERT(c->ord == 2);
34 void testAddEdgeBasic() {
37 auto graph = newDynamicDAG();
38 auto a = graph->newVertex(
"a");
39 auto b = graph->newVertex(
"b");
40 auto c = graph->newVertex(
"c");
44 AT_ASSERT(a->in_edges().size() == 0);
45 AT_ASSERT(a->out_edges().size() == 2);
46 AT_ASSERT(a->out_edges().contains(b));
47 AT_ASSERT(a->out_edges().contains(c));
48 AT_ASSERT(b->in_edges().size() == 1);
49 AT_ASSERT(b->out_edges().size() == 1);
50 AT_ASSERT(b->in_edges().contains(a));
51 AT_ASSERT(b->out_edges().contains(c));
52 AT_ASSERT(c->in_edges().size() == 2);
53 AT_ASSERT(c->out_edges().size() == 0);
54 AT_ASSERT(c->in_edges().contains(a));
55 AT_ASSERT(c->in_edges().contains(b));
58 void testAddEdgeCycleDetection() {
61 auto graph = newDynamicDAG();
62 auto a = graph->newVertex(
"a");
63 auto b = graph->newVertex(
"b");
64 auto c = graph->newVertex(
"c");
76 void testAddEdgeReordersBasic() {
78 auto graph = newDynamicDAG();
79 auto a = graph->newVertex(
"a");
80 auto b = graph->newVertex(
"b");
81 AT_ASSERT(a->ord == 0);
82 AT_ASSERT(b->ord == 1);
84 AT_ASSERT(a->ord == 1);
85 AT_ASSERT(b->ord == 0);
88 void testAddEdgeReordersComplicated() {
91 auto graph = newDynamicDAG();
92 auto a = graph->newVertex(
"a");
93 auto b = graph->newVertex(
"b");
94 auto c = graph->newVertex(
"c");
95 auto d = graph->newVertex(
"d");
98 AT_ASSERT(a->ord == 0);
99 AT_ASSERT(b->ord == 1);
100 AT_ASSERT(c->ord == 2);
101 AT_ASSERT(d->ord == 3);
102 graph->addEdge(d, a);
103 AT_ASSERT(c->ord == 0);
104 AT_ASSERT(d->ord == 1);
105 AT_ASSERT(a->ord == 2);
106 AT_ASSERT(b->ord == 3);
107 AT_ASSERT(c->in_edges().size() == 0);
108 AT_ASSERT(c->out_edges().size() == 1);
109 AT_ASSERT(c->out_edges().contains(d));
110 AT_ASSERT(d->in_edges().size() == 1);
111 AT_ASSERT(d->out_edges().size() == 1);
112 AT_ASSERT(d->in_edges().contains(c));
113 AT_ASSERT(d->out_edges().contains(a));
114 AT_ASSERT(a->in_edges().size() == 1);
115 AT_ASSERT(a->out_edges().size() == 1);
116 AT_ASSERT(a->in_edges().contains(d));
117 AT_ASSERT(a->out_edges().contains(b));
118 AT_ASSERT(b->in_edges().size() == 1);
119 AT_ASSERT(b->out_edges().size() == 0);
120 AT_ASSERT(b->in_edges().contains(a));
123 void testRemoveEdgeBasic() {
125 auto graph = newDynamicDAG();
126 auto a = graph->newVertex(
"a");
127 auto b = graph->newVertex(
"b");
128 graph->addEdge(a, b);
129 AT_ASSERT(graph->debugNumVertices() == 2);
130 graph->removeEdge(a, b);
131 AT_ASSERT(graph->debugNumVertices() == 2);
132 AT_ASSERT(a->out_edges().size() == 0);
133 AT_ASSERT(b->in_edges().size() == 0);
136 void testRemoveVertexBasic() {
138 auto graph = newDynamicDAG();
139 auto a = graph->newVertex(
"a");
140 auto b = graph->newVertex(
"b");
141 auto c = graph->newVertex(
"c");
142 graph->addEdge(a, b);
143 graph->addEdge(b, c);
144 AT_ASSERT(graph->debugNumVertices() == 3);
145 graph->removeVertex(b);
146 AT_ASSERT(graph->debugNumVertices() == 2);
147 AT_ASSERT(a->out_edges().size() == 0);
148 AT_ASSERT(c->in_edges().size() == 0);
151 void testContractEdgeBasic() {
153 auto graph = newDynamicDAG();
154 auto a = graph->newVertex(
"a");
155 auto b = graph->newVertex(
"b");
156 auto c = graph->newVertex(
"c");
157 auto d = graph->newVertex(
"d");
158 graph->addEdge(a, b);
159 graph->addEdge(b, c);
160 graph->addEdge(c, d);
161 graph->contractEdge(b, c);
162 AT_ASSERT(graph->debugNumVertices() == 3);
163 AT_ASSERT(a->out_edges().size() == 1);
164 AT_ASSERT(d->in_edges().size() == 1);
165 AT_ASSERT(*a->out_edges().begin() == *d->in_edges().begin());
166 auto* contracted = *a->out_edges().begin();
167 AT_ASSERT(contracted->data.size() == 2);
168 AT_ASSERT(contracted->data[0] ==
"b");
169 AT_ASSERT(contracted->data[1] ==
"c");
170 AT_ASSERT(contracted->out_edges().size() == 1);
171 AT_ASSERT(contracted->in_edges().size() == 1);
172 AT_ASSERT(contracted->in_edges().contains(a));
173 AT_ASSERT(contracted->out_edges().contains(d));
176 void testContractEdgeCycleDetection() {
180 auto graph = newDynamicDAG();
181 auto a = graph->newVertex(
"a");
182 auto b = graph->newVertex(
"b");
183 auto c = graph->newVertex(
"c");
184 graph->addEdge(a, b);
185 graph->addEdge(b, c);
186 graph->addEdge(a, c);
187 AT_ASSERT(!graph->contractEdge(a, c));
190 void testDynamicDAG() {
193 testAddEdgeCycleDetection();
194 testAddEdgeReordersBasic();
195 testAddEdgeReordersComplicated();
196 testRemoveEdgeBasic();
197 testRemoveVertexBasic();
198 testContractEdgeBasic();
199 testContractEdgeCycleDetection();
// The primary ATen error class (c10::Error).