Caffe2 - C++ API
A deep learning, cross platform ML framework
net_simple_async.cc
1 
17 #include "caffe2/core/net_simple_async.h"
18 #include "caffe2/core/net.h"
19 
20 #include <set>
21 #include <unordered_map>
22 #include <unordered_set>
23 
24 #include "caffe2/core/operator.h"
25 #include "caffe2/core/static_tracepoint.h"
26 #include "caffe2/core/timer.h"
27 #include "caffe2/proto/caffe2.pb.h"
28 #include "caffe2/utils/proto_utils.h"
29 
30 namespace caffe2 {
31 
32 AsyncSimpleNet::AsyncSimpleNet(
33  const std::shared_ptr<const NetDef>& net_def,
34  Workspace* ws)
35  : NetBase(net_def, ws) {
36  VLOG(1) << "Constructing AsyncSimpleNet " << net_def->name();
37  const bool net_def_has_device_option = net_def->has_device_option();
38  // Initialize the operators
39  const DeviceOption* first_device_option = nullptr;
40  const DeviceOption* current_device_option;
41  for (int idx = 0; idx < net_def->op_size(); ++idx) {
42  const auto& operator_def = net_def->op(idx);
43  VLOG(1) << "Creating operator " << operator_def.name() << ": "
44  << operator_def.type();
45  std::unique_ptr<OperatorBase> op{nullptr};
46  if (!operator_def.has_device_option() && net_def_has_device_option) {
47  // In the case that the operator def does not specify a device option but
48  // the net def has a default option, we copy the device option over to the
49  // operator def.
50  OperatorDef temp_def(operator_def);
51  temp_def.mutable_device_option()->CopyFrom(net_def->device_option());
52  op = CreateOperator(temp_def, ws, idx);
53  current_device_option = &net_def->device_option();
54  } else {
55  op = CreateOperator(operator_def, ws, idx);
56  op->set_debug_def(
57  std::shared_ptr<const OperatorDef>{net_def, &(net_def->op(idx))});
58  current_device_option = &operator_def.device_option();
59  }
60  if (!first_device_option) {
61  first_device_option = current_device_option;
62  } else {
63  CAFFE_ENFORCE(
64  IsSameDevice(*first_device_option, *current_device_option),
65  "AsyncSimpleNet supports only single device networks");
66  }
67  operators_.emplace_back(std::move(op));
68  }
69  events_ = {&operators_.back()->event()};
70 }
71 
72 bool AsyncSimpleNet::DoRunAsync() {
73  StartAllObservers();
74 
75  VLOG(1) << "Running net " << name_;
76  for (auto& op : operators_) {
77  VLOG(1) << "Running operator " << op->debug_def().name() << "("
78  << op->debug_def().type() << ").";
79 #ifdef CAFFE2_ENABLE_SDT
80  const auto& op_name = op->debug_def().name().c_str();
81  const auto& op_type = op->debug_def().type().c_str();
82  auto* op_ptr = op.get();
83  const auto& net_name = name_.c_str();
84  CAFFE_SDT(operator_start, net_name, op_name, op_type, op_ptr);
85 #endif
86  bool res = op->RunAsync();
87 #ifdef CAFFE2_ENABLE_SDT
88  CAFFE_SDT(operator_done, net_name, op_name, op_type, op_ptr);
89 #endif
90  if (!res) {
91  LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
92  return false;
93  }
94  }
95  StopAllObservers();
96  return true;
97 }
98 
100  const int warmup_runs,
101  const int main_runs,
102  const bool run_individual) {
103  LOG(INFO) << "Starting benchmark.";
104  LOG(INFO) << "Running warmup runs.";
105  CAFFE_ENFORCE(
106  warmup_runs >= 0,
107  "Number of warm up runs should be non negative, provided ",
108  warmup_runs,
109  ".");
110  for (int i = 0; i < warmup_runs; ++i) {
111  CAFFE_ENFORCE(Run(), "Warmup run ", i, " has failed.");
112  }
113 
114  LOG(INFO) << "Main runs.";
115  CAFFE_ENFORCE(
116  main_runs >= 0,
117  "Number of main runs should be non negative, provided ",
118  main_runs,
119  ".");
120  Timer timer;
121  for (int i = 0; i < main_runs; ++i) {
122  CAFFE_ENFORCE(Run(), "Main run ", i, " has failed.");
123  }
124  auto millis = timer.MilliSeconds();
125  LOG(INFO) << "Main run finished. Milliseconds per iter: "
126  << millis / main_runs
127  << ". Iters per second: " << 1000.0 * main_runs / millis;
128 
129  if (run_individual) {
130  LOG(INFO) << "AsyncSimpleNet does not do per-op benchmark. To do so, "
131  "switch to a simple net type.";
132  }
133  return vector<float>{millis / main_runs};
134 }
135 
136 REGISTER_NET(async_simple, AsyncSimpleNet);
137 
138 } // namespace caffe2
vector< float > TEST_Benchmark(const int warmup_runs, const int main_runs, const bool run_individual) override
Benchmarks a network.
Copyright (c) 2016-present, Facebook, Inc.
float MilliSeconds()
Returns the elapsed time in milliseconds.
Definition: timer.h:48
A simple timer object for measuring time.
Definition: timer.h:32