Caffe2 - C++ API
A deep learning, cross platform ML framework
graph.cc
1 #include "caffe2/core/graph.h"
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/net.h"
6 #include "caffe2/proto/caffe2_pb.h"
7 
8 namespace caffe2 {
9 
10 namespace transform {
11 
12 Graph::Graph(const NetDef& net) : netdef_(net) {
13  nodes_.clear();
14  nodes_.resize(net.op_size());
15 
16  // Copy over operators
17  for (int x = 0; x < net.op_size(); x++) {
18  node(x).op = net.op(x);
19  }
20 
21  // For any blob, which operator was the last to write to it?
22  // In python, this is known as "versions".
23  std::unordered_map<string, int> edge_parent;
24 
25  for (int i = 0; i < (int)nodes_.size(); i++) {
26  for (const string& blob : node(i).op.input()) {
27  auto it = edge_parent.find(blob);
28  if (it != edge_parent.end()) {
29  int j = it->second;
30  node(i).parents[j].push_back(blob);
31  node(j).children[i].push_back(blob);
32  } else {
33  external_input_.insert(blob);
34  }
35  }
36  for (const string& blob : node(i).op.output()) {
37  edge_parent[blob] = i;
38  }
39  }
40 
41  // Traverse opposite direction to find external outputs
42 
43  // For any blob, which operator was the last to read to from it?
44  std::unordered_map<string, int> edge_child;
45 
46  for (int i = (int)nodes_.size() - 1; i >= 0; i--) {
47  for (const string& blob : node(i).op.output()) {
48  auto it = edge_child.find(blob);
49  if (it == edge_child.end()) {
50  external_output_.insert(blob);
51  }
52  }
53  for (const string& blob : node(i).op.input()) {
54  edge_child[blob] = i;
55  }
56  }
57 }
58 
59 const std::vector<std::pair<string, int>> Graph::GetSubgraphInput(
60  const std::vector<int>& match) {
61  return GetSubgraphPerimeterHelper(true, match);
62 }
63 
64 const std::vector<std::pair<string, int>> Graph::GetSubgraphOutput(
65  const std::vector<int>& match) {
66  return GetSubgraphPerimeterHelper(false, match);
67 }
68 
69 // This helper function will either get:
70 // 1) a list for the blobs that write INTO a subgraph
71 // 2) a list of for the blobs that are written FROM a subgraph.
72 //
73 // The "from_children" flag determines if it is case 1 (true) or case 2 (false).
74 const std::vector<std::pair<string, int>> Graph::GetSubgraphPerimeterHelper(
75  bool from_children,
76  const std::vector<int>& match) {
77  std::vector<std::pair<string, int>> edge_list;
78  std::unordered_set<int> match_set(match.begin(), match.end());
79  for (int x = 0; x < (int)nodes_.size(); x++) {
80  if (!is_node_active(x)) {
81  continue;
82  }
83  if (!match_set.count(x)) { // x is not in subgraph
84  const auto& list = from_children ? node(x).children : node(x).parents;
85  for (const auto& edge : list) {
86  int parent = edge.first;
87  const auto& blobs = edge.second;
88  if (match_set.count(parent)) { // but has a parent that is in subgraph
89  for (const string& blob : blobs) {
90  edge_list.push_back({blob, x});
91  }
92  }
93  }
94  }
95  }
96  // return the list in sorted order, to allow binary searching
97  std::sort(edge_list.begin(), edge_list.end());
98  return edge_list;
99 }
100 
102  std::vector<bool> visited(nodes_.size(), false);
103 
104  // Copy over all the properties of the netdef we're based on
105  NetDef netdef = netdef_;
106 
107  // But we're going to put in our own operators.
108  netdef.clear_op();
109 
110  // Keeps track of the number of parents yet to be processed.
111  std::vector<int> unchecked_parent_count;
112 
113  // We will perform a topological traversal on the nodes, but we will prefer
114  // nodes that come earlier in the execution order.
115 
116  // This is a min-heap, which stores its elements in ascending order.
117  // This stores the nodes in the order we process them to be in.
118  // This guarantees the lowest lexicographical topological ordering.
119 
120  // This also means the original nodes will be kept in their execution order.
121  std::priority_queue<int, std::vector<int>, std::greater<int>> q;
122 
123  // In our graph, G, the nodes don't have a strict ordering. But in the netdef,
124  // they must (since nets are operators executed in some order).
125  // How do we make sure that the order of operators in our generated netdef
126  // is valid?
127  // 1) The ordering of the netdef must be topologically sorted, respect to G.
128  // If A -> B is an edge in the graph G, then A must come before B in the
129  // netdef's ordering.
130  // 2) No blob conflicts: If A -> B is an edge in the graph G, and A writes to
131  // blob X and B reads from blob X, then there cannot be an op that writes
132  // to blob X between A and B in the ordering.
133  //
134  // Perform a Topological Sort, to find an order for the Operators to be in.
135  // We will keep track of the number of parents each node has.
136  // We begin with an empty queue, and push in all nodes that do not have any
137  // parents. Then, we keep track of all unprocessed parents for each node.
138  // When a node has no more unprocessed parents, we can push it into the queue
139  // to be processed. This guarantees condition 1 is satisfied.
140 
141  // TODO(benz): Currently, condition 2 is not guaranteed to be satisified.
142  // However, giving each blob unique names via SSA will satisfy this condition.
143  // Then, the resulting graph can be optimized with memonger.
144 
145  for (int i = 0; i < (int)nodes_.size(); i++) {
146  unchecked_parent_count.push_back(node(i).parents.size());
147  if (node(i).parents.size() == 0 && is_node_active(i)) {
148  q.push(i);
149  visited[i] = true;
150  }
151  }
152 
153  while (!q.empty()) {
154  int idx = q.top();
155  q.pop();
156  if (!is_node_active(idx)) {
157  continue;
158  }
159  // Creates a new OperatorDef in NetDef
160  auto& op = *(netdef.add_op());
161  // Sets it equal to the OperatorDef at node(idx)
162  op = node(idx).op;
163  for (const auto& edge : node(idx).children) {
164  int child = edge.first;
165  if (!visited[child] && is_node_active(child)) {
166  unchecked_parent_count[child]--;
167  if (unchecked_parent_count[child] == 0) {
168  q.push(child);
169  visited[child] = true;
170  }
171  }
172  }
173  }
174  return netdef;
175 }
176 
177 void Graph::DeactivateSubgraph(std::vector<int> subgraph) {
178  for (int idx : subgraph) {
179  // remove all edges connected to inactive node
180  for (const auto& edge : node(idx).parents) {
181  int parent = edge.first;
182  node(parent).children.erase(idx);
183  }
184  for (const auto& edge : node(idx).children) {
185  int child = edge.first;
186  node(child).parents.erase(idx);
187  }
188  // actually mark flags as false
189  node(idx).active = false;
190  }
191 }
192 
193 } // namespace transform
194 
195 OperatorDef* AddOp(
196  NetDef* netdef_ptr,
197  string op_type,
198  std::vector<string> inputs,
199  std::vector<string> outputs) {
200  CHECK(netdef_ptr);
201  auto& netdef = *netdef_ptr;
202  auto op_ptr = netdef.add_op();
203  auto& op = *op_ptr;
204  op.set_type(op_type);
205  for (const string& inp : inputs) {
206  op.add_input(inp);
207  }
208  for (const string& outp : outputs) {
209  op.add_output(outp);
210  }
211  return op_ptr;
212 }
213 
// Matches value `s` against pattern `p`, allowing `*` and `|` so operator
// types, engines, or any other string property can be matched flexibly:
//   * "*"        accepts anything;
//   * "Conv|FC"  accepts exactly "Conv" or exactly "FC";
//   * "Conv"     accepts exactly "Conv".
// Every '|'-separated alternative counts, including empty ones, so "" matches
// only "" and "Conv|" matches "Conv" or "".
bool MatchStrings(std::string p, std::string s) {
  if (p == "*") { // star accepts anything
    return true;
  }
  // Walk the alternatives in place instead of materializing a vector of
  // pieces: no allocation per call, and empty alternatives are not dropped.
  std::string::size_type start = 0;
  while (true) {
    const std::string::size_type bar = p.find('|', start);
    const std::string::size_type len =
        (bar == std::string::npos) ? std::string::npos : bar - start;
    // compare(pos, len, s) compares p.substr(pos, len) with s without copying;
    // pos == p.size() (trailing '|') yields an empty alternative.
    if (p.compare(start, len, s) == 0) {
      return true;
    }
    if (bar == std::string::npos) {
      return false;
    }
    start = bar + 1;
  }
}
227 
228 bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op) {
229  for (const auto& p_arg : p_op.arg()) {
230  if (!p_arg.has_name()) {
231  continue;
232  }
233  bool found = false;
234  for (const auto& g_arg : g_op.arg()) {
235  if (p_arg.name() == g_arg.name()) {
236  found = true;
237  if (p_arg.has_f()) {
238  if (!g_arg.has_f() || p_arg.f() != g_arg.f()) {
239  return false;
240  }
241  }
242  if (p_arg.has_i()) {
243  if (!g_arg.has_i() || p_arg.i() != g_arg.i()) {
244  return false;
245  }
246  }
247  if (p_arg.has_s()) {
248  if (!g_arg.has_s() || !MatchStrings(p_arg.s(), g_arg.s())) {
249  return false;
250  }
251  }
252  if (p_arg.floats_size() != g_arg.floats_size()) {
253  return false;
254  }
255  for (int i = 0; i < p_arg.floats_size(); i++) {
256  if (p_arg.floats(i) != g_arg.floats(i)) {
257  return false;
258  }
259  }
260  if (p_arg.ints_size() != g_arg.ints_size()) {
261  return false;
262  }
263  for (int i = 0; i < p_arg.ints_size(); i++) {
264  if (p_arg.ints(i) != g_arg.ints(i)) {
265  return false;
266  }
267  }
268  if (p_arg.strings_size() != g_arg.strings_size()) {
269  return false;
270  }
271  for (int i = 0; i < p_arg.strings_size(); i++) {
272  if (!MatchStrings(p_arg.strings(i), g_arg.strings(i))) {
273  return false;
274  }
275  }
276  }
277  }
278  if (!found) {
279  return false;
280  }
281  }
282  return true;
283 }
284 
285 } // namespace caffe2
const std::vector< std::pair< string, int > > GetSubgraphInput(const std::vector< int > &subgraph)
Given a subgraph, gets all of the parents of the subgraph, as well as their associated blob names...
Definition: graph.cc:59
void DeactivateSubgraph(std::vector< int > subgraph)
Deactivate a subgraph, and get rid of all edges into this subgraph.
Definition: graph.cc:177
const std::vector< std::pair< string, int > > GetSubgraphOutput(const std::vector< int > &subgraph)
Given a subgraph, gets all of the children of the subgraph, as well as their associated blob names...
Definition: graph.cc:64
bool MatchStrings(string p, string s)
This allows for the use of * and | to match operator types, engines, or any other property that is re...
Definition: graph.cc:214
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
Definition: graph.cc:228
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Graph(const NetDef &net_def)
Graph generation.
Definition: graph.cc:12
NetDef GetNetDef()
Generates a NetDef Representation for the current graph.
Definition: graph.cc:101