Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_dag.cc
#include "caffe2/core/net_dag.h"

#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>

#include "caffe2/core/operator.h"
#include "caffe2/core/static_tracepoint.h"
#include "caffe2/core/timer.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/proto_utils.h"

CAFFE2_DEFINE_bool(
    caffe2_disable_chaining,
    false,
    "Disable chaining logic (some latent multi-device issues).");

CAFFE2_DEFINE_bool(
    caffe2_dag_net_collect_stats,
    false,
    "Collect time stats in DAG net");
namespace caffe2 {

DAGNetBase::DAGNetBase(
    const std::shared_ptr<const NetDef>& net_def,
    Workspace* ws)
    : NetBase(net_def, ws), iter_(0) {
  // Blob creator allows us to track which operator created which blob.
  VLOG(1) << "Constructing DAGNet " << net_def->name();

  operator_nodes_ = dag_utils::prepareOperatorNodes(net_def, ws);

  execution_chains_ =
      (FLAGS_caffe2_disable_chaining
           ? dag_utils::singleChains(operator_nodes_)
           : dag_utils::computeChains(operator_nodes_));
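
  // Note: execution_chains_ maps the index of each chain's first operator to
  // the ordered list of operator indices in that chain. As an illustration, a
  // purely sequential net of three chainable operators would typically reduce
  // to a single chain {0: [0, 1, 2]}, dispatched to the worker pool as one
  // task.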

  operators_.reserve(operator_nodes_.size());
  for (const auto& node : operator_nodes_) {
    operators_.push_back(node.operator_.get());
  }

  LOG(INFO) << "Number of parallel execution chains "
            << execution_chains_.size()
            << " Number of operators = " << net_def->op_size();
  // TODO: do we want to make sure that there are no loops in the
  // dependency graph?

  // Figure out the initial frontier - this is the one we will feed into the
  // job queue to start a run.
  for (int idx = 0; idx < operator_nodes_.size(); ++idx) {
    if (operator_nodes_[idx].parents_.size() == 0) {
      initial_frontier_.push_back(idx);
    }
  }
  // Finally, start the workers.
  int num_workers = net_def->has_num_workers() ? net_def->num_workers() : 1;
  CAFFE_ENFORCE(num_workers > 0, "Must have a positive number of workers.");
  if (num_workers == 1) {
    LOG(WARNING) << "Number of workers is 1: this means that all operators "
                 << "will be executed sequentially. Did you forget to set "
                 << "num_workers in the NetDef?";
  }
  num_workers_ = num_workers;

  for (int idx = 0; idx < operator_nodes_.size(); ++idx) {
    if (operator_nodes_[idx].is_chain_start_) {
      task_timers_[idx] = caffe2::make_unique<Timer>();
    }
  }
  stats_.reserve(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
  for (auto device_idx = 0;
       device_idx < DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES;
       ++device_idx) {
    stats_.emplace_back(
        "dag_net/stats/" + net_def->name() + "/" +
        caffe2::DeviceTypeName(device_idx));
  }
}

DAGNetBase::~DAGNetBase() {
  if (job_queue_) {
    job_queue_->NoMoreJobs();
    VLOG(1) << "Joining workers.";
    for (auto& worker : workers_) {
      worker.join();
    }
  }
}

bool DAGNetBase::DoRunAsync() {
  StartAllObservers();

  // Lock run_in_progress_ to prevent concurrent Run()s.
  std::unique_lock<std::mutex> run_lock(run_in_progress_);
  VLOG(1) << "Running parallel net.";
  // First, set up job queue.
  remaining_ops_ = operator_nodes_.size();
  success_ = true;
  iter_++;
  if (!job_queue_) {
    job_queue_ = caffe2::make_unique<SimpleQueue<int>>();
  }
  // Figure out number of workers to start.
  auto num_workers_to_start = num_workers_ - workers_.size();

  // Ensure the number of workers matches the defined count, in case
  // any of the previously started threads have terminated.
  for (auto i = 0; i < num_workers_to_start; i++) {
    VLOG(1) << "Start worker #" << workers_.size();
    workers_.push_back(std::thread(&DAGNetBase::WorkerFunction, this));
  }
  // Initialize the runtime parent count.
  for (auto& node : operator_nodes_) {
    node.runtime_parent_count_ = node.parents_.size();
  }
  // Kickstart the job queue.
  for (auto& value : initial_frontier_) {
    if (FLAGS_caffe2_dag_net_collect_stats) {
      task_timers_[value]->Start();
    }
    job_queue_->Push(value);
  }
  // Wait for failure or completed execution.
  {
    std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
    for (;;) {
      if (remaining_ops_ == 0 || !success_) {
        break;
      }
      cv_.wait(mutex_lock);
    }
  }
  // Wait for all workers to terminate after failure.
  // If there is a failure, it is unlikely that the net is executed
  // again without modifications. Therefore it's easier to let the
  // workers terminate here than to add a drain state to make sure
  // the job queue is cleared.
  if (!success_) {
    for (auto& worker : workers_) {
      worker.join();
    }
    workers_.clear();
    job_queue_.reset(nullptr);
    return success_;
  }
  VLOG(2) << "All ops finished running.";
  for (const auto& op : operator_nodes_) {
    CAFFE_ENFORCE(
        op.runtime_parent_count_ == 0,
        "Operator ",
        op.operator_->debug_def().name(),
        "(",
        op.operator_->debug_def().type(),
        ") has some runtime parents left.");
  }

  StopAllObservers();
  // If the above while loop finished, we know that the current run finished.
  return success_;
}

void DAGNetBase::WorkerFunction() {
  // WorkerFunction() loops until there are no more jobs to run.
  while (true) {
    int idx = 0;

    // Return if there are no more operators to run (e.g. the
    // DAGNetBase is destructing, or there was an error on another
    // worker and we're cleaning up).
    if (!job_queue_->Pop(&idx)) {
      return;
    }
    if (FLAGS_caffe2_dag_net_collect_stats) {
      auto device_option =
          operator_nodes_[idx].operator_->event().GetDeviceOption();
      CAFFE_EVENT(
          stats_[device_option.device_type()],
          task_pool_wait_time_us,
          task_timers_[idx]->MicroSeconds());
    }

    VLOG(1) << "Running operator #" << idx << " "
            << operator_nodes_[idx].operator_->debug_def().name() << "("
            << operator_nodes_[idx].operator_->debug_def().type() << ").";
    CAFFE_ENFORCE(
        execution_chains_.find(idx) != execution_chains_.end(),
        "Can't find chain ",
        idx,
        ".");
    const auto& chain = execution_chains_[idx];
    bool this_success = RunAt(idx, chain);
    if (!this_success) {
      LOG(ERROR) << "Operator chain failed: "
                 << ProtoDebugString(
                        operator_nodes_[idx].operator_->debug_def());
    }

    // Do book-keeping
    std::vector<int> chains_to_queue;
    for (const auto idx : chain) {
      for (const auto child : operator_nodes_[idx].children_) {
        const int count = --operator_nodes_[child].runtime_parent_count_;
        CAFFE_ENFORCE(
            count >= 0,
            "Found runtime parent count smaller than zero for ",
            "operator node ",
            operator_nodes_[child].operator_->debug_def().name(),
            "(",
            operator_nodes_[child].operator_->debug_def().type(),
            ").");

        if (count != 0) {
          continue;
        }

        if (operator_nodes_[child].is_chain_start_) {
          VLOG(2) << "Pushing chain #" << child << " to queue.";
          chains_to_queue.push_back(child);
        }
      }
    }

    // Notify the caller of Run
    {
      std::unique_lock<std::mutex> mutex_lock(remaining_ops_mutex_);
      remaining_ops_ -= chain.size();
      CAFFE_ENFORCE(remaining_ops_ >= 0);
      success_ &= this_success;
      if (remaining_ops_ == 0 || !success_) {
        cv_.notify_one();
      }

      // Terminate thread if this or any other operator chain failed.
      if (!success_) {
        job_queue_->NoMoreJobs();
        return;
      }

      // Queue follow-up operator chains.
      // This can't be done inline because it can race with another thread
      // calling NoMoreJobs(), so the lock needs to be held while pushing.
      for (const auto idx : chains_to_queue) {
        if (FLAGS_caffe2_dag_net_collect_stats) {
          task_timers_[idx]->Start();
        }
        job_queue_->Push(idx);
      }
    }

    VLOG(2) << "Finished executing operator #" << idx;
  }
}

vector<float> DAGNetBase::TEST_Benchmark(
    const int warmup_runs,
    const int main_runs,
    const bool run_individual) {
  LOG(INFO) << "Starting benchmark.";
  LOG(INFO) << "Running warmup runs.";
  CAFFE_ENFORCE(
      warmup_runs >= 0,
      "Number of warm up runs should be non negative, provided ",
      warmup_runs,
      ".");
  for (int i = 0; i < warmup_runs; ++i) {
    CAFFE_ENFORCE(Run(), "Warmup run ", i, " has failed.");
  }

  LOG(INFO) << "Main runs.";
  CAFFE_ENFORCE(
      main_runs >= 0,
      "Number of main runs should be non negative, provided ",
      main_runs,
      ".");
  Timer timer;
  for (int i = 0; i < main_runs; ++i) {
    CAFFE_ENFORCE(Run(), "Main run ", i, " has failed.");
  }
  auto millis = timer.MilliSeconds();
  LOG(INFO) << "Main run finished. Milliseconds per iter: "
            << millis / main_runs
            << ". Iters per second: " << 1000.0 * main_runs / millis;

  if (run_individual) {
    LOG(INFO) << "DAGNet does not do per-op benchmark. To do so, "
                 "switch to a simple net type.";
  }
  return vector<float>{millis / main_runs};
}
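
// Usage note (illustrative): once CreateNet() has returned a DAG net,
// benchmarking might look like
//   net->TEST_Benchmark(/*warmup_runs=*/10, /*main_runs=*/100,
//                       /*run_individual=*/false);
// Per-operator timing is not reported for DAG nets, as the message above
// indicates.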

bool DAGNet::RunAt(int chain_id, const std::vector<int>& chain) {
  for (const auto i : chain) {
#ifdef CAFFE2_ENABLE_SDT
    const auto& op_name =
        operator_nodes_[i].operator_->debug_def().name().c_str();
    const auto& op_type =
        operator_nodes_[i].operator_->debug_def().type().c_str();
    auto* op_ptr = operator_nodes_[i].operator_.get();
    const auto& net_name = name_.c_str();
    CAFFE_SDT(operator_start, net_name, op_name, op_type, op_ptr);
#endif
    const auto success = operator_nodes_[i].operator_->Run();
#ifdef CAFFE2_ENABLE_SDT
    CAFFE_SDT(operator_done, net_name, op_name, op_type, op_ptr);
#endif
    if (!success) {
      return false;
    }
  }
  if (FLAGS_caffe2_dag_net_collect_stats) {
    auto device_option =
        operator_nodes_[chain_id].operator_->event().GetDeviceOption();
    CAFFE_EVENT(
        stats_[device_option.device_type()],
        task_time_to_succeeded_ms,
        task_timers_[chain_id]->MilliSeconds());
  }
  return true;
}

REGISTER_NET(dag, DAGNet);

} // namespace caffe2
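
For context, the sketch below shows how a net might request this executor through the public Caffe2 API: setting the NetDef type to "dag" selects DAGNet via the REGISTER_NET call above, and num_workers sizes the worker pool. The function and net names are illustrative, and the snippet assumes the standard Workspace/CreateNet helpers.

#include "caffe2/core/logging.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2.pb.h"

// Illustrative sketch; not part of net_dag.cc.
void RunExampleDagNet() {
  caffe2::NetDef net_def;
  net_def.set_name("example_net"); // hypothetical net name
  net_def.set_type("dag");         // selects the DAG executor registered above
  net_def.set_num_workers(4);      // worker threads; defaults to 1 if unset
  // ... append OperatorDefs to net_def here ...

  caffe2::Workspace ws;
  auto* net = ws.CreateNet(net_def); // builds operator nodes and execution chains
  CAFFE_ENFORCE(net->Run());         // dispatches chains to the worker pool
}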