1 #include "caffe2/core/net_async_scheduling.h"     3 #include "caffe2/core/net_async_tracing.h"     7 AsyncSchedulingNet::AsyncSchedulingNet(
     8     const std::shared_ptr<const NetDef>& net_def,
    10     : AsyncNetBase(net_def, ws), running_(false) {}
    12 void AsyncSchedulingNet::reset() {
    13   AsyncNetBase::reset();
    14   processed_tasks_num_ = 0;
    17 void AsyncSchedulingNet::Wait() {
    18   std::unique_lock<std::mutex> lock(running_mutex_);
    20     running_cv_.wait(lock);
    24 bool AsyncSchedulingNet::isInlineTask(
int parent_id, 
int child_id)
 const {
    25   if (!options_.use_dfs_scheduling_) {
    28   const auto* last_parent_op = lastTaskOp(parent_id);
    29   const auto* first_child_op = firstTaskOp(child_id);
    32       last_parent_op->device_option(), first_child_op->device_option());
    38 void AsyncSchedulingNet::schedule(
int task_id, 
bool run_inline) noexcept {
    39   if (!testAndSetScheduled(task_id)) {
    42   auto schedule_func = [
this, task_id]() {
    46         if (options_.streams_per_gpu_ > 1) {
    48             stream_id = stream(task_id);
    49           } 
catch (
const std::exception& e) {
    50             C10_LOG_EVERY_MS(ERROR, 1000)
    51                 << 
"Failed to select a stream: " << e.what();
    54         if (!run(task_id, stream_id)) {
    59       if (options_.report_stats_) {
    61           auto last_op_id = lastTaskOpId(task_id);
    62           auto* last_op = lastTaskOp(task_id);
    63           if (last_op->device_option().device_type() == PROTO_CPU &&
    64               last_op->HasAsyncPart()) {
    65             last_op->event().SetCallback([
this, last_op_id] {
    66               counters_.AddPerOpAsyncEndTime(last_op_id);
    69         } 
catch (
const std::exception& e) {
    70           C10_LOG_EVERY_MS(ERROR, 1000)
    71               << 
"Failed to report operator stats: " << e.what();
    75       for (
auto child_id : children(task_id)) {
    76         int parent_count = updateParentCount(child_id);
    77         if (parent_count == 0) {
    84           if (!success_ || options_.always_schedule_child_ ||
    85               options_.finish_chain_ || canSchedule(child_id)) {
    88             schedule(child_id, isInlineTask(task_id, child_id));
    90             bool parent_failed = 
false;
    91             bool parent_needs_polling = 
false;
    92             std::vector<int> parents_with_callback;
    94             for (
auto parent_id : parents(child_id)) {
    95               auto& parent_event = event(parent_id);
    96               auto parent_status = parent_event.Query();
    98               if (parent_status == EventStatus::EVENT_FAILED) {
   101               } 
else if (parent_status == EventStatus::EVENT_SCHEDULED) {
   104                 if (!canSchedule(parent_id, child_id)) {
   107                   if (parent_event.SupportsCallback()) {
   108                     parents_with_callback.push_back(parent_id);
   110                     parent_needs_polling = 
true;
   114               } 
else if (parent_status != EventStatus::EVENT_SUCCESS) {
   115                 VLOG(1) << 
"Unexpected parent task state: " << parent_status
   116                         << 
", task id: " << child_id
   117                         << 
", parent task id: " << parent_id;
   118                 parent_failed = 
true;
   126               schedule(child_id, isInlineTask(task_id, child_id));
   127             } 
else if (parent_needs_polling) {
   130               const auto& child_device_option =
   131                   event(child_id).GetDeviceOption();
   132               pool(child_device_option)
   134                       &AsyncSchedulingNet::pollAndSchedule, 
this, child_id));
   135             } 
else if (!parents_with_callback.empty()) {
   138               for (
auto parent_id : parents_with_callback) {
   139                 event(parent_id).SetCallback(std::bind(
   140                     &AsyncSchedulingNet::parentCallback, 
this, parent_id));
   144               schedule(child_id, isInlineTask(task_id, child_id));
   154         for (
auto tid = 0; tid < tasksNum(); ++tid) {
   155           if (event(tid).Query() == EventStatus::EVENT_SCHEDULED) {
   159               event(tid).SetFinished(
"Cancelled");
   160             } 
catch (
const EnforceNotMet&) {
   171       auto tasks_num = tasksNum();
   172       auto cur_processed_tasks = ++processed_tasks_num_;
   173       if (cur_processed_tasks == tasks_num) {
   176     } 
catch (
const std::exception& e) {
   178       LOG(FATAL) << 
"Unexpected error during graph scheduling run: "   181       LOG(FATAL) << 
"Unknown error during graph scheduling run";
   188     const auto& device_option = event(task_id).GetDeviceOption();
   189     pool(device_option)->run(schedule_func);
   193 void AsyncSchedulingNet::parentCallback(
int parent_id) {
   194   if (event(parent_id).Query() != EventStatus::EVENT_SUCCESS) {
   197   for (
auto child_id : children(parent_id)) {
   198     int parent_count = getParentCount(child_id);
   199     if (parent_count == 0) {
   200       if (!success_ || canSchedule(child_id)) {
   207 void AsyncSchedulingNet::pollAndSchedule(
int task_id) {
   208   bool parent_failed = 
false;
   209   bool can_schedule = canSchedule(task_id, 
nullptr, &parent_failed);
   217   if (can_schedule || !success_ || parent_failed) {
   220     const auto& device_option = event(task_id).GetDeviceOption();
   222         ->run(std::bind(&AsyncSchedulingNet::pollAndSchedule, 
this, task_id));
   226 void AsyncSchedulingNet::finishRun() {
   227   std::unique_lock<std::mutex> lock(running_mutex_);
   230   if (options_.report_stats_) {
   231     counters_.ReportRunEnd();
   236   running_cv_.notify_all();
   239 bool AsyncSchedulingNet::RunAsync() {
   241     std::unique_lock<std::mutex> lock(running_mutex_);
   243       LOG(ERROR) << 
"Detected concurrent runs";
   250     tracing::startIter(tracer_);
   251     if (options_.report_stats_) {
   252       counters_.ReportRunStart();
   254   } 
catch (
const std::exception& e) {
   255     LOG(ERROR) << 
"Exception while starting an async run: " << e.what();
   259     LOG(ERROR) << 
"Exception while starting an async run: unknown error";
   266   for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
   267     if (parents(task_id).empty()) {
   268       schedule(task_id, options_.run_root_tasks_inline_);
   272   if (tasksNum() == 0) {
   276   if (options_.is_blocking_) {
   283 AsyncSchedulingNet::~AsyncSchedulingNet() {
   287 REGISTER_NET(async_scheduling, AsyncSchedulingNet);
 A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...