Caffe2 - C++ API
A deep learning, cross-platform ML framework
net.cc
#include "caffe2/core/net.h"
#include "caffe2/core/net_simple.h"

#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "caffe2/core/init.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/timer.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"
15 C10_DEFINE_string(
16  caffe2_override_executor,
17  "",
18  "Comma-separated list of executor overrides");
19 
20 namespace caffe2 {
21 
22 C10_DEFINE_REGISTRY(
23  NetRegistry,
24  NetBase,
25  const std::shared_ptr<const NetDef>&,
26  Workspace*);
27 
28 NetBase::NetBase(
29  const std::shared_ptr<const NetDef>& def,
30  Workspace* /* unused */)
31  : external_input_(
32  def->external_input().begin(),
33  def->external_input().end()),
34  external_output_(
35  def->external_output().begin(),
36  def->external_output().end()),
37  name_(def->name()),
38  net_def_(def) {
39  static GlobalInitIsCalledGuard guard;
40  // Check that node_name is empty for all ops
41  for (const OperatorDef& op : def->op()) {
42  if (op.has_device_option()) {
43  CAFFE_ENFORCE(
44  !op.device_option().has_node_name(),
45  "node_name must be empty for all operators at execution time.");
46  }
47  }
48 
49  // Go through the operators and make sure that blobs are correctly made.
50  std::set<string> known_blobs(
51  external_input_.begin(), external_input_.end());
52  std::set<string> remaining_output(
53  external_output_.begin(), external_output_.end());
54  for (const auto& blob : known_blobs) {
55  remaining_output.erase(blob);
56  }
57  for (const OperatorDef& op : def->op()) {
58  for (const string& in : op.input()) {
59  if (!known_blobs.count(in)) {
60  if (external_input_.size()) {
61  CAFFE_THROW(
62  "op ",
63  op.type(),
64  ": Source for input ",
65  in,
66  " is unknown for net ",
67  def->name(),
68  ", operator ",
69  ProtoDebugString(op));
70  } else {
71  // If we are not declaring input and output, we will simply VLOG it
72  // for debugging purposes.
73  VLOG(1) << "op " << op.type() << ": input " << in << " is unknown.";
74  }
75  }
76  }
77  for (const string& out : op.output()) {
78  known_blobs.insert(out);
79  remaining_output.erase(out);
80  }
81  }
82  // Finally, check if all declared outputs are being created.
83  CAFFE_ENFORCE(
84  remaining_output.size() == 0,
85  "Some of the blobs are declared as output but never produced by the "
86  "net ",
87  def->name(),
88  ", the first one is ",
89  *remaining_output.begin());
90 }
91 
92 bool NetBase::RunAsync() {
93  for (auto& op : GetOperators()) {
94  op->ResetEvent();
95  }
96  return DoRunAsync();
97 }
98 
99 namespace {
100 const std::string kSimpleNet = "simple";
101 
102 std::vector<NetObserverCreator>* GetNetObserverCreators() {
103  static std::vector<NetObserverCreator> creators;
104  return &creators;
105 }
106 
107 const std::unordered_map<std::string, std::string>& defaultOverrides() {
108  // redirecting legacy net types to async_scheduling (except for 'simple');
109  // async_scheduling checks net type for backward compatibility
110  static const std::unordered_map<std::string, std::string> overrides = {
111  {"dag", "async_scheduling"},
112  {"prof_dag", "async_scheduling"},
113  {"async_dag", "async_scheduling"},
114  {"async_polling", "async_scheduling"},
115  {"async_simple", "simple"}, // "async_simple" impl has been removed.
116  {"rnn", "simple"}, // "rnn" impl has been removed.
117  };
118  return overrides;
119 }
120 
121 void ApplyPotentialExecutorOverride(std::string* net_type) {
122  auto executors = caffe2::split(',', FLAGS_caffe2_override_executor);
123  CAFFE_ENFORCE(
124  executors.size() % 2 == 0, "Invalid override executors flag value");
125  std::unordered_map<std::string, std::string> overrides;
126  for (const auto& kv : defaultOverrides()) {
127  overrides[kv.first] = kv.second;
128  }
129  for (size_t idx = 0; idx < executors.size(); idx += 2) {
130  overrides[executors[idx]] = executors[idx + 1];
131  }
132  if (overrides.count(*net_type)) {
133  VLOG(1) << "Overrode net type '" << *net_type << "' with '"
134  << overrides[*net_type] << "'";
135  *net_type = overrides[*net_type];
136  }
137 }
138 
139 } // namespace
140 
141 void AddGlobalNetObserverCreator(NetObserverCreator creator) {
142  GetNetObserverCreators()->push_back(creator);
143  VLOG(1) << "Have set a custom GlobalNetObserverCreator";
144 }
145 
146 void ClearGlobalNetObservers() {
147  GetNetObserverCreators()->clear();
148  VLOG(1) << "All net observers cleared";
149 }
150 
151 unique_ptr<NetBase> CreateNet(const NetDef& net_def, Workspace* ws) {
152  std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
153  return CreateNet(tmp_net_def, ws);
154 }
155 
156 unique_ptr<NetBase> CreateNet(
157  const std::shared_ptr<const NetDef>& net_def,
158  Workspace* ws) {
159  std::string net_type;
160  if (net_def->has_type()) {
161  net_type = net_def->type();
162  } else {
163  // By default, we will return a simple network that just runs all operators
164  // sequentially.
165  net_type = kSimpleNet;
166  }
167  ApplyPotentialExecutorOverride(&net_type);
168  unique_ptr<NetBase> net = NetRegistry()->Create(net_type, net_def, ws);
169 
170  VLOG(1) << "Adding a global observer to a net";
171  if (net) {
172  auto* observer_creators = GetNetObserverCreators();
173  for (auto& creator : *observer_creators) {
174  net->AttachObserver(creator(net.get()));
175  }
176  }
177  return net;
178 }
179 
180 TaskThreadPoolBase* ExecutorHelper::GetPool(
181  const DeviceOption& /* unused */) const {
182  CAFFE_THROW("Not implemented");
183 }
184 
185 std::vector<OperatorBase*> ExecutorHelper::GetOperators() const {
186  CAFFE_THROW("Not implemented");
187 }
188 
189 int ExecutorHelper::GetNumWorkers() const {
190  CAFFE_THROW("Not implemented");
191 }
192 
193 std::vector<float> NetBase::TEST_Benchmark(
194  const int warmup_runs,
195  const int main_runs,
196  const bool run_individual) {
197  LOG(INFO) << "Starting benchmark, running warmup runs";
198  CAFFE_ENFORCE(
199  warmup_runs >= 0,
200  "Number of warm up runs should be non negative, provided ",
201  warmup_runs);
202  for (int run_idx = 0; run_idx < warmup_runs; ++run_idx) {
203  CAFFE_ENFORCE(Run(), "Warmup run ", run_idx, " has failed");
204  }
205 
206  LOG(INFO) << "Running main runs";
207  CAFFE_ENFORCE(
208  main_runs >= 0,
209  "Number of main runs should be non negative, provided ",
210  main_runs);
211 
212  Timer timer;
213  for (int run_idx = 0; run_idx < main_runs; ++run_idx) {
214  CAFFE_ENFORCE(Run(), "Main run ", run_idx, " has failed");
215  }
216  auto millis = timer.MilliSeconds();
217  LOG(INFO) << "Main runs finished. Milliseconds per iter: "
218  << millis / main_runs
219  << ". Iters per second: " << 1000.0 * main_runs / millis;
220 
221  if (run_individual) {
222  LOG(INFO) << "Net does not support per-op benchmark; "
223  "to run it, switch to a simple net type";
224  }
225  return std::vector<float>{millis / main_runs};
226 }
227 
228 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
float MilliSeconds()
Returns the elapsed time in milliseconds.
Definition: timer.h:32
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:151
A simple timer object for measuring time.
Definition: timer.h:16