Caffe2 - C++ API
A deep learning, cross platform ML framework
graph.cc
1 
17 #include "caffe2/core/graph.h"
18 
19 #include "caffe2/core/common.h"
20 #include "caffe2/core/logging.h"
21 #include "caffe2/core/net.h"
22 #include "caffe2/proto/caffe2.pb.h"
23 
24 namespace caffe2 {
25 
26 namespace transform {
27 
28 Graph::Graph(const NetDef& net) : netdef_(net) {
29  nodes_.clear();
30  nodes_.resize(net.op_size());
31 
32  // Copy over operators
33  for (int x = 0; x < net.op_size(); x++) {
34  node(x).op = net.op(x);
35  }
36 
37  // For any blob, which operator was the last to write to it?
38  // In python, this is known as "versions".
39  std::unordered_map<string, int> edge_parent;
40 
41  for (int i = 0; i < nodes_.size(); i++) {
42  for (const string& blob : node(i).op.input()) {
43  auto it = edge_parent.find(blob);
44  if (it != edge_parent.end()) {
45  int j = it->second;
46  node(i).parents[j].push_back(blob);
47  node(j).children[i].push_back(blob);
48  } else {
49  external_input_.insert(blob);
50  }
51  }
52  for (const string& blob : node(i).op.output()) {
53  edge_parent[blob] = i;
54  }
55  }
56 
57  // Traverse opposite direction to find external outputs
58 
59  // For any blob, which operator was the last to read to from it?
60  std::unordered_map<string, int> edge_child;
61 
62  for (int i = nodes_.size() - 1; i >= 0; i--) {
63  for (const string& blob : node(i).op.output()) {
64  auto it = edge_child.find(blob);
65  if (it == edge_child.end()) {
66  external_output_.insert(blob);
67  }
68  }
69  for (const string& blob : node(i).op.input()) {
70  edge_child[blob] = i;
71  }
72  }
73 }
74 
75 const std::vector<std::pair<string, int>> Graph::GetSubgraphInput(
76  const std::vector<int>& match) {
77  return GetSubgraphPerimeterHelper(true, match);
78 }
79 
80 const std::vector<std::pair<string, int>> Graph::GetSubgraphOutput(
81  const std::vector<int>& match) {
82  return GetSubgraphPerimeterHelper(false, match);
83 }
84 
85 // This helper function will either get:
86 // 1) a list for the blobs that write INTO a subgraph
87 // 2) a list of for the blobs that are written FROM a subgraph.
88 //
89 // The "from_children" flag determines if it is case 1 (true) or case 2 (false).
90 const std::vector<std::pair<string, int>> Graph::GetSubgraphPerimeterHelper(
91  bool from_children,
92  const std::vector<int>& match) {
93  std::vector<std::pair<string, int>> edge_list;
94  std::unordered_set<int> match_set(match.begin(), match.end());
95  for (int x = 0; x < nodes_.size(); x++) {
96  if (!is_node_active(x)) {
97  continue;
98  }
99  if (!match_set.count(x)) { // x is not in subgraph
100  const auto& list = from_children ? node(x).children : node(x).parents;
101  for (const auto& edge : list) {
102  int parent = edge.first;
103  const auto& blobs = edge.second;
104  if (match_set.count(parent)) { // but has a parent that is in subgraph
105  for (const string& blob : blobs) {
106  edge_list.push_back({blob, x});
107  }
108  }
109  }
110  }
111  }
112  // return the list in sorted order, to allow binary searching
113  std::sort(edge_list.begin(), edge_list.end());
114  return edge_list;
115 }
116 
118  std::vector<bool> visited(nodes_.size(), false);
119 
120  // Copy over all the properties of the netdef we're based on
121  NetDef netdef = netdef_;
122 
123  // But we're going to put in our own operators.
124  netdef.clear_op();
125 
126  // Keeps track of the number of parents yet to be processed.
127  std::vector<int> unchecked_parent_count;
128 
129  // We will perform a topological traversal on the nodes, but we will prefer
130  // nodes that come earlier in the execution order.
131 
132  // This is a min-heap, which stores its elements in ascending order.
133  // This stores the nodes in the order we process them to be in.
134  // This guarantees the lowest lexicographical topological ordering.
135 
136  // This also means the original nodes will be kept in their execution order.
137  std::priority_queue<int, std::vector<int>, std::greater<int>> q;
138 
139  // In our graph, G, the nodes don't have a strict ordering. But in the netdef,
140  // they must (since nets are operators executed in some order).
141  // How do we make sure that the order of operators in our generated netdef
142  // is valid?
143  // 1) The ordering of the netdef must be topologically sorted, respect to G.
144  // If A -> B is an edge in the graph G, then A must come before B in the
145  // netdef's ordering.
146  // 2) No blob conflicts: If A -> B is an edge in the graph G, and A writes to
147  // blob X and B reads from blob X, then there cannot be an op that writes
148  // to blob X between A and B in the ordering.
149  //
150  // Perform a Topological Sort, to find an order for the Operators to be in.
151  // We will keep track of the number of parents each node has.
152  // We begin with an empty queue, and push in all nodes that do not have any
153  // parents. Then, we keep track of all unprocessed parents for each node.
154  // When a node has no more unprocessed parents, we can push it into the queue
155  // to be processed. This guarantees condition 1 is satisfied.
156 
157  // TODO(benz): Currently, condition 2 is not guaranteed to be satisified.
158  // However, giving each blob unique names via SSA will satisfy this condition.
159  // Then, the resulting graph can be optimized with memonger.
160 
161  for (int i = 0; i < nodes_.size(); i++) {
162  unchecked_parent_count.push_back(node(i).parents.size());
163  if (node(i).parents.size() == 0 && is_node_active(i)) {
164  q.push(i);
165  visited[i] = true;
166  }
167  }
168 
169  while (!q.empty()) {
170  int idx = q.top();
171  q.pop();
172  if (!is_node_active(idx)) {
173  continue;
174  }
175  // Creates a new OperatorDef in NetDef
176  auto& op = *(netdef.add_op());
177  // Sets it equal to the OperatorDef at node(idx)
178  op = node(idx).op;
179  for (const auto& edge : node(idx).children) {
180  int child = edge.first;
181  if (!visited[child] && is_node_active(child)) {
182  unchecked_parent_count[child]--;
183  if (unchecked_parent_count[child] == 0) {
184  q.push(child);
185  visited[child] = true;
186  }
187  }
188  }
189  }
190  return netdef;
191 }
192 
193 void Graph::DeactivateSubgraph(std::vector<int> subgraph) {
194  for (int idx : subgraph) {
195  // remove all edges connected to inactive node
196  for (const auto& edge : node(idx).parents) {
197  int parent = edge.first;
198  node(parent).children.erase(idx);
199  }
200  for (const auto& edge : node(idx).children) {
201  int child = edge.first;
202  node(child).parents.erase(idx);
203  }
204  // actually mark flags as false
205  node(idx).active = false;
206  }
207 }
208 
209 } // namespace transform
210 
211 OperatorDef* AddOp(
212  NetDef* netdef_ptr,
213  string op_type,
214  std::vector<string> inputs,
215  std::vector<string> outputs) {
216  CHECK(netdef_ptr);
217  auto& netdef = *netdef_ptr;
218  auto op_ptr = netdef.add_op();
219  auto& op = *op_ptr;
220  op.set_type(op_type);
221  for (const string& inp : inputs) {
222  op.add_input(inp);
223  }
224  for (const string& outp : outputs) {
225  op.add_output(outp);
226  }
227  return op_ptr;
228 }
229 
230 bool MatchStrings(string p, string s) {
231  if (p == "*") { // star accepts anything
232  return true;
233  }
234  // TODO(benz): memoize this. (high constant factor boost in performance)
235  vector<string> choices = split('|', p);
236  for (const string& candidate : choices) {
237  if (candidate == s) {
238  return true;
239  }
240  }
241  return false;
242 }
243 
244 bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op) {
245  for (const auto& p_arg : p_op.arg()) {
246  if (!p_arg.has_name()) {
247  continue;
248  }
249  bool found = false;
250  for (const auto& g_arg : g_op.arg()) {
251  if (p_arg.name() == g_arg.name()) {
252  found = true;
253  if (p_arg.has_f()) {
254  if (!g_arg.has_f() || p_arg.f() != g_arg.f()) {
255  return false;
256  }
257  }
258  if (p_arg.has_i()) {
259  if (!g_arg.has_i() || p_arg.i() != g_arg.i()) {
260  return false;
261  }
262  }
263  if (p_arg.has_s()) {
264  if (!g_arg.has_s() || !MatchStrings(p_arg.s(), g_arg.s())) {
265  return false;
266  }
267  }
268  if (p_arg.floats_size() != g_arg.floats_size()) {
269  return false;
270  }
271  for (int i = 0; i < p_arg.floats_size(); i++) {
272  if (p_arg.floats(i) != g_arg.floats(i)) {
273  return false;
274  }
275  }
276  if (p_arg.ints_size() != g_arg.ints_size()) {
277  return false;
278  }
279  for (int i = 0; i < p_arg.ints_size(); i++) {
280  if (p_arg.ints(i) != g_arg.ints(i)) {
281  return false;
282  }
283  }
284  if (p_arg.strings_size() != g_arg.strings_size()) {
285  return false;
286  }
287  for (int i = 0; i < p_arg.strings_size(); i++) {
288  if (!MatchStrings(p_arg.strings(i), g_arg.strings(i))) {
289  return false;
290  }
291  }
292  }
293  }
294  if (!found) {
295  return false;
296  }
297  }
298  return true;
299 }
300 
301 } // namespace caffe2
const std::vector< std::pair< string, int > > GetSubgraphInput(const std::vector< int > &subgraph)
Given a subgraph, gets all of the parents of the subgraph, as well as their associated blob names...
Definition: graph.cc:75
void DeactivateSubgraph(std::vector< int > subgraph)
Deactivate a subgraph, and get rid of all edges into and out of this subgraph.
Definition: graph.cc:193
const std::vector< std::pair< string, int > > GetSubgraphOutput(const std::vector< int > &subgraph)
Given a subgraph, gets all of the children of the subgraph, as well as their associated blob names...
Definition: graph.cc:80
bool MatchArguments(const OperatorDef &p_op, const OperatorDef &g_op)
This ensures that each named arg that exists in the pattern exists in g_op, is equal in value...
Definition: graph.cc:244
Copyright (c) 2016-present, Facebook, Inc.
Graph(const NetDef &net_def)
Graph generation.
Definition: graph.cc:28
bool MatchStrings(string p, string s)
This allows for the use of * and | to match operator types, engines, or any other property that is re...
Definition: graph.cc:230
NetDef GetNetDef()
Generates a NetDef Representation for the current graph.
Definition: graph.cc:117