1 #ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_ 2 #define CAFFE2_CORE_NET_ASYNC_BASE_H_ 4 #include "c10/core/thread_pool.h" 5 #include "c10/util/Registry.h" 6 #include "caffe2/core/common.h" 7 #include "caffe2/core/net.h" 8 #include "caffe2/core/net_dag_utils.h" 9 #include "caffe2/core/prof_dag_counters.h" 10 #include "caffe2/core/stats.h" 11 #include "caffe2/core/timer.h" 12 #include "caffe2/core/workspace.h" 13 #include "caffe2/proto/caffe2_pb.h" 14 #include "caffe2/proto/prof_dag.pb.h" 15 #include "caffe2/utils/proto_utils.h" 17 C10_DECLARE_int(caffe2_streams_per_gpu);
// Declarations of the flags that tune the async net executor. These only make
// the FLAGS_* symbols visible to includers of this header.
// NOTE(review): the matching definitions are presumably in the .cc file - confirm.
// Of these, FLAGS_caffe2_net_async_thread_pool_size is read further down in
// GetAsyncNetThreadPool() as the fallback when the caller passes pool_size <= 0.
18 C10_DECLARE_int(caffe2_net_async_max_gpus);
19 C10_DECLARE_int(caffe2_net_async_max_numa_nodes);
20 C10_DECLARE_int(caffe2_net_async_thread_pool_size);
21 C10_DECLARE_bool(caffe2_net_async_check_stream_status);
22 C10_DECLARE_bool(caffe2_net_async_use_single_pool);
23 C10_DECLARE_bool(caffe2_net_async_use_per_net_pools);
24 C10_DECLARE_bool(caffe2_net_async_run_root_tasks_inline);
28 class AsyncNetExecutorHelper;
// Executor option fields with their in-class default values: one stream per
// GPU and every boolean switch off.
// NOTE(review): the enclosing type and the code that overrides these defaults
// are not visible in this excerpt. Several fields share a name with a
// C10_DECLARE_* flag above (streams_per_gpu, check_stream_status,
// use_single_pool, use_per_net_pools, run_root_tasks_inline) and are
// presumably seeded from those flags - confirm against the constructor.
38 int streams_per_gpu_ = 1;
40 bool finish_chain_ =
false;
41 bool always_schedule_child_ =
false;
43 bool check_stream_status_ =
false;
45 bool use_single_pool_ =
false;
47 bool use_per_net_pools_ =
false;
49 bool is_blocking_ =
false;
51 bool report_stats_ =
false;
53 bool use_dfs_scheduling_ =
false;
55 bool run_root_tasks_inline_ =
false;
63 bool SupportsAsync()
override {
67 vector<OperatorBase*> GetOperators()
const override {
// Overrides the base-class RunAsync(); only the declaration is present here.
71 bool RunAsync()
override;
// Test-only accessor: exposes the execution_chains_ member by const reference
// without copying it.
73 const dag_utils::ExecutionChains& TEST_execution_chains()
const {
74 return execution_chains_;
// Profiling accessors. Both are const and return a ProfDAGProtos by value.
// NOTE(review): what each one aggregates is defined in the implementation
// file, which is not visible here.
77 ProfDAGProtos GetOperatorStats()
const;
78 ProfDAGProtos GetPerOperatorCost()
const;
84 const std::vector<EventStatus>* status =
nullptr,
85 bool* parent_failed =
nullptr);
// Task-graph query helpers. Everything below is a declaration only; the
// implementations are not part of this excerpt, so the notes describe the
// signatures rather than the behaviour.

// Two-argument overload taking a parent task id and a child task id.
86 bool canSchedule(
int parent_id,
int child_id);
// Per-task accessors keyed by task id. All are const member functions.
89 Event& event(
int task_id)
const;
90 EventStatus query(
int task_id)
const;
// children()/parents() return references to int vectors, so the returned
// reference is only valid while the underlying storage is.
91 const std::vector<int>& children(
int task_id)
const;
92 const std::vector<int>& parents(
int task_id)
const;
// Non-const bookkeeping on a child task.
// NOTE(review): presumably updateParentCount() decrements a pending-parent
// counter and returns the new value - confirm in the .cc.
93 int updateParentCount(
int child_id);
94 int getParentCount(
int child_id);
// NOTE(review): name suggests an atomic test-and-set of a "scheduled" marker;
// the return-value convention is not visible here.
95 bool testAndSetScheduled(
int task_id);
// Const accessors returning ints for the given task: an op count and the
// first/last op id.
96 int numOps(
int task_id)
const;
98 int firstTaskOpId(
int task_id)
const;
99 int lastTaskOpId(
int task_id)
const;
108 const std::vector<int>& wait_task_ids)
const;
// Execution and lifecycle declarations; the bodies are not in this excerpt.

// Runs the task identified by task_id on the given stream. Declared noexcept,
// so failures are reported through the bool result rather than by throwing.
109 bool run(
int task_id,
int stream_id) noexcept;
// Returns an int for the task. Non-const, so the call may update state.
// NOTE(review): presumably picks the stream id to run the task on - confirm.
110 int stream(
int task_id);
// Takes the set of task ids by const reference.
114 void finishTasks(
const std::unordered_set<int>& task_ids);
115 void finalizeEvents();
// Const query for a (task, stream) pair.
117 bool isStreamFree(
int task_id,
int stream_id)
const;
// Virtual, so derived nets can extend or replace the reset logic.
119 virtual void reset();
// Overrides the base-class error hook.
121 bool handleRunError()
override;
// Graph representation members. operators_ stores raw OperatorBase pointers.
// NOTE(review): who owns the pointed-to operators is not visible here.
124 std::vector<OperatorBase*> operators_;
125 std::vector<dag_utils::OperatorNode> operator_nodes_;
126 std::vector<std::vector<int>> chains_;
127 std::vector<dag_utils::OpGraphNode> chain_nodes_;
// Returned by const reference from TEST_execution_chains().
128 dag_utils::ExecutionChains execution_chains_;
// NOTE(review): presumably guards the thread-pool maps declared next - confirm.
131 std::mutex pools_mutex_;
// Fragment of a typedef for a two-level map whose inner level maps an int key
// to a shared_ptr<TaskThreadPoolBase>. The outer key type and the alias name
// are on lines missing from this excerpt; poolGetter() below takes a PoolsMap&.
133 typedef std::unordered_map<
135 std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
// Static accessor returning a mutable reference to a vector of ints.
139 static std::vector<int>& getStreamCounters();
// Fragment of the handleChainError declaration: only the trailing defaulted
// parameter (save_exception = false) and the noexcept specifier are visible.
143 void handleChainError(
147 bool save_exception =
false) noexcept;
// Atomic, so it can be read and written from several threads without a lock.
148 std::atomic<bool> success_;
151 std::shared_ptr<tracing::Tracer> tracer_;
// Fragment: the return type of poolGetter is on a line not shown here.
162 poolGetter(PoolsMap& pools,
int device_type,
int device_id,
int pool_size);
// Exclusively owned helper object; the type is only forward-declared in this
// header.
164 std::unique_ptr<AsyncNetExecutorHelper> helper_;
174 return net_->pool(option);
// Creates or looks up a thread pool of type TaskThreadPoolImpl for the given
// device id and pool size, returned as shared_ptr<TaskThreadPoolBase>.
// device_type is a template argument and is used here only to build the
// device name for the log messages.
// NOTE(review): several lines of this function (the name of the static map,
// the else branches, the create_new guard and the closing braces) are missing
// from this excerpt. Comments marked "presumably" describe structure inferred
// from the visible statements - confirm against the full source.
181 template <
class TaskThreadPoolImpl,
int device_type>
182 std::shared_ptr<TaskThreadPoolBase>
183 GetAsyncNetThreadPool(
int device_id,
int pool_size,
bool create_new) {
// Function-local static cache, indexed below as pools[device_id][pool_size].
// It holds weak_ptrs, so the cache by itself does not keep a pool alive.
// Being a static inside a function template, each instantiation has its own.
184 static std::unordered_map<
186 std::unordered_map<int, std::weak_ptr<TaskThreadPoolBase>>>
// Serialises access to the cache; it is locked before the lookup further down.
188 static std::mutex pool_mutex;
190 const auto& device_type_name = DeviceTypeName(device_type);
// The caller gave no usable size: prefer the thread-pool-size flag when it is
// positive.
192 if (pool_size <= 0) {
193 if (FLAGS_caffe2_net_async_thread_pool_size > 0) {
194 pool_size = FLAGS_caffe2_net_async_thread_pool_size;
195 LOG(INFO) <<
"Using default " << device_type_name
196 <<
" pool size: " << pool_size <<
"; device id: " << device_id;
// Presumably the else branch: fall back to the hardware thread count.
// std::thread::hardware_concurrency() may return 0 when it cannot be
// determined, hence the enforce. Its unsigned result is converted to int on
// assignment to pool_size.
198 auto num_cores = std::thread::hardware_concurrency();
199 CAFFE_ENFORCE(num_cores > 0,
"Failed to get number of CPU cores");
200 LOG(INFO) <<
"Using estimated " << device_type_name
201 <<
" pool size: " << num_cores <<
"; device id: " << device_id;
202 pool_size = num_cores;
// Presumably the branch taken when the caller supplied a positive pool_size.
205 LOG(INFO) <<
"Using specified " << device_type_name
206 <<
" pool size: " << pool_size <<
"; device id: " << device_id;
// Presumably guarded by create_new: returns a fresh pool directly, without
// taking the lock or recording it in the cache.
210 LOG(INFO) <<
"Created new " << device_type_name
211 <<
" pool, size: " << pool_size <<
"; device id: " << device_id;
212 return std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
// Shared path: under the lock, try to promote the cached weak_ptr. lock()
// yields an empty shared_ptr if no pool was cached or the previous one has
// already been destroyed. operator[] default-inserts missing map entries.
214 std::lock_guard<std::mutex> lock(pool_mutex);
216 auto shared_pool = pools[device_id][pool_size].lock();
// Presumably executed only when shared_pool is empty: build a new pool and
// store a weak reference to it for later callers.
218 LOG(INFO) <<
"Created shared " << device_type_name
219 <<
" pool, size: " << pool_size <<
"; device id: " << device_id;
220 shared_pool = std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
221 pools[device_id][pool_size] = shared_pool;
229 #endif // CAFFE2_CORE_NET_ASYNC_BASE_H_ A simple wrapper around prof_dag's counters.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...