1 #ifndef CAFFE2_CORE_NET_PARALLEL_H 2 #define CAFFE2_CORE_NET_PARALLEL_H 4 #include "caffe2/core/net_async_base.h" 5 #include "caffe2/core/net_async_task_graph.h" 7 C10_DECLARE_string(caffe2_task_graph_engine);
11 class ParallelNetExecutorHelper;
17 bool RunAsync()
override;
20 bool SupportsAsync()
override;
21 std::vector<OperatorBase*> GetOperators()
const override;
26 bool handleRunError()
override;
27 virtual void finishRun();
33 std::unique_ptr<ParallelNetExecutorHelper> helper_;
34 std::shared_ptr<AsyncTaskGraphBase> task_graph_;
37 std::vector<dag_utils::OperatorNode> operator_nodes_;
38 std::vector<OperatorBase*> operators_;
40 std::mutex pools_mutex_;
41 typedef std::unordered_map<
43 std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
48 poolGetter(PoolsMap& pools,
int device_type,
int device_id,
int pool_size);
54 C10_DECLARE_SHARED_REGISTRY(
60 std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
68 return net_->Pool(option);
71 std::vector<OperatorBase*> GetOperators()
const override {
72 return net_->GetOperators();
75 int GetNumWorkers()
const override {
76 return net_->num_workers_;
85 #endif // CAFFE2_CORE_NET_PARALLEL_H
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...