Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_async_gpu_thread_pool_gpu.cc
1 
17 #include "caffe2/core/net_async_gpu_thread_pool.h"
18 
19 #include "caffe2/core/context_gpu.h"
20 
21 CAFFE2_DEFINE_int(caffe2_threads_per_gpu, 1, "Number of CPU threads per GPU");
22 
23 namespace caffe2 {
24 
25 namespace {
26 std::shared_ptr<TaskThreadPool> AsyncNetGPUThreadPoolCreator(
27  const DeviceOption& device_option) {
28  CAFFE_ENFORCE_EQ(
29  device_option.device_type(),
30  CUDA,
31  "Unexpected device type for CUDA thread pool");
32  return GetAsyncNetGPUThreadPool(device_option.cuda_gpu_id());
33 }
34 } // namespace
35 
36 CAFFE_REGISTER_CREATOR(ThreadPoolRegistry, CUDA, AsyncNetGPUThreadPoolCreator);
37 
38 std::shared_ptr<TaskThreadPool> GetAsyncNetGPUThreadPool(int gpu_id) {
39  static std::unordered_map<int, std::weak_ptr<TaskThreadPool>> pools;
40  static std::mutex pool_mutex;
41  std::lock_guard<std::mutex> lock(pool_mutex);
42 
43  std::shared_ptr<TaskThreadPool> shared_pool = nullptr;
44  if (pools.count(gpu_id)) {
45  shared_pool = pools.at(gpu_id).lock();
46  }
47  if (!shared_pool) {
48  shared_pool =
49  std::make_shared<TaskThreadPool>(FLAGS_caffe2_threads_per_gpu);
50  pools[gpu_id] = shared_pool;
51  }
52  return shared_pool;
53 }
54 
55 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.