1 #include "caffe2/core/net.h" 2 #include "caffe2/core/net_simple.h" 5 #include <unordered_map> 6 #include <unordered_set> 8 #include "caffe2/core/init.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/timer.h" 11 #include "caffe2/proto/caffe2_pb.h" 12 #include "caffe2/utils/proto_utils.h" 13 #include "caffe2/utils/string_utils.h" 16 caffe2_override_executor,
18 "Comma-separated list of executor overrides");
25 const std::shared_ptr<const NetDef>&,
// NetBase constructor (fragment). NOTE(review): this listing is missing source
// lines (the embedded numbering skips), so the signature line, the member
// initializer names and several statements are not visible here.
// Visible behavior: takes ranges over def's external inputs/outputs, then checks
// that the ops in `def` form a well-formed data-flow sequence.
29 const std::shared_ptr<const NetDef>& def,
// Range over def->external_input(); presumably initializes external_input_
// (used below) — the initializer name is on a missing line.
32 def->external_input().begin(),
33 def->external_input().end()),
// Range over def->external_output(); presumably initializes external_output_.
35 def->external_output().begin(),
36 def->external_output().end()),
// Function-local static: constructed only on the first constructor call.
// Presumably asserts that caffe2 global init has run — TODO confirm.
39 static GlobalInitIsCalledGuard guard;
// Operators must not carry a node_name in their device option at execution
// time (the enforcing call's opening line is missing from this listing).
41 for (
const OperatorDef& op : def->op()) {
42 if (op.has_device_option()) {
44 !op.device_option().has_node_name(),
45 "node_name must be empty for all operators at execution time.");
// known_blobs: blobs available so far, seeded with the external inputs.
// remaining_output: declared external outputs not yet produced; an output that
// is also an external input counts as already produced.
50 std::set<string> known_blobs(
51 external_input_.begin(), external_input_.end());
52 std::set<string> remaining_output(
53 external_output_.begin(), external_output_.end());
54 for (
const auto& blob : known_blobs) {
55 remaining_output.erase(blob);
// Walk the ops in order: every input must already be known, i.e. be an
// external input or an output of an earlier op.
57 for (
const OperatorDef& op : def->op()) {
58 for (
const string& in : op.input()) {
59 if (!known_blobs.count(in)) {
// When external inputs are declared, an unknown input is reported with the
// message below (the reporting call itself is on a missing line; presumably it
// raises).
60 if (external_input_.size()) {
64 ": Source for input ",
66 " is unknown for net ",
69 ProtoDebugString(op));
// Presumably the branch for nets that declare no external inputs: the unknown
// input is tolerated and only logged at VLOG(1).
73 VLOG(1) <<
"op " << op.type() <<
": input " << in <<
" is unknown.";
// Each output becomes visible to later ops and is no longer outstanding.
77 for (
const string& out : op.output()) {
78 known_blobs.insert(out);
79 remaining_output.erase(out);
// Every declared external output must have been produced by some op.
// NOTE(review): *remaining_output.begin() is only valid when the set is
// non-empty, i.e. when the check fails; this relies on the enforcing macro
// (opening line missing) evaluating its message arguments lazily — confirm.
84 remaining_output.size() == 0,
85 "Some of the blobs are declared as output but never produced by the " 88 ", the first one is ",
89 *remaining_output.begin());
// NetBase::RunAsync (fragment). Visits every operator returned by
// GetOperators(); the per-op work and the return statement are on source lines
// missing from this listing.
92 bool NetBase::RunAsync() {
93 for (
auto& op : GetOperators()) {
// Fallback net type: CreateNet assigns it to net_type when the NetDef does not
// set a type of its own.
const std::string kSimpleNet = "simple";
// Accessor for the global list of net observer creators (fragment: the return
// statement and closing brace are on lines missing from this listing).
// The vector is a function-local static, created on first use and shared by
// AddGlobalNetObserverCreator, ClearGlobalNetObservers and CreateNet.
// NOTE(review): no synchronization is visible around mutations of this vector.
102 std::vector<NetObserverCreator>* GetNetObserverCreators() {
103 static std::vector<NetObserverCreator> creators;
// Built-in net-type redirects (fragment). ApplyPotentialExecutorOverride copies
// these first and then lets FLAGS_caffe2_override_executor entries overwrite
// them. Visible entries send the dag/prof_dag/async_dag/async_polling types to
// "async_scheduling" and "async_simple" to "simple".
// The map is a function-local static const, initialized once on first call.
// NOTE(review): original lines 108-109 and 116-120 are missing from this
// listing; there may be further entries and the return statement there.
107 const std::unordered_map<std::string, std::string>& defaultOverrides() {
110 static const std::unordered_map<std::string, std::string> overrides = {
111 {
"dag",
"async_scheduling"},
112 {
"prof_dag",
"async_scheduling"},
113 {
"async_dag",
"async_scheduling"},
114 {
"async_polling",
"async_scheduling"},
115 {
"async_simple",
"simple"},
// Rewrites *net_type in place using an executor override table (fragment).
// The table starts as a copy of defaultOverrides(); FLAGS_caffe2_override_executor
// is then split on ',' and read as consecutive (from, to) pairs that are written
// over it, so flag entries take precedence over the defaults.
// Net types with no table entry are left unchanged.
121 void ApplyPotentialExecutorOverride(std::string* net_type) {
122 auto executors = caffe2::split(
',', FLAGS_caffe2_override_executor);
// The flag must hold an even number of items (pairs); the enforcing call's
// opening line is missing from this listing.
124 executors.size() % 2 == 0,
"Invalid override executors flag value");
125 std::unordered_map<std::string, std::string> overrides;
126 for (
const auto& kv : defaultOverrides()) {
127 overrides[kv.first] = kv.second;
// idx += 2: executors[idx] is the net type to replace, executors[idx + 1] its
// replacement; safe because the size was checked to be even above.
129 for (
size_t idx = 0; idx < executors.size(); idx += 2) {
130 overrides[executors[idx]] = executors[idx + 1];
// Only rewrite (and log at VLOG(1)) when an override exists.
// NOTE(review): the key is looked up three times (count + two operator[]);
// a single find() would do.
132 if (overrides.count(*net_type)) {
133 VLOG(1) <<
"Overrode net type '" << *net_type <<
"' with '" 134 << overrides[*net_type] <<
"'";
135 *net_type = overrides[*net_type];
// Registers `creator` in the global observer-creator list (copied in by value)
// and logs at VLOG(1). CreateNet iterates this list and attaches creator(net)
// to each net it builds. (The closing brace is on a line missing from this
// listing.)
141 void AddGlobalNetObserverCreator(NetObserverCreator creator) {
142 GetNetObserverCreators()->push_back(creator);
143 VLOG(1) <<
"Have set a custom GlobalNetObserverCreator";
// Empties the global observer-creator list and logs at VLOG(1). Only affects
// nets created afterwards; observers already attached to existing nets are not
// touched here. (The closing brace is on a line missing from this listing.)
146 void ClearGlobalNetObservers() {
147 GetNetObserverCreators()->clear();
148 VLOG(1) <<
"All net observers cleared";
// Body fragment of the CreateNet(const NetDef&, Workspace*) overload (signature
// and remainder are on missing lines). Copies the caller's NetDef into a
// shared_ptr, presumably to forward it to the shared_ptr overload — TODO confirm.
// NOTE(review): std::make_shared<NetDef>(net_def) would avoid the separate
// control-block allocation.
152 std::shared_ptr<NetDef> tmp_net_def(
new NetDef(net_def));
// CreateNet(const std::shared_ptr<const NetDef>&, Workspace*) (fragment).
// Chooses the net type (the NetDef's own type if set, else kSimpleNet), applies
// executor overrides, instantiates the net through NetRegistry(), then attaches
// one observer per registered global observer creator.
157 const std::shared_ptr<const NetDef>& net_def,
159 std::string net_type;
160 if (net_def->has_type()) {
161 net_type = net_def->type();
// Fallback when the NetDef has no type (the else-branch lines are missing).
165 net_type = kSimpleNet;
167 ApplyPotentialExecutorOverride(&net_type);
// NOTE(review): original lines 169/171 are missing between here and the
// observer loop; if NetRegistry()->Create can return null, confirm a null check
// guards the net->AttachObserver dereference below.
168 unique_ptr<NetBase> net = NetRegistry()->Create(net_type, net_def, ws);
170 VLOG(1) <<
"Adding a global observer to a net";
172 auto* observer_creators = GetNetObserverCreators();
173 for (
auto& creator : *observer_creators) {
174 net->AttachObserver(creator(net.get()));
// Base ExecutorHelper member taking an (unused, unnamed) DeviceOption (fragment:
// the return type and name are on a missing line). Always throws
// "Not implemented"; presumably overridden by concrete helpers — TODO confirm.
181 const DeviceOption& )
const {
182 CAFFE_THROW(
"Not implemented");
// Base ExecutorHelper::GetOperators: always throws "Not implemented".
// (The closing brace is on a line missing from this listing.)
185 std::vector<OperatorBase*> ExecutorHelper::GetOperators()
const {
186 CAFFE_THROW(
"Not implemented");
// Base ExecutorHelper::GetNumWorkers: always throws "Not implemented".
// (The closing brace is on a line missing from this listing.)
189 int ExecutorHelper::GetNumWorkers()
const {
190 CAFFE_THROW(
"Not implemented");
// NetBase::TEST_Benchmark (fragment; several source lines are missing).
// Performs `warmup_runs` Run() calls, then `main_runs` Run() calls; a failing
// Run() is reported via CAFFE_ENFORCE together with its run index. `millis`
// (computed on a missing line, presumably the timed duration of the main runs)
// is logged as ms/iter and iters/sec. Returns a one-element vector holding the
// average milliseconds per main run. Per-op benchmarking (`run_individual`) is
// not supported here: it only logs a hint to switch to a simple net type.
193 std::vector<float> NetBase::TEST_Benchmark(
194 const int warmup_runs,
196 const bool run_individual) {
197 LOG(INFO) <<
"Starting benchmark, running warmup runs";
200 "Number of warm up runs should be non negative, provided ",
202 for (
int run_idx = 0; run_idx < warmup_runs; ++run_idx) {
203 CAFFE_ENFORCE(Run(),
"Warmup run ", run_idx,
" has failed");
206 LOG(INFO) <<
"Running main runs";
// NOTE(review): this message implies main_runs only has to be non-negative.
// main_runs == 0 would then reach the divisions by main_runs below, yielding
// inf/NaN for a floating-point millis rather than a crash. Consider requiring
// main_runs > 0; confirm the actual check on the missing line.
209 "Number of main runs should be non negative, provided ",
213 for (
int run_idx = 0; run_idx < main_runs; ++run_idx) {
214 CAFFE_ENFORCE(Run(),
"Main run ", run_idx,
" has failed");
217 LOG(INFO) <<
"Main runs finished. Milliseconds per iter: " 218 << millis / main_runs
219 <<
". Iters per second: " << 1000.0 * main_runs / millis;
221 if (run_individual) {
222 LOG(INFO) <<
"Net does not support per-op benchmark; " 223 "to run it, switch to a simple net type";
225 return std::vector<float>{millis / main_runs};
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
float MilliSeconds()
Returns the elapsed time in milliseconds.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
A simple timer object for measuring time.