1 #include "caffe2/core/net_async_task_graph.h" 3 #include "caffe2/core/net_parallel.h" 7 AsyncTaskGraph::AsyncTaskGraph(
8 ExecutorHelper* helper,
9 const ExecutionOptions& options)
10 : helper_(helper), options_(options), frozen_(false) {}
12 bool AsyncTaskGraph::CreateNode(
14 const std::vector<OperatorBase*>& ops) {
15 CAFFE_ENFORCE(!frozen_);
16 if (!nodes_.count(node_id)) {
17 nodes_[node_id] = caffe2::make_unique<AsyncTask>(ops);
24 bool AsyncTaskGraph::AddDependency(
26 const std::vector<int>& parent_node_ids) {
27 CAFFE_ENFORCE(!frozen_);
28 CAFFE_ENFORCE(!parent_node_ids.empty());
29 CAFFE_ENFORCE(nodes_.count(child_node_id));
30 for (
auto node_id : parent_node_ids) {
31 CAFFE_ENFORCE(nodes_.count(node_id));
33 CAFFE_ENFORCE(!parents_.count(child_node_id));
35 auto* child_task = nodes_[child_node_id].get();
36 auto child_device = child_task->GetDeviceOption();
38 std::vector<AsyncTaskFuture*> parent_futures;
39 for (
auto node_id : parent_node_ids) {
40 parents_[child_node_id].insert(node_id);
41 children_[node_id].insert(child_node_id);
42 parent_futures.push_back(&nodes_[node_id]->GetFuture());
45 AsyncTaskFuture* parents_future =
nullptr;
46 if (parent_futures.size() > 1) {
47 edge_futures_.push_back(
48 caffe2::make_unique<AsyncTaskFuture>(parent_futures));
49 parents_future = edge_futures_.back().get();
51 CAFFE_ENFORCE_EQ(parent_futures.size(), 1);
52 parents_future = parent_futures.back();
56 parents_future->SetCallback(
57 [
this, child_task, child_device](
const AsyncTaskFuture* f) {
58 CAFFE_ENFORCE(f->IsCompleted());
62 auto* pool = helper_->GetPool(child_device);
63 if (pool->inThreadPool() && options_.use_dfs_scheduling_) {
64 child_task->Run(options_);
66 pool->run([
this, child_task]() { child_task->Run(options_); });
70 child_task->GetFuture().SetCompleted(f->ErrorMessage().c_str());
77 void AsyncTaskGraph::FreezeGraph() {
82 CAFFE_ENFORCE(!run_future_);
83 CAFFE_ENFORCE(root_tasks_.empty());
85 std::vector<AsyncTaskFuture*> final_futures;
86 for (
auto& kv : nodes_) {
87 auto task_id = kv.first;
88 auto* task = kv.second.get();
90 if (parents_[task_id].empty()) {
91 root_tasks_.push_back(task);
94 if (children_[task_id].empty()) {
95 auto& future = task->GetFuture();
96 final_futures.push_back(&future);
100 CAFFE_ENFORCE(!root_tasks_.empty());
101 CAFFE_ENFORCE(!final_futures.empty());
103 run_future_ = caffe2::make_unique<AsyncTaskFuture>(final_futures);
108 AsyncTaskFuture* AsyncTaskGraph::ExecuteGraph() {
109 CAFFE_ENFORCE(frozen_);
110 CAFFE_ENFORCE(run_future_ && !run_future_->IsCompleted());
113 for (
auto* task : root_tasks_) {
114 auto task_device = task->GetDeviceOption();
115 helper_->GetPool(task_device)->run([
this, task]() { task->Run(options_); });
118 return run_future_.get();
121 AsyncTaskFuture* AsyncTaskGraph::GetFuture() {
122 CAFFE_ENFORCE(frozen_);
123 return run_future_.get();
126 void AsyncTaskGraph::Reset() {
127 CAFFE_ENFORCE(frozen_);
128 for (
auto& kv : nodes_) {
131 for (
auto& future : edge_futures_) {
132 future->ResetState();
135 run_future_->ResetState();
// A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...