Caffe2 - C++ API
A deep learning, cross platform ML framework
recurrent_network_executor_gpu.h
1 
17 #ifndef CAFFE2_OPERATORS_RECURRENT_NETWORK_GPU_EXECUTOR_H_
18 #define CAFFE2_OPERATORS_RECURRENT_NETWORK_GPU_EXECUTOR_H_
19 
20 #include "caffe2/core/context_gpu.h"
21 #include "caffe2/operators/recurrent_network_executor.h"
22 
23 
24 #include <map>
25 
26 namespace caffe2 {
27 
29  public:
31  const NetDef& step_net_def,
32  std::map<string, string>& recurrent_input_map,
33  std::string timestep_blob)
34  : RecurrentNetworkExecutorBase(step_net_def, recurrent_input_map, timestep_blob) {}
35 
37 
38  protected:
// Declared as an override of a base-class virtual; the definition is not part
// of this header listing.
// NOTE(review): T is presumably the number of timesteps to execute - confirm
// against the implementation file.
39  bool Run(int T) override;
40 
// Declared as an override of a base-class virtual; the definition is not part
// of this header listing.
// NOTE(review): by its name this is the backward-pass counterpart of Run(T),
// with T presumably the number of timesteps - confirm against the
// implementation file.
41  bool RunBackwards(int T) override;
42 
// Override that unconditionally returns true for this executor.
// NOTE(review): presumably tells the base executor to leave link ops out of
// its dependency computation - confirm where the base class calls this hook.
43  bool ignoreLinkDependencies() override {
44  return true;
45  }
46 
// Recomputes has_timestep_parallelism_ (what the log messages below call
// "timestep parallelism") from timestep_ops_template_.
// The flag is reset to false, then the template is scanned for the first op
// that (a) has at least one parent, (b) does not have the last index as its
// order, and (c) has only parents whose index is greater than its own order.
// That op decides the result and ends the scan. The outcome is logged at INFO.
// NOTE(review): original source lines 48-54 are absent from this listing;
// whatever comment lived there is not reproduced here.
// NOTE(review): std::all_of is declared in <algorithm>, which is not among the
// includes visible in this listing - presumably pulled in transitively.
47  void AnalyzeOps() override {
55  has_timestep_parallelism_ = false;
56  for (auto& rnn_op : timestep_ops_template_) {
57  int i = rnn_op.order;
    // Only ops with at least one parent, and whose order is below the last
    // index, are candidates.
    // NOTE(review): `i` is a signed int compared against an unsigned size()
    // expression; this assumes order is never negative - confirm.
58  if (rnn_op.parents.size() >= 1 && i < timestep_ops_template_.size() - 1) {
    // True when every parent index is greater than this op's own order.
    // NOTE(review): the variable name suggests such parents are treated as
    // dependencies carried over from the previous timestep - confirm.
59  bool only_recurrent_deps = std::all_of(
60  rnn_op.parents.begin(),
61  rnn_op.parents.end(), [&](const int &parent) {
62  return parent > i;
63  }
64  );
65  if (only_recurrent_deps) {
66  VLOG(1) << "Timestep parallel op: " << ProtoDebugString(step_net_def_.op(i));
67  has_timestep_parallelism_ = true;
68 
    // Revoke the flag again if any parent is the last op in the template.
69  for (int dep : rnn_op.parents) {
70  if (dep == timestep_ops_template_.size() - 1) {
71  // This op depends on the last op of the previous iteration,
72  // so it will block any parallelism
73  has_timestep_parallelism_ = false;
74  break;
75  }
76  }
    // Stop scanning: the first op that passes the all_of check settles the
    // result, whichever way the parent check above went.
77  break;
78  }
79  }
80  }
81  LOG(INFO) << "Analyzed ops for timestep parallelism: " << has_timestep_parallelism_;
82  }
83 
84  public:
85 
// Stores n in max_cuda_streams_, replacing the in-class default of 2.
// No validation is performed on n here.
// NOTE(review): how the value is consumed is not visible in this header -
// presumably it caps the number of CUDA streams used by Run/RunBackwards.
86  void setMaxStreams(int n) {
87  max_cuda_streams_ = n;
88  }
89 
90  private:
// Private helper, declared only; its definition is not part of this header
// listing.
// NOTE(review): `from`/`to` presumably bound a range of timesteps, and the
// order of the two may encode direction - confirm against the implementation.
91  void _ExecRange(int from, int to);
92 
// CUDA event handles owned by this executor; how they are created, recorded
// and destroyed is not visible in this header.
93  std::vector<cudaEvent_t> events_;
// Result of the most recent AnalyzeOps() call; false until it runs.
94  bool has_timestep_parallelism_ = false;
// Value set through setMaxStreams(); defaults to 2.
95  int max_cuda_streams_ = 2;
96 };
97 }
98 #endif
RecurrentNetworkExecutor is a specialized runtime for recurrent neural networks (RNNs).
Copyright (c) 2016-present, Facebook, Inc.