#include "caffe2/core/net_dag_utils.h"

#include <algorithm>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>

#include "caffe2/core/operator.h"
#include "caffe2/core/static_tracepoint.h"
#include "caffe2/core/timer.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {
namespace dag_utils {

namespace {

// Removes redundant (transitive) dependency edges from the graph reachable
// from node_idx, using an iterative DFS with an ancestor table.
void prune(int node_idx, std::vector<OpGraphNode>& nodes) {
  // Ancestor table for tracking the visited nodes.
  std::vector<bool> ancestors(nodes.size(), false);
  // Stack elements are pairs of <current_node, previous_node>.
  std::stack<std::pair<int, int>> nodes_stack;
  // Initialize the previous node to -1 for the root.
  nodes_stack.push(std::make_pair(node_idx, -1));

  while (!nodes_stack.empty()) {
    const auto& node_pair = nodes_stack.top();
    int curr = node_pair.first;
    int prev = node_pair.second;

    // If the node has already been visited, pop it off the stack and clean up
    // the ancestor table.
    CAFFE_ENFORCE(curr < (int)ancestors.size(), "Out of bound access");
    if (ancestors[curr]) {
      ancestors[curr] = false;
      nodes_stack.pop();
      continue;
    }

    // Check whether any parent edge can be pruned: if a parent is not the
    // previous node visited but is an ancestor of the current traversal, the
    // edge is redundant and can be removed.
    std::vector<int> new_parents;
    for (auto parent : nodes[curr].parents_) {
      if (parent != prev && ancestors[parent]) {
        // Prune this edge: remove curr from the parent's children list.
        nodes[parent].children_.erase(
            std::remove(
                nodes[parent].children_.begin(),
                nodes[parent].children_.end(),
                curr),
            nodes[parent].children_.end());
      } else {
        new_parents.push_back(parent);
      }
    }
    nodes[curr].parents_ = new_parents;

    // Mark the current node as an ancestor for the rest of this traversal.
    ancestors[curr] = true;

    // Descend to the children, but only once all of a node's original parents
    // have been visited.
    if (nodes[curr].visited_inputs == nodes[curr].num_orig_parents) {
      const auto& children = nodes[curr].children_;
      for (auto child : children) {
        nodes[child].visited_inputs++;
        nodes_stack.push(std::make_pair(child, curr));
      }
    }
  }
}
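
// Illustrative note (added; node names are hypothetical): for a diamond graph
// with edges A->B, B->C and a redundant direct edge A->C, the DFS above drops
// A->C, since C is already reachable through B, leaving
// nodes[A].children_ == {B} and nodes[C].parents_ == {B}.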

// Prunes redundant dependencies from a graph of OperatorNodes and returns the
// result as a graph of lightweight OpGraphNodes.
std::vector<OpGraphNode> pruneOpNodeGraph(
    const std::vector<OperatorNode>& nodes) {
  Timer t;
  std::vector<OpGraphNode> pruned;

  // Create a separate list of pruned OpGraphNodes used for the chaining
  // computation. Because of the unique_ptr in OperatorNode we cannot copy the
  // nodes directly; copy only the fields we need.
  for (auto& node : nodes) {
    OpGraphNode nd;
    nd.children_ = node.children_;
    nd.parents_ = node.parents_;
    nd.num_orig_parents = nd.parents_.size();
    pruned.push_back(nd);
  }

  // Prune starting from every root (a node without parents).
  for (int i = 0; i < (int)pruned.size(); ++i) {
    if (pruned[i].parents_.size() == 0) {
      prune(i, pruned);
    }
  }

  LOG(INFO) << "Operator graph pruning prior to chain compute took: "
            << t.Seconds() << " secs";
  return pruned;
}

// Resets per-node scheduling state and marks the nodes that start a chain.
void updateOperatorNodes(
    std::vector<OperatorNode>& nodes,
    const ExecutionChains& chains) {
  for (int i = 0; i < (int)nodes.size(); ++i) {
    auto& node = nodes[i];
    if (chains.find(i) != chains.end()) {
      node.is_chain_start_ = true;
    } else {
      node.is_chain_start_ = false;
    }
    node.runtime_parent_count_ = 0;
    node.scheduled_.clear();
  }
}

} // namespace
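
// Note on the result type: ExecutionChains (declared in net_dag_utils.h) maps
// the index of a chain's first operator to the ordered list of operator
// indices in that chain, e.g. {0: [0, 1, 2], 3: [3]}.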

ExecutionChains computeChains(std::vector<OperatorNode>& orig_nodes) {
  const std::vector<OpGraphNode> nodes = pruneOpNodeGraph(orig_nodes);
  std::vector<int> initial_frontier;
  for (int idx = 0; idx < (int)nodes.size(); ++idx) {
    if (nodes[idx].parents_.size() == 0) {
      initial_frontier.push_back(idx);
    }
  }

  // Compute node_seen_count so that we know how many incoming (inner) edges
  // each node has in the pruned graph.
  std::unordered_map<int, int> node_seen_count;

  for (int root_index : initial_frontier) {
    const auto& root = nodes[root_index];
    std::stack<std::pair<int, std::vector<int>::const_iterator>> depth_stack;
    depth_stack.push(std::make_pair(root_index, root.children_.begin()));
    node_seen_count[root_index]++;
    CAFFE_ENFORCE(
        node_seen_count[root_index] == 1,
        "root node ",
        root_index,
        " visit count must be == 1");

    while (depth_stack.size() > 0) {
      auto cur = depth_stack.top();
      depth_stack.pop();
      if (cur.second != nodes[cur.first].children_.end()) {
        int node_index = *cur.second;
        node_seen_count[node_index]++;
        cur.second++;
        depth_stack.push(cur);
        if (node_seen_count[node_index] == 1) {
          // Descend into each child only the first time it is seen.
          depth_stack.push(
              std::make_pair(node_index, nodes[node_index].children_.begin()));
        }
      }
    }
  }

  // Now compute the set of execution chains. An execution chain is a linear
  // sequence of nodes that can be executed on a single stream (e.g. a chain
  // of single-input, single-output operators).
  ExecutionChains chains;
  std::unordered_set<int> seen_nodes;
  std::vector<int> chain;
  std::pair<int, std::vector<int>::const_iterator> cur;
  std::stack<std::pair<int, std::vector<int>::const_iterator>> depth_stack;
  auto check_current_for_chaining = [&]() -> bool {
    return (
        node_seen_count[cur.first] == 1 &&
        (chain.size() == 0 ||
         (
             // The current op can extend the chain only if it runs on the
             // same device as the chain's tail, and the tail either has no
             // async part or the current op supports async scheduling.
             IsSameDevice(
                 orig_nodes[cur.first].operator_->device_option(),
                 orig_nodes[chain.back()].operator_->device_option()) &&
             (!orig_nodes[chain.back()].operator_->HasAsyncPart() ||
              orig_nodes[cur.first].operator_->SupportsAsyncScheduling()))));
  };
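  // In other words, a node may join the current chain only if it has a single
  // incoming edge in the pruned graph and is device-compatible with the
  // chain's tail; e.g. a Conv -> Relu -> Pool sequence on one GPU would
  // typically collapse into a single chain (the example ops are illustrative).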
  auto commit_chain = [&]() {
    if (chain.size() > 0) {
      CAFFE_ENFORCE(
          chains.insert({chain.front(), chain}).second,
          "Chain ",
          chain.front(),
          " was already added.");
      VLOG(2) << "Added chain: " << chain.front() << " with elements";
      for (auto ch : chain) {
        VLOG(2) << ch << ", ";
      }
      chain.clear();
    }
  };
  // Skip over children that were already visited, then push the current node
  // and the next unvisited child onto the stack.
  auto depth_traverse = [&]() {
    while (cur.second != nodes[cur.first].children_.end() &&
           seen_nodes.find(*cur.second) != seen_nodes.end()) {
      cur.second++;
    }

    if (cur.second != nodes[cur.first].children_.end()) {
      auto next =
          std::make_pair(*cur.second, nodes[*cur.second].children_.begin());
      depth_stack.push(cur);
      depth_stack.push(next);
    }
  };

  for (int root_index : initial_frontier) {
    depth_stack.push(
        std::make_pair(root_index, nodes[root_index].children_.begin()));
    while (depth_stack.size() > 0) {
      cur = depth_stack.top();
      depth_stack.pop();
      if (seen_nodes.find(cur.first) == seen_nodes.end()) {
        seen_nodes.insert(cur.first);
        if (nodes[cur.first].children_.size() == 1) {
          // Single child: candidate for joining or starting a chain.
          if (check_current_for_chaining()) {
            VLOG(1) << "Adding to existing chain " << cur.first;
            chain.push_back(cur.first);
            int index = *nodes[cur.first].children_.begin();
            depth_stack.push(
                std::make_pair(index, nodes[index].children_.begin()));
          } else {
            // Cannot belong to the previous chain: commit it and start a new
            // one.
            commit_chain();
            chain.push_back(cur.first);
            int index = *nodes[cur.first].children_.begin();
            depth_stack.push(
                std::make_pair(index, nodes[index].children_.begin()));
          }
        } else if (
            nodes[cur.first].children_.size() == 0 &&
            check_current_for_chaining()) {
          // Chainable leaf node: add it and commit the chain.
          chain.push_back(cur.first);
          commit_chain();
        } else {
          // Multiple children (or not chainable): commit the pending chain,
          // emit this node as a chain of its own, and keep traversing.
          commit_chain();
          chain.push_back(cur.first);
          commit_chain();
          depth_traverse();
        }
      } else {
        // Already seen: commit any pending chain and traverse the remaining
        // children only.
        commit_chain();
        depth_traverse();
      }
    }
    commit_chain();
  }

  CAFFE_ENFORCE(
      seen_nodes.size() == nodes.size(),
      "Haven't seen all the nodes, expected number of nodes ",
      nodes.size(),
      ", but seen only ",
      seen_nodes.size(),
      ".");

  updateOperatorNodes(orig_nodes, chains);
  return chains;
}
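
// Minimal usage sketch (illustrative only; the DAG/async net executors wire
// this up internally; net_def and ws are placeholders for a
// shared_ptr<const NetDef> and a Workspace*):
//
//   std::vector<OperatorNode> op_nodes = prepareOperatorNodes(net_def, ws);
//   ExecutionChains chains = computeChains(op_nodes);
//   for (const auto& kv : chains) {
//     // kv.first: index of the chain's first operator;
//     // kv.second: all operator indices of the chain, in execution order.
//   }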

ExecutionChains computeGroups(std::vector<OperatorNode>& orig_nodes) {
  const std::vector<OpGraphNode> nodes = pruneOpNodeGraph(orig_nodes);
  ExecutionChains chains;
  std::vector<int> sync_frontier;
  std::vector<int> async_frontier;

  std::vector<int> in_degrees;
  in_degrees.reserve(nodes.size());
  std::transform(
      nodes.begin(),
      nodes.end(),
      std::back_inserter(in_degrees),
      [](const OpGraphNode& n) { return n.parents_.size(); });

  // Screen out the root nodes and split them by whether they have an async
  // part.
  for (int idx = 0; idx < (int)nodes.size(); ++idx) {
    if (in_degrees[idx] == 0) {
      if (orig_nodes[idx].operator_->HasAsyncPart()) {
        async_frontier.push_back(idx);
      } else {
        sync_frontier.push_back(idx);
      }
    }
  }
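
  // The queue-driven loop below alternates between the two frontiers: each
  // connected run of sync ops is collected into a single chain, while every
  // async op is emitted as a chain of its own.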
  std::queue<int> q;
  while (!(async_frontier.empty() && sync_frontier.empty())) {
    // Sync ops
    for (const auto i : sync_frontier) {
      q.push(i);
    }
    sync_frontier.clear();
    std::vector<int> chain;
    while (!q.empty()) {
      int idx = q.front();
      q.pop();
      chain.push_back(idx);
      for (int child : nodes[idx].children_) {
        if (--in_degrees[child] == 0) {
          if (orig_nodes[child].operator_->HasAsyncPart()) {
            async_frontier.push_back(child);
          } else {
            q.push(child);
          }
        }
      }
    }
    // Add the whole group of contiguous sync ops as one chain.
    if (!chain.empty()) {
      chains.emplace(chain.front(), chain);
    }

    // Async ops
    for (const auto i : async_frontier) {
      q.push(i);
    }
    async_frontier.clear();
    while (!q.empty()) {
      int idx = q.front();
      q.pop();
      // Each async op becomes its own chain.
      chains[idx] = {idx};
      for (int child : nodes[idx].children_) {
        if (--in_degrees[child] == 0) {
          if (orig_nodes[child].operator_->HasAsyncPart()) {
            q.push(child);
          } else {
            sync_frontier.push_back(child);
          }
        }
      }
    }
  }

  updateOperatorNodes(orig_nodes, chains);
  return chains;
}

ExecutionChains singleChains(std::vector<OperatorNode>& nodes) {
  ExecutionChains chains;
  for (int i = 0; i < (int)nodes.size(); ++i) {
    chains[i] = {i};
  }
  updateOperatorNodes(nodes, chains);
  return chains;
}
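
// For a net with three operators this yields {0: [0], 1: [1], 2: [2]}, i.e.
// every operator is scheduled as its own single-op chain.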

std::vector<OperatorNode> prepareOperatorNodes(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws) {
  std::vector<OperatorNode> operator_nodes(net_def->op_size());
  std::map<std::string, int> blob_creator;
  std::map<std::string, std::set<int>> blob_readers;
  bool net_def_has_device_option = net_def->has_device_option();
  // Instantiate the operators.
  for (int idx = 0; idx < net_def->op_size(); ++idx) {
    const OperatorDef& op_def = net_def->op(idx);
    VLOG(1) << "Creating operator #" << idx << ": " << op_def.name() << ": "
            << op_def.type();
    if (!op_def.has_device_option() && net_def_has_device_option) {
      // Inherit the net-level device option when the op does not set its own.
      OperatorDef temp_def(op_def);
      temp_def.mutable_device_option()->CopyFrom(net_def->device_option());
      operator_nodes[idx].operator_ = CreateOperator(temp_def, ws, idx);
    } else {
      auto op = CreateOperator(op_def, ws, idx);
      op->set_debug_def(
          std::shared_ptr<const OperatorDef>{net_def, &(net_def->op(idx))});
      operator_nodes[idx].operator_ = std::move(op);
    }

    // Check the inputs and set up parents if necessary. This addresses the
    // read-after-write (RaW) case.
    auto checkInputs =
        [&](const google::protobuf::RepeatedPtrField<std::string>& inputs) {
          for (const std::string& input : inputs) {
            if (blob_creator.count(input) == 0) {
              VLOG(1) << "Input " << input << " not produced by this net. "
                      << "Assuming it is pre-existing.";
            } else {
              int parent = blob_creator[input];
              VLOG(1) << "op dependency (RaW " << input << "): " << parent
                      << "->" << idx;
              operator_nodes[idx].parents_.push_back(parent);
              operator_nodes[parent].children_.push_back(idx);
            }
            // Record the current op as a reader of this input.
            blob_readers[input].insert(idx);
          }
        };
    checkInputs(op_def.input());
    checkInputs(op_def.control_input());

    // Check the outputs.
    for (const std::string& output : op_def.output()) {
      if (blob_creator.count(output) != 0) {
        // Write-after-write (WaW): assume all writes to a blob are inherently
        // sequential.
        int waw_parent = blob_creator[output];
        VLOG(1) << "op dependency (WaW " << output << "): " << waw_parent
                << "->" << idx;
        operator_nodes[idx].parents_.push_back(waw_parent);
        operator_nodes[waw_parent].children_.push_back(idx);
      }
      // Write-after-read (WaR): a write should only occur after all previous
      // reads of the blob have finished.
      for (const int war_parent : blob_readers[output]) {
        VLOG(1) << "op dependency (WaR " << output << "): " << war_parent
                << "->" << idx;
        operator_nodes[idx].parents_.push_back(war_parent);
        operator_nodes[war_parent].children_.push_back(idx);
      }
      // Renew the creator of the output blob. The write acts as an implicit
      // barrier: all earlier readers of this blob are now parents of the
      // current op, so later writes no longer need to depend on them and the
      // reader set can be cleared.
      blob_creator[output] = idx;
      blob_readers[output].clear();
    }
  }
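
  // Illustrative example (hypothetical three-op net): with
  //   #0: X -> Y,   #1: Y -> Z,   #2: Y -> Y (in-place)
  // op #1 gains a RaW parent #0, and op #2 gains parent #0 (RaW and WaW on Y,
  // deduplicated below) plus a WaR parent #1, since the read of Y by #1 must
  // finish before #2 rewrites it.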

  // Make sure the parent and children lists contain no duplicates and no
  // self-dependencies.
  for (int i = 0; i < (int)operator_nodes.size(); ++i) {
    auto& node = operator_nodes[i];
    // Sort, remove duplicates, and drop any self dependency.
    auto& p = node.parents_;
    std::sort(p.begin(), p.end());
    p.erase(std::unique(p.begin(), p.end()), p.end());
    p.erase(std::remove(p.begin(), p.end(), i), p.end());
    // Do the same for the children list.
    auto& c = node.children_;
    std::sort(c.begin(), c.end());
    c.erase(std::unique(c.begin(), c.end()), c.end());
    c.erase(std::remove(c.begin(), c.end(), i), c.end());
  }

  return operator_nodes;
}

std::vector<OpGraphNode> prepareChainGraphNodes(
    const std::vector<dag_utils::OperatorNode>& operator_nodes,
    const std::vector<std::vector<int>>& execution_chains) {
  std::unordered_map<int, int> op_to_chain_idx;
  for (int chain_idx = 0; chain_idx < (int)execution_chains.size();
       ++chain_idx) {
    const auto& chain_indices = execution_chains[chain_idx];
    for (const auto& chain_op_idx : chain_indices) {
      CAFFE_ENFORCE(!op_to_chain_idx.count(chain_op_idx));
      op_to_chain_idx[chain_op_idx] = chain_idx;
    }
  }
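
  // At this point every operator index maps to exactly one chain index; the
  // CAFFE_ENFORCE above guarantees no operator appears in two chains. For
  // example, chains {{0, 1}, {2}} would produce
  // op_to_chain_idx = {0: 0, 1: 0, 2: 1}.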

  // Build a graph over chains: a chain's children/parents are the chains of
  // its operators' children/parents, deduplicated.
  std::vector<OpGraphNode> chain_nodes(execution_chains.size());
  for (int op_idx = 0; op_idx < (int)operator_nodes.size(); ++op_idx) {
    CAFFE_ENFORCE(op_to_chain_idx.count(op_idx));
    auto chain_idx = op_to_chain_idx[op_idx];
    auto& chain = chain_nodes[chain_idx];
    auto& op_node = operator_nodes[op_idx];

    for (const auto& child_idx : op_node.children_) {
      CAFFE_ENFORCE(op_to_chain_idx.count(child_idx));
      auto child_chain_idx = op_to_chain_idx[child_idx];
      if (child_chain_idx != chain_idx) {
        auto it = std::find(
            chain.children_.begin(), chain.children_.end(), child_chain_idx);
        if (it == chain.children_.end()) {
          chain.children_.push_back(child_chain_idx);
        }
      }
    }

    for (const auto& parent_idx : op_node.parents_) {
      CAFFE_ENFORCE(op_to_chain_idx.count(parent_idx));
      auto parent_chain_idx = op_to_chain_idx[parent_idx];
      if (parent_chain_idx != chain_idx) {
        auto it = std::find(
            chain.parents_.begin(), chain.parents_.end(), parent_chain_idx);
        if (it == chain.parents_.end()) {
          chain.parents_.push_back(parent_chain_idx);
        }
      }
    }
  }

  return chain_nodes;
}

} // namespace dag_utils
} // namespace caffe2