Caffe2 - C++ API
A deep learning, cross platform ML framework
net_async_base.h
1 #ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
2 #define CAFFE2_CORE_NET_ASYNC_BASE_H_
3 
4 #include "c10/core/thread_pool.h"
5 #include "c10/util/Registry.h"
6 #include "caffe2/core/common.h"
7 #include "caffe2/core/net.h"
8 #include "caffe2/core/net_dag_utils.h"
9 #include "caffe2/core/prof_dag_counters.h"
10 #include "caffe2/core/stats.h"
11 #include "caffe2/core/timer.h"
12 #include "caffe2/core/workspace.h"
13 #include "caffe2/proto/caffe2_pb.h"
14 #include "caffe2/proto/prof_dag.pb.h"
15 #include "caffe2/utils/proto_utils.h"
16 
17 C10_DECLARE_int(caffe2_streams_per_gpu);
18 C10_DECLARE_int(caffe2_net_async_max_gpus);
19 C10_DECLARE_int(caffe2_net_async_max_numa_nodes);
20 C10_DECLARE_int(caffe2_net_async_thread_pool_size);
21 C10_DECLARE_bool(caffe2_net_async_check_stream_status);
22 C10_DECLARE_bool(caffe2_net_async_use_single_pool);
23 C10_DECLARE_bool(caffe2_net_async_use_per_net_pools);
24 C10_DECLARE_bool(caffe2_net_async_run_root_tasks_inline);
25 
26 namespace caffe2 {
27 
28 class AsyncNetExecutorHelper;
29 
30 namespace tracing {
31 class Tracer;
32 }
33 
// Execution-mode settings for an async net; AsyncNetBase keeps one instance
// in its options_ member. Every field below has an in-class default.
// NOTE(review): the line that opens this type is not present in this
// listing, and the constructor is defined out of line, so how the NetDef
// influences these fields cannot be confirmed from here.
35  explicit ExecutionOptions(const std::shared_ptr<const NetDef>& net_def);
36 
37  // number of gpu streams per gpu per cpu thread
38  int streams_per_gpu_ = 1;
39  // ops synchronization options
40  bool finish_chain_ = false;
41  bool always_schedule_child_ = false;
42  // try to pick gpu stream that is not busy
43  bool check_stream_status_ = false;
44  // use single thread pool for all devices
45  bool use_single_pool_ = false;
46  // use per net instances thread pools instead of global ones
47  bool use_per_net_pools_ = false;
48  // whether RunAsync is blocking
49  bool is_blocking_ = false;
50  // prof_dag counters reporting
51  bool report_stats_ = false;
52  // immediately run children tasks inline whenever possible
53  bool use_dfs_scheduling_ = false;
54  // run net's root tasks in RunAsync thread instead of in thread pool
55  bool run_root_tasks_inline_ = false;
56 };
57 
// NetBase subclass that reports asynchronous-execution support and declares
// the operator/task graph, thread-pool maps, success flag, tracer, execution
// options and profiling counters, together with the protected helpers that
// operate on them. Most member functions are only declared here; comments on
// them are limited to what the declarations themselves show.
58 class CAFFE2_API AsyncNetBase : public NetBase {
59  public:
  // Takes the shared net definition and the workspace pointer; defined out
  // of line.
60  AsyncNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
61  ~AsyncNetBase() override;
62 
  // Always true for this class.
63  bool SupportsAsync() override {
64  return true;
65  }
66 
  // Returns a copy of the operators_ vector (returned by value).
67  vector<OperatorBase*> GetOperators() const override {
68  return operators_;
69  }
70 
71  bool RunAsync() override;
72 
  // Read-only access to execution_chains_, which is kept for testing.
73  const dag_utils::ExecutionChains& TEST_execution_chains() const {
74  return execution_chains_;
75  }
76 
  // Profiling accessors.
  // NOTE(review): presumably backed by counters_ - confirm in the .cc file.
77  ProfDAGProtos GetOperatorStats() const;
78  ProfDAGProtos GetPerOperatorCost() const;
79  ProfDAGReport GetProfReport() const;
80 
81  protected:
  // Scheduling predicates. In the first overload both pointer arguments are
  // optional and default to nullptr.
  // NOTE(review): parent_failed looks like an out-parameter - confirm.
82  bool canSchedule(
83  int chain_id,
84  const std::vector<EventStatus>* status = nullptr,
85  bool* parent_failed = nullptr);
86  bool canSchedule(int parent_id, int child_id);
87 
  // Per-task accessors and bookkeeping, addressed by integer task id.
88  int tasksNum() const;
89  Event& event(int task_id) const;
90  EventStatus query(int task_id) const;
91  const std::vector<int>& children(int task_id) const;
92  const std::vector<int>& parents(int task_id) const;
93  int updateParentCount(int child_id);
94  int getParentCount(int child_id);
95  bool testAndSetScheduled(int task_id);
96  int numOps(int task_id) const;
97 
  // First/last operator of a task, by operator id or by pointer; the pointer
  // versions come in const and non-const overloads.
98  int firstTaskOpId(int task_id) const;
99  int lastTaskOpId(int task_id) const;
100  const OperatorBase* firstTaskOp(int task_id) const;
101  const OperatorBase* lastTaskOp(int task_id) const;
102  OperatorBase* firstTaskOp(int task_id);
103  OperatorBase* lastTaskOp(int task_id);
104 
  // Task execution helpers. run() is declared noexcept and returns bool.
  // NOTE(review): presumably failures are reported through the return value
  // and success_ rather than by throwing - confirm in the implementation.
105  void asyncWait(
106  int task_id,
107  int stream_id,
108  const std::vector<int>& wait_task_ids) const;
109  bool run(int task_id, int stream_id) noexcept;
110  int stream(int task_id);
  // Thread-pool lookup, with and without an explicit DeviceOption.
111  TaskThreadPoolBase* pool(const DeviceOption& device_option);
112  TaskThreadPoolBase* pool();
113 
114  void finishTasks(const std::unordered_set<int>& task_ids);
115  void finalizeEvents();
116 
117  bool isStreamFree(int task_id, int stream_id) const;
118 
  // Virtual so that derived classes can override it.
119  virtual void reset();
120 
121  bool handleRunError() override;
122 
123  // Operator/task graph
124  std::vector<OperatorBase*> operators_;
125  std::vector<dag_utils::OperatorNode> operator_nodes_;
126  std::vector<std::vector<int>> chains_;
127  std::vector<dag_utils::OpGraphNode> chain_nodes_; // chains' parents/children
128  dag_utils::ExecutionChains execution_chains_; // for testing
129 
130  // Pools and streams
  // NOTE(review): pools_mutex_ presumably guards cpu_pools_/gpu_pools_ -
  // confirm against the out-of-line pool lookup code.
131  std::mutex pools_mutex_;
132  // first int key - device id, second - pool size, one pool per (device, size)
133  typedef std::unordered_map<
134  int,
135  std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
136  PoolsMap;
137  PoolsMap cpu_pools_;
138  PoolsMap gpu_pools_;
139  static std::vector<int>& getStreamCounters();
140  int num_workers_;
141 
142  // Exception/error handling
  // Declared noexcept; save_exception defaults to false.
143  void handleChainError(
144  int task_id,
145  OperatorBase* op,
146  const char* err_msg,
147  bool save_exception = false) noexcept;
  // Atomic flag.
  // NOTE(review): presumably cleared when a task fails - confirm.
148  std::atomic<bool> success_;
149 
150  // Tracing
151  std::shared_ptr<tracing::Tracer> tracer_;
152 
153  // execution mode flags
154  ExecutionOptions options_;
155 
156  ProfDAGCounters counters_;
157 
158  C10_DISABLE_COPY_AND_ASSIGN(AsyncNetBase);
159 
160  private:
  // NOTE(review): the line carrying poolGetter's return type is missing from
  // this listing; restore it from the original header before compiling.
162  poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);
163 
  // Helper object held through a unique_ptr.
164  std::unique_ptr<AsyncNetExecutorHelper> helper_;
165 
  // Both friends get access to the protected and private members above.
166  friend class AsyncNetExecutorHelper;
167  friend class tracing::Tracer;
168 };
169 
// Body of AsyncNetExecutorHelper: wraps an AsyncNetBase pointer and answers
// GetPool() by forwarding to that net's pool(option).
// NOTE(review): the class header line (including the base class that
// GetPool overrides) is missing from this listing.
171  public:
  // Stores the net pointer as given; nothing in this body deletes it.
172  explicit AsyncNetExecutorHelper(AsyncNetBase* net) : net_(net) {}
  // Delegates to AsyncNetBase::pool(const DeviceOption&).
173  TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
174  return net_->pool(option);
175  }
176 
177  private:
  // The net whose pools are exposed through GetPool().
178  AsyncNetBase* net_;
179 };
180 
181 template <class TaskThreadPoolImpl, int device_type>
182 std::shared_ptr<TaskThreadPoolBase>
183 GetAsyncNetThreadPool(int device_id, int pool_size, bool create_new) {
184  static std::unordered_map<
185  int,
186  std::unordered_map<int, std::weak_ptr<TaskThreadPoolBase>>>
187  pools;
188  static std::mutex pool_mutex;
189 
190  const auto& device_type_name = DeviceTypeName(device_type);
191 
192  if (pool_size <= 0) {
193  if (FLAGS_caffe2_net_async_thread_pool_size > 0) {
194  pool_size = FLAGS_caffe2_net_async_thread_pool_size;
195  LOG(INFO) << "Using default " << device_type_name
196  << " pool size: " << pool_size << "; device id: " << device_id;
197  } else {
198  auto num_cores = std::thread::hardware_concurrency();
199  CAFFE_ENFORCE(num_cores > 0, "Failed to get number of CPU cores");
200  LOG(INFO) << "Using estimated " << device_type_name
201  << " pool size: " << num_cores << "; device id: " << device_id;
202  pool_size = num_cores;
203  }
204  } else {
205  LOG(INFO) << "Using specified " << device_type_name
206  << " pool size: " << pool_size << "; device id: " << device_id;
207  }
208 
209  if (create_new) {
210  LOG(INFO) << "Created new " << device_type_name
211  << " pool, size: " << pool_size << "; device id: " << device_id;
212  return std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
213  } else {
214  std::lock_guard<std::mutex> lock(pool_mutex);
215 
216  auto shared_pool = pools[device_id][pool_size].lock();
217  if (!shared_pool) {
218  LOG(INFO) << "Created shared " << device_type_name
219  << " pool, size: " << pool_size << "; device id: " << device_id;
220  shared_pool = std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
221  pools[device_id][pool_size] = shared_pool;
222  }
223  return shared_pool;
224  }
225 }
226 
227 } // namespace caffe2
228 
229 #endif // CAFFE2_CORE_NET_ASYNC_BASE_H_
A simple wrapper around prof_dag's counters.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13