Caffe2 - C++ API
A deep-learning, cross-platform ML framework
net_parallel.h
1 #ifndef CAFFE2_CORE_NET_PARALLEL_H
2 #define CAFFE2_CORE_NET_PARALLEL_H
3 
4 #include "caffe2/core/net_async_base.h"
5 #include "caffe2/core/net_async_task_graph.h"
6 
7 C10_DECLARE_string(caffe2_task_graph_engine);
8 
9 namespace caffe2 {
10 
// Forward declaration: ParallelNet stores a std::unique_ptr to this helper
// and befriends it before the helper's full definition later in this header.
class ParallelNetExecutorHelper;
12 
13 class CAFFE2_API ParallelNet : public NetBase {
14  public:
15  ParallelNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
16 
17  bool RunAsync() override;
18  void Wait() override;
19 
20  bool SupportsAsync() override;
21  std::vector<OperatorBase*> GetOperators() const override;
22 
23  TaskThreadPoolBase* Pool(const DeviceOption& device_option);
24 
25  protected:
26  bool handleRunError() override;
27  virtual void finishRun();
28  virtual void reset();
29 
30  ExecutionOptions options_;
31  int num_workers_;
32 
33  std::unique_ptr<ParallelNetExecutorHelper> helper_;
34  std::shared_ptr<AsyncTaskGraphBase> task_graph_;
35  AsyncTaskFuture* run_future_;
36 
37  std::vector<dag_utils::OperatorNode> operator_nodes_;
38  std::vector<OperatorBase*> operators_;
39 
40  std::mutex pools_mutex_;
41  typedef std::unordered_map<
42  int,
43  std::unordered_map<int, std::shared_ptr<TaskThreadPoolBase>>>
44  PoolsMap;
45  PoolsMap cpu_pools_;
46  PoolsMap gpu_pools_;
48  poolGetter(PoolsMap& pools, int device_type, int device_id, int pool_size);
49 
50  friend class ParallelNetExecutorHelper;
51  C10_DISABLE_COPY_AND_ASSIGN(ParallelNet);
52 };
53 
54 C10_DECLARE_SHARED_REGISTRY(
55  TaskGraphRegistry,
58  const ExecutionOptions&);
59 
60 std::shared_ptr<AsyncTaskGraphBase> GetAsyncTaskGraph(
61  ExecutorHelper* helper,
62  const ExecutionOptions& options);
63 
65  public:
66  explicit ParallelNetExecutorHelper(ParallelNet* net) : net_(net) {}
67  TaskThreadPoolBase* GetPool(const DeviceOption& option) const override {
68  return net_->Pool(option);
69  }
70 
71  std::vector<OperatorBase*> GetOperators() const override {
72  return net_->GetOperators();
73  }
74 
75  int GetNumWorkers() const override {
76  return net_->num_workers_;
77  }
78 
79  private:
80  ParallelNet* net_;
81 };
82 
83 } // namespace caffe2
84 
85 #endif // CAFFE2_CORE_NET_PARALLEL_H
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13