Caffe2 - C++ API
A deep-learning, cross-platform ML framework
net.h
1 #ifndef CAFFE2_CORE_NET_H_
2 #define CAFFE2_CORE_NET_H_
3 
4 #include <atomic>
5 #include <climits>
6 #include <cstddef>
7 #include <thread> // NOLINT
8 #include <typeinfo>
9 #include <unordered_map>
10 #include <vector>
11 
12 #include "c10/core/thread_pool.h"
13 #include "c10/util/Registry.h"
14 #include "caffe2/core/blob.h"
15 #include "caffe2/core/common.h"
16 #include "caffe2/core/logging.h"
17 #include "caffe2/core/observer.h"
18 #include "caffe2/core/operator_schema.h"
19 #include "caffe2/core/tensor.h"
20 #include "caffe2/core/workspace.h"
21 #include "caffe2/proto/caffe2_pb.h"
22 #include "caffe2/utils/simple_queue.h"
23 
// Declares (does not define) the string flag `caffe2_override_executor`.
// The flag's definition and the code that reads it are not in this header
// (presumably net.cc -- TODO confirm).
24 C10_DECLARE_string(caffe2_override_executor);
25 
26 namespace caffe2 {
27 
28 class NetBase;
29 typedef ObserverBase<NetBase> NetObserver;
30 typedef std::function<std::unique_ptr<NetObserver>(NetBase*)>
31  NetObserverCreator;
32 
33 class OperatorBase;
34 class Workspace;
35 
36 // Net is a thin struct that owns all the operators together with the operator
37 // contexts.
38 class CAFFE2_API NetBase : public Observable<NetBase> {
39  public:
40  NetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
41  virtual ~NetBase() noexcept {}
42 
43  virtual bool SupportsAsync() = 0;
44  inline const vector<const Event*>& events() const {
45  return events_;
46  }
47 
48  virtual void Wait() {
49  // by default just wait till all events are finished
50  for (const auto& event : events_) {
51  event->Finish();
52  }
53  }
54 
55  virtual bool Run() {
56  if (!RunAsync()) {
57  LOG(ERROR) << "Failed to execute async run";
58  return false;
59  }
60  Wait();
61  return handleRunError();
62  }
63 
64  virtual bool RunAsync();
65 
75  virtual vector<float> TEST_Benchmark(
76  const int /*warmup_runs*/,
77  const int /*main_runs*/,
78  const bool /*run_individual*/);
79 
80  inline const vector<string>& external_output() const {
81  return external_output_;
82  }
83 
84  inline const vector<string>& external_input() const {
85  return external_input_;
86  }
87 
88  /* Used to attach Observers to operators of a Net
89  *
90  * Returns pointers to objects owned with unique_ptrs.
91  * Use with caution.
92  */
93  virtual vector<OperatorBase*> GetOperators() const = 0;
94 
95  const string& Name() const {
96  return name_;
97  }
98 
99  inline const NetDef& debug_def() const {
100  CAFFE_ENFORCE(has_debug_def(), "net_def was null!");
101  return *net_def_;
102  }
103 
104  inline bool has_debug_def() const {
105  return net_def_ != nullptr;
106  }
107 
108  protected:
109  virtual bool DoRunAsync() {
110  CAFFE_THROW("Not implemented");
111  };
112 
113  virtual bool handleRunError() {
114  for (const Event* event : events_) {
115  if (event->Query() != EventStatus::EVENT_SUCCESS) {
116  CAFFE_THROW(event->ErrorMessage());
117  }
118  }
119  return true;
120  }
121 
122  vector<string> external_input_;
123  vector<string> external_output_;
124  string name_;
125  vector<const Event*> events_;
126  std::shared_ptr<const NetDef> net_def_;
127  C10_DISABLE_COPY_AND_ASSIGN(NetBase);
128 };
129 
130 class CAFFE2_API ExecutorHelper {
131  public:
132  ExecutorHelper() {}
133  virtual TaskThreadPoolBase* GetPool(const DeviceOption& option) const;
134  virtual std::vector<OperatorBase*> GetOperators() const;
135  virtual int GetNumWorkers() const;
136  virtual ~ExecutorHelper() {}
137 };
138 
139 C10_DECLARE_REGISTRY(
140  NetRegistry,
141  NetBase,
142  const std::shared_ptr<const NetDef>&,
143  Workspace*);
144 #define REGISTER_NET_CREATOR(key, ...) \
145  C10_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__)
146 #define REGISTER_NET(name, ...) \
147  C10_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__)
148 
156 CAFFE2_API unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws);
157 CAFFE2_API unique_ptr<NetBase> CreateNet(
158  const std::shared_ptr<const NetDef>& net_def,
159  Workspace* ws);
160 
161 CAFFE2_API void AddGlobalNetObserverCreator(NetObserverCreator creator);
162 
163 CAFFE2_API void ClearGlobalNetObservers();
164 
165 } // namespace caffe2
166 
167 #endif // CAFFE2_CORE_NET_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
Inherit to make your class observable.
Definition: observer.h:45
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:151