1 #ifndef CAFFE2_CORE_NET_H_ 2 #define CAFFE2_CORE_NET_H_ 9 #include <unordered_map> 12 #include "c10/core/thread_pool.h" 13 #include "c10/util/Registry.h" 14 #include "caffe2/core/blob.h" 15 #include "caffe2/core/common.h" 16 #include "caffe2/core/logging.h" 17 #include "caffe2/core/observer.h" 18 #include "caffe2/core/operator_schema.h" 19 #include "caffe2/core/tensor.h" 20 #include "caffe2/core/workspace.h" 21 #include "caffe2/proto/caffe2_pb.h" 22 #include "caffe2/utils/simple_queue.h" 24 C10_DECLARE_string(caffe2_override_executor);
29 typedef ObserverBase<NetBase> NetObserver;
30 typedef std::function<std::unique_ptr<NetObserver>(NetBase*)>
43 virtual bool SupportsAsync() = 0;
44 inline const vector<const Event*>& events()
const {
50 for (
const auto& event : events_) {
57 LOG(ERROR) <<
"Failed to execute async run";
61 return handleRunError();
64 virtual bool RunAsync();
75 virtual vector<float> TEST_Benchmark(
80 inline const vector<string>& external_output()
const {
81 return external_output_;
84 inline const vector<string>& external_input()
const {
85 return external_input_;
93 virtual vector<OperatorBase*> GetOperators()
const = 0;
95 const string& Name()
const {
99 inline const NetDef& debug_def()
const {
100 CAFFE_ENFORCE(has_debug_def(),
"net_def was null!");
104 inline bool has_debug_def()
const {
105 return net_def_ !=
nullptr;
109 virtual bool DoRunAsync() {
110 CAFFE_THROW(
"Not implemented");
113 virtual bool handleRunError() {
114 for (
const Event* event : events_) {
115 if (event->Query() != EventStatus::EVENT_SUCCESS) {
116 CAFFE_THROW(event->ErrorMessage());
122 vector<string> external_input_;
123 vector<string> external_output_;
125 vector<const Event*> events_;
126 std::shared_ptr<const NetDef> net_def_;
127 C10_DISABLE_COPY_AND_ASSIGN(
NetBase);
134 virtual std::vector<OperatorBase*> GetOperators()
const;
135 virtual int GetNumWorkers()
const;
139 C10_DECLARE_REGISTRY(
142 const std::shared_ptr<const NetDef>&,
144 #define REGISTER_NET_CREATOR(key, ...) \ 145 C10_REGISTER_CREATOR(NetRegistry, key, __VA_ARGS__) 146 #define REGISTER_NET(name, ...) \ 147 C10_REGISTER_CLASS(NetRegistry, name, __VA_ARGS__) 157 CAFFE2_API unique_ptr<NetBase>
CreateNet(
158 const std::shared_ptr<const NetDef>& net_def,
161 CAFFE2_API
void AddGlobalNetObserverCreator(NetObserverCreator creator);
163 CAFFE2_API
void ClearGlobalNetObservers();
167 #endif // CAFFE2_CORE_NET_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Inherit to make your class observable.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.