Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_async_base.h
1 
17 #ifndef CAFFE2_CORE_NET_ASYNC_BASE_H_
18 #define CAFFE2_CORE_NET_ASYNC_BASE_H_
19 
20 #include "caffe2/core/common.h"
21 #include "caffe2/core/net.h"
22 #include "caffe2/core/net_dag_utils.h"
23 #include "caffe2/core/registry.h"
24 #include "caffe2/core/stats.h"
25 #include "caffe2/core/timer.h"
26 #include "caffe2/core/workspace.h"
27 #include "caffe2/proto/caffe2.pb.h"
28 #include "caffe2/utils/proto_utils.h"
29 #include "caffe2/utils/thread_pool.h"
30 
31 namespace caffe2 {
32 
33 class AsyncNetBase : public NetBase {
34  public:
35  AsyncNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
36  ~AsyncNetBase() override;
37 
38  bool SupportsAsync() override {
39  return true;
40  }
41 
42  vector<OperatorBase*> GetOperators() const override {
43  return operators_;
44  }
45 
46  protected:
47  bool canSchedule(
48  int chain_id,
49  const std::vector<EventStatus>* status = nullptr);
50 
51  int tasksNum() const;
52  Event& event(int task_id) const;
53  EventStatus query(int task_id) const;
54  const std::vector<int>& children(int task_id) const;
55  const std::vector<int>& parents(int task_id) const;
56  void asyncWait(
57  int task_id,
58  int stream_id,
59  const std::vector<int>& wait_task_ids) const;
60  void run(int task_id, int stream_id);
61  int stream(int task_id);
62  std::shared_ptr<TaskThreadPool> pool(const DeviceOption& device_option);
63 
64  void finishTasks(const std::unordered_set<int>& task_ids);
65  void finalizeEvents();
66 
67  bool isStreamFree(int task_id, int stream_id) const;
68 
69  // Operator/task graph
70  std::vector<OperatorBase*> operators_;
71  std::vector<dag_utils::OperatorNode> operator_nodes_;
72  std::vector<std::vector<int>> chains_;
73  std::vector<dag_utils::OpGraphNode> chain_nodes_; // chains' parents/children
74 
75  // Pools and streams
76  std::mutex pools_mutex_;
77  std::vector<std::shared_ptr<TaskThreadPool>> gpu_pools_;
78  std::shared_ptr<TaskThreadPool> cpu_pool_;
79  std::shared_ptr<TaskThreadPool> gpu_pool_;
80  static thread_local std::vector<int> stream_counters_;
81 
82  DISABLE_COPY_AND_ASSIGN(AsyncNetBase);
83 };
84 
85 CAFFE_DECLARE_SHARED_REGISTRY(
86  ThreadPoolRegistry,
88  const DeviceOption&);
89 
90 std::shared_ptr<TaskThreadPool> GetAsyncNetCPUThreadPool();
91 
92 } // namespace caffe2
93 
94 #endif // CAFFE2_CORE_NET_ASYNC_POLLING_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.