1 #ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_     2 #define CAFFE2_CORE_NET_ASYNC_BASE_H_     4 #include "c10/core/thread_pool.h"     5 #include "c10/util/Registry.h"     6 #include "caffe2/core/common.h"     7 #include "caffe2/core/net.h"     8 #include "caffe2/core/net_dag_utils.h"     9 #include "caffe2/core/prof_dag_counters.h"    10 #include "caffe2/core/stats.h"    11 #include "caffe2/core/timer.h"    12 #include "caffe2/core/workspace.h"    13 #include "caffe2/proto/caffe2_pb.h"    14 #include "caffe2/proto/prof_dag.pb.h"    15 #include "caffe2/utils/proto_utils.h"    17 C10_DECLARE_int(caffe2_streams_per_gpu);
    18 C10_DECLARE_int(caffe2_net_async_max_gpus);
    19 C10_DECLARE_int(caffe2_net_async_max_numa_nodes);
    20 C10_DECLARE_int(caffe2_net_async_thread_pool_size);
    21 C10_DECLARE_bool(caffe2_net_async_check_stream_status);
    22 C10_DECLARE_bool(caffe2_net_async_use_single_pool);
    23 C10_DECLARE_bool(caffe2_net_async_use_per_net_pools);
    24 C10_DECLARE_bool(caffe2_net_async_run_root_tasks_inline);
    28 class AsyncNetExecutorHelper;
    38   int streams_per_gpu_ = 1;
    40   bool finish_chain_ = 
false;
    41   bool always_schedule_child_ = 
false;
    43   bool check_stream_status_ = 
false;
    45   bool use_single_pool_ = 
false;
    47   bool use_per_net_pools_ = 
false;
    49   bool is_blocking_ = 
false;
    51   bool report_stats_ = 
false;
    53   bool use_dfs_scheduling_ = 
false;
    55   bool run_root_tasks_inline_ = 
false;
    63   bool SupportsAsync()
 override {
    67   vector<OperatorBase*> GetOperators()
 const override {
    71   bool RunAsync() 
override;
    73   const dag_utils::ExecutionChains& TEST_execution_chains()
 const {
    74     return execution_chains_;
    77   ProfDAGProtos GetOperatorStats() 
const;
    78   ProfDAGProtos GetPerOperatorCost() 
const;
    84       const std::vector<EventStatus>* status = 
nullptr,
    85       bool* parent_failed = 
nullptr);
    86   bool canSchedule(
int parent_id, 
int child_id);
    89   Event& event(
int task_id) 
const;
    90   EventStatus query(
int task_id) 
const;
    91   const std::vector<int>& children(
int task_id) 
const;
    92   const std::vector<int>& parents(
int task_id) 
const;
    93   int updateParentCount(
int child_id);
    94   int getParentCount(
int child_id);
    95   bool testAndSetScheduled(
int task_id);
    96   int numOps(
int task_id) 
const;
    98   int firstTaskOpId(
int task_id) 
const;
    99   int lastTaskOpId(
int task_id) 
const;
   108       const std::vector<int>& wait_task_ids) 
const;
   109   bool run(
int task_id, 
int stream_id) noexcept;
   110   int stream(
int task_id);
   114   void finishTasks(
const std::unordered_set<int>& task_ids);
   115   void finalizeEvents();
   117   bool isStreamFree(
int task_id, 
int stream_id) 
const;
   119   virtual void reset();
   121   bool handleRunError() 
override;
   124   std::vector<OperatorBase*> operators_;
   125   std::vector<dag_utils::OperatorNode> operator_nodes_;
   126   std::vector<std::vector<int>> chains_;
   127   std::vector<dag_utils::OpGraphNode> chain_nodes_; 
   128   dag_utils::ExecutionChains execution_chains_; 
   131   std::mutex pools_mutex_;
   133   typedef std::unordered_map<
   135       std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
   139   static std::vector<int>& getStreamCounters();
   143   void handleChainError(
   147       bool save_exception = 
false) noexcept;
   148   std::atomic<bool> success_;
   151   std::shared_ptr<tracing::Tracer> tracer_;
   162   poolGetter(PoolsMap& pools, 
int device_type, 
int device_id, 
int pool_size);
   164   std::unique_ptr<AsyncNetExecutorHelper> helper_;
   174     return net_->pool(option);
   181 template <
class TaskThreadPoolImpl, 
int device_type>
   182 std::shared_ptr<TaskThreadPoolBase>
   183 GetAsyncNetThreadPool(
int device_id, 
int pool_size, 
bool create_new) {
   184   static std::unordered_map<
   186       std::unordered_map<int, std::weak_ptr<TaskThreadPoolBase>>>
   188   static std::mutex pool_mutex;
   190   const auto& device_type_name = DeviceTypeName(device_type);
   192   if (pool_size <= 0) {
   193     if (FLAGS_caffe2_net_async_thread_pool_size > 0) {
   194       pool_size = FLAGS_caffe2_net_async_thread_pool_size;
   195       LOG(INFO) << 
"Using default " << device_type_name
   196                 << 
" pool size: " << pool_size << 
"; device id: " << device_id;
   198       auto num_cores = std::thread::hardware_concurrency();
   199       CAFFE_ENFORCE(num_cores > 0, 
"Failed to get number of CPU cores");
   200       LOG(INFO) << 
"Using estimated " << device_type_name
   201                 << 
" pool size: " << num_cores << 
"; device id: " << device_id;
   202       pool_size = num_cores;
   205     LOG(INFO) << 
"Using specified " << device_type_name
   206               << 
" pool size: " << pool_size << 
"; device id: " << device_id;
   210     LOG(INFO) << 
"Created new " << device_type_name
   211               << 
" pool, size: " << pool_size << 
"; device id: " << device_id;
   212     return std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
   214     std::lock_guard<std::mutex> lock(pool_mutex);
   216     auto shared_pool = pools[device_id][pool_size].lock();
   218       LOG(INFO) << 
"Created shared " << device_type_name
   219                 << 
" pool, size: " << pool_size << 
"; device id: " << device_id;
   220       shared_pool = std::make_shared<TaskThreadPoolImpl>(pool_size, device_id);
   221       pools[device_id][pool_size] = shared_pool;
   229 #endif // CAFFE2_CORE_NET_ASYNC_BASE_H_ A simple wrapper around prof_dag's counters. 
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...