Caffe2 - C++ API
A deep learning, cross platform ML framework
net_async_task_graph.cc
1 #include "caffe2/core/net_async_task_graph.h"
2 
3 #include "caffe2/core/net_parallel.h"
4 
5 namespace caffe2 {
6 
7 AsyncTaskGraph::AsyncTaskGraph(
8  ExecutorHelper* helper,
9  const ExecutionOptions& options)
10  : helper_(helper), options_(options), frozen_(false) {}
11 
12 bool AsyncTaskGraph::CreateNode(
13  int node_id,
14  const std::vector<OperatorBase*>& ops) {
15  CAFFE_ENFORCE(!frozen_);
16  if (!nodes_.count(node_id)) {
17  nodes_[node_id] = caffe2::make_unique<AsyncTask>(ops);
18  return true;
19  } else {
20  return false;
21  }
22 }
23 
24 bool AsyncTaskGraph::AddDependency(
25  int child_node_id,
26  const std::vector<int>& parent_node_ids) {
27  CAFFE_ENFORCE(!frozen_);
28  CAFFE_ENFORCE(!parent_node_ids.empty());
29  CAFFE_ENFORCE(nodes_.count(child_node_id));
30  for (auto node_id : parent_node_ids) {
31  CAFFE_ENFORCE(nodes_.count(node_id));
32  }
33  CAFFE_ENFORCE(!parents_.count(child_node_id));
34 
35  auto* child_task = nodes_[child_node_id].get();
36  auto child_device = child_task->GetDeviceOption();
37 
38  std::vector<AsyncTaskFuture*> parent_futures;
39  for (auto node_id : parent_node_ids) {
40  parents_[child_node_id].insert(node_id);
41  children_[node_id].insert(child_node_id);
42  parent_futures.push_back(&nodes_[node_id]->GetFuture());
43  }
44 
45  AsyncTaskFuture* parents_future = nullptr;
46  if (parent_futures.size() > 1) {
47  edge_futures_.push_back(
48  caffe2::make_unique<AsyncTaskFuture>(parent_futures));
49  parents_future = edge_futures_.back().get();
50  } else {
51  CAFFE_ENFORCE_EQ(parent_futures.size(), 1);
52  parents_future = parent_futures.back();
53  }
54 
55  // TODO: CUDA polling
56  parents_future->SetCallback(
57  [this, child_task, child_device](const AsyncTaskFuture* f) {
58  CAFFE_ENFORCE(f->IsCompleted());
59  if (!f->IsFailed()) {
60  // if we're in the correct thread pool and DFS scheduling is enabled,
61  // immediately call task inline, otherwise send task into thread pool
62  auto* pool = helper_->GetPool(child_device);
63  if (pool->inThreadPool() && options_.use_dfs_scheduling_) {
64  child_task->Run(options_);
65  } else {
66  pool->run([this, child_task]() { child_task->Run(options_); });
67  }
68  } else {
69  // skip task execution and propagate error further
70  child_task->GetFuture().SetCompleted(f->ErrorMessage().c_str());
71  }
72  });
73 
74  return true;
75 }
76 
77 void AsyncTaskGraph::FreezeGraph() {
78  if (frozen_) {
79  return;
80  }
81 
82  CAFFE_ENFORCE(!run_future_);
83  CAFFE_ENFORCE(root_tasks_.empty());
84 
85  std::vector<AsyncTaskFuture*> final_futures;
86  for (auto& kv : nodes_) {
87  auto task_id = kv.first;
88  auto* task = kv.second.get();
89 
90  if (parents_[task_id].empty()) {
91  root_tasks_.push_back(task);
92  }
93 
94  if (children_[task_id].empty()) {
95  auto& future = task->GetFuture();
96  final_futures.push_back(&future);
97  }
98  }
99 
100  CAFFE_ENFORCE(!root_tasks_.empty());
101  CAFFE_ENFORCE(!final_futures.empty());
102 
103  run_future_ = caffe2::make_unique<AsyncTaskFuture>(final_futures);
104 
105  frozen_ = true;
106 }
107 
108 AsyncTaskFuture* AsyncTaskGraph::ExecuteGraph() {
109  CAFFE_ENFORCE(frozen_);
110  CAFFE_ENFORCE(run_future_ && !run_future_->IsCompleted());
111 
112  // TODO: run root tasks inline in inference mode
113  for (auto* task : root_tasks_) {
114  auto task_device = task->GetDeviceOption();
115  helper_->GetPool(task_device)->run([this, task]() { task->Run(options_); });
116  }
117 
118  return run_future_.get();
119 }
120 
121 AsyncTaskFuture* AsyncTaskGraph::GetFuture() {
122  CAFFE_ENFORCE(frozen_);
123  return run_future_.get();
124 }
125 
126 void AsyncTaskGraph::Reset() {
127  CAFFE_ENFORCE(frozen_);
128  for (auto& kv : nodes_) {
129  kv.second->Reset();
130  }
131  for (auto& future : edge_futures_) {
132  future->ResetState();
133  }
134  if (run_future_) {
135  run_future_->ResetState();
136  }
137 }
138 
139 }; // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13