Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_async_base.cc
#include "caffe2/core/net_async_base.h"

#include "caffe2/core/net_async_tracing.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"

// experimental support for multiple streams per worker per GPU
C10_DEFINE_int(
    caffe2_streams_per_gpu,
    1,
    "Number of streams per worker per GPU"
    " to use in GPU thread pool (experimental)");

C10_DEFINE_bool(
    caffe2_net_async_inference_mode,
    false,
    "If set, use one single chain containing all ops");

C10_DEFINE_int(
    caffe2_net_async_max_gpus,
    16,
    "Max number of GPUs allowed in net async executor");

C10_DEFINE_int(
    caffe2_net_async_max_numa_nodes,
    8,
    "Max number of NUMA nodes allowed in net async executor");

C10_DEFINE_int(
    caffe2_net_async_thread_pool_size,
    0,
    "Number of threads in device thread pool by default");

C10_DEFINE_bool(
    caffe2_net_async_check_stream_status,
    false,
    "Select next non-busy stream");

C10_DEFINE_bool(
    caffe2_net_async_use_single_pool,
    false,
    "Use single thread pool for all devices");

C10_DEFINE_bool(
    caffe2_net_async_use_per_net_pools,
    false,
    "Use per net thread pools");

C10_DEFINE_bool(
    caffe2_net_async_run_root_tasks_inline,
    false,
    "Run root tasks in current thread instead of scheduling to threadpool");

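// The C10_DEFINE_* flags above expose global tuning knobs for the async
// executor; they are typically set on the command line (e.g.
// --caffe2_streams_per_gpu=2) and are picked up as defaults by
// ExecutionOptions below for nets that do not use the "dag", "prof_dag", or
// "async_dag" net types.
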
namespace caffe2 {

std::vector<int>& AsyncNetBase::getStreamCounters() {
  static thread_local std::vector<int> stream_counters_;
  return stream_counters_;
}

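// The constructor partitions the net into execution chains (tasks). In
// inference mode a single group containing all ops is used; otherwise
// dag_utils::computeChains builds the chains. Events of inner chain ops are
// disabled unless profiling is on (report_stats_), so a chain's status is
// tracked through its boundary ops (completion via the last op's event).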
AsyncNetBase::AsyncNetBase(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws)
    : NetBase(net_def, ws), options_(net_def), counters_(net_def) {
  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
  helper_ = caffe2::make_unique<AsyncNetExecutorHelper>(this);
  operators_.reserve(operator_nodes_.size());
  for (const auto& node : operator_nodes_) {
    auto op_ptr = node.operator_.get();
    op_ptr->SetExecutorHelper(helper_.get());
    operators_.push_back(op_ptr);
  }

  if (FLAGS_caffe2_net_async_inference_mode) {
    execution_chains_ = dag_utils::computeGroups(operator_nodes_);
  } else {
    execution_chains_ = dag_utils::computeChains(operator_nodes_);
  }
  chains_.reserve(execution_chains_.size());
  for (const auto& kv : execution_chains_) {
    chains_.push_back(kv.second);
  }
  chain_nodes_ = dag_utils::prepareChainGraphNodes(operator_nodes_, chains_);

  events_.reserve(chains_.size());
  for (const auto& chain : chains_) {
    const auto& last_op = operators_[chain.back()];
    events_.push_back(&last_op->event());
    // keep events for inner chain ops in case of profiling
    if (!options_.report_stats_) {
      for (const auto& op_id : chain) {
        if (op_id == chain.back() || op_id == chain.front()) {
          continue;
        }
        const auto& op = operators_[op_id];
        op->DisableEvent();
      }
    }
  }

  num_workers_ = net_def->has_num_workers() ? net_def->num_workers() : -1;

  tracer_ = tracing::create(this, net_def->name());
  if (tracer_) {
    LOG(INFO) << "Tracing net: " << net_def->name();
  }
}

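// handleRunError scans all task events for stored exceptions and rethrows the
// chronologically first one (when exception_ptr support is available); if no
// exception was recorded it just logs on failure and returns success_.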
bool AsyncNetBase::handleRunError() {
#ifdef CAFFE2_USE_EXCEPTION_PTR
  // Check the net's events for exceptions and rethrow the chronologically
  // first one
  int first_exc_task_id = -1;
  int64_t first_exc_ts = 0;
  for (int task_id = 0; task_id < tasksNum(); ++task_id) {
    if (event(task_id).HasException()) {
      if (first_exc_task_id >= 0) {
        auto exc_ts = event(task_id).ExceptionTimestamp();
        if (exc_ts < first_exc_ts) {
          first_exc_task_id = task_id;
          first_exc_ts = exc_ts;
        }
      } else {
        first_exc_task_id = task_id;
        first_exc_ts = event(task_id).ExceptionTimestamp();
      }
    }
  }
  if (first_exc_task_id >= 0) {
    LOG(ERROR) << "Rethrowing exception from the run of '" << Name() << "'";
    event(first_exc_task_id).RethrowException();
  }
#endif // CAFFE2_USE_EXCEPTION_PTR

  if (!success_) {
    LOG(ERROR) << "Error encountered in the run of '" << Name() << "'";
  }
  return success_;
}

bool AsyncNetBase::RunAsync() {
  tracing::startIter(tracer_);
  reset();
  return DoRunAsync();
}

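// poolGetter lazily creates and caches one thread pool per (device id,
// pool size) pair through the c10 ThreadPoolRegistry; pool() below selects
// the cache and key from the device option: a single shared pool, a
// per-NUMA-node CPU pool, or a per-GPU pool.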
TaskThreadPoolBase* AsyncNetBase::poolGetter(
    PoolsMap& pools,
    int device_type,
    int device_id,
    int pool_size) {
  std::unique_lock<std::mutex> pools_lock(pools_mutex_);
  auto pool = pools[device_id][pool_size];
  if (!pool) {
    pool = c10::ThreadPoolRegistry()->Create(
        DeviceTypeName(device_type),
        device_id,
        pool_size,
        options_.use_per_net_pools_);
    pools[device_id][pool_size] = pool;
  }
  return pool.get();
}

TaskThreadPoolBase* AsyncNetBase::pool() {
  // By default, use a non-pinned CPU option
  DeviceOption dev;
  dev.set_device_type(PROTO_CPU);
  return pool(dev);
}

TaskThreadPoolBase* AsyncNetBase::pool(const DeviceOption& device_option) {
  if (options_.use_single_pool_) {
    return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
  }
  const auto device_type = device_option.device_type();
  if (IsCPUDeviceType(device_type)) {
    auto numa_node_id = -1;
    if (device_option.has_numa_node_id()) {
      numa_node_id = device_option.numa_node_id();
      CAFFE_ENFORCE_GE(numa_node_id, 0, "Invalid NUMA node id: ", numa_node_id);
    }
    CAFFE_ENFORCE_LT(
        numa_node_id,
        FLAGS_caffe2_net_async_max_numa_nodes,
        "Invalid NUMA node id: ",
        numa_node_id);
    return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
  } else if (IsGPUDeviceType(device_type)) {
    auto gpu_id = device_option.device_id();
    CAFFE_ENFORCE(
        gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
        "Invalid GPU id: " + c10::to_string(gpu_id));
    return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
  } else {
    CAFFE_THROW("Unsupported device type " + c10::to_string(device_type));
  }
}

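// stream() assigns a stream to a GPU task by round-robining a thread-local
// per-GPU counter over streams_per_gpu_; with check_stream_status_ enabled it
// keeps advancing until a non-busy stream is found. Non-GPU tasks always use
// stream 0.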
int AsyncNetBase::stream(int task_id) {
  const auto& device_option = event(task_id).GetDeviceOption();
  int stream_id = 0;
  if (IsGPUDeviceType(device_option.device_type())) {
    int gpu_id = device_option.device_id();
    CAFFE_ENFORCE_GE(gpu_id, 0, "Invalid gpu id: " + c10::to_string(gpu_id));
    if ((unsigned)gpu_id >= getStreamCounters().size()) {
      getStreamCounters().resize(gpu_id + 1, 0);
    }
    do {
      stream_id = getStreamCounters().at(gpu_id)++;
      getStreamCounters().at(gpu_id) %= options_.streams_per_gpu_;
    } while (options_.check_stream_status_ &&
             !isStreamFree(task_id, stream_id));
  }
  return stream_id;
}

bool AsyncNetBase::isStreamFree(int task_id, int stream_id) const {
  auto& task = chains_[task_id];
  auto& last_task_op = operators_[task.back()];
  return last_task_op->IsStreamFree(stream_id);
}

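// A task may be scheduled only if none of its parents has failed and every
// (parent event, first child op) pair passes Event::CanSchedule, which also
// takes the child's async scheduling support into account.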
bool AsyncNetBase::canSchedule(
    int task_id,
    const std::vector<EventStatus>* status,
    bool* parent_failed) {
  auto first_child_op_id = chains_[task_id].front();
  for (auto parent_id : parents(task_id)) {
    auto last_parent_op_id = chains_[parent_id].back();
    EventStatus parent_status;
    if (status) {
      parent_status = status->at(parent_id);
    } else {
      parent_status = operators_[last_parent_op_id]->event().Query();
    }

    if (parent_status == EventStatus::EVENT_FAILED) {
      if (parent_failed) {
        *parent_failed = true;
      }
      return false;
    }

    bool can_schedule = Event::CanSchedule(
        operators_[last_parent_op_id]->event().GetType(),
        parent_status,
        operators_[first_child_op_id]->event().GetType(),
        operators_[first_child_op_id]->SupportsAsyncScheduling());
    if (!can_schedule) {
      return false;
    }
  }

  return true;
}

bool AsyncNetBase::canSchedule(int parent_id, int child_id) {
  auto& parent_event = event(parent_id);
  auto first_child_op_id = chains_[child_id].front();
  auto* first_child_op = operators_[first_child_op_id];
  return Event::CanSchedule(
      parent_event.GetType(),
      parent_event.Query(),
      first_child_op->event().GetType(),
      first_child_op->SupportsAsyncScheduling());
}

int AsyncNetBase::tasksNum() const {
  return chains_.size();
}

Event& AsyncNetBase::event(int task_id) const {
  auto& task = chains_[task_id];
  auto& last_task_op = operators_[task.back()];
  return last_task_op->event();
}

EventStatus AsyncNetBase::query(int task_id) const {
  return event(task_id).Query();
}

const std::vector<int>& AsyncNetBase::children(int task_id) const {
  const auto& task_node = chain_nodes_[task_id];
  return task_node.children_;
}

const std::vector<int>& AsyncNetBase::parents(int task_id) const {
  const auto& task_node = chain_nodes_[task_id];
  return task_node.parents_;
}

int AsyncNetBase::getParentCount(int child_id) {
  auto& child_ops = chains_[child_id];
  auto& child_node = operator_nodes_[child_ops.front()];
  return child_node.runtime_parent_count_.load();
}

int AsyncNetBase::updateParentCount(int child_id) {
  auto& child_ops = chains_[child_id];
  auto& child_node = operator_nodes_[child_ops.front()];
  int parent_count = --child_node.runtime_parent_count_;
  CAFFE_ENFORCE_GE(parent_count, 0);
  return parent_count;
}

bool AsyncNetBase::testAndSetScheduled(int task_id) {
  auto& task_ops = chains_[task_id];
  auto& task_op_node = operator_nodes_[task_ops.front()];
  return !task_op_node.scheduled_.test_and_set();
}

int AsyncNetBase::numOps(int task_id) const {
  return chains_[task_id].size();
}

int AsyncNetBase::firstTaskOpId(int task_id) const {
  return chains_[task_id].front();
}

int AsyncNetBase::lastTaskOpId(int task_id) const {
  return chains_[task_id].back();
}

const OperatorBase* AsyncNetBase::firstTaskOp(int task_id) const {
  return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
}

const OperatorBase* AsyncNetBase::lastTaskOp(int task_id) const {
  return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
}

OperatorBase* AsyncNetBase::firstTaskOp(int task_id) {
  return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
}

OperatorBase* AsyncNetBase::lastTaskOp(int task_id) {
  return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
}

void AsyncNetBase::asyncWait(
    int task_id,
    int stream_id,
    const std::vector<int>& wait_task_ids) const {
  auto first_op_id = chains_[task_id].front();
  auto& first_op = operators_[first_op_id];
  std::vector<const Event*> events;
  events.reserve(wait_task_ids.size());
  for (auto wait_task_id : wait_task_ids) {
    events.push_back(&event(wait_task_id));
  }
  first_op->WaitEvents(events, stream_id);
}

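// reset() prepares the net for a new iteration: operator events are reset and
// every task's runtime parent count and "scheduled" flag are re-initialized.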
void AsyncNetBase::reset() {
  for (auto& op : GetOperators()) {
    op->ResetEvent();
  }
  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
    auto& task_ops = chains_[task_id];
    auto& task_op_node = operator_nodes_[task_ops.front()];
    task_op_node.runtime_parent_count_ = parents(task_id).size();
    task_op_node.scheduled_.clear();
  }

  success_ = true;
}

void AsyncNetBase::handleChainError(
    int task_id,
    OperatorBase* op,
    const char* err_str,
    bool save_exception) noexcept {
  std::string err_msg = err_str;
  if (op) {
    err_msg += ", op " + (op->has_debug_def() ? op->type() : " unknown");
  }
  LOG(ERROR) << err_msg;
  // mark end of chain with an error
  if (query(task_id) == EventStatus::EVENT_INITIALIZED) {
    if (save_exception) {
      event(task_id).SetFinishedWithException(err_msg.c_str());
    } else {
      event(task_id).SetFinished(err_msg.c_str());
    }
  }
}

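// run() executes a single task (chain) on the given stream: it makes the
// chain wait on its parents' events (unless finish_chain_ guarantees parents
// have finished), runs each op asynchronously, records per-op timings when
// profiling, and marks the chain's event as failed on any error or exception.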
bool AsyncNetBase::run(int task_id, int stream_id) noexcept {
  OperatorBase* op = nullptr;
  try {
    // Optionally insert async wait ops; skip when finish_chain_ is set,
    // since all parents are then guaranteed to be finished
    if (!options_.finish_chain_) {
      asyncWait(task_id, stream_id, parents(task_id));
    }
    for (auto& op_id : chains_[task_id]) {
      op = operators_[op_id];
      bool success = false;
      if (!options_.report_stats_) {
        TRACE_EVENT(
            tracing::TRACE_OP,
            op_id,
            tracing::TRACE_TASK,
            task_id,
            tracing::TRACE_STREAM,
            stream_id);
        success = op->RunAsync(stream_id);
      } else {
        counters_.AddPerOpStartTime(op_id);
        success = op->RunAsync(stream_id);
        if (success && op->device_option().device_type() != PROTO_CPU) {
          op->Finish();
        }
        counters_.AddPerOpEndTime(op_id);
      }

      if (!success) {
        handleChainError(task_id, op, "Failed to execute an op");
        return false;
      }
    }

    op = nullptr;
    if (options_.finish_chain_) {
      operators_[chains_[task_id].back()]->event().Finish();
    }
  } catch (const std::exception& e) {
    handleChainError(task_id, op, e.what(), /* save_exception */ true);
    return false;
  } catch (...) {
    handleChainError(
        task_id,
        op,
        "Failed to execute task: unknown error",
        /* save_exception */ true);
    return false;
  }

  return true;
}

void AsyncNetBase::finishTasks(const std::unordered_set<int>& task_ids) {
  for (const auto& task_id : task_ids) {
    event(task_id).Finish();
  }
}

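// finalizeEvents drives every task event to a terminal state (finishing
// scheduled events, marking initialized ones as finished) and records overall
// failure if any event did not end in EVENT_SUCCESS.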
void AsyncNetBase::finalizeEvents() {
  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
    auto status = query(task_id);
    if (status == EventStatus::EVENT_SCHEDULED) {
      event(task_id).Finish();
    } else if (status == EventStatus::EVENT_INITIALIZED) {
      event(task_id).SetFinished();
    }
    if (event(task_id).Query() != EventStatus::EVENT_SUCCESS) {
      success_ = false;
    }
  }
}

ProfDAGProtos AsyncNetBase::GetOperatorStats() const {
  return counters_.GetReport().GetOperatorStats();
}

ProfDAGProtos AsyncNetBase::GetPerOperatorCost() const {
  return counters_.GetReport().GetPerOperatorCost();
}

ProfDAGReport AsyncNetBase::GetProfReport() const {
  return counters_.GetReport();
}

AsyncNetBase::~AsyncNetBase() {
  if (options_.report_stats_) {
    counters_.GetReport().PrintStats();
  }
}

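// ExecutionOptions derives the executor settings from the net type: "dag" and
// "prof_dag" run blocking with chains finished synchronously, "async_dag" is
// blocking without finishing chains, and any other type falls back to the
// global flags defined at the top of this file. The net arguments
// "enable_profiling" and "deferrable_mode" override profiling and DFS
// scheduling. A minimal NetDef sketch (hypothetical net name, assuming an
// executor type other than dag/prof_dag/async_dag) that enables profiling:
//
//   name: "my_net"
//   type: "async_scheduling"
//   arg { name: "enable_profiling" i: 1 }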
ExecutionOptions::ExecutionOptions(
    const std::shared_ptr<const NetDef>& net_def) {
  static const std::string kDag = "dag";
  static const std::string kProfDag = "prof_dag";
  static const std::string kAsyncDag = "async_dag";
  static const std::string kSimpleNet = "simple";

  std::string net_type;
  if (net_def->has_type() && !net_def->type().empty()) {
    net_type = net_def->type();
  } else {
    net_type = kSimpleNet;
  }
  if (net_type == kDag || net_type == kProfDag) {
    streams_per_gpu_ = 1;
    finish_chain_ = true;
    always_schedule_child_ = true;
    check_stream_status_ = false;
    use_single_pool_ = true;
    use_per_net_pools_ = true;
    is_blocking_ = true;
    report_stats_ = (net_type == kProfDag);
  } else if (net_type == kAsyncDag) {
    streams_per_gpu_ = 1;
    finish_chain_ = false;
    always_schedule_child_ = true;
    check_stream_status_ = false;
    use_single_pool_ = true;
    use_per_net_pools_ = true;
    is_blocking_ = true;
    report_stats_ = false;
  } else {
    streams_per_gpu_ = FLAGS_caffe2_streams_per_gpu;
    finish_chain_ = false;
    always_schedule_child_ = false;
    check_stream_status_ = FLAGS_caffe2_net_async_check_stream_status;
    use_single_pool_ = FLAGS_caffe2_net_async_use_single_pool;
    use_per_net_pools_ = FLAGS_caffe2_net_async_use_per_net_pools;
    is_blocking_ = false;
    report_stats_ = false;
  }

  use_dfs_scheduling_ = false;

  for (int arg_idx = 0; arg_idx < net_def->arg_size(); ++arg_idx) {
    auto& arg = net_def->arg(arg_idx);
    if (arg.has_name() && arg.name() == "enable_profiling") {
      CAFFE_ENFORCE(arg.has_i(), "enable_profiling should be an int");
      report_stats_ = arg.i() == 1;
    }
    if (arg.has_name() && arg.name() == "deferrable_mode") {
      CAFFE_ENFORCE(arg.has_i(), "deferrable_mode should be an int");
      use_dfs_scheduling_ = arg.i() == 1; // corr. to DFS scheduling
    }
  }

  run_root_tasks_inline_ = FLAGS_caffe2_net_async_run_root_tasks_inline;
}

} // namespace caffe2

namespace c10 {

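// Thread pool creators for CPU, CUDA and HIP devices; poolGetter above looks
// these up through c10::ThreadPoolRegistry()->Create(), keyed by
// DeviceTypeName(device_type).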
C10_REGISTER_CREATOR(
    ThreadPoolRegistry,
    CPU,
    caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CPU>);
C10_REGISTER_CREATOR(
    ThreadPoolRegistry,
    CUDA,
    caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CUDA>);
C10_REGISTER_CREATOR(
    ThreadPoolRegistry,
    HIP,
    caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_HIP>);

} // namespace c10