Caffe2 - C++ API
A deep learning, cross platform ML framework
net_async_polling.cc
1 
17 #include "caffe2/core/net_async_polling.h"
18 
19 #include "caffe2/core/operator.h"
20 #include "caffe2/core/timer.h"
21 
// Declares the boolean flag that this file reads as
// FLAGS_caffe2_dag_net_collect_stats; every timing/stats CAFFE_EVENT below is
// guarded by it. NOTE(review): the flag is presumably defined in another
// translation unit - confirm where.
22 CAFFE2_DECLARE_bool(caffe2_dag_net_collect_stats);
23 
24 namespace caffe2 {
25 
26 AsyncPollingNet::AsyncPollingNet(
27  const std::shared_ptr<const NetDef>& net_def,
28  Workspace* ws)
29  : AsyncNetBase(net_def, ws), running_(false) {
30  task_timers_.resize(tasksNum());
31  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
32  task_timers_[task_id] = caffe2::make_unique<Timer>();
33  }
34 
35  stats_.reserve(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
36  for (auto device_idx = 0;
37  device_idx < DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
38  ++device_idx) {
39  stats_.emplace_back(
40  "async_net/stats/" + net_def->name() + "/" +
41  caffe2::DeviceTypeName(device_idx));
42  }
43 
44  reset();
45 }
46 
47 bool AsyncPollingNet::DoRunAsync() {
48  CAFFE_ENFORCE(!running_, "Concurrent RunAsync calls");
49  running_ = true;
50  reset();
51 
52  StartAllObservers();
53 
54  Timer timer;
55  bool success = pollAndSchedule();
56  if (FLAGS_caffe2_dag_net_collect_stats) {
57  CAFFE_EVENT(stats_[CPU], poll_time_ms, timer.MilliSeconds());
58  }
59  if (!success) {
60  finalizeEvents();
61  }
62 
63  StopAllObservers();
64  running_ = false;
65  return success;
66 }
67 
68 void AsyncPollingNet::schedule(int task_id) {
69  if (FLAGS_caffe2_dag_net_collect_stats) {
70  task_timers_[task_id]->Start();
71  }
72  const auto& device_option = event(task_id).GetDeviceOption();
73  pool(device_option)->run([this, task_id, device_option]() {
74  int stream_id = stream(task_id);
75 
76  if (FLAGS_caffe2_dag_net_collect_stats) {
77  CAFFE_EVENT(
78  stats_[device_option.device_type()],
79  task_pool_wait_time_us,
80  task_timers_[task_id]->MicroSeconds());
81  }
82 
83  // Non-blocking wait, setups scheduling of dependent async computations;
84  // canSchedule ensures that there's no busy wait,
85  // for CUDA events we need to insert CUDA event synchronization to ensure
86  // that async CUDA computations are executed in correct order
87  asyncWait(task_id, stream_id, parents(task_id));
88  try {
89  if (FLAGS_caffe2_dag_net_collect_stats) {
90  Timer run_time;
91  run(task_id, stream_id);
92  CAFFE_EVENT(
93  stats_[device_option.device_type()],
94  task_run_time_us,
95  run_time.MicroSeconds());
96  } else {
97  run(task_id, stream_id);
98  }
99  } catch (const std::exception&) {
100  has_chain_failed_ = true;
101  }
102  });
103 }
104 
105 void AsyncPollingNet::reset() {
106  status_.clear();
107  status_.resize(tasksNum(), EventStatus::EVENT_INITIALIZED);
108  has_chain_failed_ = false;
109 }
110 
111 bool AsyncPollingNet::pollAndSchedule() {
112  std::unordered_set<int> scheduled_tasks;
113  std::unordered_set<int> current_tasks;
114 
115  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
116  if (parents(task_id).empty()) {
117  current_tasks.insert(task_id);
118  scheduled_tasks.insert(task_id);
119  schedule(task_id);
120  }
121  }
122 
123  Timer timer;
124  while (!current_tasks.empty()) {
125  std::unordered_set<int> updated_tasks;
126  std::unordered_set<int> next_tasks;
127  updated_tasks.reserve(current_tasks.size());
128 
129  if (FLAGS_caffe2_dag_net_collect_stats) {
130  timer.Start();
131  }
132  if (has_chain_failed_) {
133  finishTasks(current_tasks);
134  return false;
135  }
136  for (auto& task_id : current_tasks) {
137  auto prev_status = status_[task_id];
138  status_[task_id] = query(task_id);
139  if (status_[task_id] == EventStatus::EVENT_FAILED) {
140  finishTasks(current_tasks);
141  return false;
142  }
143 
144  if (prev_status != status_[task_id]) {
145  updated_tasks.insert(task_id);
146  if (FLAGS_caffe2_dag_net_collect_stats) {
147  updateTaskStats(task_id);
148  }
149  }
150 
151  if (status_[task_id] != EventStatus::EVENT_SUCCESS) {
152  next_tasks.insert(task_id);
153  }
154  }
155  if (FLAGS_caffe2_dag_net_collect_stats) {
156  CAFFE_EVENT(
157  stats_[CPU], poll_status_update_time_us, timer.MicroSeconds());
158  }
159 
160  std::unordered_set<int> visited_children;
161  for (auto& task_id : updated_tasks) {
162  CAFFE_ENFORCE(
163  status_[task_id] == EventStatus::EVENT_SCHEDULED ||
164  status_[task_id] == EventStatus::EVENT_SUCCESS);
165 
166  for (auto& child_id : children(task_id)) {
167  if (!visited_children.count(child_id)) {
168  visited_children.insert(child_id);
169  // Important - check whether we have already scheduled the task,
170  // e.g. a child CUDA task can be scheduled after parent CUDA
171  // task becomes EventStatus::EVENT_SCHEDULED and also later when
172  // parent CUDA task becomes EventStatus::EVENT_SUCCESS
173  if (!scheduled_tasks.count(child_id) &&
174  canSchedule(child_id, &status_)) {
175  next_tasks.insert(child_id);
176  scheduled_tasks.insert(child_id);
177  schedule(child_id);
178  }
179  }
180  }
181  }
182 
183  current_tasks.swap(next_tasks);
184  }
185  return true;
186 }
187 
188 void AsyncPollingNet::updateTaskStats(int task_id) {
189  const auto& device_option = event(task_id).GetDeviceOption();
190  if (status_[task_id] == EventStatus::EVENT_SCHEDULED) {
191  CAFFE_EVENT(
192  stats_[device_option.device_type()],
193  task_time_to_scheduled_us,
194  task_timers_[task_id]->MicroSeconds());
195  }
196  if (status_[task_id] == EventStatus::EVENT_SUCCESS) {
197  CAFFE_EVENT(
198  stats_[device_option.device_type()],
199  task_time_to_succeeded_ms,
200  task_timers_[task_id]->MilliSeconds());
201  }
202 }
203 
204 AsyncPollingNet::~AsyncPollingNet() {}
205 
// Registers AsyncPollingNet under the key `async_polling`.
// NOTE(review): presumably this makes the executor selectable via the net
// type "async_polling" - confirm against the REGISTER_NET definition.
206 REGISTER_NET(async_polling, AsyncPollingNet);
207 
208 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.