1 #include "caffe2/operators/rnn/recurrent_network_executor.h" 3 #include "caffe2/core/timer.h" 14 const NetDef& step_net_def,
15 std::map<string, string>& recurrent_input_map,
16 std::string timestep_blob,
19 step_net_def, recurrent_input_map, timestep_blob);
21 rnn_args.GetSingleArgument<
int>(
"rnn_executor.num_threads", 0);
22 if (num_threads > 0) {
23 exec->setNumThreads(num_threads);
24 LOG(INFO) <<
"Set num threads: " << num_threads;
26 exec->debug_ = rnn_args.GetSingleArgument<
int>(
"rnn_executor_debug", 0);
27 return std::unique_ptr<RecurrentNetworkExecutorBase>(exec);
34 CAFFE_ENFORCE_GE(T, 0,
"Negative number of steps");
39 CAFFE_ENFORCE(timestep_ops_.size() >= T);
40 countdown_ = T * timestep_ops_[0].size();
41 finished_timesteps_ = 0;
43 CHECK(task_queue_.size() == 0);
45 for (
auto& rnn_op : timestep_ops_[0]) {
47 if (rnn_op.frontier) {
48 task_queue_.Push(
OpTask(0, rnn_op.order, T, 1));
60 CAFFE_ENFORCE_GE(T, 0,
"Negative number of steps");
65 CAFFE_ENFORCE(timestep_ops_.size() >= T);
66 countdown_ = T * timestep_ops_[0].size();
67 finished_timesteps_ = 0;
70 CHECK(task_queue_.size() == 0);
72 for (
auto& rnn_op : timestep_ops_[T - 1]) {
73 if (rnn_op.frontier) {
74 task_queue_.Push(
OpTask(T - 1, rnn_op.order, T, -1));
86 void ThreadedRecurrentNetworkExecutor::RunOp(
OpTask job,
int ) {
88 ((job.forward() && job.timestep == 0) ||
89 (job.backward() && job.timestep == job.T - 1));
91 ((job.backward() && job.timestep == 0) ||
92 (job.forward() && job.timestep == job.T - 1));
93 auto& rnn_op = timestep_ops_[job.timestep][job.op_idx];
94 if (rnn_op.num_dynamic_inputs > 0 && !rnn_op.frontier) {
97 rnn_op.num_dynamic_inputs -
98 first_timestep * rnn_op.num_recurrent_inputs,
110 rnn_op.proc_inputs = 0;
117 for (
int depidx : rnn_op.dependencies) {
118 int t = job.timestep;
119 bool for_next_timestep = depidx <= rnn_op.order;
120 if (!last_timestep && for_next_timestep) {
122 }
else if (for_next_timestep) {
126 auto& dep_op = timestep_ops_[t][depidx];
127 int proc_inputs = dep_op.proc_inputs.fetch_add(1) + 1;
131 int num_req_inputs = dep_op.num_dynamic_inputs;
132 if (first_timestep && !for_next_timestep) {
133 num_req_inputs -= dep_op.num_recurrent_inputs;
136 if (proc_inputs == num_req_inputs || num_req_inputs == 0) {
137 task_queue_.Push(
OpTask(t, depidx, job.T, job.direction));
143 if (countdown_.fetch_sub(1) == 1) {
144 CAFFE_ENFORCE_EQ(0, task_queue_.size());
145 std::unique_lock<std::mutex> lk(countdown_mtx_);
154 void ThreadedRecurrentNetworkExecutor::WorkerFunction() {
156 static std::atomic<int> seq(0);
157 int id = seq.fetch_add(1);
161 if (!task_queue_.Pop(&job)) {
167 if (max_parallel_timesteps_ > 0) {
168 int t = (job.direction == 1 ? job.timestep : job.T - job.timestep + 1);
169 if (t - finished_timesteps_ >= max_parallel_timesteps_) {
171 task_queue_.Push(job);
178 if (job.op_idx == timestep_ops_template_.size() - 1) {
179 finished_timesteps_.fetch_add(1);
183 std::unique_lock<std::mutex> lk(countdown_mtx_);
184 LOG(ERROR) <<
"Crash at thread " <<
id <<
" timestep " << job.timestep
185 <<
" op:" << ProtoDebugString(step_net_def_.op(job.op_idx))
187 task_queue_.NoMoreJobs();
193 VLOG(1) <<
"Worker exiting, did run: " << num_jobs <<
" jobs";
200 void ThreadedRecurrentNetworkExecutor::_Exec() {
202 false, failed_,
"Tried to execute a previously failed RNN executor");
205 std::unique_lock<std::mutex> lk(countdown_mtx_);
206 while (workers_.size() < num_threads_) {
207 VLOG(1) <<
"Start RNN worker " << workers_.size() <<
" / " << num_threads_;
209 std::thread(&ThreadedRecurrentNetworkExecutor::WorkerFunction,
this));
214 while (!failed_ && countdown_ > 0) {
215 cv_.wait_for(lk, std::chrono::seconds(30), [&] {
219 LOG(INFO) <<
"RNN Executor still running, remaining ops: " 222 return failed_ || countdown_ == 0;
229 "RNN executor encountered failure. See prior error logs for details.");
Data structure for a scheduled task in the task queue.
std::unique_ptr< RecurrentNetworkExecutorBase > createRNNExecutor< CPUContext >(const NetDef &step_net_def, std::map< string, string > &recurrent_input_map, std::string timestep_blob, ArgumentHelper rnn_args)
Implementation of RecurrentNetworkExecutor that uses a thread pool for multithreaded execution of RNNs...
A helper class to index into arguments.
The primary ATen error class.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
const char * what() const noexcept override
Returns the complete error message, including the source location.
float Seconds()
Returns the elapsed time in seconds.
bool RunBackwards(int T) override
Run backward pass with T timesteps.
A simple timer object for measuring time.
bool Run(int T) override
Run forward pass with T timesteps.