Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_dag.h
1 
17 #ifndef CAFFE2_CORE_NET_DAG_H_
18 #define CAFFE2_CORE_NET_DAG_H_
19 
#include <atomic>
#include <climits>
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <thread> // NOLINT
#include <typeinfo>
#include <unordered_map>
#include <vector>
27 
28 #include "caffe2/core/blob.h"
29 #include "caffe2/core/common.h"
30 #include "caffe2/core/logging.h"
31 #include "caffe2/core/net_dag_utils.h"
32 #include "caffe2/core/observer.h"
33 #include "caffe2/core/operator_schema.h"
34 #include "caffe2/core/registry.h"
35 #include "caffe2/core/stats.h"
36 #include "caffe2/core/tensor.h"
37 #include "caffe2/core/timer.h"
38 #include "caffe2/core/workspace.h"
39 #include "caffe2/proto/caffe2.pb.h"
40 #include "caffe2/utils/simple_queue.h"
41 
42 namespace caffe2 {
43 
44 class DAGNetBase : public NetBase {
45  public:
46  DAGNetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
47  ~DAGNetBase() override;
48 
49  // WorkerFunction() is a function wrapper to allow us to run worker threads.
50  // It checks out one ready-to-run operator from the job queue, runs it,
51  // notifies all its children, and for any children that is ready, enqueues
52  // it to the job queue.
53  void WorkerFunction();
54  vector<float> TEST_Benchmark(
55  const int warmup_runs,
56  const int main_runs,
57  const bool run_individual) override;
58 
59  const dag_utils::ExecutionChains& TEST_execution_chains() const {
60  return execution_chains_;
61  }
62 
63  vector<OperatorBase*> GetOperators() const override {
64  return operators_;
65  }
66 
67  protected:
68  bool DoRunAsync() override;
69 
70  virtual bool RunAt(int chain_id, const std::vector<int>& chain) = 0;
71 
72  vector<dag_utils::OperatorNode> operator_nodes_;
73  vector<OperatorBase*> operators_;
74  dag_utils::ExecutionChains execution_chains_;
75  vector<int> initial_frontier_;
76  std::unique_ptr<SimpleQueue<int>> job_queue_;
77  std::vector<std::thread> workers_;
78  int num_workers_;
79  int remaining_ops_;
80 
81  bool success_;
82  int iter_;
83  std::mutex remaining_ops_mutex_;
84  std::condition_variable cv_;
85  std::mutex run_in_progress_;
86 
87  struct DAGNetStats {
88  CAFFE_STAT_CTOR(DAGNetStats);
89  CAFFE_AVG_EXPORTED_STAT(task_pool_wait_time_us);
90  CAFFE_AVG_EXPORTED_STAT(task_time_to_scheduled_us);
91  CAFFE_AVG_EXPORTED_STAT(task_time_to_succeeded_ms);
92  CAFFE_AVG_EXPORTED_STAT(task_wait_time_us);
93  };
94  mutable std::vector<DAGNetStats> stats_;
95  std::unordered_map<int, std::unique_ptr<Timer>> task_timers_;
96 
97  DISABLE_COPY_AND_ASSIGN(DAGNetBase);
98 };
99 
100 class DAGNet : public DAGNetBase {
101  public:
102  using DAGNetBase::DAGNetBase;
103 
104  protected:
105  bool RunAt(int chain_id, const std::vector<int>& chain) override;
106  bool SupportsAsync() override {
107  return false;
108  }
109 };
110 
111 } // namespace caffe2
112 
113 #endif // CAFFE2_CORE_NET_DAG_H_
vector< float > TEST_Benchmark(const int warmup_runs, const int main_runs, const bool run_individual) override
Benchmarks a network.
Definition: net_dag.cc:273
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.