// AsyncSchedulingNet constructor (its signature begins at the end of the next
// line): forwards net_def and ws to AsyncNetBase and starts with running_ set
// to false, i.e. no run in progress.
// NOTE(review): this listing is missing source lines (gaps in the embedded
// line numbers), e.g. the `Workspace* ws` parameter line.
1 #include "caffe2/core/net_async_scheduling.h" 3 #include "caffe2/core/net_async_tracing.h" 7 AsyncSchedulingNet::AsyncSchedulingNet(
8 const std::shared_ptr<const NetDef>& net_def,
10 : AsyncNetBase(net_def, ws), running_(false) {}
// Resets base-class state, then zeroes the counter of processed tasks
// (the counter is compared against tasksNum() in schedule()).
12 void AsyncSchedulingNet::reset() {
13 AsyncNetBase::reset();
14 processed_tasks_num_ = 0;
// Blocks the caller on running_cv_ while holding running_mutex_;
// finishRun() calls running_cv_.notify_all().
// NOTE(review): std::condition_variable::wait may wake spuriously. The
// missing line before the wait presumably loops on running_; confirm.
17 void AsyncSchedulingNet::Wait() {
18 std::unique_lock<std::mutex> lock(running_mutex_);
20 running_cv_.wait(lock);
// Decides whether child_id may be run inline right after parent_id.
// The decision is gated on options_.use_dfs_scheduling_; the early-return
// line for the non-DFS case is not visible in this listing.
// It otherwise compares the device option of the parent task's last op with
// that of the child task's first op. Presumably the result is "same device";
// the comparison function name is on a missing line, so verify.
24 bool AsyncSchedulingNet::isInlineTask(
int parent_id,
int child_id)
const {
25 if (!options_.use_dfs_scheduling_) {
28 const auto* last_parent_op = lastTaskOp(parent_id);
29 const auto* first_child_op = firstTaskOp(child_id);
32 last_parent_op->device_option(), first_child_op->device_option());
// Schedules task task_id. It builds a closure (schedule_func) that runs the
// task and then propagates readiness to the task's children. The closure is
// dispatched either to the thread pool for the task's device or, presumably,
// run inline when run_inline is set (that branch is not visible).
// Declared noexcept; exceptions inside the closure are caught and logged.
// NOTE(review): many source lines (try/return/else/closing braces) are
// missing from this listing; comments describe only the visible statements.
38 void AsyncSchedulingNet::schedule(
int task_id,
bool run_inline) noexcept {
// testAndSetScheduled() guards against scheduling the same task twice; the
// body of this branch (presumably an early return) is not shown.
39 if (!testAndSetScheduled(task_id)) {
42 auto schedule_func = [
this, task_id]() {
// Stream selection is only done when more than one stream per GPU is
// configured; failures are logged (rate-limited) instead of propagated.
46 if (options_.streams_per_gpu_ > 1) {
48 stream_id = stream(task_id);
49 }
catch (
const std::exception& e) {
50 C10_LOG_EVERY_MS(ERROR, 1000)
51 <<
"Failed to select a stream: " << e.what();
// Execute the task; the handling of a false return is on missing lines.
54 if (!run(task_id, stream_id)) {
// Optional per-op stats: for a CPU last op with an async part, record its
// async end time from the op event's callback.
59 if (options_.report_stats_) {
61 auto last_op_id = lastTaskOpId(task_id);
62 auto* last_op = lastTaskOp(task_id);
63 if (last_op->device_option().device_type() == PROTO_CPU &&
64 last_op->HasAsyncPart()) {
65 last_op->event().SetCallback([
this, last_op_id] {
66 counters_.AddPerOpAsyncEndTime(last_op_id);
69 }
catch (
const std::exception& e) {
70 C10_LOG_EVERY_MS(ERROR, 1000)
71 <<
"Failed to report operator stats: " << e.what();
// Propagate to children: a child becomes a scheduling candidate once
// updateParentCount() reports zero remaining parents.
75 for (
auto child_id : children(task_id)) {
76 int parent_count = updateParentCount(child_id);
77 if (parent_count == 0) {
// Schedule immediately when the run has failed, when options force
// child scheduling or chain finishing, or when canSchedule() allows it.
84 if (!success_ || options_.always_schedule_child_ ||
85 options_.finish_chain_ || canSchedule(child_id)) {
88 schedule(child_id, isInlineTask(task_id, child_id));
// Otherwise inspect each parent's event status to decide how to wait.
90 bool parent_failed =
false;
91 bool parent_needs_polling =
false;
92 std::vector<int> parents_with_callback;
94 for (
auto parent_id : parents(child_id)) {
95 auto& parent_event = event(parent_id);
96 auto parent_status = parent_event.Query();
98 if (parent_status == EventStatus::EVENT_FAILED) {
101 }
else if (parent_status == EventStatus::EVENT_SCHEDULED) {
// A still-scheduled parent that blocks the child is either waited on
// via an event callback (if supported) or by polling.
104 if (!canSchedule(parent_id, child_id)) {
107 if (parent_event.SupportsCallback()) {
108 parents_with_callback.push_back(parent_id);
110 parent_needs_polling =
true;
114 }
else if (parent_status != EventStatus::EVENT_SUCCESS) {
// Any status other than failed/scheduled/success is logged and
// treated as a parent failure.
115 VLOG(1) <<
"Unexpected parent task state: " << parent_status
116 <<
", task id: " << child_id
117 <<
", parent task id: " << parent_id;
118 parent_failed =
true;
// Presumably the parent_failed case (guard line not shown): schedule now.
126 schedule(child_id, isInlineTask(task_id, child_id));
127 }
else if (parent_needs_polling) {
// Enqueue pollAndSchedule(child_id) on the child's device pool.
130 const auto& child_device_option =
131 event(child_id).GetDeviceOption();
132 pool(child_device_option)
134 &AsyncSchedulingNet::pollAndSchedule,
this, child_id));
135 }
else if (!parents_with_callback.empty()) {
// Register parentCallback on each blocking parent's event.
138 for (
auto parent_id : parents_with_callback) {
139 event(parent_id).SetCallback(std::bind(
140 &AsyncSchedulingNet::parentCallback,
this, parent_id));
// Presumably the fall-through case (enclosing else not shown).
144 schedule(child_id, isInlineTask(task_id, child_id));
// Under a condition not visible here, every task event still in
// EVENT_SCHEDULED is marked finished with "Cancelled"; EnforceNotMet
// thrown by SetFinished is caught.
154 for (
auto tid = 0; tid < tasksNum(); ++tid) {
155 if (event(tid).Query() == EventStatus::EVENT_SCHEDULED) {
159 event(tid).SetFinished(
"Cancelled");
160 }
catch (
const EnforceNotMet&) {
// Count this task as processed; the action taken when the last task
// finishes is on a missing line (presumably finishRun()).
171 auto tasks_num = tasksNum();
172 auto cur_processed_tasks = ++processed_tasks_num_;
173 if (cur_processed_tasks == tasks_num) {
176 }
catch (
const std::exception& e) {
178 LOG(FATAL) <<
"Unexpected error during graph scheduling run: " 181 LOG(FATAL) <<
"Unknown error during graph scheduling run";
// Dispatch the closure to the thread pool for the task's device option.
188 const auto& device_option = event(task_id).GetDeviceOption();
189 pool(device_option)->run(schedule_func);
// Event callback registered by schedule() on a blocking parent's event.
// It acts only when the parent's event is EVENT_SUCCESS; the early-exit body
// is not shown. It then revisits each child with zero remaining parents
// (getParentCount) and, if the run has failed or canSchedule() allows it,
// proceeds to schedule it (the scheduling line is not visible).
193 void AsyncSchedulingNet::parentCallback(
int parent_id) {
194 if (event(parent_id).Query() != EventStatus::EVENT_SUCCESS) {
197 for (
auto child_id : children(parent_id)) {
198 int parent_count = getParentCount(child_id);
199 if (parent_count == 0) {
200 if (!success_ || canSchedule(child_id)) {
// Polling fallback for parents whose events do not support callbacks.
// canSchedule() reports readiness and, via the out-parameter, parent failure.
// NOTE(review): lines between the canSchedule() call and the condition are
// missing from this listing.
207 void AsyncSchedulingNet::pollAndSchedule(
int task_id) {
208 bool parent_failed =
false;
209 bool can_schedule = canSchedule(task_id,
nullptr, &parent_failed);
// When ready, when the run has failed, or when a parent failed, the
// (not shown) branch presumably schedules the task and returns.
217 if (can_schedule || !success_ || parent_failed) {
// Otherwise re-enqueue this same poll on the pool for the task's device
// option (the pool(...) expression is on a missing line).
220 const auto& device_option = event(task_id).GetDeviceOption();
222 ->run(std::bind(&AsyncSchedulingNet::pollAndSchedule,
this, task_id));
// Completes a run under running_mutex_: reports run-end stats when enabled,
// then wakes every thread blocked in Wait() via running_cv_.notify_all().
// NOTE(review): intervening lines (presumably resetting running_) are not
// visible in this listing.
226 void AsyncSchedulingNet::finishRun() {
227 std::unique_lock<std::mutex> lock(running_mutex_);
230 if (options_.report_stats_) {
231 counters_.ReportRunEnd();
236 running_cv_.notify_all();
// Starts one asynchronous run of the net.
// NOTE(review): several lines are missing from this listing (the
// concurrent-run guard, the try block, and the bodies of the tasksNum()==0 and
// is_blocking_ branches); comments describe only the visible statements.
239 bool AsyncSchedulingNet::RunAsync() {
241 std::unique_lock<std::mutex> lock(running_mutex_);
// Logged when a run is already in progress (guard line not shown).
243 LOG(ERROR) <<
"Detected concurrent runs";
// Begin tracing for this iteration and optionally record run-start stats;
// std::exception and unknown exceptions are logged as errors.
250 tracing::startIter(tracer_);
251 if (options_.report_stats_) {
252 counters_.ReportRunStart();
254 }
catch (
const std::exception& e) {
255 LOG(ERROR) <<
"Exception while starting an async run: " << e.what();
259 LOG(ERROR) <<
"Exception while starting an async run: unknown error";
// Kick off every root task (no parents); inlining follows
// options_.run_root_tasks_inline_.
266 for (
auto task_id = 0; task_id < tasksNum(); ++task_id) {
267 if (parents(task_id).empty()) {
268 schedule(task_id, options_.run_root_tasks_inline_);
// Special-cased empty net; presumably finishes the run immediately.
272 if (tasksNum() == 0) {
// Blocking mode; presumably waits for completion (body not shown).
276 if (options_.is_blocking_) {
// Destructor. Its body is not visible in this listing.
283 AsyncSchedulingNet::~AsyncSchedulingNet() {
// Registers AsyncSchedulingNet under the net type name "async_scheduling"
// (via the REGISTER_NET macro; presumably selected by NetDef.type).
287 REGISTER_NET(async_scheduling, AsyncSchedulingNet);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...