Caffe2 - C++ API
A deep learning, cross platform ML framework
graph_node_list.h
1 #pragma once
2 
3 #include <c10/util/Exception.h>
4 
5 namespace torch {
6 namespace jit {
7 
8 // Intrusive doubly linked lists with sane reverse iterators.
9 // The header file is named graph_node_list.h because it is ONLY
10 // used for Graph's Node lists, and if you want to use it for other
11 // things, you will have to do some refactoring.
12 //
13 // At the moment, the templated type T must support a few operations:
14 //
15 // - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
16 // which are used for the intrusive linked list pointers.
17 //
18 // - It must have a method 'destroy()', which removes T from the
19 // list and frees a T.
20 //
21 // In practice, we are only using it with Node and const Node. 'destroy()'
22 // needs to be renegotiated if you want to use this somewhere else.
23 //
24 // Besides the benefits of being intrusive, unlike std::list, these lists handle
25 // forward and backward iteration uniformly because we require a
26 // "before-first-element" sentinel. This means that reverse iterators
27 // physically point to the element they logically point to, rather than
28 // the off-by-one behavior for all standard library reverse iterators.
29 
// Indices into a node's next_in_graph[2] link array: slot 0 is the forward
// ("next") link and slot 1 is the reverse ("prev") link.
30 static constexpr int kNextDirection = 0;
31 static constexpr int kPrevDirection = 1;
32 
33 template <typename T>
35 
36 template <typename T>
38 
39 struct Node;
45 
46 template <typename T>
// Default constructor: no current node (cur == nullptr), forward direction.
48  generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
// Builds an iterator positioned at `cur` that advances via next_in_graph[d].
49  generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {}
51  const generic_graph_node_list_iterator& rhs) = default;
53  default;
55  const generic_graph_node_list_iterator& rhs) = default;
57  generic_graph_node_list_iterator&& rhs) = default;
// Dereference: yields the current node pointer itself (a T*, not a T&).
// No null check is performed; a default-constructed iterator yields nullptr.
58  T* operator*() const {
59  return cur;
60  }
// Member access: forwards to the current node pointer, so `it->member`
// accesses a member of the node. No null check is performed here.
61  T* operator->() const {
62  return cur;
63  }
// Pre-increment: steps to the neighbouring node in this iterator's own
// direction `d` and returns *this.
64  generic_graph_node_list_iterator& operator++() {
65  AT_ASSERT(cur); // the iterator must refer to a node before its link is read
66  cur = cur->next_in_graph[d];
67  return *this;
68  }
69  generic_graph_node_list_iterator operator++(int) {
71  ++(*this);
72  return old;
73  }
// Pre-decrement: steps to the neighbouring node in the direction opposite
// to `d` and returns *this.
74  generic_graph_node_list_iterator& operator--() {
75  AT_ASSERT(cur); // the iterator must refer to a node before its link is read
76  cur = cur->next_in_graph[reverseDir()];
77  return *this;
78  }
79  generic_graph_node_list_iterator operator--(int) {
81  --(*this);
82  return old;
83  }
84 
85  // erase cur without invalidating this iterator
86  // named differently from destroy so that ->/. bugs do not
87  // silently cause the wrong one to be called.
88  // iterator will point to the previous entry after call
// NOTE(review): unlike operator++/--, there is no AT_ASSERT(cur) here, so
// calling this on an iterator with a null `cur` dereferences nullptr.
89  void destroyCurrent() {
90  T* n = cur; // remember the node to destroy before moving off it
91  cur = cur->next_in_graph[reverseDir()]; // step back first, while n's links are still readable
92  n->destroy(); // per the header comment, destroy() removes n from the list and frees it
93  }
95  return generic_graph_node_list_iterator(cur, reverseDir());
96  }
97 
98  private:
// Returns the direction opposite to `d`: kPrevDirection when d is
// kNextDirection, otherwise kNextDirection.
// NOTE(review): not const-qualified although it only reads `d`, so it cannot
// be called from const member functions.
99  int reverseDir() {
100  return d == kNextDirection ? kPrevDirection : kNextDirection;
101  }
102  T* cur; // node this iterator currently refers to; nullptr when default-constructed
103  int d; // direction 0 is forward 1 is reverse, see next_in_graph
104 };
105 
106 template <typename T>
// Iterator types exposed by the list: `iterator` walks T nodes and
// `const_iterator` walks const T nodes; both use the same iterator template.
108  using iterator = generic_graph_node_list_iterator<T>;
109  using const_iterator = generic_graph_node_list_iterator<const T>;
111  return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
112  }
114  return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
115  }
117  return generic_graph_node_list_iterator<T>(head, d);
118  }
121  }
123  return reverse().begin();
124  }
126  return reverse().begin();
127  }
129  return reverse().end();
130  }
132  return reverse().end();
133  }
134  generic_graph_node_list reverse() {
136  head, d == kNextDirection ? kPrevDirection : kNextDirection);
137  }
138  const generic_graph_node_list reverse() const {
140  head, d == kNextDirection ? kPrevDirection : kNextDirection);
141  }
// Returns the node linked from `head` in this list's direction `d`, i.e. the
// first element in iteration order.
// NOTE(review): presumably returns the sentinel itself when the list is
// empty -- confirm against how Graph links an empty list.
142  T* front() {
143  return head->next_in_graph[d];
144  }
// Const overload of front(): same lookup, but the node is returned as a
// pointer-to-const.
145  const T* front() const {
146  return head->next_in_graph[d];
147  }
// Returns the node linked from `head` in the direction opposite to `d`, i.e.
// the last element in iteration order.
148  T* back() {
149  return head->next_in_graph[!d]; // !d flips 0 <-> 1; assumes d is kNextDirection or kPrevDirection
150  }
// Const overload of back(): same lookup, but the node is returned as a
// pointer-to-const.
151  const T* back() const {
152  return head->next_in_graph[!d]; // !d flips 0 <-> 1; assumes d is kNextDirection or kPrevDirection
153  }
// Creates a list view anchored at `head` and traversed along
// next_in_graph[d]. Only the pointer is stored; nothing is allocated or
// linked here.
154  generic_graph_node_list(T* head, int d) : head(head), d(d) {}
155 
156  private:
157  T* head; // anchor node; the accessors read head->next_in_graph[...] (presumably the sentinel described in the header comment -- confirm)
158  int d; // traversal direction: kNextDirection (0) or kPrevDirection (1)
159 };
160 
161 template <typename T>
162 static inline bool operator==(
165  return *a == *b;
166 }
167 
168 template <typename T>
169 static inline bool operator!=(
172  return *a != *b;
173 }
174 
175 } // namespace jit
176 } // namespace torch
177 
178 namespace std {
179 
// Specialization of std::iterator_traits for the graph node list iterator,
// so standard-library code can query the iterator's associated types.
// The element type is the node pointer (T*), and the iterator is declared
// bidirectional.
180 template <typename T>
181 struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> {
182  using difference_type = int64_t;
183  using value_type = T*; // dereferencing the iterator yields a T*
184  using pointer = T**;
185  using reference = T*&; // NOTE(review): the visible operator* returns T* by value, not T*& -- confirm this alias is intended
186  using iterator_category = bidirectional_iterator_tag;
187 };
188 
189 } // namespace std
Definition: jit_type.h:17