Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_async_scheduling.cc
1 
17 #include "caffe2/core/net_async_scheduling.h"
18 
19 CAFFE2_DEFINE_bool(
20  caffe2_net_async_always_schedule_child,
21  false,
22  "Always schedule child chains from parent chain");
23 
24 namespace caffe2 {
25 
26 AsyncSchedulingNet::AsyncSchedulingNet(
27  const std::shared_ptr<const NetDef>& net_def,
28  Workspace* ws)
29  : AsyncNetBase(net_def, ws), running_(false) {
30  reset();
31 }
32 
33 void AsyncSchedulingNet::reset() {
34  processed_tasks_num_ = 0;
35  cleanup_ = false;
36  success_ = true;
37 
38  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
39  auto& task_ops = chains_[task_id];
40  auto& task_op_node = operator_nodes_[task_ops.front()];
41  task_op_node.runtime_parent_count_ = parents(task_id).size();
42  }
43  exception_messages_.clear();
44 }
45 
46 void AsyncSchedulingNet::Wait() {
47  std::unique_lock<std::mutex> lock(running_mutex_);
48  while (running_) {
49  running_cv_.wait(lock);
50  }
51 }
52 
53 void AsyncSchedulingNet::schedule(int task_id) {
54  const auto& device_option = event(task_id).GetDeviceOption();
55  pool(device_option)->run([this, task_id]() {
56  if (success_) {
57  int stream_id = stream(task_id);
58  asyncWait(task_id, stream_id, parents(task_id));
59  try {
60  run(task_id, stream_id);
61  } catch (const std::exception& e) {
62  std::unique_lock<std::mutex> lock(exception_mutex_);
63  exception_messages_.push_back(e.what());
64  success_ = false;
65  }
66  }
67 
68  auto task_count = ++processed_tasks_num_;
69 
70  for (auto child_id : children(task_id)) {
71  int parent_count = updateParentCount(child_id);
72  if (parent_count == 0) {
73  if (cleanup_ || FLAGS_caffe2_net_async_always_schedule_child ||
74  canSchedule(child_id)) {
75  schedule(child_id);
76  } else {
77  const auto& device_option = event(child_id).GetDeviceOption();
78  pool(device_option)
79  ->run(std::bind(
80  &AsyncSchedulingNet::pollAndSchedule, this, child_id));
81  }
82  }
83  }
84 
85  if (success_) {
86  if (task_count == tasksNum()) {
87  // All tasks are finished, polling thread is sleeping;
88  // only one thread enters here
89  finalizeEvents();
90  finishRun();
91  return;
92  }
93  } else {
94  // Before setting running_ to false and notifying waiters we need to
95  // 1. Ensure that only one thread does the cleanup
96  // 2. Ensure that all other pending tasks in workers and polling threads
97  // are finished and
98  // 3. Ensure that all tasks that were not scheduled have their events set
99  {
100  std::unique_lock<std::mutex> cleanup_lock(cleanup_mutex_);
101  if (cleanup_) {
102  return;
103  }
104  cleanup_ = true;
105  }
106 
107  // Errors are not recoverable and happen in exceptional cases,
108  // ok to busy wait
109  while (processed_tasks_num_ != tasksNum()) {
110  }
111 
112  // Make sure all events are set, wait for scheduled events
113  finalizeEvents();
114 
115  // Notify observers and waiters
116  finishRun();
117  }
118  });
119 }
120 
121 void AsyncSchedulingNet::pollAndSchedule(int task_id) {
122  if (canSchedule(task_id) || cleanup_) {
123  // force schedule the rest of the tasks if cleanup is started
124  schedule(task_id);
125  } else {
126  const auto& device_option = event(task_id).GetDeviceOption();
127  pool(device_option)
128  ->run(std::bind(&AsyncSchedulingNet::pollAndSchedule, this, task_id));
129  }
130 }
131 
132 int AsyncSchedulingNet::updateParentCount(int child_id) {
133  auto& child_ops = chains_[child_id];
134  auto& child_node = operator_nodes_[child_ops.front()];
135  int parent_count = --child_node.runtime_parent_count_;
136  CAFFE_ENFORCE_GE(parent_count, 0);
137  return parent_count;
138 }
139 
140 void AsyncSchedulingNet::finishRun() {
141  // notify observers and waiters
142  StopAllObservers();
143  running_ = false;
144  running_cv_.notify_all();
145 }
146 
147 bool AsyncSchedulingNet::DoRunAsync() {
148  std::unique_lock<std::mutex> lock(running_mutex_);
149  CAFFE_ENFORCE(!running_, "Concurrent RunAsync calls");
150  running_ = true;
151  reset();
152 
153  StartAllObservers();
154 
155  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
156  if (parents(task_id).empty()) {
157  schedule(task_id);
158  }
159  }
160 
161  return true;
162 }
163 
164 AsyncSchedulingNet::~AsyncSchedulingNet() {}
165 
166 REGISTER_NET(async_scheduling, AsyncSchedulingNet);
167 
168 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.