// Flag names and help strings used by the async net executor.
// NOTE(review): the defining macro and the default value of each flag are not
// visible in this listing; only the flag name and its help text are.
// FLAGS_caffe2_streams_per_gpu is read by ExecutionOptions for net types that
// take the flag-driven settings.
1 #include "caffe2/core/net_async_base.h" 3 #include "caffe2/core/net_async_tracing.h" 4 #include "caffe2/core/operator.h" 5 #include "caffe2/core/timer.h" 9 caffe2_streams_per_gpu,
11 "Number of streams per worker per GPU" 12 " to use in GPU thread pool (experimental)");
// Read in the AsyncNetBase constructor: when set, chains come from
// dag_utils::computeGroups.
15 caffe2_net_async_inference_mode,
17 "If set, use one single chain containing all ops");
// Read in pool(const DeviceOption&) as the exclusive upper bound on GPU ids.
20 caffe2_net_async_max_gpus,
22 "Max number of GPUs allowed in net async executor");
// Read in pool(const DeviceOption&) as part of the NUMA node id validation.
25 caffe2_net_async_max_numa_nodes,
27 "Max number of NUMA nodes allowed in net async executor");
// NOTE(review): no read of this flag is visible in this listing.
30 caffe2_net_async_thread_pool_size,
32 "Number of threads in device thread pool by default");
// Copied into ExecutionOptions::check_stream_status_ on the flag-driven path.
35 caffe2_net_async_check_stream_status,
37 "Select next non-busy stream");
// Copied into ExecutionOptions::use_single_pool_ on the flag-driven path.
40 caffe2_net_async_use_single_pool,
42 "Use single thread pool for all devices");
// Copied into ExecutionOptions::use_per_net_pools_ on the flag-driven path.
45 caffe2_net_async_use_per_net_pools,
47 "Use per net thread pools");
// Copied into ExecutionOptions::run_root_tasks_inline_.
// NOTE(review): "instread" in the help text below looks like a typo for
// "instead"; it is a runtime string, so it is left untouched here.
50 caffe2_net_async_run_root_tasks_inline,
52 "Run root tasks in current thread instread of scheduling to threadpool");
56 std::vector<int>& AsyncNetBase::getStreamCounters() {
57 static thread_local std::vector<int> stream_counters_;
58 return stream_counters_;
// Constructs the async net: prepares operator nodes, groups them into
// execution chains, builds the chain graph, collects one event per chain,
// reads the worker count and creates the tracer.
// NOTE(review): several lines of this definition are not visible in this
// listing (the parameter that declares `ws`, the else line between the two
// chain computations, the body applied to inner chain ops, closing braces).
61 AsyncNetBase::AsyncNetBase(
62 const std::shared_ptr<const NetDef>& net_def,
64 : NetBase(net_def, ws), options_(net_def), counters_(net_def) {
65 operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
// The helper receives `this` and is attached to every operator in the loop
// below.
66 helper_ = caffe2::make_unique<AsyncNetExecutorHelper>(
this);
// operators_ holds the raw operator pointers, in operator_nodes_ order.
67 operators_.reserve(operator_nodes_.size());
68 for (
const auto& node : operator_nodes_) {
69 auto op_ptr = node.operator_.get();
70 op_ptr->SetExecutorHelper(helper_.get());
71 operators_.push_back(op_ptr);
// Inference mode selects computeGroups; computeChains is presumably the
// alternative branch (the else line is not visible here).
74 if (FLAGS_caffe2_net_async_inference_mode) {
75 execution_chains_ = dag_utils::computeGroups(operator_nodes_);
77 execution_chains_ = dag_utils::computeChains(operator_nodes_);
// chains_ keeps only the mapped values of execution_chains_.
79 chains_.reserve(execution_chains_.size());
80 for (
const auto& kv : execution_chains_) {
81 chains_.push_back(kv.second);
83 chain_nodes_ = dag_utils::prepareChainGraphNodes(operator_nodes_, chains_);
// One event per chain: the event of the chain's last operator.
85 events_.reserve(chains_.size());
86 for (
const auto& chain : chains_) {
87 const auto& last_op = operators_[chain.back()];
88 events_.push_back(&last_op->event());
// Without stats reporting, ops that are neither first nor last in their chain
// are singled out. NOTE(review): what is done to them is not visible here.
90 if (!options_.report_stats_) {
91 for (
const auto& op_id : chain) {
92 if (op_id == chain.back() || op_id == chain.front()) {
95 const auto& op = operators_[op_id];
// -1 is stored when the net definition does not specify num_workers.
101 num_workers_ = net_def->has_num_workers() ? net_def->num_workers() : -1;
103 tracer_ = tracing::create(
this, net_def->name());
// NOTE(review): presumably logged only when a tracer was created; the guarding
// condition is not visible here.
105 LOG(INFO) <<
"Tracing net: " << net_def->name();
// Handles a failed run. When CAFFE2_USE_EXCEPTION_PTR is defined, scans every
// task's event for a stored exception, picks the one with the smallest
// exception timestamp and rethrows it.
// NOTE(review): the return statements (and any other lines between the
// visible ones) are not shown in this listing.
109 bool AsyncNetBase::handleRunError() {
110 #ifdef CAFFE2_USE_EXCEPTION_PTR 112 int first_exc_task_id = -1;
113 int64_t first_exc_ts = 0;
// Track the id and timestamp of the earliest exception seen so far;
// first_exc_task_id stays -1 until one is found.
114 for (
int task_id = 0; task_id < tasksNum(); ++task_id) {
115 if (event(task_id).HasException()) {
116 if (first_exc_task_id >= 0) {
117 auto exc_ts = event(task_id).ExceptionTimestamp();
118 if (exc_ts < first_exc_ts) {
119 first_exc_task_id = task_id;
120 first_exc_ts = exc_ts;
123 first_exc_task_id = task_id;
124 first_exc_ts = event(task_id).ExceptionTimestamp();
// Rethrow the earliest exception, if any task recorded one.
128 if (first_exc_task_id >= 0) {
129 LOG(ERROR) <<
"Rethrowing exception from the run of '" << Name() <<
"'";
130 event(first_exc_task_id).RethrowException();
132 #endif // CAFFE2_USE_EXCEPTION_PTR 135 LOG(ERROR) <<
"Error encountered in the run of '" << Name() <<
"'";
// Entry point for an asynchronous run; begins by starting a tracing iteration
// on tracer_.
// NOTE(review): the rest of the body, including the return value, is not
// visible in this listing.
140 bool AsyncNetBase::RunAsync() {
141 tracing::startIter(tracer_);
// Looks up the thread pool stored at pools[device_id][pool_size] while holding
// pools_mutex_, and creates one through c10::ThreadPoolRegistry() keyed by the
// device type name, storing it back into the same slot.
// NOTE(review): the parameter list, the condition that triggers creation, some
// Create() arguments and the return statement are not visible here;
// presumably a pool is created only when the slot is empty.
146 TaskThreadPoolBase* AsyncNetBase::poolGetter(
151 std::unique_lock<std::mutex> pools_lock(pools_mutex_);
// operator[] on pools is used for the lookup.
152 auto pool = pools[device_id][pool_size];
154 pool = c10::ThreadPoolRegistry()->Create(
155 DeviceTypeName(device_type),
158 options_.use_per_net_pools_);
159 pools[device_id][pool_size] = pool;
// Parameterless overload of pool(). The visible line sets the device type of
// `dev` to PROTO_CPU.
// NOTE(review): the declaration of `dev` and the return are not visible here;
// presumably this delegates to pool(const DeviceOption&) — confirm.
164 TaskThreadPoolBase* AsyncNetBase::pool() {
167 dev.set_device_type(PROTO_CPU);
// Selects the thread pool for `device_option`:
//  - with options_.use_single_pool_: the CPU pool keyed by device id -1;
//  - CPU device types: the CPU pool keyed by the NUMA node id, or -1 when the
//    option carries none;
//  - GPU device types: the GPU pool keyed by the device id;
//  - anything else: CAFFE_THROW.
// num_workers_ is passed as the pool size in every case.
// NOTE(review): some lines (enforce macro names, else, closing braces) are not
// visible in this listing.
171 TaskThreadPoolBase* AsyncNetBase::pool(
const DeviceOption& device_option) {
172 if (options_.use_single_pool_) {
173 return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
175 const auto device_type = device_option.device_type();
176 if (IsCPUDeviceType(device_type)) {
// -1 stands for "no NUMA node specified".
177 auto numa_node_id = -1;
178 if (device_option.has_numa_node_id()) {
179 numa_node_id = device_option.numa_node_id();
180 CAFFE_ENFORCE_GE(numa_node_id, 0,
"Invalid NUMA node id: ", numa_node_id);
// NOTE(review): presumably an upper-bound check of numa_node_id against the
// max-NUMA-nodes flag; the enforcing macro itself is not visible here.
184 FLAGS_caffe2_net_async_max_numa_nodes,
185 "Invalid NUMA node id: ",
187 return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
188 }
else if (IsGPUDeviceType(device_type)) {
189 auto gpu_id = device_option.device_id();
// GPU ids must lie in [0, FLAGS_caffe2_net_async_max_gpus).
191 gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
192 "Invalid GPU id: " + c10::to_string(gpu_id));
193 return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
195 CAFFE_THROW(
"Unsupported device type " + c10::to_string(device_type));
// Picks a stream id for task `task_id`. For a GPU task, the calling thread's
// counter for that GPU supplies the id and is then advanced and wrapped modulo
// options_.streams_per_gpu_. With options_.check_stream_status_ set, selection
// repeats while the chosen stream is not free.
// NOTE(review): the declaration of stream_id, the head of the loop and the
// return statement are not visible in this listing.
199 int AsyncNetBase::stream(
int task_id) {
200 const auto& device_option = event(task_id).GetDeviceOption();
202 if (IsGPUDeviceType(device_option.device_type())) {
203 int gpu_id = device_option.device_id();
204 CAFFE_ENFORCE_GE(gpu_id, 0,
"Invalid gpu id: " + c10::to_string(gpu_id));
// Grow the per-thread counter vector so that index gpu_id exists; new entries
// start at 0.
205 if ((
unsigned)gpu_id >= getStreamCounters().size()) {
206 getStreamCounters().resize(gpu_id + 1, 0);
// Post-increment: stream_id receives the counter value before the increment.
209 stream_id = getStreamCounters().at(gpu_id)++;
210 getStreamCounters().at(gpu_id) %= options_.streams_per_gpu_;
211 }
while (options_.check_stream_status_ &&
212 !isStreamFree(task_id, stream_id));
217 bool AsyncNetBase::isStreamFree(
int task_id,
int stream_id)
const {
218 auto& task = chains_[task_id];
219 auto& last_task_op = operators_[task.back()];
220 return last_task_op->IsStreamFree(stream_id);
// Decides whether task `task_id` can be scheduled by examining each of its
// parent tasks. A parent's status comes either from `status` (indexed by
// parent id) or from querying the event of the parent's last operator. A
// parent in EVENT_FAILED state is recorded through *parent_failed.
// Event::CanSchedule is then consulted with the parent's and child's event
// types and the child's async-scheduling support.
// NOTE(review): the first parameter, the condition choosing between the two
// status sources, any null check on parent_failed, one CanSchedule argument
// and the return statements are not visible in this listing.
223 bool AsyncNetBase::canSchedule(
225 const std::vector<EventStatus>* status,
226 bool* parent_failed) {
// The child is represented by the first operator of its chain, each parent by
// the last operator of its chain.
227 auto first_child_op_id = chains_[task_id].front();
228 for (
auto parent_id : parents(task_id)) {
229 auto last_parent_op_id = chains_[parent_id].back();
230 EventStatus parent_status;
232 parent_status = status->at(parent_id);
234 parent_status = operators_[last_parent_op_id]->event().Query();
237 if (parent_status == EventStatus::EVENT_FAILED) {
239 *parent_failed =
true;
244 bool can_schedule = Event::CanSchedule(
245 operators_[last_parent_op_id]->event().GetType(),
247 operators_[first_child_op_id]->event().GetType(),
248 operators_[first_child_op_id]->SupportsAsyncScheduling());
257 bool AsyncNetBase::canSchedule(
int parent_id,
int child_id) {
258 auto& parent_event = event(parent_id);
259 auto first_child_op_id = chains_[child_id].front();
260 auto* first_child_op = operators_[first_child_op_id];
261 return Event::CanSchedule(
262 parent_event.GetType(),
263 parent_event.Query(),
264 first_child_op->event().GetType(),
265 first_child_op->SupportsAsyncScheduling());
268 int AsyncNetBase::tasksNum()
const {
269 return chains_.size();
272 Event& AsyncNetBase::event(
int task_id)
const {
273 auto& task = chains_[task_id];
274 auto& last_task_op = operators_[task.back()];
275 return last_task_op->event();
278 EventStatus AsyncNetBase::query(
int task_id)
const {
279 return event(task_id).Query();
282 const std::vector<int>& AsyncNetBase::children(
int task_id)
const {
283 const auto& task_node = chain_nodes_[task_id];
284 return task_node.children_;
287 const std::vector<int>& AsyncNetBase::parents(
int task_id)
const {
288 const auto& task_node = chain_nodes_[task_id];
289 return task_node.parents_;
292 int AsyncNetBase::getParentCount(
int child_id) {
293 auto& child_ops = chains_[child_id];
294 auto& child_node = operator_nodes_[child_ops.front()];
295 return child_node.runtime_parent_count_.load();
// Decrements the runtime parent counter of task `child_id` (stored on the
// operator node of the chain's first operator) and enforces that the new value
// is not negative.
// NOTE(review): the return statement is not visible in this listing;
// presumably the decremented count is returned — confirm.
298 int AsyncNetBase::updateParentCount(
int child_id) {
299 auto& child_ops = chains_[child_id];
300 auto& child_node = operator_nodes_[child_ops.front()];
// Pre-decrement: parent_count is the value after the decrement.
301 int parent_count = --child_node.runtime_parent_count_;
302 CAFFE_ENFORCE_GE(parent_count, 0);
306 bool AsyncNetBase::testAndSetScheduled(
int task_id) {
307 auto& task_ops = chains_[task_id];
308 auto& task_op_node = operator_nodes_[task_ops.front()];
309 return !task_op_node.scheduled_.test_and_set();
312 int AsyncNetBase::numOps(
int task_id)
const {
313 return chains_[task_id].size();
316 int AsyncNetBase::firstTaskOpId(
int task_id)
const {
317 return chains_[task_id].front();
320 int AsyncNetBase::lastTaskOpId(
int task_id)
const {
321 return chains_[task_id].back();
324 const OperatorBase* AsyncNetBase::firstTaskOp(
int task_id)
const {
325 return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
328 const OperatorBase* AsyncNetBase::lastTaskOp(
int task_id)
const {
329 return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
332 OperatorBase* AsyncNetBase::firstTaskOp(
int task_id) {
333 return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
336 OperatorBase* AsyncNetBase::lastTaskOp(
int task_id) {
337 return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
// Makes the first operator of task `task_id` wait, on `stream_id`, for the
// events of every task listed in `wait_task_ids`.
// NOTE(review): the declarations of the task_id and stream_id parameters are
// not visible in this listing.
340 void AsyncNetBase::asyncWait(
343 const std::vector<int>& wait_task_ids)
const {
344 auto first_op_id = chains_[task_id].front();
345 auto& first_op = operators_[first_op_id];
// Collect a pointer to each awaited task's event.
346 std::vector<const Event*> events;
347 events.reserve(wait_task_ids.size());
348 for (
auto wait_task_id : wait_task_ids) {
349 events.push_back(&event(wait_task_id));
351 first_op->WaitEvents(events, stream_id);
// Restores per-run state. Iterates over all operators, then for every task
// re-initializes the operator node of the chain's first operator:
// runtime_parent_count_ is set to the number of parents and the scheduled_
// flag is cleared.
// NOTE(review): the body of the first loop is not visible in this listing.
354 void AsyncNetBase::reset() {
355 for (
auto& op : GetOperators()) {
358 for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
359 auto& task_ops = chains_[task_id];
360 auto& task_op_node = operator_nodes_[task_ops.front()];
361 task_op_node.runtime_parent_count_ = parents(task_id).size();
362 task_op_node.scheduled_.clear();
// Reports a failure of task `task_id`: builds a message from `err_str`
// (optionally extended with the operator type), logs it, and — only if the
// task's event is still EVENT_INITIALIZED — marks the event finished with that
// message, storing it as an exception when `save_exception` is true.
// NOTE(review): the leading parameters (task_id, op, err_str), the condition
// guarding the op suffix and the else line are not visible in this listing.
368 void AsyncNetBase::handleChainError(
372 bool save_exception) noexcept {
373 std::string err_msg = err_str;
// " unknown" is used when the operator has no debug def.
375 err_msg +=
", op " + (op->has_debug_def() ? op->type() :
" unknown");
377 LOG(ERROR) << err_msg;
// Events that have already left the initialized state are not touched.
379 if (query(task_id) == EventStatus::EVENT_INITIALIZED) {
380 if (save_exception) {
381 event(task_id).SetFinishedWithException(err_msg.c_str());
383 event(task_id).SetFinished(err_msg.c_str());
// Runs every operator of task `task_id` on stream `stream_id`, in chain order.
// Failures and exceptions are routed to handleChainError; the function is
// noexcept.
// NOTE(review): the try line, the tracing macro call, the per-op failure
// branch, the catch-all handler head and the return statements are not
// visible in this listing.
388 bool AsyncNetBase::run(
int task_id,
int stream_id) noexcept {
// Tracks the operator being run so error handlers can name it.
389 OperatorBase* op =
nullptr;
// Unless chains are finished synchronously (finish_chain_), first wait
// asynchronously for the parent tasks' events.
394 if (!options_.finish_chain_) {
395 asyncWait(task_id, stream_id, parents(task_id));
397 for (
auto& op_id : chains_[task_id]) {
398 op = operators_[op_id];
399 bool success =
false;
// Without stats reporting the op is simply run; otherwise start and end times
// are recorded around it through counters_.
400 if (!options_.report_stats_) {
406 tracing::TRACE_STREAM,
408 success = op->RunAsync(stream_id);
410 counters_.AddPerOpStartTime(op_id);
411 success = op->RunAsync(stream_id);
// NOTE(review): the body of this non-CPU branch is not visible here.
412 if (success && op->device_option().device_type() != PROTO_CPU) {
415 counters_.AddPerOpEndTime(op_id);
419 handleChainError(task_id, op,
"Failed to execute an op");
// With finish_chain_ set, the event of the chain's last operator is finished
// before leaving.
425 if (options_.finish_chain_) {
426 operators_[chains_[task_id].back()]->event().Finish();
428 }
catch (
const std::exception& e) {
// The exception text is saved on the task's event (save_exception = true).
429 handleChainError(task_id, op, e.what(),
true);
435 "Failed to execute task: unknown error",
443 void AsyncNetBase::finishTasks(
const std::unordered_set<int>& task_ids) {
444 for (
const auto& task_id : task_ids) {
445 event(task_id).Finish();
// Brings every task's event to a final state: an event that is
// EVENT_SCHEDULED is Finish()ed and one that is still EVENT_INITIALIZED is
// SetFinished(). Afterwards the event is checked against EVENT_SUCCESS.
// NOTE(review): what happens when the final status is not EVENT_SUCCESS is not
// visible in this listing.
449 void AsyncNetBase::finalizeEvents() {
450 for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
451 auto status = query(task_id);
452 if (status == EventStatus::EVENT_SCHEDULED) {
453 event(task_id).Finish();
454 }
else if (status == EventStatus::EVENT_INITIALIZED) {
455 event(task_id).SetFinished();
// Re-query rather than reuse `status`: the calls above may have changed it.
457 if (event(task_id).Query() != EventStatus::EVENT_SUCCESS) {
463 ProfDAGProtos AsyncNetBase::GetOperatorStats()
const {
464 return counters_.GetReport().GetOperatorStats();
467 ProfDAGProtos AsyncNetBase::GetPerOperatorCost()
const {
468 return counters_.GetReport().GetPerOperatorCost();
471 ProfDAGReport AsyncNetBase::GetProfReport()
const {
472 return counters_.GetReport();
475 AsyncNetBase::~AsyncNetBase() {
476 if (options_.report_stats_) {
477 counters_.GetReport().PrintStats();
// Derives the executor options from `net_def`: first from the net type
// ("dag" / "prof_dag", "async_dag", or anything else), then from the
// "enable_profiling" and "deferrable_mode" net arguments, and finally from the
// run-root-tasks-inline flag.
// NOTE(review): the else lines, closing braces and possibly some assignments
// are not visible in this listing.
481 ExecutionOptions::ExecutionOptions(
482 const std::shared_ptr<const NetDef>& net_def) {
483 static const std::string kDag =
"dag";
484 static const std::string kProfDag =
"prof_dag";
485 static const std::string kAsyncDag =
"async_dag";
486 static const std::string kSimpleNet =
"simple";
// A missing or empty net type is treated as kSimpleNet.
488 std::string net_type;
489 if (net_def->has_type() && !net_def->type().empty()) {
490 net_type = net_def->type();
492 net_type = kSimpleNet;
// "dag" and "prof_dag": fixed settings; they differ only in report_stats_.
494 if (net_type == kDag || net_type == kProfDag) {
495 streams_per_gpu_ = 1;
496 finish_chain_ =
true;
497 always_schedule_child_ =
true;
498 check_stream_status_ =
false;
499 use_single_pool_ =
true;
500 use_per_net_pools_ =
true;
502 report_stats_ = (net_type == kProfDag);
503 }
// "async_dag": same fixed settings except that finish_chain_ is false and
// stats are never reported.
else if (net_type == kAsyncDag) {
504 streams_per_gpu_ = 1;
505 finish_chain_ =
false;
506 always_schedule_child_ =
true;
507 check_stream_status_ =
false;
508 use_single_pool_ =
true;
509 use_per_net_pools_ =
true;
511 report_stats_ =
false;
// NOTE(review): presumably the else branch for every other net type (the else
// line is not visible here); these settings come from the command-line flags.
513 streams_per_gpu_ = FLAGS_caffe2_streams_per_gpu;
514 finish_chain_ =
false;
515 always_schedule_child_ =
false;
516 check_stream_status_ = FLAGS_caffe2_net_async_check_stream_status;
517 use_single_pool_ = FLAGS_caffe2_net_async_use_single_pool;
518 use_per_net_pools_ = FLAGS_caffe2_net_async_use_per_net_pools;
519 is_blocking_ =
false;
520 report_stats_ =
false;
523 use_dfs_scheduling_ =
false;
// Net arguments can override report_stats_ and use_dfs_scheduling_; each is
// enabled only when the integer argument equals 1.
525 for (
int arg_idx = 0; arg_idx < net_def->arg_size(); ++arg_idx) {
526 auto& arg = net_def->arg(arg_idx);
527 if (arg.has_name() && arg.name() ==
"enable_profiling") {
528 CAFFE_ENFORCE(arg.has_i(),
"enable_profiling should be an int");
529 report_stats_ = arg.i() == 1;
531 if (arg.has_name() && arg.name() ==
"deferrable_mode") {
532 CAFFE_ENFORCE(arg.has_i(),
"deferrable_mode should be an int");
533 use_dfs_scheduling_ = arg.i() == 1;
537 run_root_tasks_inline_ = FLAGS_caffe2_net_async_run_root_tasks_inline;
// Registers caffe2::GetAsyncNetThreadPool<TaskThreadPool, DEVICE> as a creator
// for three device types: PROTO_CPU, PROTO_CUDA and PROTO_HIP.
// NOTE(review): the registry and key arguments of each call are not visible in
// this listing; presumably this is the registry poolGetter() creates pools
// from — confirm.
544 C10_REGISTER_CREATOR(
547 caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CPU>);
548 C10_REGISTER_CREATOR(
551 caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CUDA>);
552 C10_REGISTER_CREATOR(
555 caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_HIP>);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...