Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_async_scheduling.cc
1 #include "caffe2/core/net_async_scheduling.h"
2 
3 #include "caffe2/core/net_async_tracing.h"
4 
5 namespace caffe2 {
6 
7 AsyncSchedulingNet::AsyncSchedulingNet(
8  const std::shared_ptr<const NetDef>& net_def,
9  Workspace* ws)
10  : AsyncNetBase(net_def, ws), running_(false) {}
11 
12 void AsyncSchedulingNet::reset() {
13  AsyncNetBase::reset();
14  processed_tasks_num_ = 0;
15 }
16 
17 void AsyncSchedulingNet::Wait() {
18  std::unique_lock<std::mutex> lock(running_mutex_);
19  while (running_) {
20  running_cv_.wait(lock);
21  }
22 }
23 
24 bool AsyncSchedulingNet::isInlineTask(int parent_id, int child_id) const {
25  if (!options_.use_dfs_scheduling_) {
26  return false;
27  }
28  const auto* last_parent_op = lastTaskOp(parent_id);
29  const auto* first_child_op = firstTaskOp(child_id);
30  // check that we do not cross device boundary
31  return IsSameDevice(
32  last_parent_op->device_option(), first_child_op->device_option());
33 }
34 
35 // schedule() is not supposed to throw: all exceptions in the ops are caught
36 // and reported at the end of the graph's execution, and the full graph of
37 // tasks is expected to be scheduled.
//
// Schedules task `task_id` for execution. A task whose testAndSetScheduled()
// check fails is skipped, so each task body is launched at most once per run.
//
// @param task_id     index of the task to execute
// @param run_inline  true: run the task body synchronously on the calling
//                    thread; false: submit it to the pool chosen by the
//                    device option of the task's event
//
// The task body (schedule_func) runs the task's ops unless the net already
// failed, optionally hooks per-op stats, releases children whose parent count
// drops to zero, cancels still-scheduled events if the net failed, and
// finally increments processed_tasks_num_; the task that brings the counter
// to tasksNum() calls finishRun().
38 void AsyncSchedulingNet::schedule(int task_id, bool run_inline) noexcept {
39  if (!testAndSetScheduled(task_id)) {
// testAndSetScheduled() returned false: treat the task as already scheduled.
40  return;
41  }
// Captures `this` and task_id by value so the closure can be copied into a
// pool queue.
42  auto schedule_func = [this, task_id]() {
43  try {
// Once the net has failed the ops are skipped, but execution falls through so
// children are still released and the run can reach finishRun().
44  if (success_) {
45  int stream_id = 0;
// Stream selection only matters with several streams per GPU; a selection
// error is logged (rate-limited) and stream 0 is used.
46  if (options_.streams_per_gpu_ > 1) {
47  try {
48  stream_id = stream(task_id);
49  } catch (const std::exception& e) {
50  C10_LOG_EVERY_MS(ERROR, 1000)
51  << "Failed to select a stream: " << e.what();
52  }
53  }
54  if (!run(task_id, stream_id)) {
55  success_ = false;
56  }
57  }
58
// Optional stats: for a CPU last op with an async part, record the async end
// time from a callback on that op's event. Failures here are only logged.
59  if (options_.report_stats_) {
60  try {
61  auto last_op_id = lastTaskOpId(task_id);
62  auto* last_op = lastTaskOp(task_id);
63  if (last_op->device_option().device_type() == PROTO_CPU &&
64  last_op->HasAsyncPart()) {
65  last_op->event().SetCallback([this, last_op_id] {
66  counters_.AddPerOpAsyncEndTime(last_op_id);
67  });
68  }
69  } catch (const std::exception& e) {
70  C10_LOG_EVERY_MS(ERROR, 1000)
71  << "Failed to report operator stats: " << e.what();
72  }
73  }
74
// Release children. Only the call for which updateParentCount() returns 0
// (presumably the last parent to finish -- confirm in AsyncNetBase) decides
// how the child gets scheduled.
75  for (auto child_id : children(task_id)) {
76  int parent_count = updateParentCount(child_id);
77  if (parent_count == 0) {
78  // Schedule a child if:
79  // - there is failure, we skip an op execution and finish the job
80  // - forced scheduling through always_schedule_child_
81  // - finish_chain_ is set, in this case parents are
82  // guaranteed to be finished
83  // - in all other cases, check parents with canSchedule
84  if (!success_ || options_.always_schedule_child_ ||
85  options_.finish_chain_ || canSchedule(child_id)) {
86  // if DFS scheduling is enabled, run children inline,
87  // ignore DFS scheduling in callbacks
88  schedule(child_id, isInlineTask(task_id, child_id));
89  } else {
// Not schedulable yet: inspect every parent's event to choose between
// failing the net, polling, or waiting for parent callbacks.
90  bool parent_failed = false;
91  bool parent_needs_polling = false;
92  std::vector<int> parents_with_callback;
93
94  for (auto parent_id : parents(child_id)) {
95  auto& parent_event = event(parent_id);
96  auto parent_status = parent_event.Query();
97
98  if (parent_status == EventStatus::EVENT_FAILED) {
99  parent_failed = true;
100  break;
101  } else if (parent_status == EventStatus::EVENT_SCHEDULED) {
102  // parent is not finished yet, check if this is blocking us
103  // from scheduling a child
104  if (!canSchedule(parent_id, child_id)) {
105  // we can't schedule a child because of this parent,
106  // check if parent supports callback
107  if (parent_event.SupportsCallback()) {
108  parents_with_callback.push_back(parent_id);
109  } else {
110  parent_needs_polling = true;
111  break;
112  }
113  }
114  } else if (parent_status != EventStatus::EVENT_SUCCESS) {
115  VLOG(1) << "Unexpected parent task state: " << parent_status
116  << ", task id: " << child_id
117  << ", parent task id: " << parent_id;
118  parent_failed = true;
119  break;
120  }
121  }
122
123  if (parent_failed) {
124  // one of parents failed, set failure flag and wrap up execution
125  success_ = false;
126  schedule(child_id, isInlineTask(task_id, child_id));
127  } else if (parent_needs_polling) {
128  // some parents are blocking us from scheduling a child and don't
129  // support callbacks, using polling
130  const auto& child_device_option =
131  event(child_id).GetDeviceOption();
132  pool(child_device_option)
133  ->run(std::bind(
134  &AsyncSchedulingNet::pollAndSchedule, this, child_id));
135  } else if (!parents_with_callback.empty()) {
136  // some parents are blocking us from scheduling a child and they
137  // support callbacks: each of them re-checks the child from parentCallback
138  for (auto parent_id : parents_with_callback) {
139  event(parent_id).SetCallback(std::bind(
140  &AsyncSchedulingNet::parentCallback, this, parent_id));
141  }
142  } else {
143  // we're ready to schedule a child
144  schedule(child_id, isInlineTask(task_id, child_id));
145  }
146  }
147  }
148  }
149
150  // In case of net's failure, make sure all pending tasks are finished
151  if (!success_) {
152  // Simple logic to capture all pending tasks - check all tasks
153  // at the end of each task in case of net's failure
154  for (auto tid = 0; tid < tasksNum(); ++tid) {
155  if (event(tid).Query() == EventStatus::EVENT_SCHEDULED) {
156  // SetFinished may throw, e.g. when we call it on already finished
157  // event, and in some other cases (CUDA)
// Only EnforceNotMet is swallowed; any other exception reaches the outer
// handlers below, which LOG(FATAL).
158  try {
159  event(tid).SetFinished("Cancelled");
160  } catch (const EnforceNotMet&) {
161  // ignore
162  }
163  }
164  }
165  }
166
167  // finishRun may cause waiters to wake up and destroy the net,
168  // before we call finishRun we need to make sure all other (finishing)
169  // tasks are done;
170  // Bumping and checking the counter after the task's job is done
// Nothing after finishRun() below touches `this`, since the net may already
// have been destroyed by a woken waiter.
171  auto tasks_num = tasksNum();
172  auto cur_processed_tasks = ++processed_tasks_num_;
173  if (cur_processed_tasks == tasks_num) {
174  finishRun();
175  }
176  } catch (const std::exception& e) {
177  // error of core scheduling and/or logic, will call terminate
178  LOG(FATAL) << "Unexpected error during graph scheduling run: "
179  << e.what();
180  } catch (...) {
181  LOG(FATAL) << "Unknown error during graph scheduling run";
182  }
183  };
184
// Execute on the caller's thread, or hand a copy of the task body to the pool
// selected by the task's device option.
185  if (run_inline) {
186  schedule_func();
187  } else {
188  const auto& device_option = event(task_id).GetDeviceOption();
189  pool(device_option)->run(schedule_func);
190  }
191 }
192 
193 void AsyncSchedulingNet::parentCallback(int parent_id) {
194  if (event(parent_id).Query() != EventStatus::EVENT_SUCCESS) {
195  success_ = false;
196  }
197  for (auto child_id : children(parent_id)) {
198  int parent_count = getParentCount(child_id);
199  if (parent_count == 0) {
200  if (!success_ || canSchedule(child_id)) {
201  schedule(child_id);
202  }
203  }
204  }
205 }
206 
207 void AsyncSchedulingNet::pollAndSchedule(int task_id) {
208  bool parent_failed = false;
209  bool can_schedule = canSchedule(task_id, nullptr, &parent_failed);
210  if (parent_failed) {
211  success_ = false;
212  }
213  // schedule the task if:
214  // - parents are ready
215  // - we failed / cleanup started (no ops will run)
216 
217  if (can_schedule || !success_ || parent_failed) {
218  schedule(task_id);
219  } else {
220  const auto& device_option = event(task_id).GetDeviceOption();
221  pool(device_option)
222  ->run(std::bind(&AsyncSchedulingNet::pollAndSchedule, this, task_id));
223  }
224 }
225 
226 void AsyncSchedulingNet::finishRun() {
227  std::unique_lock<std::mutex> lock(running_mutex_);
228  // wait for scheduled ops and make sure all events are marked as finished
229  finalizeEvents();
230  if (options_.report_stats_) {
231  counters_.ReportRunEnd();
232  }
233  // notify observers and waiters
234  StopAllObservers();
235  running_ = false;
236  running_cv_.notify_all();
237 }
238 
239 bool AsyncSchedulingNet::RunAsync() {
240  try {
241  std::unique_lock<std::mutex> lock(running_mutex_);
242  if (running_) {
243  LOG(ERROR) << "Detected concurrent runs";
244  return false;
245  }
246  running_ = true;
247  reset();
248 
249  StartAllObservers();
250  tracing::startIter(tracer_);
251  if (options_.report_stats_) {
252  counters_.ReportRunStart();
253  }
254  } catch (const std::exception& e) {
255  LOG(ERROR) << "Exception while starting an async run: " << e.what();
256  finishRun();
257  throw;
258  } catch (...) {
259  LOG(ERROR) << "Exception while starting an async run: unknown error";
260  finishRun();
261  throw;
262  }
263 
264  // schedule() is not expected to throw, at this moment all the initial tasks
265  // will be scheduled and the full graph of tasks will be executed
266  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
267  if (parents(task_id).empty()) {
268  schedule(task_id, options_.run_root_tasks_inline_);
269  }
270  }
271 
272  if (tasksNum() == 0) {
273  finishRun();
274  }
275 
276  if (options_.is_blocking_) {
277  Wait();
278  }
279 
280  return true;
281 }
282 
283 AsyncSchedulingNet::~AsyncSchedulingNet() {
284  Wait();
285 }
286 
// Presumably registers AsyncSchedulingNet in the net registry under the net
// type name "async_scheduling" -- confirm against the REGISTER_NET definition.
287 REGISTER_NET(async_scheduling, AsyncSchedulingNet);
288 
289 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13