Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_dag_utils.cc
#include "caffe2/core/net_dag_utils.h"

#include <algorithm>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>

#include "caffe2/core/operator.h"
#include "caffe2/core/static_tracepoint.h"
#include "caffe2/core/timer.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {
namespace dag_utils {

namespace {
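// Prunes redundant (transitive) dependency edges in the subgraph reachable
// from node_idx (the callers below invoke this on root nodes). For example
// (hypothetical graph), with edges A->B, B->C and A->C, the direct edge A->C
// carries no extra scheduling constraint and is removed, leaving A->B->C.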
void prune(int node_idx, std::vector<OpGraphNode>& nodes) {
  // Ancestor table for tracking the visited nodes
  std::vector<bool> ancestors(nodes.size(), false);
  // stack element is a pair of <curr_node, previous_node>
  std::stack<std::pair<int, int>> nodes_stack;
  // initialize the prev_node to be -1
  nodes_stack.push(std::make_pair(node_idx, -1));

  while (!nodes_stack.empty()) {
    const auto& node_pair = nodes_stack.top();
    int curr = node_pair.first;
    int prev = node_pair.second;

    // If the node has already been visited, pop curr off
    // the stack and clean up the ancestor table
    CAFFE_ENFORCE(curr < (int)ancestors.size(), "Out of bound access");
    if (ancestors[curr]) {
      ancestors[curr] = false;
      nodes_stack.pop();
      continue;
    }

    // Check if this node has a parent that can be pruned:
    // if the parent is not the previously visited node and is
    // an ancestor of the current traversal, it can be pruned.
    if (prev >= 0) {
      std::vector<int> new_parents;
      for (auto parent : nodes[curr].parents_) {
        if (parent != prev && ancestors[parent]) {
          // We can prune this one
          nodes[parent].children_.erase(
              std::remove(
                  nodes[parent].children_.begin(),
                  nodes[parent].children_.end(),
                  curr),
              nodes[parent].children_.end());
        } else {
          new_parents.push_back(parent);
        }
      }
      nodes[curr].parents_ = new_parents;
    }

    ancestors[curr] = true;

    // Descend -- but only once from each node
    if (nodes[curr].visited_inputs == nodes[curr].num_orig_parents) {
      const auto& children = nodes[curr].children_;
      for (auto child : children) {
        nodes[child].visited_inputs++;
        nodes_stack.push(std::make_pair(child, curr));
      }
    }
  }
}

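// Builds a lightweight copy of the operator graph (OpGraphNode holds only
// parent/child indices, not the operators themselves) and prunes redundant
// edges starting from every root node. The pruned graph is used only for the
// chain computation below; the original OperatorNode graph is left untouched.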
std::vector<OpGraphNode> pruneOpNodeGraph(
    const std::vector<OperatorNode>& nodes) {
  Timer t;
  std::vector<OpGraphNode> pruned;

  // Create a separate list of pruned OperatorNodes used
  // for the chaining computation. Because of the unique_ptr
  // in the OperatorNode, we cannot do a copy but have to
  // copy just the fields we need.
  for (auto& node : nodes) {
    OpGraphNode nd;
    nd.children_ = node.children_;
    nd.parents_ = node.parents_;
    nd.num_orig_parents = nd.parents_.size();
    pruned.push_back(nd);
  }

  for (int i = 0; i < (int)pruned.size(); ++i) {
    if (pruned[i].parents_.size() == 0) {
      prune(i, pruned);
    }
  }

  LOG(INFO) << "Operator graph pruning prior to chain compute took: "
            << t.Seconds() << " secs";
  return pruned;
}

void updateOperatorNodes(
    std::vector<OperatorNode>& nodes,
    const ExecutionChains& chains) {
  for (int i = 0; i < (int)nodes.size(); ++i) {
    auto& node = nodes[i];
    if (chains.find(i) != chains.end()) {
      node.is_chain_start_ = true;
    } else {
      node.is_chain_start_ = false;
    }
    node.runtime_parent_count_ = 0;
    node.scheduled_.clear();
  }
}
} // namespace

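// Computes execution chains over the pruned operator graph. The result maps
// the index of the first operator in each chain to the list of operator
// indices in that chain; e.g. a purely sequential net 0->1->2 (same device,
// no async constraints) would roughly yield a single entry {0: [0, 1, 2]}.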
ExecutionChains computeChains(std::vector<OperatorNode>& orig_nodes) {
  const std::vector<OpGraphNode> nodes = pruneOpNodeGraph(orig_nodes);
  std::vector<int> initial_frontier;
  for (int idx = 0; idx < (int)nodes.size(); ++idx) {
    if (nodes[idx].parents_.size() == 0) {
      initial_frontier.push_back(idx);
    }
  }

  // We need to construct node_seen_count to know how many incoming edges each
  // node has.
  std::unordered_map<int, int> node_seen_count;

  for (int root_index : initial_frontier) {
    const auto& root = nodes[root_index];
    std::stack<std::pair<int, std::vector<int>::const_iterator>> depth_stack;
    depth_stack.push(make_pair(root_index, root.children_.begin()));
    node_seen_count[root_index]++;
    CAFFE_ENFORCE(
        node_seen_count[root_index] == 1,
        "root node ",
        root_index,
        " visit count must be == 1");

    while (depth_stack.size() > 0) {
      auto cur = depth_stack.top();
      depth_stack.pop();
      if (cur.second != nodes[cur.first].children_.end()) {
        int node_index = *cur.second;
        node_seen_count[node_index]++;
        cur.second++;
        depth_stack.push(cur);
        if (node_seen_count[node_index] == 1) {
          // Visit each child only once.
          depth_stack.push(
              make_pair(node_index, nodes[node_index].children_.begin()));
        }
      }
    }
  }
  // Now, we compute the set of execution chains. An execution chain is
  // a linear set of nodes that can be executed on a single stream
  // (e.g. a chain of single input, single output operators).
  ExecutionChains chains;
  std::unordered_set<int> seen_nodes;
  std::vector<int> chain;
  std::pair<int, std::vector<int>::const_iterator> cur;
  std::stack<std::pair<int, std::vector<int>::const_iterator>> depth_stack;
  auto check_current_for_chaining = [&]() -> bool {
    return (
        node_seen_count[cur.first] == 1 &&
        (chain.size() == 0 ||
         (
             // A chain of operators is executed without additional
             // synchronization by calling RunAsync sequentially on each
             // operator and passing the same stream id on each call.
             // RunAsync may schedule an async computation on device.
             // In order to be scheduled on the same chain two operators
             // (parent and dependent) need to satisfy:
             //  1. Both ops are on the same device _and_
             //  2. Parent op does not have an async part, or the
             //     dependent op can be executed as an async dependency

             IsSameDevice(
                 orig_nodes[cur.first].operator_->device_option(),
                 orig_nodes[chain.back()].operator_->device_option()) &&
             (!orig_nodes[chain.back()].operator_->HasAsyncPart() ||
              orig_nodes[cur.first].operator_->SupportsAsyncScheduling()))));
  };
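  // Example of the rule above: two CPU operators with identical device
  // options and no async parts can share a chain; a CPU operator followed by
  // a GPU operator cannot (different devices), nor can an op follow an async
  // parent unless it supports async scheduling.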
  auto commit_chain = [&]() {
    if (chain.size() > 0) {
      CAFFE_ENFORCE(
          chains.insert({chain.front(), chain}).second,
          "Chain ",
          chain.front(),
          " was already added.");
      VLOG(2) << "Added chain: " << chain.front() << " with elements";
      for (auto ch : chain) {
        VLOG(2) << ch << ", ";
      }
      chain.clear();
    }
  };
  auto depth_traverse = [&]() {
    while (cur.second != nodes[cur.first].children_.end() &&
           seen_nodes.find(*cur.second) != seen_nodes.end()) {
      cur.second++;
    }

    if (cur.second != nodes[cur.first].children_.end()) {
      auto next = make_pair(*cur.second, nodes[*cur.second].children_.begin());
      depth_stack.push(cur);
      depth_stack.push(next);
    }
  };
  for (int root_index : initial_frontier) {
    depth_stack.push(
        make_pair(root_index, nodes[root_index].children_.begin()));
    while (depth_stack.size() > 0) {
      cur = depth_stack.top();
      depth_stack.pop();
      if (seen_nodes.find(cur.first) == seen_nodes.end()) {
        seen_nodes.insert(cur.first);
        // Has one child, so it is a candidate for chaining or can be added to
        // the previous chain.
        if (nodes[cur.first].children_.size() == 1) {
          if (check_current_for_chaining()) {
            // Add oneself to the current chain.
            VLOG(1) << "Adding to existing chain " << cur.first;
            chain.push_back(cur.first);
            int index = *nodes[cur.first].children_.begin();
            depth_stack.push(make_pair(index, nodes[index].children_.begin()));
          } else {
            // Can't belong to the previous chain; commit the previous chain
            // and start a new one.
            commit_chain();
            chain.push_back(cur.first);
            int index = *nodes[cur.first].children_.begin();
            depth_stack.push(make_pair(index, nodes[index].children_.begin()));
          }
        } else if (
            nodes[cur.first].children_.size() == 0 &&
            check_current_for_chaining()) {
          // Add current node to the current chain and commit.
          chain.push_back(cur.first);
          commit_chain();
        } else {
          // Node has more than one child.
          commit_chain();
          // Add current node as an independent chain since it won't be a part
          // of a bigger chain.
          chain.push_back(cur.first);
          commit_chain();
          depth_traverse();
        }
      } else {
        // This node has been seen before, we will only traverse its children.
        // Commit any pending chains and continue traversing.
        commit_chain();
        depth_traverse();
      }
    } // End while

    // Check whether this commit is even needed.
    commit_chain();
  }
  CAFFE_ENFORCE(
      seen_nodes.size() == nodes.size(),
      "Haven't seen all the nodes, expected number of nodes ",
      nodes.size(),
      ", but seen only ",
      seen_nodes.size(),
      ".");

  updateOperatorNodes(orig_nodes, chains);
  return chains;
}

// Here chains are essentially groups; we use the terms chain and group
// interchangeably.
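// Roughly: a breadth-first sweep that alternates between draining all
// currently-ready sync ops (collected together into one chain/group) and all
// currently-ready async ops (each becoming its own single-op chain), until
// every node has been assigned.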
ExecutionChains computeGroups(std::vector<OperatorNode>& orig_nodes) {
  const std::vector<OpGraphNode> nodes = pruneOpNodeGraph(orig_nodes);
  ExecutionChains chains;
  std::vector<int> sync_frontier;
  std::vector<int> async_frontier;

  std::vector<int> in_degrees;
  in_degrees.reserve(nodes.size());
  std::transform(
      nodes.begin(),
      nodes.end(),
      std::back_inserter(in_degrees),
      [](const OpGraphNode& n) { return n.parents_.size(); });

  // Screen out the primary root nodes
  for (int idx = 0; idx < (int)nodes.size(); ++idx) {
    if (in_degrees[idx] == 0) {
      if (orig_nodes[idx].operator_->HasAsyncPart()) {
        async_frontier.push_back(idx);
      } else {
        sync_frontier.push_back(idx);
      }
    }
  }

  // We check sync ops on the frontier first and then async ops. This gives us
  // a head start to execute sync ops locally while waiting for async ops to
  // finish.
  std::queue<int> q;
  while (!(async_frontier.empty() && sync_frontier.empty())) {
    // Sync ops
    for (const auto i : sync_frontier) {
      q.push(i);
    }
    sync_frontier.clear();
    std::vector<int> chain;
    while (!q.empty()) {
      int idx = q.front();
      q.pop();
      chain.push_back(idx);
      for (int child : nodes[idx].children_) {
        if (--in_degrees[child] == 0) {
          if (orig_nodes[child].operator_->HasAsyncPart()) {
            async_frontier.push_back(child);
          } else {
            q.push(child);
          }
        }
      }
    }
    // Add the whole group of contiguous sync ops as one chain
    if (!chain.empty()) {
      chains.emplace(chain.front(), chain);
    }

    // Async ops
    for (const auto i : async_frontier) {
      q.push(i);
    }
    async_frontier.clear();
    while (!q.empty()) {
      int idx = q.front();
      q.pop();
      // Put each individual node into a new chain of its own
      chains[idx] = {idx};
      for (int child : nodes[idx].children_) {
        if (--in_degrees[child] == 0) {
          if (orig_nodes[child].operator_->HasAsyncPart()) {
            q.push(child);
          } else {
            sync_frontier.push_back(child);
          }
        }
      }
    }
  }

  updateOperatorNodes(orig_nodes, chains);
  return chains;
}

ExecutionChains singleChains(std::vector<OperatorNode>& nodes) {
  ExecutionChains chains;
  for (int i = 0; i < (int)nodes.size(); ++i) {
    chains[i] = {i};
  }
  updateOperatorNodes(nodes, chains);
  return chains;
}

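// Instantiates one operator per OperatorDef in the net and derives dependency
// edges from blob names. As a (hypothetical) example, given
//   op0: X -> Y     op1: Y -> Z     op2: W -> Y
// op1 depends on op0 (read-after-write on Y), and op2 depends on op1
// (write-after-read on Y) as well as on op0 (write-after-write on Y).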
std::vector<OperatorNode> prepareOperatorNodes(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws) {
  std::vector<OperatorNode> operator_nodes(net_def->op_size());
  std::map<string, int> blob_creator;
  std::map<string, std::set<int>> blob_readers;
  bool net_def_has_device_option = net_def->has_device_option();
  // Initialize the operators
  for (int idx = 0; idx < net_def->op_size(); ++idx) {
    const OperatorDef& op_def = net_def->op(idx);
    VLOG(1) << "Creating operator #" << idx << ": " << op_def.name() << ": "
            << op_def.type();
    if (!op_def.has_device_option() && net_def_has_device_option) {
      OperatorDef temp_def(op_def);
      temp_def.mutable_device_option()->CopyFrom(net_def->device_option());
      operator_nodes[idx].operator_ = CreateOperator(temp_def, ws, idx);
    } else {
      auto op = CreateOperator(op_def, ws, idx);
      op->set_debug_def(
          std::shared_ptr<const OperatorDef>{net_def, &(net_def->op(idx))});
      operator_nodes[idx].operator_ = std::move(op);
    }
    // Check the inputs, and set up parents if necessary. This addresses the
    // read-after-write case.
    auto checkInputs =
        [&](const google::protobuf::RepeatedPtrField<std::string>& inputs) {
          for (const string& input : inputs) {
            if (blob_creator.count(input) == 0) {
              VLOG(1) << "Input " << input << " not produced by this net. "
                      << "Assuming it is pre-existing.";
            } else {
              int parent = blob_creator[input];
              VLOG(1) << "op dependency (RaW " << input << "): " << parent
                      << "->" << idx;
              operator_nodes[idx].parents_.push_back(parent);
              operator_nodes[parent].children_.push_back(idx);
            }
            // Add the current idx to the readers of this input.
            blob_readers[input].insert(idx);
          }
        };
    checkInputs(op_def.input());
    checkInputs(op_def.control_input());

    // Check the outputs.
    for (const string& output : op_def.output()) {
      if (blob_creator.count(output) != 0) {
        // This addresses the write-after-write case - we will assume that all
        // writes are inherently sequential.
        int waw_parent = blob_creator[output];
        VLOG(1) << "op dependency (WaW " << output << "): " << waw_parent
                << "->" << idx;
        operator_nodes[idx].parents_.push_back(waw_parent);
        operator_nodes[waw_parent].children_.push_back(idx);
      }
      // This addresses the write-after-read case - we will assume that writes
      // should only occur after all previous reads are finished.
      for (const int war_parent : blob_readers[output]) {
        VLOG(1) << "op dependency (WaR " << output << "): " << war_parent
                << "->" << idx;
        operator_nodes[idx].parents_.push_back(war_parent);
        operator_nodes[war_parent].children_.push_back(idx);
      }
      // Renew the creator of the output name.
      blob_creator[output] = idx;
      // The write creates an implicit barrier: all earlier readers of this
      // output are now parents of the current op, and future writes will not
      // need to depend on these earlier readers. Thus, we can clear out the
      // blob readers.
      blob_readers[output].clear();
    }
  }

  // Now, make sure that the parent list and the children list do not contain
  // duplicated items.
  for (int i = 0; i < (int)operator_nodes.size(); ++i) {
    auto& node = operator_nodes[i];
    // Sort, remove duplicates, and delete self dependency.
    auto& p = node.parents_;
    std::sort(p.begin(), p.end());
    p.erase(std::unique(p.begin(), p.end()), p.end());
    p.erase(std::remove(p.begin(), p.end(), i), p.end());
    // Do the same for the children vector.
    auto& c = node.children_;
    std::sort(c.begin(), c.end());
    c.erase(std::unique(c.begin(), c.end()), c.end());
    c.erase(std::remove(c.begin(), c.end(), i), c.end());
  }

  return operator_nodes;
}

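// Collapses the operator-level graph into a chain-level graph: each execution
// chain becomes a single OpGraphNode whose parents/children are the chains of
// the corresponding parent/child operators (cross-chain edges only, with
// duplicates removed).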
std::vector<OpGraphNode> prepareChainGraphNodes(
    const std::vector<dag_utils::OperatorNode>& operator_nodes,
    const std::vector<std::vector<int>>& execution_chains) {
  std::unordered_map<int, int> op_to_chain_idx;
  for (int chain_idx = 0; chain_idx < (int)execution_chains.size();
       ++chain_idx) {
    const auto& chain_indices = execution_chains[chain_idx];
    for (const auto& chain_op_idx : chain_indices) {
      CAFFE_ENFORCE(!op_to_chain_idx.count(chain_op_idx));
      op_to_chain_idx[chain_op_idx] = chain_idx;
    }
  }

  std::vector<OpGraphNode> chain_nodes(execution_chains.size());
  for (int op_idx = 0; op_idx < (int)operator_nodes.size(); ++op_idx) {
    CAFFE_ENFORCE(op_to_chain_idx.count(op_idx));
    auto chain_idx = op_to_chain_idx[op_idx];
    auto& chain = chain_nodes[chain_idx];
    auto& op_node = operator_nodes[op_idx];

    for (const auto& child_idx : op_node.children_) {
      CAFFE_ENFORCE(op_to_chain_idx.count(child_idx));
      auto child_chain_idx = op_to_chain_idx[child_idx];
      if (child_chain_idx != chain_idx) {
        auto it = std::find(
            chain.children_.begin(), chain.children_.end(), child_chain_idx);
        if (it == chain.children_.end()) {
          chain.children_.push_back(child_chain_idx);
        }
      }
    }

    for (const auto& parent_idx : op_node.parents_) {
      CAFFE_ENFORCE(op_to_chain_idx.count(parent_idx));
      auto parent_chain_idx = op_to_chain_idx[parent_idx];
      if (parent_chain_idx != chain_idx) {
        auto it = std::find(
            chain.parents_.begin(), chain.parents_.end(), parent_chain_idx);
        if (it == chain.parents_.end()) {
          chain.parents_.push_back(parent_chain_idx);
        }
      }
    }
  }

  return chain_nodes;
}

} // namespace dag_utils
} // namespace caffe2
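
The helpers above are typically combined by a DAG-style net executor. Below is a minimal sketch of that flow, not part of the file itself: buildExecutionStructures is a hypothetical wrapper name, and a real executor would add error handling and keep these structures as members rather than locals.

// Illustrative sketch only: build the operator dependency graph for a net,
// group operators into execution chains, and derive the chain-level graph.
#include "caffe2/core/net_dag_utils.h"

namespace caffe2 {

void buildExecutionStructures(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws) {
  // One OperatorNode per op, with RaW/WaW/WaR dependency edges.
  std::vector<dag_utils::OperatorNode> op_nodes =
      dag_utils::prepareOperatorNodes(net_def, ws);

  // Group operators into chains that can run back-to-back on one stream.
  dag_utils::ExecutionChains chains = dag_utils::computeChains(op_nodes);

  // Flatten the map into a list of chains for the chain-level graph.
  std::vector<std::vector<int>> chain_list;
  chain_list.reserve(chains.size());
  for (const auto& kv : chains) {
    chain_list.push_back(kv.second);
  }

  // Chain-level dependency graph used by the executor's scheduler.
  std::vector<dag_utils::OpGraphNode> chain_nodes =
      dag_utils::prepareChainGraphNodes(op_nodes, chain_list);

  // ... hand op_nodes / chains / chain_nodes to the executor ...
}

} // namespace caffe2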