Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_dynamic_dag.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_base.h"
4 #include "test/cpp/jit/test_utils.h"
5 
6 #include "torch/csrc/jit/dynamic_dag.h"
7 
8 namespace torch {
9 namespace jit {
10 namespace test {
11 
12 std::unique_ptr<detail::DynamicDAG<std::string>> newDynamicDAG() {
13  return std::unique_ptr<detail::DynamicDAG<std::string>>(
14  new detail::DynamicDAG<std::string>());
15 }
16 
17 void testNewVertex() {
18  auto graph = newDynamicDAG();
19  AT_ASSERT(graph->debugNumVertices() == 0);
20  auto a = graph->newVertex("a");
21  AT_ASSERT(graph->debugNumVertices() == 1);
22  AT_ASSERT(a->ord == 0);
23  AT_ASSERT(a->data.size() == 1);
24  AT_ASSERT(a->data[0] == "a");
25  AT_ASSERT(a->in_edges().size() == 0);
26  AT_ASSERT(a->out_edges().size() == 0);
27  auto b = graph->newVertex("b");
28  auto c = graph->newVertex("c");
29  AT_ASSERT(graph->debugNumVertices() == 3);
30  AT_ASSERT(b->ord == 1);
31  AT_ASSERT(c->ord == 2);
32 }
33 
34 void testAddEdgeBasic() {
35  // a -> b -> c
36  // \---------^
37  auto graph = newDynamicDAG();
38  auto a = graph->newVertex("a");
39  auto b = graph->newVertex("b");
40  auto c = graph->newVertex("c");
41  graph->addEdge(a, b);
42  graph->addEdge(b, c);
43  graph->addEdge(a, c);
44  AT_ASSERT(a->in_edges().size() == 0);
45  AT_ASSERT(a->out_edges().size() == 2);
46  AT_ASSERT(a->out_edges().contains(b));
47  AT_ASSERT(a->out_edges().contains(c));
48  AT_ASSERT(b->in_edges().size() == 1);
49  AT_ASSERT(b->out_edges().size() == 1);
50  AT_ASSERT(b->in_edges().contains(a));
51  AT_ASSERT(b->out_edges().contains(c));
52  AT_ASSERT(c->in_edges().size() == 2);
53  AT_ASSERT(c->out_edges().size() == 0);
54  AT_ASSERT(c->in_edges().contains(a));
55  AT_ASSERT(c->in_edges().contains(b));
56 }
57 
58 void testAddEdgeCycleDetection() {
59  // a -> b -> c
60  // ^---------/
61  auto graph = newDynamicDAG();
62  auto a = graph->newVertex("a");
63  auto b = graph->newVertex("b");
64  auto c = graph->newVertex("c");
65  graph->addEdge(a, b);
66  graph->addEdge(b, c);
67  bool erred = false;
68  try {
69  graph->addEdge(c, a);
70  } catch (c10::Error& err) {
71  erred = true;
72  }
73  AT_ASSERT(erred);
74 }
75 
76 void testAddEdgeReordersBasic() {
77  // a, b => b -> a
78  auto graph = newDynamicDAG();
79  auto a = graph->newVertex("a");
80  auto b = graph->newVertex("b");
81  AT_ASSERT(a->ord == 0);
82  AT_ASSERT(b->ord == 1);
83  graph->addEdge(b, a);
84  AT_ASSERT(a->ord == 1);
85  AT_ASSERT(b->ord == 0);
86 }
87 
88 void testAddEdgeReordersComplicated() {
89  // a -> b c -> d with addEdge(d, b) ==>
90  // c -> d -> a -> b
91  auto graph = newDynamicDAG();
92  auto a = graph->newVertex("a");
93  auto b = graph->newVertex("b");
94  auto c = graph->newVertex("c");
95  auto d = graph->newVertex("d");
96  graph->addEdge(a, b);
97  graph->addEdge(c, d);
98  AT_ASSERT(a->ord == 0);
99  AT_ASSERT(b->ord == 1);
100  AT_ASSERT(c->ord == 2);
101  AT_ASSERT(d->ord == 3);
102  graph->addEdge(d, a);
103  AT_ASSERT(c->ord == 0);
104  AT_ASSERT(d->ord == 1);
105  AT_ASSERT(a->ord == 2);
106  AT_ASSERT(b->ord == 3);
107  AT_ASSERT(c->in_edges().size() == 0);
108  AT_ASSERT(c->out_edges().size() == 1);
109  AT_ASSERT(c->out_edges().contains(d));
110  AT_ASSERT(d->in_edges().size() == 1);
111  AT_ASSERT(d->out_edges().size() == 1);
112  AT_ASSERT(d->in_edges().contains(c));
113  AT_ASSERT(d->out_edges().contains(a));
114  AT_ASSERT(a->in_edges().size() == 1);
115  AT_ASSERT(a->out_edges().size() == 1);
116  AT_ASSERT(a->in_edges().contains(d));
117  AT_ASSERT(a->out_edges().contains(b));
118  AT_ASSERT(b->in_edges().size() == 1);
119  AT_ASSERT(b->out_edges().size() == 0);
120  AT_ASSERT(b->in_edges().contains(a));
121 }
122 
123 void testRemoveEdgeBasic() {
124  // a -> b
125  auto graph = newDynamicDAG();
126  auto a = graph->newVertex("a");
127  auto b = graph->newVertex("b");
128  graph->addEdge(a, b);
129  AT_ASSERT(graph->debugNumVertices() == 2);
130  graph->removeEdge(a, b);
131  AT_ASSERT(graph->debugNumVertices() == 2);
132  AT_ASSERT(a->out_edges().size() == 0);
133  AT_ASSERT(b->in_edges().size() == 0);
134 }
135 
136 void testRemoveVertexBasic() {
137  // a -> b
138  auto graph = newDynamicDAG();
139  auto a = graph->newVertex("a");
140  auto b = graph->newVertex("b");
141  auto c = graph->newVertex("c");
142  graph->addEdge(a, b);
143  graph->addEdge(b, c);
144  AT_ASSERT(graph->debugNumVertices() == 3);
145  graph->removeVertex(b);
146  AT_ASSERT(graph->debugNumVertices() == 2);
147  AT_ASSERT(a->out_edges().size() == 0);
148  AT_ASSERT(c->in_edges().size() == 0);
149 }
150 
151 void testContractEdgeBasic() {
152  // a -> b -> c -> d
153  auto graph = newDynamicDAG();
154  auto a = graph->newVertex("a");
155  auto b = graph->newVertex("b");
156  auto c = graph->newVertex("c");
157  auto d = graph->newVertex("d");
158  graph->addEdge(a, b);
159  graph->addEdge(b, c);
160  graph->addEdge(c, d);
161  graph->contractEdge(b, c);
162  AT_ASSERT(graph->debugNumVertices() == 3);
163  AT_ASSERT(a->out_edges().size() == 1);
164  AT_ASSERT(d->in_edges().size() == 1);
165  AT_ASSERT(*a->out_edges().begin() == *d->in_edges().begin());
166  auto* contracted = *a->out_edges().begin();
167  AT_ASSERT(contracted->data.size() == 2);
168  AT_ASSERT(contracted->data[0] == "b");
169  AT_ASSERT(contracted->data[1] == "c");
170  AT_ASSERT(contracted->out_edges().size() == 1);
171  AT_ASSERT(contracted->in_edges().size() == 1);
172  AT_ASSERT(contracted->in_edges().contains(a));
173  AT_ASSERT(contracted->out_edges().contains(d));
174 }
175 
176 void testContractEdgeCycleDetection() {
177  // a -> b -> c
178  // `---------^
179  // contractEdge(a, c) will cause a cycle
180  auto graph = newDynamicDAG();
181  auto a = graph->newVertex("a");
182  auto b = graph->newVertex("b");
183  auto c = graph->newVertex("c");
184  graph->addEdge(a, b);
185  graph->addEdge(b, c);
186  graph->addEdge(a, c);
187  AT_ASSERT(!graph->contractEdge(a, c));
188 }
189 
190 void testDynamicDAG() {
191  testNewVertex();
192  testAddEdgeBasic();
193  testAddEdgeCycleDetection();
194  testAddEdgeReordersBasic();
195  testAddEdgeReordersComplicated();
196  testRemoveEdgeBasic();
197  testRemoveVertexBasic();
198  testContractEdgeBasic();
199  testContractEdgeCycleDetection();
200 }
201 } // namespace test
202 } // namespace jit
203 } // namespace torch
Definition: module.cpp:17
The primary ATen error class.
Definition: Exception.h:27
Definition: jit_type.h:17