// Includes and command-line flags that configure the async net executor.
// NOTE(review): this excerpt is missing each flag's definition macro line and
// default value; only the flag names and help strings are visible here.
1 #include "caffe2/core/net_async_base.h"     3 #include "caffe2/core/net_async_tracing.h"     4 #include "caffe2/core/operator.h"     5 #include "caffe2/core/timer.h"     9     caffe2_streams_per_gpu,
    11     "Number of streams per worker per GPU"    12     " to use in GPU thread pool (experimental)");
// When set, the constructor groups ops via dag_utils::computeGroups instead of
// dag_utils::computeChains.
    15     caffe2_net_async_inference_mode,
    17     "If set, use one single chain containing all ops");
// Upper bound used by pool(DeviceOption) to validate GPU ids.
    20     caffe2_net_async_max_gpus,
    22     "Max number of GPUs allowed in net async executor");
// Used by pool(DeviceOption) when validating NUMA node ids.
    25     caffe2_net_async_max_numa_nodes,
    27     "Max number of NUMA nodes allowed in net async executor");
    30     caffe2_net_async_thread_pool_size,
    32     "Number of threads in device thread pool by default");
// The following flags seed ExecutionOptions for net types other than
// "dag"/"prof_dag"/"async_dag".
    35     caffe2_net_async_check_stream_status,
    37     "Select next non-busy stream");
    40     caffe2_net_async_use_single_pool,
    42     "Use single thread pool for all devices");
    45     caffe2_net_async_use_per_net_pools,
    47     "Use per net thread pools");
// NOTE(review): the help text below contains the typo "instread" (for
// "instead"); fixing it changes the flag's user-visible help string.
    50     caffe2_net_async_run_root_tasks_inline,
    52     "Run root tasks in current thread instread of scheduling to threadpool");
    56 std::vector<int>& AsyncNetBase::getStreamCounters() {
    57   static thread_local std::vector<int> stream_counters_;
    58   return stream_counters_;
// Builds the async execution plan for `net_def`. It creates the operator
// nodes, attaches one executor helper to every operator, and groups the
// operators into tasks (chains). It then records the event of each chain's
// last operator in events_.
// NOTE(review): `ws` is used below but its parameter line is not visible in
// this excerpt.
    61 AsyncNetBase::AsyncNetBase(
    62     const std::shared_ptr<const NetDef>& net_def,
    64     : NetBase(net_def, ws), options_(net_def), counters_(net_def) {
    65   operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
  // A single helper (owned by this net via unique_ptr) is handed to every op.
    66   helper_ = caffe2::make_unique<AsyncNetExecutorHelper>(
this);
    67   operators_.reserve(operator_nodes_.size());
    68   for (
const auto& node : operator_nodes_) {
    69     auto op_ptr = node.operator_.get();
    70     op_ptr->SetExecutorHelper(helper_.get());
    71     operators_.push_back(op_ptr);
  // Inference mode uses computeGroups (flag help: "one single chain containing
  // all ops"); the computeChains assignment below is the alternative branch.
    74   if (FLAGS_caffe2_net_async_inference_mode) {
    75     execution_chains_ = dag_utils::computeGroups(operator_nodes_);
    77     execution_chains_ = dag_utils::computeChains(operator_nodes_);
  // Each chain becomes one task; chains_[task_id] holds that task's op ids.
    79   chains_.reserve(execution_chains_.size());
    80   for (
const auto& kv : execution_chains_) {
    81     chains_.push_back(kv.second);
    83   chain_nodes_ = dag_utils::prepareChainGraphNodes(operator_nodes_, chains_);
  // A task's completion event is the event of its last operator.
    85   events_.reserve(chains_.size());
    86   for (
const auto& chain : chains_) {
    87     const auto& last_op = operators_[chain.back()];
    88     events_.push_back(&last_op->event());
    // Without per-op stats, only the interior ops of a chain are processed
    // (first/last are skipped). The per-op statement is not visible here;
    // presumably it disables the op's event. TODO confirm.
    90     if (!options_.report_stats_) {
    91       for (
const auto& op_id : chain) {
    92         if (op_id == chain.back() || op_id == chain.front()) {
    95         const auto& op = operators_[op_id];
  // -1 when the NetDef does not specify num_workers.
    101   num_workers_ = net_def->has_num_workers() ? net_def->num_workers() : -1;
  // The guard around the log line is not visible; presumably it only logs
  // when tracing::create returned a tracer.
    103   tracer_ = tracing::create(
this, net_def->name());
    105     LOG(INFO) << 
"Tracing net: " << net_def->name();
// Post-run error handling. When built with CAFFE2_USE_EXCEPTION_PTR, it scans
// every task's event for a stored exception. It then rethrows the one with the
// earliest ExceptionTimestamp(), so the first failure in time is reported.
// After that an error is logged. The condition guarding that log and the
// returned value are in lines not visible in this excerpt.
   109 bool AsyncNetBase::handleRunError() {
   110 #ifdef CAFFE2_USE_EXCEPTION_PTR   112   int first_exc_task_id = -1;
   113   int64_t first_exc_ts = 0;
   114   for (
int task_id = 0; task_id < tasksNum(); ++task_id) {
   115     if (event(task_id).HasException()) {
   116       if (first_exc_task_id >= 0) {
        // Keep the earliest exception seen so far.
   117         auto exc_ts = event(task_id).ExceptionTimestamp();
   118         if (exc_ts < first_exc_ts) {
   119           first_exc_task_id = task_id;
   120           first_exc_ts = exc_ts;
   123         first_exc_task_id = task_id;
   124         first_exc_ts = event(task_id).ExceptionTimestamp();
   128   if (first_exc_task_id >= 0) {
   129     LOG(ERROR) << 
"Rethrowing exception from the run of '" << Name() << 
"'";
   130     event(first_exc_task_id).RethrowException();
   132 #endif // CAFFE2_USE_EXCEPTION_PTR   135     LOG(ERROR) << 
"Error encountered in the run of '" << Name() << 
"'";
// Starts an asynchronous run of the net, beginning a tracing iteration for
// tracer_. The rest of the body is not visible in this excerpt.
   140 bool AsyncNetBase::RunAsync() {
   141   tracing::startIter(tracer_);
// Looks up the thread pool cached in `pools` under [device_id][pool_size].
// When needed, it creates one through c10::ThreadPoolRegistry, keyed by
// DeviceTypeName(device_type), and caches it. pools_mutex_ is held for the
// whole lookup/creation.
// NOTE(review): the parameter list and the creation condition are not visible
// here; presumably a pool is created only when the cached entry is empty.
   146 TaskThreadPoolBase* AsyncNetBase::poolGetter(
   151   std::unique_lock<std::mutex> pools_lock(pools_mutex_);
   152   auto pool = pools[device_id][pool_size];
   154     pool = c10::ThreadPoolRegistry()->Create(
   155         DeviceTypeName(device_type),
   158         options_.use_per_net_pools_);
   159     pools[device_id][pool_size] = pool;
// Default pool accessor. It builds a DeviceOption `dev` with device type
// PROTO_CPU; presumably it then forwards to pool(dev). Not visible here;
// TODO confirm.
   164 TaskThreadPoolBase* AsyncNetBase::pool() {
   167   dev.set_device_type(PROTO_CPU);
// Returns the thread pool to use for ops placed on `device_option`:
//  - single-pool mode: always the CPU pool keyed by -1;
//  - CPU devices: a CPU pool keyed by NUMA node id. The id is -1 unless
//    specified; if specified, it is validated as non-negative and checked
//    against FLAGS_caffe2_net_async_max_numa_nodes;
//  - GPU devices: a GPU pool keyed by device id, which must lie in
//    [0, FLAGS_caffe2_net_async_max_gpus);
//  - any other device type throws.
   171 TaskThreadPoolBase* AsyncNetBase::pool(
const DeviceOption& device_option) {
   172   if (options_.use_single_pool_) {
   173     return poolGetter(cpu_pools_, PROTO_CPU, -1, num_workers_);
   175   const auto device_type = device_option.device_type();
   176   if (IsCPUDeviceType(device_type)) {
   177     auto numa_node_id = -1;
   178     if (device_option.has_numa_node_id()) {
   179       numa_node_id = device_option.numa_node_id();
   180       CAFFE_ENFORCE_GE(numa_node_id, 0, 
"Invalid NUMA node id: ", numa_node_id);
   184         FLAGS_caffe2_net_async_max_numa_nodes,
   185         "Invalid NUMA node id: ",
   187     return poolGetter(cpu_pools_, device_type, numa_node_id, num_workers_);
   188   } 
else if (IsGPUDeviceType(device_type)) {
   189     auto gpu_id = device_option.device_id();
   191         gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
   192         "Invalid GPU id: " + c10::to_string(gpu_id));
   193     return poolGetter(gpu_pools_, device_type, gpu_id, num_workers_);
   195     CAFFE_THROW(
"Unsupported device type " + c10::to_string(device_type));
// Picks a stream id for `task_id` based on the device of the task's event.
// For GPU devices it advances this thread's per-GPU round-robin counter
// (modulo options_.streams_per_gpu_), growing the counter vector on demand.
// With check_stream_status_ set, the do/while keeps advancing until
// isStreamFree() reports the stream free.
// NOTE(review): with check_stream_status_ set and no free stream, this loop
// keeps spinning. The `(unsigned)` C-style cast would read better as
// static_cast<size_t>(gpu_id).
   199 int AsyncNetBase::stream(
int task_id) {
   200   const auto& device_option = event(task_id).GetDeviceOption();
   202   if (IsGPUDeviceType(device_option.device_type())) {
   203     int gpu_id = device_option.device_id();
   204     CAFFE_ENFORCE_GE(gpu_id, 0, 
"Invalid gpu id: " + c10::to_string(gpu_id));
   205     if ((
unsigned)gpu_id >= getStreamCounters().size()) {
   206       getStreamCounters().resize(gpu_id + 1, 0);
   209       stream_id = getStreamCounters().at(gpu_id)++;
   210       getStreamCounters().at(gpu_id) %= options_.streams_per_gpu_;
   211     } 
while (options_.check_stream_status_ &&
   212              !isStreamFree(task_id, stream_id));
   217 bool AsyncNetBase::isStreamFree(
int task_id, 
int stream_id)
 const {
   218   auto& task = chains_[task_id];
   219   auto& last_task_op = operators_[task.back()];
   220   return last_task_op->IsStreamFree(stream_id);
// Checks whether a task may be scheduled given the state of all its parents.
// Each parent's status comes from `status` (presumably when it is non-null)
// or else from querying the event of the parent's last op. A failed parent
// sets *parent_failed (the guard on that write is not visible here). Each
// parent is then checked with Event::CanSchedule against the child's first op.
// NOTE(review): the `task_id` parameter line and several branch lines are not
// visible in this excerpt.
   223 bool AsyncNetBase::canSchedule(
   225     const std::vector<EventStatus>* status,
   226     bool* parent_failed) {
   227   auto first_child_op_id = chains_[task_id].front();
   228   for (
auto parent_id : parents(task_id)) {
   229     auto last_parent_op_id = chains_[parent_id].back();
   230     EventStatus parent_status;
   232       parent_status = status->at(parent_id);
   234       parent_status = operators_[last_parent_op_id]->event().Query();
   237     if (parent_status == EventStatus::EVENT_FAILED) {
   239         *parent_failed = 
true;
   244     bool can_schedule = Event::CanSchedule(
   245         operators_[last_parent_op_id]->event().GetType(),
   247         operators_[first_child_op_id]->event().GetType(),
   248         operators_[first_child_op_id]->SupportsAsyncScheduling());
   257 bool AsyncNetBase::canSchedule(
int parent_id, 
int child_id) {
   258   auto& parent_event = event(parent_id);
   259   auto first_child_op_id = chains_[child_id].front();
   260   auto* first_child_op = operators_[first_child_op_id];
   261   return Event::CanSchedule(
   262       parent_event.GetType(),
   263       parent_event.Query(),
   264       first_child_op->event().GetType(),
   265       first_child_op->SupportsAsyncScheduling());
   268 int AsyncNetBase::tasksNum()
 const {
   269   return chains_.size();
   272 Event& AsyncNetBase::event(
int task_id)
 const {
   273   auto& task = chains_[task_id];
   274   auto& last_task_op = operators_[task.back()];
   275   return last_task_op->event();
   278 EventStatus AsyncNetBase::query(
int task_id)
 const {
   279   return event(task_id).Query();
   282 const std::vector<int>& AsyncNetBase::children(
int task_id)
 const {
   283   const auto& task_node = chain_nodes_[task_id];
   284   return task_node.children_;
   287 const std::vector<int>& AsyncNetBase::parents(
int task_id)
 const {
   288   const auto& task_node = chain_nodes_[task_id];
   289   return task_node.parents_;
   292 int AsyncNetBase::getParentCount(
int child_id) {
   293   auto& child_ops = chains_[child_id];
   294   auto& child_node = operator_nodes_[child_ops.front()];
   295   return child_node.runtime_parent_count_.load();
// Decrements the runtime parent count of task `child_id` (stored on the node
// of its first operator) and enforces that it never drops below zero. The
// return statement is not visible in this excerpt; presumably it returns the
// decremented count.
   298 int AsyncNetBase::updateParentCount(
int child_id) {
   299   auto& child_ops = chains_[child_id];
   300   auto& child_node = operator_nodes_[child_ops.front()];
   301   int parent_count = --child_node.runtime_parent_count_;
   302   CAFFE_ENFORCE_GE(parent_count, 0);
   306 bool AsyncNetBase::testAndSetScheduled(
int task_id) {
   307   auto& task_ops = chains_[task_id];
   308   auto& task_op_node = operator_nodes_[task_ops.front()];
   309   return !task_op_node.scheduled_.test_and_set();
   312 int AsyncNetBase::numOps(
int task_id)
 const {
   313   return chains_[task_id].size();
   316 int AsyncNetBase::firstTaskOpId(
int task_id)
 const {
   317   return chains_[task_id].front();
   320 int AsyncNetBase::lastTaskOpId(
int task_id)
 const {
   321   return chains_[task_id].back();
   324 const OperatorBase* AsyncNetBase::firstTaskOp(
int task_id)
 const {
   325   return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
   328 const OperatorBase* AsyncNetBase::lastTaskOp(
int task_id)
 const {
   329   return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
   332 OperatorBase* AsyncNetBase::firstTaskOp(
int task_id) {
   333   return operator_nodes_[firstTaskOpId(task_id)].operator_.get();
   336 OperatorBase* AsyncNetBase::lastTaskOp(
int task_id) {
   337   return operator_nodes_[lastTaskOpId(task_id)].operator_.get();
// Makes the first operator of task `task_id` wait, on stream `stream_id`, for
// the completion events of every task listed in `wait_task_ids`.
// NOTE(review): the `task_id` / `stream_id` parameter lines are not visible in
// this excerpt; their names are taken from their uses below.
   340 void AsyncNetBase::asyncWait(
   343     const std::vector<int>& wait_task_ids)
 const {
   344   auto first_op_id = chains_[task_id].front();
   345   auto& first_op = operators_[first_op_id];
   346   std::vector<const Event*> events;
   347   events.reserve(wait_task_ids.size());
   348   for (
auto wait_task_id : wait_task_ids) {
   349     events.push_back(&event(wait_task_id));
   351   first_op->WaitEvents(events, stream_id);
// Prepares the net for another run. The per-operator loop body is not
// visible here; presumably it resets each op's event (TODO confirm). Then,
// for every task, it restores the runtime parent count to the number of
// parents and clears the `scheduled_` flag, so testAndSetScheduled can claim
// the task again.
   354 void AsyncNetBase::reset() {
   355   for (
auto& op : GetOperators()) {
   358   for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
   359     auto& task_ops = chains_[task_id];
   360     auto& task_op_node = operator_nodes_[task_ops.front()];
   361     task_op_node.runtime_parent_count_ = parents(task_id).size();
   362     task_op_node.scheduled_.clear();
// Logs a chain failure and, if the task's event is still EVENT_INITIALIZED,
// completes it with the error message. SetFinishedWithException is used when
// `save_exception` is set, otherwise SetFinished. Declared noexcept.
// NOTE(review): the appended text is ",  op <type>" (two spaces) or
// ",  op  unknown" (the fallback literal carries its own leading space).
// This is cosmetic only, but changing it alters log output.
   368 void AsyncNetBase::handleChainError(
   372     bool save_exception) noexcept {
   373   std::string err_msg = err_str;
   375     err_msg += 
",  op " + (op->has_debug_def() ? op->type() : 
" unknown");
   377   LOG(ERROR) << err_msg;
  // Only an event that has not been scheduled/finished yet is completed here.
   379   if (query(task_id) == EventStatus::EVENT_INITIALIZED) {
   380     if (save_exception) {
   381       event(task_id).SetFinishedWithException(err_msg.c_str());
   383       event(task_id).SetFinished(err_msg.c_str());
// Runs every operator of task `task_id` on stream `stream_id`. Declared
// noexcept; exceptions are routed to handleChainError.
// Unless options_.finish_chain_ is set, it first makes the chain wait on its
// parents' events. When report_stats_ is set, per-op start/end times are
// recorded in counters_. A failing op triggers handleChainError. With
// finish_chain_, the last op's event is explicitly finished afterwards.
// NOTE(review): several lines (try, tracing macro, error returns) are not
// visible in this excerpt.
   388 bool AsyncNetBase::run(
int task_id, 
int stream_id) noexcept {
  // Tracks the op being run so error handlers can report it.
   389   OperatorBase* op = 
nullptr;
   394     if (!options_.finish_chain_) {
   395       asyncWait(task_id, stream_id, parents(task_id));
   397     for (
auto& op_id : chains_[task_id]) {
   398       op = operators_[op_id];
   398       bool success = 
false;
   400       if (!options_.report_stats_) {
   406             tracing::TRACE_STREAM,
   408         success = op->RunAsync(stream_id);
   410         counters_.AddPerOpStartTime(op_id);
   411         success = op->RunAsync(stream_id);
   412         if (success && op->device_option().device_type() != PROTO_CPU) {
   415         counters_.AddPerOpEndTime(op_id);
   419         handleChainError(task_id, op, 
"Failed to execute an op");
   425     if (options_.finish_chain_) {
   426       operators_[chains_[task_id].back()]->event().Finish();
   428   } 
catch (
const std::exception& e) {
   429     handleChainError(task_id, op, e.what(),  
true);
   435         "Failed to execute task: unknown error",
   443 void AsyncNetBase::finishTasks(
const std::unordered_set<int>& task_ids) {
   444   for (
const auto& task_id : task_ids) {
   445     event(task_id).Finish();
// Drives every task event to a terminal state at the end of a run.
// EVENT_SCHEDULED events are waited on via Finish(); EVENT_INITIALIZED
// events (never scheduled) are marked finished. Any event that then is not
// EVENT_SUCCESS is handled in lines not visible in this excerpt.
   449 void AsyncNetBase::finalizeEvents() {
   450   for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
   451     auto status = query(task_id);
   452     if (status == EventStatus::EVENT_SCHEDULED) {
   453       event(task_id).Finish();
   454     } 
else if (status == EventStatus::EVENT_INITIALIZED) {
   455       event(task_id).SetFinished();
   457     if (event(task_id).Query() != EventStatus::EVENT_SUCCESS) {
   463 ProfDAGProtos AsyncNetBase::GetOperatorStats()
 const {
   464   return counters_.GetReport().GetOperatorStats();
   467 ProfDAGProtos AsyncNetBase::GetPerOperatorCost()
 const {
   468   return counters_.GetReport().GetPerOperatorCost();
   471 ProfDAGReport AsyncNetBase::GetProfReport()
 const {
   472   return counters_.GetReport();
   475 AsyncNetBase::~AsyncNetBase() {
   476   if (options_.report_stats_) {
   477     counters_.GetReport().PrintStats();
// Derives executor options from the NetDef's type and args.
//  - "dag" / "prof_dag": one stream per GPU, finish_chain_ and
//    always_schedule_child_ on, no stream-status checks, use_single_pool_ and
//    use_per_net_pools_ on. Per-op stats are collected only for "prof_dag".
//  - "async_dag": same as "dag" except finish_chain_ is off and stats are off.
//  - any other type (the type defaults to "simple" when unset or empty):
//    streams and pools come from the caffe2_* flags; finish_chain_,
//    always_schedule_child_, is_blocking_ and report_stats_ are off.
// Afterwards, the integer net args "enable_profiling" and "deferrable_mode"
// (1 = on) override report_stats_ and use_dfs_scheduling_.
// NOTE(review): is_blocking_ is not assigned in the visible lines of the
// "dag"/"async_dag" branches (original lines 501/510 are missing here).
   481 ExecutionOptions::ExecutionOptions(
   482     const std::shared_ptr<const NetDef>& net_def) {
   483   static const std::string kDag = 
"dag";
   484   static const std::string kProfDag = 
"prof_dag";
   485   static const std::string kAsyncDag = 
"async_dag";
   486   static const std::string kSimpleNet = 
"simple";
   488   std::string net_type;
   489   if (net_def->has_type() && !net_def->type().empty()) {
   490     net_type = net_def->type();
   492     net_type = kSimpleNet;
   494   if (net_type == kDag || net_type == kProfDag) {
   495     streams_per_gpu_ = 1;
   496     finish_chain_ = 
true;
   497     always_schedule_child_ = 
true;
   498     check_stream_status_ = 
false;
   499     use_single_pool_ = 
true;
   500     use_per_net_pools_ = 
true;
   502     report_stats_ = (net_type == kProfDag);
   503   } 
else if (net_type == kAsyncDag) {
   504     streams_per_gpu_ = 1;
   505     finish_chain_ = 
false;
   506     always_schedule_child_ = 
true;
   507     check_stream_status_ = 
false;
   508     use_single_pool_ = 
true;
   509     use_per_net_pools_ = 
true;
   511     report_stats_ = 
false;
    // Remaining net types: behavior is controlled by command-line flags.
   513     streams_per_gpu_ = FLAGS_caffe2_streams_per_gpu;
   514     finish_chain_ = 
false;
   515     always_schedule_child_ = 
false;
   516     check_stream_status_ = FLAGS_caffe2_net_async_check_stream_status;
   517     use_single_pool_ = FLAGS_caffe2_net_async_use_single_pool;
   518     use_per_net_pools_ = FLAGS_caffe2_net_async_use_per_net_pools;
   519     is_blocking_ = 
false;
   520     report_stats_ = 
false;
   523   use_dfs_scheduling_ = 
false;
  // Net-level args override the type-based defaults chosen above.
   525   for (
int arg_idx = 0; arg_idx < net_def->arg_size(); ++arg_idx) {
   526     auto& arg = net_def->arg(arg_idx);
   527     if (arg.has_name() && arg.name() == 
"enable_profiling") {
   528       CAFFE_ENFORCE(arg.has_i(), 
"enable_profiling should be an int");
   529       report_stats_ = arg.i() == 1;
   531     if (arg.has_name() && arg.name() == 
"deferrable_mode") {
   532       CAFFE_ENFORCE(arg.has_i(), 
"deferrable_mode should be an int");
   533       use_dfs_scheduling_ = arg.i() == 1; 
   537   run_root_tasks_inline_ = FLAGS_caffe2_net_async_run_root_tasks_inline;
// Registers caffe2::GetAsyncNetThreadPool<TaskThreadPool, ...> creators for
// the PROTO_CPU, PROTO_CUDA and PROTO_HIP device types. The registry and key
// arguments are not visible in this excerpt; presumably this is the
// ThreadPoolRegistry keyed by device type name used in poolGetter. TODO
// confirm.
   544 C10_REGISTER_CREATOR(
   547     caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CPU>);
   548 C10_REGISTER_CREATOR(
   551     caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_CUDA>);
   552 C10_REGISTER_CREATOR(
   555     caffe2::GetAsyncNetThreadPool<TaskThreadPool, caffe2::PROTO_HIP>);
 A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...