Caffe2 - C++ API A deep learning, cross platform ML framework
dynamic_dag.h
1 #pragma once
2
3 #include <algorithm>
4 #include <iostream>
5 #include <type_traits>
6 #include <unordered_set>
7 #include <vector>
8
9 #include <ATen/core/functional.h>
10 #include <torch/csrc/utils/memory.h>
11
12 namespace torch {
13 namespace jit {
14 namespace detail {
15
16 // DynamicDAG is a simple directed acyclic graph that dynamically maintains a
17 // topological order as edges/vertices are added and removed.
18 //
19 // [Example applications]
20 // - Let's say you have a DAG where each vertex is black or red. How do we
21 // merge black nodes that are directly connected by contracting the
22 // edge between them while still maintaining the DAG and a topological order?
23 // Use contractEdge().
24 // - Let's say you have a DAG where each vertex is a Node* and the edges
25 // represent data dependencies. We wish to determine if adding a new Node*
26 // with certain data dependencies (or moving an existing one to use new
27 // dependencies) is valid. Use DynamicDAG::addEdge() to add the new data
28 // dependencies to the DAG: it will either find a valid reordering of the
29 // DAG's topological order or throw if the resulting DAG is invalid.
30 //
31 // The implementation is based off of the PK algorithm in the following paper:
32 // "A Dynamic Topsort Algorithm for Directed Acyclic Graphs"
33 // by David Pearce and Paul Kelly
34 // https://www.doc.ic.ac.uk/~phjk/Publications/DynamicTopoSortAlg-JEA-07.pdf
36
// Forward declarations so the aliases below can name the class templates.
template <typename T> struct Vertex;
template <typename T> struct DynamicDAG;

// Non-owning sequence of vertex pointers.
template <typename T>
using vertex_list = std::vector<Vertex<T>*>;

// Owning handle to a single heap-allocated vertex.
template <typename T>
using unique_vertex = std::unique_ptr<Vertex<T>>;

// Selects which adjacency list a depth-first search follows:
// out-edges (forward) or in-edges (backward).
enum class DFSDirection { forward, backward };
47
48 // Used to represent adjacency lists in DynamicDAG.
49 // Has set semantics: stores distinct elements.
50 //
51 // Because our graphs shouldn't fan out or in very much,
52 // we use std::vector<Vertex<T>*> to record edges.
53 // In all of the complexity analysis it is assumed that
54 // inserting, erasing, and finding take constant time.
55 template <typename T>
56 struct vertex_set {
57  using iterator = typename vertex_list<T>::iterator;
58  using reverse_iterator = typename vertex_list<T>::reverse_iterator;
59
60  // returns if we inserted v into the set.
61  bool insert(Vertex<T>* v) {
62  if (contains(v)) {
63  return false;
64  } else {
65  data_.push_back(v);
66  return true;
67  }
68  }
69  void erase(Vertex<T>* v) {
70  data_.erase(std::find(data_.begin(), data_.end(), v));
71  }
72  bool contains(Vertex<T>* v) const {
73  return std::find(data_.begin(), data_.end(), v) != data_.end();
74  }
75  void sort() {
76  std::sort(data_.begin(), data_.end(), [](Vertex<T>* a, Vertex<T>* b) {
77  return a->ord < b->ord;
78  });
79  }
80  size_t size() const {
81  return data_.size();
82  }
83  iterator begin() {
84  return data_.begin();
85  }
86  iterator end() {
87  return data_.end();
88  }
89  reverse_iterator rbegin() {
90  return data_.rbegin();
91  }
92  reverse_iterator rend() {
93  return data_.rend();
94  }
95
96  private:
97  std::vector<Vertex<T>*> data_;
98 };
99
100 template <typename T>
101 struct IOEdges {
102  vertex_set<T> in_edges;
103  vertex_set<T> out_edges;
104 };
105
106 // Simple RAII wrapper around a vertex_list<T>.
107 // When adding a vertex to the list, mark it as visited.
108 // Clears the visited flag of each vertex in the vertex_list on deletion.
109 template <typename T>
110 struct visited_list {
111  ~visited_list() {
112  for (auto* v : data_) {
113  v->visited_ = false;
114  }
115  }
116
117  void push_back(Vertex<T>* elt) {
118  AT_ASSERT(!elt->visited_);
119  elt->visited_ = true;
120  data_.push_back(elt);
121  }
122
123  void sort() {
124  std::sort(data_.begin(), data_.end(), [](Vertex<T>* a, Vertex<T>* b) {
125  return a->ord < b->ord;
126  });
127  }
128
129  const vertex_list<T>& vector() {
130  return data_;
131  }
132
133  private:
134  vertex_list<T> data_;
135 };
136
137 template <typename T>
138 struct Vertex {
139  Vertex(size_t ord, T datum) : ord(ord), visited_(false) {
140  data.push_back(datum);
141  }
142
143  std::vector<T> data;
144  size_t ord; // unique topological index
145
146  std::string toString();
147  vertex_set<T>& in_edges() {
148  return edges_.in_edges;
149  }
150  vertex_set<T>& out_edges() {
151  return edges_.out_edges;
152  }
153  IOEdges<T>&& move_edges() {
154  return std::move(edges_);
155  }
156
157  bool visited() {
158  return visited_;
159  }
160
161  private:
162  IOEdges<T> edges_;
163
164  friend visited_list<T>;
165  bool visited_; // If this vertex has been visited
166 };
167
168 template <typename T>
169 struct DynamicDAG {
170  Vertex<T>* newVertex(T datum);
171  IOEdges<T> removeVertex(Vertex<T>* v);
172
173  void addEdge(Vertex<T>* producer, Vertex<T>* consumer);
174  void removeEdge(Vertex<T>* producer, Vertex<T>* consumer);
175  bool contractEdge(Vertex<T>* producer, Vertex<T>* consumer);
176
177  // max_size() >= the number of live vertices.
178  // for all vertices v, v.ord < max_size()
179  size_t max_size() const {
180  return vertices_.size();
181  };
182  c10::optional<Vertex<T>*> at(size_t ord) const;
183
184  std::string toString();
185
186  // Use for debugging. Don't call these often.
187  size_t debugNumVertices() const;
188  void debugCheckInvariants();
189
190  private:
191  void mergeProducerIntoConsumer(Vertex<T>* producer, Vertex<T>* consumer);
192  void mergeConsumerIntoProducer(Vertex<T>* producer, Vertex<T>* consumer);
193  void reorder(visited_list<T> deltaF, visited_list<T> deltaB);
194  bool contractionProducesCycle(Vertex<T>* producer, Vertex<T>* consumer);
195  bool dfsSearch(
196  DFSDirection direction,
197  Vertex<T>* start,
198  Vertex<T>* end,
199  size_t bound,
200  visited_list<T>& visited);
201
202  // Store vertices indexed by their topological order.
203  // If a vertex v has ord 5, then it can be found at vertices_[5].
204  // There may be gaps in vertices_; this is to enable fast deletion.
205  std::vector<unique_vertex<T>> vertices_;
206 };
207
208 // O(vertices_.size()). Used for testing, don't call this often.
209 template <typename T>
210 size_t DynamicDAG<T>::debugNumVertices() const {
211  return std::count_if(
212  vertices_.begin(), vertices_.end(), [](const unique_vertex<T>& v) {
213  if (v)
214  return true;
215  return false;
216  });
217 }
218
219 template <typename T>
221  vertices_.push_back(torch::make_unique<Vertex<T>>(vertices_.size(), datum));
222  return vertices_.back().get();
223 }
224
225 template <typename T>
226 void DynamicDAG<T>::removeEdge(Vertex<T>* producer, Vertex<T>* consumer) {
227  AT_ASSERT(producer != consumer);
228  AT_ASSERT(producer->out_edges().contains(consumer));
229  AT_ASSERT(consumer->in_edges().contains(producer));
230  producer->out_edges().erase(consumer);
231  consumer->in_edges().erase(producer);
232 }
233
234 template <typename T>
236  for (size_t ord = 0; ord < vertices_.size(); ++ord) {
237  const auto& vertex = vertices_.at(ord);
238  if (!vertex)
239  continue;
240
241  AT_ASSERTM(vertex->ord == ord, toString());
242  for (auto* v : vertex->in_edges()) {
243  AT_ASSERTM(v->ord < ord, toString());
244  }
245  for (auto* v : vertex->out_edges()) {
246  AT_ASSERTM(v->ord > ord, toString());
247  }
248  }
249 }
250
251 template <typename T>
252 c10::optional<Vertex<T>*> DynamicDAG<T>::at(size_t ord) const {
253  const auto& vertex = vertices_.at(ord);
254  if (!vertex) {
255  return c10::nullopt;
256  } else {
257  return vertex.get();
258  }
259 }
260
261 template <typename T>
263  for (auto* parent : v->in_edges()) {
264  parent->out_edges().erase(v);
265  }
266  for (auto* child : v->out_edges()) {
267  child->in_edges().erase(v);
268  }
269  auto edges = v->move_edges();
270  vertices_[v->ord] = nullptr;
271  return edges;
272 }
273
274 /*
276  * When adding an edge x -> y,
277  * - if ord(x) < ord(y), don't do anything.
278  * - if ord(y) < ord(x), some graph reordering must occur.
279  *
280  * Assume we are adding an edge x -> y and that ord(x) > ord(y).
281  * First, if there is a path y ----> x through some other vertices, then this
282  * edge addition would create a cycle. Figure this out via DFS and throw if
283  * necessary.
284  *
285  * Now, consider the set of all vertices v such that ord(x) > ord(v) > ord(y).
286  * Call this set the affected region (AR) -- these are the only vertices we
287  * need to consider for reordering to make the resulting graph valid.
288  *
289  * Find all children of y (through DFS) in AR (call this set deltaF and add y to
290  * it) Find all parents of x in AR (call this set deltaB and add x to it).
291  *
292  * Move y and all the children of y to after x and all the parents of x. The
293  * result topological ordering is valid.
294  *
295  * [Visual algorithm reference]
296  * Higher nodes come earlier in topological order.
297  * We are adding an edge between x -> y.
298  * The topological ordering is e, y, c, a, d, b, x, f.
299  * The affected region is {y, c, a, d, b, x}. e and f cannot be involved
300  * in the reorder.
301  *
302  * (e) <- ord = 0 -> (e)
303  * | |
304  * v v
305  * (y) <- ord = 1 -> \ (c)
306  * ^ \ -----\ |
307  * (c) | v <- ord = 2 -> -----/ (d) v
308  * \ | (a) <- ord = 3 -> / \->(x)
309  * || | /\
310  * (d) || | <- ord = 4 -> (y)<-/ \
311  * | || v \ |
312  * \ v| (b) <- ord = 5 -> \->(a) |
313  * ->(x) <- ord = 6 -> (b)<--/ v
314  * \->(f) <- ord = 7 -> (f)
315  *
316  * We find all children of y in the affected region. deltaF = {y, a, b}
317  * We find all parents of x via DFS. deltaB = {c, d, x}
318  *
319  * Now, we reorder all vertices in deltaB to come before deltaF. This is
320  * a little involved and happens in four steps:
321  *
322  * 1) sort all vertices in deltaB, and all vertices in deltaF.
323  * deltaB (sorted) = {c(2), d(4), x(6)}. deltaB ords = { 2, 4, 6 }
324  * deltaF (sorted) = {y(1), a(3), b(5)}. deltaF ords = { 1, 3, 5 }
325  *
326  * 2) append the two lists: the result is the order we want these vertices to
327  * have.
328  * L = {c(2), d(4), x(6), y(1), a(3), b(5)}.
329  *
330  * 3) Merge the sorted ords: R = { 1, 2, 3, 4, 5, 6 }.
331  *
332  * 4) Reassign the vertices in L in order with the sorted ords.
333  * We always use the vertices in deltaB, then deltaF, in that order.
334  * L = { c(1), d(2), x(3), y(4) a(5), b(6) }
335  *
336 * This produces the graph shown on the right.
337  *
338  * [Analysis]
339 * This is O(|AR| log |AR|). |AR| is equal to ord(producer) - ord(consumer).
340 * AR is the "affected region": { v s.t. ord(v) in [ord(consumer),
341 * ord(producer)] } consisting of the only vertices that can possibly be moved
342 * around due to this edge addition.
343  *
344  * NB: Pearce and Kelly give a complexity bound of <<delta>> where
345  * delta = union(deltaF, deltaB) and <<S>> on a set S is
346  * <<S>> = |S| + |edges out of vertices of S| + |edges into vertices of S|.
347  */
348 template <typename T>
349 void DynamicDAG<T>::addEdge(Vertex<T>* producer, Vertex<T>* consumer) {
350  AT_ASSERT(producer != consumer);
351
352  // NB: DynamicDAG is a simple graph. If an edge exists already, don't do
353  // anything.
354  bool is_distinct = producer->out_edges().insert(consumer);
355  if (!is_distinct)
356  return;
357  is_distinct = consumer->in_edges().insert(producer);
358  AT_ASSERT(is_distinct);
359
360  if (producer->ord <= consumer->ord) {
361  // topological ordering is already consistent, no need to update.
362  return;
363  }
364
365  visited_list<T> deltaF;
366  visited_list<T> deltaB;
367
368  // Search for vertices that are reachable from consumer that have a now
369  // incorrect topological ordering.
370  if (dfsSearch(
371  DFSDirection::forward,
372  consumer,
373  producer,
374  /*bound=*/producer->ord,
375  deltaF)) {
376  // Path found! This means there's a cycle.
377  AT_ERROR("Cycle detected while trying to add edge.");
378  }
379
380  // Search for vertices that can reach producer that have a now incorrect
381  // topological ordering
382  AT_ASSERT(!dfsSearch(
383  DFSDirection::backward,
384  producer,
385  consumer,
386  /*bound=*/consumer->ord,
387  deltaB));
388
389  // Reorder the vertices that are reachable from consumer to occur BEFORE
390  // the vertices that can reach producer.
391  reorder(std::move(deltaF), std::move(deltaB));
392 }
393
394 // Define the affected region for contractEdge(producer, consumer) as
395 // { v s.t. ord(v) in [ord(producer), ord(consumer)] }.
396 // These are the only vertices that can possibly be moved around
397 // during edge contraction.
398 //
399 // contractEdge is O(|AR| log |AR| * min(|out_edges(producer)|,
400 // |in_edges(consumer)|))
401 template <typename T>
402 bool DynamicDAG<T>::contractEdge(Vertex<T>* producer, Vertex<T>* consumer) {
403  AT_ASSERT(producer != consumer);
404  if (contractionProducesCycle(producer, consumer)) {
405  return false;
406  }
407
408  removeEdge(producer, consumer);
409
410  // Optimization: pick which order to merge depending on potential complexity.
411  if (producer->out_edges().size() > consumer->in_edges().size()) {
412  mergeConsumerIntoProducer(producer, consumer);
413  } else {
414  mergeProducerIntoConsumer(producer, consumer);
415  }
416
417  return true;
418 }
419
420 template <typename T>
422  Vertex<T>* producer,
423  Vertex<T>* consumer) {
424  // Optimization: we want to concat lists [producer.data, consumer.data].
425  // Instead of inserting into the beginning of consumer.data, do a swap.
426  producer->data.insert(
427  producer->data.end(), consumer->data.begin(), consumer->data.end());
428  std::swap(consumer->data, producer->data);
429
430  auto edges = removeVertex(producer);
431
432  // Each of these are constant b/c ord(consumer) > ord(producer) > ord(parent)
433  // so the edge addition still preserves the existing topological order.
434  for (auto* parent : edges.in_edges) {
436  }
437
438  // NB: each addEdge call is linear in (ord(consumer) - ord(child)).
439  // This makes this function O(|out_edges(producer)| * |AR| log |AR|).
440  for (auto* child : edges.out_edges) {
442  }
443 }
444
445 template <typename T>
447  Vertex<T>* producer,
448  Vertex<T>* consumer) {
449  producer->data.insert(
450  producer->data.end(), consumer->data.begin(), consumer->data.end());
451
452  auto edges = removeVertex(consumer);
453
454  // Each of these are constant b/c ord(child) > ord(consumer) > ord(producer)
455  // so the edge addition still preserves the existing topological order.
456  for (auto* child : edges.out_edges) {
458  }
459
460  // NB: each addEdge call is linear in (ord(producer) - ord(parent)).
461  // This makes this function O(|in_edges(consumer)| * |AR| log |AR|).
462  for (auto* parent : edges.in_edges) {
464  }
465 }
466
467 template <typename T>
469  Vertex<T>* producer,
470  Vertex<T>* consumer) {
471  visited_list<T> visited;
472
473  // If there are multiple paths from producer to consumer then contracting
474  // (merging) producer and consumer would create a cycle.
475  //
476  // Search for a path from producer to consumer while ignoring the
477  // producer -> consumer edge.
478  size_t upper_bound = consumer->ord;
479  for (auto* child : producer->out_edges()) {
480  if (child == consumer)
481  continue;
482  if (child->visited())
483  continue; // already visited by dfs
484  if (dfsSearch(
485  DFSDirection::forward, child, consumer, upper_bound, visited)) {
486  return true;
487  }
488  }
489  return false;
490 }
491
492 static bool is_within_bound(
493  DFSDirection direction,
494  size_t value,
495  size_t bound) {
496  if (direction == DFSDirection::forward) {
497  return value < bound; // upper bound
498  } else {
499  return value > bound; // lower bound
500  }
501 }
502
503 // Searches for a path from start to end via a forward or backward dfs.
504 // Returns if a path exists from start to end.
505 // In addition, dfsSearch inserts visited vertices into the visited list.
506 template <typename T>
508  DFSDirection direction,
509  Vertex<T>* start,
510  Vertex<T>* end,
511  size_t bound,
512  visited_list<T>& visited) {
513  vertex_list<T> stack;
514
515  auto visit = [&](Vertex<T>* v) {
516  visited.push_back(v);
517  stack.push_back(v);
518  };
519
520  visit(start);
521
522  while (!stack.empty()) {
523  auto* vertex = stack.back();
524  stack.pop_back();
525
526  auto& next_edges = (direction == DFSDirection::forward)
527  ? vertex->out_edges()
528  : vertex->in_edges();
529
530  for (Vertex<T>* next : next_edges) {
531  if (next == end) {
532  // Path found from start to end.
533  visit(next);
534  return true;
535  }
536  if (!next->visited() && is_within_bound(direction, next->ord, bound)) {
537  visit(next);
538  }
539  }
540  }
541  return false;
542 }
543
544 // Reorder deltaB vertices to occur before deltaF vertices.
545 template <typename T>
547  deltaB.sort();
548  deltaF.sort();
549
550  const auto& deltaB_ = deltaB.vector();
551  const auto& deltaF_ = deltaF.vector();
552
553  size_t num_affected = deltaB_.size() + deltaF_.size();
554
555  // Gather vertices in the desired order. They don't have correct ords yet.
556  std::vector<unique_vertex<T>> desired_vertex_ordering;
557  desired_vertex_ordering.reserve(num_affected);
558  for (auto it = deltaB_.begin(); it != deltaB_.end(); ++it) {
559  desired_vertex_ordering.push_back(std::move(vertices_.at((*it)->ord)));
560  }
561  for (auto it = deltaF_.begin(); it != deltaF_.end(); ++it) {
562  desired_vertex_ordering.push_back(std::move(vertices_.at((*it)->ord)));
563  }
564
565  // Sort the ords by merging two already sorted lists into a large sorted list.
566  // input (example): deltaB = { v(1), v(4), v(7) } ,
567  // deltaF = { v(0), v(2), v(5) }.
568  // output: { 0, 1, 2, 4, 5, 7 }.
569  std::vector<size_t> gathered_ords;
570  gathered_ords.reserve(num_affected);
571  for (const auto* v : deltaB_) {
572  gathered_ords.push_back(v->ord);
573  }
574  auto middle = gathered_ords.size();
575  for (const auto* v : deltaF_) {
576  gathered_ords.push_back(v->ord);
577  }
578  std::inplace_merge(
579  gathered_ords.begin(),
580  gathered_ords.begin() + middle,
581  gathered_ords.end());
582
583  // Return the vertices back into the vertices_ storage.
584  for (size_t i = 0; i < num_affected; ++i) {
585  desired_vertex_ordering[i]->ord = gathered_ords[i];
586  vertices_[gathered_ords[i]] = std::move(desired_vertex_ordering[i]);
587  }
588 }
589
590 template <typename T>
591 std::string DynamicDAG<T>::toString() {
592  std::stringstream ss;
593  for (auto& v : vertices_) {
594  if (v) {
595  ss << v->toString() << "\n";
596  }
597  }
598  return ss.str();
599 }
600
601 template <typename T>
602 std::string Vertex<T>::toString() {
603  std::stringstream ss;
604  ss << "node(" << ord << ")\n";
605  ss << "[";
606  for (auto* c : in_edges()) {
607  ss << c->ord << " ";
608  }
609  ss << "] -> {\n";
610  for (auto& d : data) {
611  if (std::is_pointer<T>::value) {
612  ss << " " << *d;
613  } else {
614  ss << " " << d;
615  }
616  }
617  ss << "} (" << ord << ") -> [";
618  for (auto* c : out_edges()) {
619  ss << c->ord << " ";
620  }
621  ss << "]\n";
622  return ss.str();
623 }
624
625 } // namespace detail
626 } // namespace jit
627 } // namespace torch
Definition: jit_type.h:17
Flush-To-Zero and Denormals-Are-Zero mode.