// Caffe2 - C++ API
// A deep learning, cross platform ML framework
//
// net_async_base.cc
#include "caffe2/core/net_async_polling.h"

#include <mutex>
#include <thread>

#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
22 CAFFE2_DEFINE_int(
23  caffe2_streams_per_gpu,
24  32,
25  "Number of streams per GPU to use in GPU thread pool");
26 
27 CAFFE2_DECLARE_bool(caffe2_dag_net_collect_stats);
28 
29 CAFFE2_DEFINE_bool(
30  caffe2_net_async_use_single_pool,
31  false,
32  "Use single thread pool for all chain types");
33 
34 CAFFE2_DEFINE_bool(
35  caffe2_net_async_use_single_gpu_pool,
36  false,
37  "Use single thread pool for all GPU chains");
38 
39 CAFFE2_DEFINE_bool(
40  caffe2_net_async_finish_chain,
41  false,
42  "Wait for chain to finish");
43 
44 CAFFE2_DEFINE_int(
45  caffe2_net_async_max_gpus,
46  16,
47  "Max number of GPUs allowed in net async executor");
48 
49 CAFFE2_DEFINE_int(
50  caffe2_net_async_cpu_pool_size,
51  0,
52  "Number of threads in CPU pool (default - number of cores)");
53 
54 CAFFE2_DEFINE_bool(
55  caffe2_net_async_check_stream_status,
56  true,
57  "Select next non-busy stream");
58 
59 namespace caffe2 {
60 
61 thread_local std::vector<int> AsyncNetBase::stream_counters_;
62 
63 AsyncNetBase::AsyncNetBase(
64  const std::shared_ptr<const NetDef>& net_def,
65  Workspace* ws)
66  : NetBase(net_def, ws) {
67  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);
68  operators_.reserve(operator_nodes_.size());
69  for (const auto& node : operator_nodes_) {
70  operators_.push_back(node.operator_.get());
71  }
72 
73  const auto& execution_chains = dag_utils::computeChains(operator_nodes_);
74  chains_.reserve(execution_chains.size());
75  for (const auto& kv : execution_chains) {
76  chains_.push_back(kv.second);
77  }
78  chain_nodes_ = dag_utils::prepareChainGraphNodes(operator_nodes_, chains_);
79 
80  events_.reserve(chains_.size());
81  for (const auto& chain : chains_) {
82  const auto& op = operators_[chain.back()];
83  events_.push_back(&op->event());
84  }
85 
86  DeviceOption cpu_option;
87  cpu_option.set_device_type(CPU);
88  cpu_pool_ = ThreadPoolRegistry()->Create(
89  DeviceTypeName(cpu_option.device_type()), cpu_option);
90  gpu_pools_.resize(FLAGS_caffe2_net_async_max_gpus);
91  if (FLAGS_caffe2_net_async_use_single_gpu_pool) {
92  DeviceOption gpu_option;
93  gpu_option.set_device_type(CUDA);
94  gpu_option.set_cuda_gpu_id(0);
95  gpu_pool_ = ThreadPoolRegistry()->Create(
96  DeviceTypeName(gpu_option.device_type()), gpu_option);
97  }
98 }
99 
100 std::shared_ptr<TaskThreadPool> AsyncNetBase::pool(
101  const DeviceOption& device_option) {
102  if (FLAGS_caffe2_net_async_use_single_pool ||
103  device_option.device_type() == CPU) {
104  return cpu_pool_;
105  } else if (device_option.device_type() == CUDA) {
106  if (FLAGS_caffe2_net_async_use_single_gpu_pool) {
107  return gpu_pool_;
108  } else {
109  auto gpu_id = device_option.cuda_gpu_id();
110  CAFFE_ENFORCE(
111  gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
112  "Invalid GPU id: " + caffe2::to_string(gpu_id));
113  auto pool = gpu_pools_[gpu_id];
114  if (!pool) {
115  std::unique_lock<std::mutex> pools_lock(pools_mutex_);
116  pool = gpu_pools_[gpu_id];
117  if (!pool) {
118  pool = ThreadPoolRegistry()->Create(
119  DeviceTypeName(device_option.device_type()), device_option);
120  gpu_pools_[gpu_id] = pool;
121  }
122  }
123  return pool;
124  }
125  } else {
126  CAFFE_THROW(
127  "Unsupported device type " +
128  caffe2::to_string(device_option.device_type()));
129  }
130 }
131 
132 int AsyncNetBase::stream(int task_id) {
133  const auto& device_option = event(task_id).GetDeviceOption();
134  int stream_id = 0;
135  if (device_option.device_type() == CUDA) {
136  int gpu_id = device_option.cuda_gpu_id();
137  CAFFE_ENFORCE_GE(gpu_id, 0, "Invalid gpu id: " + caffe2::to_string(gpu_id));
138  if (gpu_id >= stream_counters_.size()) {
139  stream_counters_.resize(gpu_id + 1, 0);
140  }
141  do {
142  stream_id = stream_counters_[gpu_id]++;
143  stream_counters_[gpu_id] %= FLAGS_caffe2_streams_per_gpu;
144  } while (!isStreamFree(task_id, stream_id) &&
145  FLAGS_caffe2_net_async_check_stream_status);
146  }
147  return stream_id;
148 }
149 
150 bool AsyncNetBase::isStreamFree(int task_id, int stream_id) const {
151  auto& task = chains_[task_id];
152  auto& last_task_op = operators_[task.back()];
153  return last_task_op->IsStreamFree(stream_id);
154 }
155 
156 bool AsyncNetBase::canSchedule(
157  int task_id,
158  const std::vector<EventStatus>* status) {
159  auto first_child_op_id = chains_[task_id].front();
160  for (auto parent_id : parents(task_id)) {
161  auto last_parent_op_id = chains_[parent_id].back();
162  EventStatus parent_status;
163  if (status) {
164  parent_status = status->at(parent_id);
165  } else {
166  parent_status = operators_[last_parent_op_id]->event().Query();
167  }
168  bool can_schedule = Event::CanSchedule(
169  operators_[last_parent_op_id]->event().GetType(),
170  parent_status,
171  operators_[first_child_op_id]->event().GetType(),
172  operators_[first_child_op_id]->SupportsAsyncScheduling());
173  if (!can_schedule) {
174  return false;
175  }
176  }
177 
178  return true;
179 }
180 
181 int AsyncNetBase::tasksNum() const {
182  return chains_.size();
183 }
184 
185 Event& AsyncNetBase::event(int task_id) const {
186  auto& task = chains_[task_id];
187  auto& last_task_op = operators_[task.back()];
188  return last_task_op->event();
189 }
190 
191 EventStatus AsyncNetBase::query(int task_id) const {
192  return event(task_id).Query();
193 }
194 
195 const std::vector<int>& AsyncNetBase::children(int task_id) const {
196  const auto& task_node = chain_nodes_[task_id];
197  return task_node.children_;
198 }
199 
200 const std::vector<int>& AsyncNetBase::parents(int task_id) const {
201  const auto& task_node = chain_nodes_[task_id];
202  return task_node.parents_;
203 }
204 
205 void AsyncNetBase::asyncWait(
206  int task_id,
207  int stream_id,
208  const std::vector<int>& wait_task_ids) const {
209  auto first_op_id = chains_[task_id].front();
210  auto& first_op = operators_[first_op_id];
211  std::vector<const Event*> events;
212  events.reserve(wait_task_ids.size());
213  for (auto wait_task_id : wait_task_ids) {
214  events.push_back(&event(wait_task_id));
215  }
216  first_op->WaitEvents(events, stream_id);
217 }
218 
219 void AsyncNetBase::run(int task_id, int stream_id) {
220  std::string err_msg;
221  for (auto& op_id : chains_[task_id]) {
222  auto& op = operators_[op_id];
223  try {
224  CAFFE_ENFORCE(op->RunAsync(stream_id), "Failed to execute an op");
225  } catch (const std::exception& e) {
226  CAFFE_THROW(
227  std::string(e.what()) + ", op " +
228  (op->has_debug_def() ? op->type() : " unknown"));
229  } catch (...) {
230  CAFFE_THROW(
231  "Failed to execute task: unknown error, op " +
232  (op->has_debug_def() ? op->type() : " unknown"));
233  }
234  }
235 
236  if (FLAGS_caffe2_net_async_finish_chain) {
237  operators_[chains_[task_id].back()]->event().Finish();
238  }
239 }
240 
241 void AsyncNetBase::finishTasks(const std::unordered_set<int>& task_ids) {
242  for (const auto& task_id : task_ids) {
243  event(task_id).Finish();
244  }
245 }
246 
247 void AsyncNetBase::finalizeEvents() {
248  for (auto task_id = 0; task_id < tasksNum(); ++task_id) {
249  auto status = query(task_id);
250  if (status == EventStatus::EVENT_SCHEDULED) {
251  event(task_id).Finish();
252  } else if (status == EventStatus::EVENT_INITIALIZED) {
253  event(task_id).SetFinished();
254  }
255  }
256 }
257 
258 AsyncNetBase::~AsyncNetBase() {}
259 
260 CAFFE_DEFINE_SHARED_REGISTRY(
261  ThreadPoolRegistry,
262  TaskThreadPool,
263  const DeviceOption&);
264 
265 namespace {
266 std::shared_ptr<TaskThreadPool> AsyncNetCPUThreadPoolCreator(
267  const DeviceOption& device_option) {
268  CAFFE_ENFORCE_EQ(
269  device_option.device_type(),
270  CPU,
271  "Unexpected device type for CPU thread pool");
272  return GetAsyncNetCPUThreadPool();
273 }
274 } // namespace
275 
276 CAFFE_REGISTER_CREATOR(ThreadPoolRegistry, CPU, AsyncNetCPUThreadPoolCreator);
277 
278 /* static */
279 std::shared_ptr<TaskThreadPool> GetAsyncNetCPUThreadPool() {
280  static std::weak_ptr<TaskThreadPool> pool;
281  static std::mutex pool_mutex;
282  std::lock_guard<std::mutex> lock(pool_mutex);
283 
284  auto shared_pool = pool.lock();
285  if (!shared_pool) {
286  auto pool_size = FLAGS_caffe2_net_async_cpu_pool_size;
287  if (pool_size <= 0) {
288  auto num_cores = std::thread::hardware_concurrency();
289  CAFFE_ENFORCE(num_cores > 0, "Failed to get number of CPU cores");
290  pool_size = num_cores;
291  }
292  LOG(INFO) << "Using cpu pool size: " << pool_size;
293  shared_pool = std::make_shared<TaskThreadPool>(pool_size);
294  pool = shared_pool;
295  }
296  return shared_pool;
297 }
298 
299 } // namespace caffe2
// Copyright (c) 2016-present, Facebook, Inc.