Caffe2 - C++ API
A deep learning, cross platform ML framework
recurrent_network_executor.cc
1 
#include "caffe2/operators/recurrent_network_executor.h"

#include <exception>

#include "caffe2/core/timer.h"

21 namespace caffe2 {
22 
28 template <>
29 std::unique_ptr<RecurrentNetworkExecutorBase> createRNNExecutor<CPUContext>(
30  const NetDef& step_net_def,
31  std::map<string, string>& recurrent_input_map,
32  std::string timestep_blob,
33  ArgumentHelper rnn_args) {
34  auto* exec = new ThreadedRecurrentNetworkExecutor(
35  step_net_def, recurrent_input_map, timestep_blob);
36  int num_threads =
37  rnn_args.GetSingleArgument<int>("rnn_executor.num_threads", 0);
38  if (num_threads > 0) {
39  exec->setNumThreads(num_threads);
40  LOG(INFO) << "Set num threads: " << num_threads;
41  }
42  exec->debug_ = rnn_args.GetSingleArgument<int>("rnn_executor_debug", 0);
43  return std::unique_ptr<RecurrentNetworkExecutorBase>(exec);
44 }
45 
50  CAFFE_ENFORCE(timestep_ops_.size() >= T);
51  countdown_ = T * timestep_ops_[0].size();
52  finished_timesteps_ = 0;
53 
54  CHECK(task_queue_.size() == 0);
55 
56  for (auto& rnn_op : timestep_ops_[0]) {
57  // Launch "frontier"-ops first.
58  if (rnn_op.frontier) {
59  task_queue_.Push(OpTask(0, rnn_op.order, T, 1));
60  }
61  }
62 
63  _Exec();
64  return true;
65 }
66 
71  CAFFE_ENFORCE(timestep_ops_.size() >= T);
72  countdown_ = T * timestep_ops_[0].size();
73  finished_timesteps_ = 0;
74 
75  // Frontier
76  CHECK(task_queue_.size() == 0);
77 
78  for (auto& rnn_op : timestep_ops_[T - 1]) {
79  if (rnn_op.frontier) {
80  task_queue_.Push(OpTask(T - 1, rnn_op.order, T, -1));
81  }
82  }
83 
84  _Exec();
85  return true;
86 }
87 
/**
 * Runs the op `job.op_idx` of timestep `job.timestep`. Then, for each
 * dependent op, it increments that op's input counter and enqueues the op
 * once all of its required dynamic inputs are available. When the last op
 * overall has run, it wakes the thread waiting in _Exec().
 *
 * `thread_id` is not read in this function body.
 */
void ThreadedRecurrentNetworkExecutor::RunOp(OpTask job, int thread_id) {
  // "First" / "last" are relative to the execution direction: forward walks
  // timesteps 0..T-1, backward walks T-1..0.
  bool first_timestep =
      ((job.forward() && job.timestep == 0) ||
       (job.backward() && job.timestep == job.T - 1));
  bool last_timestep =
      ((job.backward() && job.timestep == 0) ||
       (job.forward() && job.timestep == job.T - 1));
  auto& rnn_op = timestep_ops_[job.timestep][job.op_idx];
  // Sanity check: a non-frontier op with dynamic inputs must only have been
  // scheduled after all of them were counted. On the first timestep the
  // recurrent inputs are excluded from that count (see note further below).
  if (rnn_op.num_dynamic_inputs > 0 && !rnn_op.frontier) {
    CAFFE_ENFORCE_EQ(
        rnn_op.proc_inputs,
        rnn_op.num_dynamic_inputs -
            first_timestep * rnn_op.num_recurrent_inputs,
        "Error at operator ",
        job.op_idx,
        " on timestep ",
        job.timestep,
        " T=",
        job.T,
        " first =",
        first_timestep);
  }

  // Reset input dependency counter
  rnn_op.proc_inputs = 0;

  // Run the operator
  rnn_op.op->Run();

  // Knock down dependencies and start next ops, if this
  // was last dependency fulfilled.
  for (int depidx : rnn_op.dependencies) {
    int t = job.timestep;
    // A dependency whose index is not after this op's order refers to the
    // op in the next timestep (in execution direction). On the last
    // timestep there is no next timestep, so such dependencies are skipped.
    bool for_next_timestep = depidx <= rnn_op.order;
    if (!last_timestep && for_next_timestep) {
      t += job.direction;
    } else if (for_next_timestep) {
      continue;
    }

    auto& dep_op = timestep_ops_[t][depidx];
    // Atomic increment; the value computed here is the count including this
    // op's contribution.
    int proc_inputs = dep_op.proc_inputs.fetch_add(1) + 1;

    // Schedule next op, if this was the last dependency. Note that on
    // first timestep we don't have recurrent inputs.
    int num_req_inputs = dep_op.num_dynamic_inputs;
    if (first_timestep && !for_next_timestep) {
      num_req_inputs -= dep_op.num_recurrent_inputs;
    }

    // Ops with no required dynamic inputs are scheduled unconditionally.
    if (proc_inputs == num_req_inputs || num_req_inputs == 0) {
      task_queue_.Push(OpTask(t, depidx, job.T, job.direction));
    }
  }

  // Decrement countdown: when at zero, we have run all ops and can
  // notify the caller thread.
  if (countdown_.fetch_sub(1) == 1) {
    CAFFE_ENFORCE_EQ(0, task_queue_.size());
    // Notify while holding countdown_mtx_, the mutex the waiter in _Exec()
    // uses for its condition-variable wait.
    std::unique_lock<std::mutex> lk(countdown_mtx_);
    cv_.notify_one();
  }
}
155 
160 void ThreadedRecurrentNetworkExecutor::WorkerFunction() {
161  size_t num_jobs = 0;
162  static std::atomic<int> seq(0);
163  int id = seq.fetch_add(1);
164 
165  while (!failed_) {
166  OpTask job;
167  if (!task_queue_.Pop(&job)) {
168  break;
169  }
170 
171  // Check for limited timestep parallelism, and if too many timesteps would
172  // be started concurrently, return the task to task queue.
173  if (max_parallel_timesteps_ > 0) {
174  int t = (job.direction == 1 ? job.timestep : job.T - job.timestep + 1);
175  if (t - finished_timesteps_ >= max_parallel_timesteps_) {
176  // Return to queue
177  task_queue_.Push(job);
178  continue;
179  }
180  }
181 
182  try {
183  RunOp(job, id);
184  if (job.op_idx == timestep_ops_template_.size() - 1) {
185  finished_timesteps_.fetch_add(1);
186  }
187  num_jobs++;
188  } catch (::caffe2::EnforceNotMet& enf) {
189  std::unique_lock<std::mutex> lk(countdown_mtx_);
190  LOG(ERROR) << "Crash at thread " << id << " timestep " << job.timestep
191  << " op:" << ProtoDebugString(step_net_def_.op(job.op_idx))
192  << enf.what();
193  task_queue_.NoMoreJobs();
194  failed_ = true;
195  cv_.notify_one();
196  return;
197  }
198  }
199  VLOG(1) << "Worker exiting, did run: " << num_jobs << " jobs";
200 }
201 
206 void ThreadedRecurrentNetworkExecutor::_Exec() {
207  CAFFE_ENFORCE_EQ(
208  false, failed_, "Tried to execute a previously failed RNN executor");
209 
210  // Start threads if not started
211  std::unique_lock<std::mutex> lk(countdown_mtx_);
212  while (workers_.size() < num_threads_) {
213  VLOG(1) << "Start RNN worker " << workers_.size() << " / " << num_threads_;
214  workers_.push_back(
215  std::thread(&ThreadedRecurrentNetworkExecutor::WorkerFunction, this));
216  }
217 
218  // Wait until threads finish.
219  Timer t;
220  while (!failed_ && countdown_ > 0) {
221  cv_.wait_for(lk, std::chrono::seconds(30), [&] {
222  // Log if we are still running, so that we catch deadlocks.. there
223  // should not be any deadlocks, but...
224  if (t.Seconds() > 10) {
225  LOG(INFO) << "RNN Executor still running, remaining ops: "
226  << countdown_;
227  }
228  return failed_ || countdown_ == 0;
229  });
230  }
231 
232  CAFFE_ENFORCE_EQ(
233  false,
234  failed_,
235  "RNN executor encountered failure. See prior error logs for details.");
236 }
237 
238 } // namespace caffe2
Data structure for a scheduled task in the task queue.
std::unique_ptr< RecurrentNetworkExecutorBase > createRNNExecutor< CPUContext >(const NetDef &step_net_def, std::map< string, string > &recurrent_input_map, std::string timestep_blob, ArgumentHelper rnn_args)
Implementation of RecurrentNetworkExecutor that uses thread pool for multithreaded execution of RNNs...
A helper class to index into arguments.
Definition: proto_utils.h:198
Copyright (c) 2016-present, Facebook, Inc.
float Seconds()
Returns the elapsed time in seconds.
Definition: timer.h:56
bool RunBackwards(int T) override
Run backward pass with T timesteps.
A simple timer object for measuring time.
Definition: timer.h:32
bool Run(int T) override
Run forward pass with T timesteps.