Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_network_executor_gpu.h
1 #ifndef CAFFE2_OPERATORS_RECURRENT_NETWORK_GPU_EXECUTOR_H_
2 #define CAFFE2_OPERATORS_RECURRENT_NETWORK_GPU_EXECUTOR_H_
3 
4 #include "caffe2/core/context_gpu.h"
5 #include "caffe2/operators/rnn/recurrent_network_executor.h"
6 
7 
8 #include <map>
9 
10 namespace caffe2 {
11 
13  public:
15  const NetDef& step_net_def,
16  std::map<string, string>& recurrent_input_map,
17  std::string timestep_blob)
18  : RecurrentNetworkExecutorBase(step_net_def, recurrent_input_map, timestep_blob) {}
19 
21 
22  protected:
23  bool Run(int T) override;
24 
25  bool RunBackwards(int T) override;
26 
27  bool ignoreLinkDependencies() override {
28  return true;
29  }
30 
31  void AnalyzeOps() override {
39  has_timestep_parallelism_ = false;
40  for (auto& rnn_op : timestep_ops_template_) {
41  int i = rnn_op.order;
42  if (rnn_op.parents.size() >= 1 && i < timestep_ops_template_.size() - 1) {
43  bool only_recurrent_deps = std::all_of(
44  rnn_op.parents.begin(),
45  rnn_op.parents.end(), [&](const int &parent) {
46  return parent > i;
47  }
48  );
49  if (only_recurrent_deps) {
50  VLOG(1) << "Timestep parallel op: " << ProtoDebugString(step_net_def_.op(i));
51  has_timestep_parallelism_ = true;
52 
53  for (int dep : rnn_op.parents) {
54  if (dep == timestep_ops_template_.size() - 1) {
55  // This op depends on the last op of the previous iteration,
56  // so it will block any parallelism
57  has_timestep_parallelism_ = false;
58  break;
59  }
60  }
61  break;
62  }
63  }
64  }
65  LOG(INFO) << "Analyzed ops for timestep parallelism: " << has_timestep_parallelism_;
66  }
67 
68  public:
69 
70  void setMaxStreams(int n) {
71  max_cuda_streams_ = n;
72  }
73 
74  private:
75  void _ExecRange(int from, int to);
76 
77  std::vector<cudaEvent_t> events_;
78  bool has_timestep_parallelism_ = false;
79  int max_cuda_streams_ = 2;
80 };
81 }
82 #endif
RecurrentNetworkExecutor is a specialized runtime for recurrent neural networks (RNNs).
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13