1 #ifndef CAFFE2_OPERATORS_RECURRENT_NETWORK_EXECUTOR_H_ 2 #define CAFFE2_OPERATORS_RECURRENT_NETWORK_EXECUTOR_H_ 5 #include <unordered_set> 8 #include "caffe2/core/context.h" 9 #include "caffe2/core/logging.h" 10 #include "caffe2/core/operator.h" 11 #include "caffe2/core/timer.h" 12 #include "caffe2/operators/rnn/recurrent_network_executor_incl.h" 34 const NetDef& step_net_def,
35 std::map<string, string>& recurrent_input_map,
36 std::string timestep_blob)
37 : step_net_def_(step_net_def),
38 recurrent_input_map_(recurrent_input_map),
39 timestep_blob_(timestep_blob) {
40 const bool net_def_has_device_option = step_net_def_.has_device_option();
41 for (
int i = 0; i < step_net_def_.op_size(); i++) {
42 if (!step_net_def_.op(i).has_device_option() &&
43 net_def_has_device_option) {
47 step_net_def_.mutable_op(i)->mutable_device_option()->CopyFrom(
48 step_net_def_.device_option());
50 op_deps_.push_back(op_deps(i));
57 if (timestep_ops_.size() > 0) {
63 virtual bool Run(
int T) = 0;
65 virtual bool RunBackwards(
int T) = 0;
79 if (timestep_ops_template_.size() == 0) {
81 CalculateInternalDependencies();
85 for (
auto& rnn_op : timestep_ops_template_) {
86 rnn_op.has_timestep_blob =
false;
87 const OperatorDef& op = step_net_def_.op(rnn_op.order);
88 for (
int i = 0; i < op.input_size(); i++) {
89 if (op.input(i) == timestep_blob_) {
90 rnn_op.has_timestep_blob =
true;
95 !HasOutput(op, timestep_blob_),
96 "Timestep cannot be output of an op: ",
98 " op=" + ProtoDebugString(op));
103 if (timestep_ops_.size() <= t ||
104 (timestep_ops_.size() > t && timestep_ops_[t].size() == 0)) {
107 for (
int j = timestep_ops_.size(); j < t + 1; j++) {
108 timestep_ops_.push_back(std::vector<RNNNetOperator>());
109 timestep_ops_.back().reserve(timestep_ops_template_.size());
113 if (workspaces_.size() < t + 1) {
114 workspaces_.resize(t + 1);
121 std::string this_timestep_blob =
122 timestep_blob_ +
"_rnnexec_t" + c10::to_string(t);
123 BlobGetMutableTensor(ws->
CreateBlob(this_timestep_blob), CPU)->Resize(1);
124 auto b = ws->
GetBlob(this_timestep_blob);
126 BlobGetMutableTensor(b, CPU)->template mutable_data<int32_t>()[0] = t;
129 for (
auto& template_rnn_op : timestep_ops_template_) {
130 auto& rnn_op = template_rnn_op;
136 if (rnn_op.has_timestep_blob) {
137 OperatorDef op_copy = step_net_def_.op(rnn_op.order);
139 for (
int i = 0; i < op_copy.input_size(); i++) {
140 if (op_copy.input(i) == timestep_blob_) {
141 op_copy.set_input(i, this_timestep_blob);
145 rnn_op.op = CreateOperator(op_copy, ws);
146 for (
const auto& observer : observers_list) {
147 std::unique_ptr<ObserverBase<OperatorBase>> rnn_observer_copy =
148 observer.get()->rnnCopy(rnn_op.op.get(), rnn_op.order);
149 if (rnn_observer_copy) {
150 rnn_op.op->AttachObserver(std::move(rnn_observer_copy));
156 if (t > max_parallel_timesteps_ && max_parallel_timesteps_ > 0 &&
157 workspaces_[t - max_parallel_timesteps_] == ws) {
159 timestep_ops_[t - max_parallel_timesteps_][rnn_op.order].op;
163 rnn_op.op = CreateOperator(step_net_def_.op(rnn_op.order), ws);
164 for (
const auto& observer : observers_list) {
165 std::unique_ptr<ObserverBase<OperatorBase>> rnn_observer_copy =
166 observer.get()->rnnCopy(rnn_op.op.get(), rnn_op.order);
167 if (rnn_observer_copy) {
168 rnn_op.op->AttachObserver(std::move(rnn_observer_copy));
173 rnn_op.op->DisableEvent();
175 timestep_ops_[t].emplace_back(rnn_op);
186 max_parallel_timesteps_ = p;
189 size_t NumObserversStepNet() {
191 for (
auto& ops_at_timestep_t : timestep_ops_) {
192 for (
auto& rnn_op : ops_at_timestep_t) {
193 num += rnn_op.op->NumObservers();
202 bool has_input(std::string x,
int opidx) {
203 for (
auto& inp : step_net_def_.op(opidx).input()) {
208 for (
auto& inp : step_net_def_.op(opidx).control_input()) {
218 std::vector<string> op_deps(
int i) {
219 std::vector<string> outs;
220 auto& opdef = step_net_def_.op(i);
221 for (
string o : opdef.output()) {
224 for (
auto& arg : opdef.arg()) {
225 if (arg.name().find(
"rnn_dependency") == 0) {
226 outs.push_back(arg.s());
236 void infer_dependencies(
238 std::unordered_set<string> outputs,
239 std::vector<RNNNetOperator>& rnn_ops,
240 std::unordered_set<int>* dep_ops) {
241 std::unordered_set<int> already_accounted_deps;
242 int num_ops = step_net_def_.op_size();
243 bool ignore_links = this->ignoreLinkDependencies();
244 for (
int j = 0; j < num_ops - 1 && !outputs.empty(); j++) {
245 int i = (start_i + j) % num_ops;
246 if (ignore_links && rnn_ops[i].link_op) {
249 for (
auto& outp : outputs) {
250 if (has_input(outp, i)) {
251 if (already_accounted_deps.find(i) == already_accounted_deps.end()) {
257 for (
int odep : rnn_ops[i].dependencies) {
258 already_accounted_deps.insert(odep);
260 for (
string& dep_out : op_deps_[i]) {
261 auto oit = outputs.find(dep_out);
262 if (oit != outputs.end()) {
281 void add_race_conflict_dependencies(
283 std::vector<RNNNetOperator>& rnn_ops,
284 std::unordered_set<int>* dep_ops) {
285 for (
int i = 0; i < rnn_ops.size(); i++) {
289 if (rnn_ops[i].link_op && this->ignoreLinkDependencies()) {
292 for (
auto& dep_blob : op_deps_[i]) {
293 for (
auto& inp : step_net_def_.op(opidx).input()) {
294 if (inp == dep_blob) {
300 for (
auto& outp : step_net_def_.op(opidx).output()) {
301 if (outp == dep_blob) {
316 void CalculateInternalDependencies() {
317 for (
int i = 0; i < step_net_def_.op_size(); i++) {
318 timestep_ops_template_.push_back(
RNNNetOperator(step_net_def_.op(i), i));
322 for (
auto& rnn_op : timestep_ops_template_) {
323 std::unordered_set<string> dep_outputs;
324 for (
auto& outp : op_deps_[rnn_op.order]) {
325 dep_outputs.insert(outp);
329 for (
auto& outp : dep_outputs) {
330 auto rit = recurrent_input_map_.find(outp);
331 if (rit != recurrent_input_map_.end()) {
332 dep_outputs.insert(rit->second);
334 dep_outputs.insert(outp);
339 if (!rnn_op.link_op || !this->ignoreLinkDependencies()) {
340 std::unordered_set<int> dependent_ops;
344 timestep_ops_template_,
349 if (!this->ignoreLinkDependencies()) {
350 add_race_conflict_dependencies(
351 rnn_op.order, timestep_ops_template_, &dependent_ops);
354 for (
int i : dependent_ops) {
355 rnn_op.dependencies.push_back(i);
362 rnn_op.dependencies.begin(),
363 rnn_op.dependencies.end(),
364 [&](
const int& a,
const int& b) {
365 if (a < rnn_op.order && b < rnn_op.order) {
368 if (a >= rnn_op.order && b >= rnn_op.order) {
371 if (a >= rnn_op.order && b < rnn_op.order) {
380 for (
auto& rnn_op : timestep_ops_template_) {
381 for (
int i : rnn_op.dependencies) {
382 timestep_ops_template_[i].num_dynamic_inputs++;
384 if (i > rnn_op.order) {
385 timestep_ops_template_[i].frontier =
false;
387 timestep_ops_template_[i].num_recurrent_inputs++;
395 for (
auto& rnn_op : timestep_ops_template_) {
396 if (rnn_op.num_dynamic_inputs == 0 && rnn_op.num_recurrent_inputs == 0) {
397 if (rnn_op.link_op && this->ignoreLinkDependencies()) {
400 timestep_ops_template_.back().dependencies.push_back(rnn_op.order);
405 for (
auto& rnn_op : timestep_ops_template_) {
406 for (
int dep : rnn_op.dependencies) {
407 timestep_ops_template_[dep].parents.push_back(rnn_op.order);
419 auto& rnn_ops = timestep_ops_[t];
421 LOG(INFO) <<
"Timestep: " << t;
422 for (
auto& rnn_op : rnn_ops) {
423 auto& op = rnn_op.op;
424 LOG(INFO) <<
"Operator " << rnn_op.order <<
": " << op->type()
425 <<
" dep inputs:" << rnn_op.num_dynamic_inputs
426 <<
" rec inputs:" << rnn_op.num_recurrent_inputs
427 <<
" frontier: " << rnn_op.frontier;
428 for (
auto& inp : rnn_op.op->debug_def().input()) {
429 LOG(INFO) <<
" ---- input: " << inp;
431 for (
auto& outp : rnn_op.op->debug_def().output()) {
432 LOG(INFO) <<
" ---- output: " << outp;
434 for (
auto j : rnn_op.dependencies) {
435 LOG(INFO) <<
" dep: " << j <<
": " << rnn_ops[j].op->type();
437 for (
auto j : rnn_op.parents) {
438 LOG(INFO) <<
" parent: " << j <<
": " << rnn_ops[j].op->type();
442 LOG(INFO) <<
"recurrent_inputs:" << recurrent_input_map_;
444 for (
auto& rnn_op : rnn_ops) {
445 LOG(INFO) <<
"Operator " << rnn_op.order;
446 LOG(INFO) << ProtoDebugString(rnn_op.op->debug_def());
450 virtual void AnalyzeOps() {}
452 virtual bool ignoreLinkDependencies() = 0;
454 std::vector<std::vector<RNNNetOperator>> timestep_ops_;
455 std::vector<OperatorBase*> op_ptrs_;
457 std::vector<RNNNetOperator> timestep_ops_template_;
459 NetDef step_net_def_;
460 std::vector<std::vector<string>> op_deps_;
461 std::vector<Workspace*> workspaces_;
462 std::map<string, string> recurrent_input_map_;
463 std::string timestep_blob_;
465 int max_parallel_timesteps_ = -1;
471 template <
class Context>
472 std::unique_ptr<RecurrentNetworkExecutorBase> createRNNExecutor(
473 const NetDef& step_net_def,
474 std::map<string, string>& recurrent_input_map,
475 std::string timestep_blob,
481 const NetDef& step_net_def,
482 std::map<string, string>& recurrent_input_map,
483 std::string timestep_blob)
488 task_queue_.NoMoreJobs();
489 VLOG(1) <<
"Joining workers.";
490 for (
auto& worker : workers_) {
495 bool Run(
int T)
override;
497 bool RunBackwards(
int T)
override;
499 bool ignoreLinkDependencies()
override {
503 void setNumThreads(
int n) {
508 void _ExecRange(
int from,
int to);
512 void WorkerFunction();
514 void RunOp(
OpTask job,
int thread_id);
517 std::atomic<int> countdown_;
518 std::atomic<bool> failed_;
519 std::atomic<int> finished_timesteps_;
521 std::mutex countdown_mtx_;
522 std::condition_variable cv_;
523 std::vector<std::thread> workers_;
524 int num_threads_ = 4;
529 #endif // CAFFE2_OPERATORS_RECURRENT_NETWORK_EXECUTOR_H_ Blob * CreateBlob(const string &name)
Creates a blob of the given name.
RecurrentNetworkExecutor is a specialized runtime for recurrent neural networks (RNNs).
Struct for an operator in a timestep and its dependencies.
Data structure for a scheduled task in the task queue.
void EnsureTimestepInitialized(int t, Workspace *ws, const std::vector< std::unique_ptr< ObserverBase< OperatorBase >>> &observers_list)
Callers must call EnsureTimestepInitialized before starting execution for each of the relevant timesteps.
void PrintInfo(int t)
For debug purposes, print the dependency structure.
A helper class to index into arguments.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
void SetMaxParallelTimesteps(int p)
Set limit for the number of timesteps that run in parallel.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.