Caffe2 - C++ API
A deep learning, cross-platform ML framework
net.h
1 
17 #ifndef CAFFE2_CORE_NET_H_
18 #define CAFFE2_CORE_NET_H_
19 
20 #include <atomic>
21 #include <climits>
22 #include <cstddef>
23 #include <thread> // NOLINT
24 #include <typeinfo>
25 #include <unordered_map>
26 #include <vector>
27 
28 #include "caffe2/core/blob.h"
29 #include "caffe2/core/common.h"
30 #include "caffe2/core/logging.h"
31 #include "caffe2/core/observer.h"
32 #include "caffe2/core/operator_schema.h"
33 #include "caffe2/core/registry.h"
34 #include "caffe2/core/tensor.h"
35 #include "caffe2/core/workspace.h"
36 #include "caffe2/proto/caffe2.pb.h"
37 #include "caffe2/utils/simple_queue.h"
38 
39 namespace caffe2 {
40 
41 class NetBase;
42 typedef ObserverBase<NetBase> NetObserver;
43 typedef std::function<std::unique_ptr<NetObserver>(NetBase*)>
44  NetObserverCreator;
45 
46 class OperatorBase;
47 class Workspace;
48 
49 // Net is a thin struct that owns all the operators together with the operator
50 // contexts.
51 class NetBase : public Observable<NetBase> {
52  public:
53  NetBase(const std::shared_ptr<const NetDef>& net_def, Workspace* ws);
54  virtual ~NetBase() noexcept {}
55 
56  virtual bool SupportsAsync() = 0;
57  inline const vector<const Event*>& events() const {
58  return events_;
59  }
60 
61  virtual void Wait() {
62  // by default just wait till all events are finished
63  for (const auto& event : events_) {
64  event->Finish();
65  }
66  }
67 
68  virtual bool Run() {
69  if (!RunAsync()) {
70  LOG(ERROR) << "Failed to execute async run";
71  return false;
72  }
73  Wait();
74  for (const Event* event : events_) {
75  if (event->Query() != EventStatus::EVENT_SUCCESS) {
76  CAFFE_THROW(event->ErrorMessage());
77  }
78  }
79  return true;
80  }
81 
82  virtual bool RunAsync();
83 
93  virtual vector<float> TEST_Benchmark(
94  const int /*warmup_runs*/,
95  const int /*main_runs*/,
96  const bool /*run_individual*/) {
97  LOG(ERROR) << "Benchmark not implemented for this net type.";
98  return vector<float>();
99  }
100 
101  inline const vector<string>& external_output() const {
102  return external_output_;
103  }
104 
105  inline const vector<string>& external_input() const {
106  return external_input_;
107  }
108 
109  /* Used to attach Observers to operators of a Net
110  *
111  * Returns pointers to objects owned with unique_ptrs.
112  * Use with caution.
113  */
114  virtual vector<OperatorBase*> GetOperators() const = 0;
115 
116  const string& Name() const {
117  return name_;
118  }
119 
120  inline const NetDef& debug_def() const {
121  CAFFE_ENFORCE(has_debug_def(), "net_def was null!");
122  return *net_def_;
123  }
124 
125  inline bool has_debug_def() const {
126  return net_def_ != nullptr;
127  }
128 
129  protected:
130  virtual bool DoRunAsync() {
131  CAFFE_THROW("Not implemented");
132  };
133 
134  vector<string> external_input_;
135  vector<string> external_output_;
136  string name_;
137  vector<const Event*> events_;
138  std::shared_ptr<const NetDef> net_def_;
139  DISABLE_COPY_AND_ASSIGN(NetBase);
140 };
141 
142 CAFFE_DECLARE_REGISTRY(
143  NetRegistry,
144  NetBase,
145  const std::shared_ptr<const NetDef>&,
146  Workspace*);
147 #define REGISTER_NET_CREATOR(key, ...) \
148  CAFFE_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__)
149 #define REGISTER_NET(name, ...) \
150  CAFFE_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__)
151 
159 unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws);
160 unique_ptr<NetBase> CreateNet(
161  const std::shared_ptr<const NetDef>& net_def,
162  Workspace* ws);
163 
164 void SetGlobalNetObserverCreator(NetObserverCreator creator);
165 
166 } // namespace caffe2
167 
168 #endif // CAFFE2_CORE_NET_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Inherit to make your class observable.
Definition: observer.h:60
Copyright (c) 2016-present, Facebook, Inc.
unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:117
virtual vector<float> TEST_Benchmark(const int, const int, const bool)
Benchmarks a network.
Definition: net.h:93