Caffe2 - C++ API
A deep learning, cross platform ML framework
recurrent_op_cudnn.h
1 #ifndef CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
2 #define CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/context_gpu.h"
6 #include "caffe2/core/cudnn_wrappers.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 
10 namespace caffe2 {
11 namespace detail {
12 
13 template <typename T>
14 class TensorDescriptors {
15  public:
16  TensorDescriptors(
17  size_t n,
18  const std::vector<int>& dim,
19  const std::vector<int>& stride);
20  ~TensorDescriptors();
21  const cudnnTensorDescriptor_t* descs() const {
22  return descs_.data();
23  }
24 
25  private:
26  std::vector<cudnnTensorDescriptor_t> descs_;
27 };
28 
29 } // namespace detail
30 
31 template <typename T>
32 class RecurrentBaseOp : public Operator<CUDAContext> {
33  public:
34  USE_OPERATOR_FUNCTIONS(CUDAContext);
35  template<class... Args> explicit RecurrentBaseOp(Args&&... args)
36  : Operator<CUDAContext>(std::forward<Args>(args)...), cudnn_wrapper_(&context_) {
37  CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
38  CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
39  CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
40  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
41  }
42  virtual ~RecurrentBaseOp();
43 
44  protected:
45  void initialize(
46  const Tensor& input,
47  Tensor* dropoutStates = nullptr,
48  // If passed, reshapes to the appropriate size
49  Tensor* output = nullptr,
50  Tensor* hiddenOutput = nullptr,
51  Tensor* cellOutput = nullptr);
52 
53  CuDNNWrapper cudnn_wrapper_;
54  cudnnDropoutDescriptor_t dropoutDesc_;
55  cudnnRNNDescriptor_t rnnDesc_;
56  cudnnFilterDescriptor_t wDesc_;
57  cudnnTensorDescriptor_t hxDesc_;
58  cudnnTensorDescriptor_t cxDesc_;
59  cudnnTensorDescriptor_t hyDesc_;
60  cudnnTensorDescriptor_t cyDesc_;
61 
62  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
63  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;
64 
65  std::vector<int64_t> cachedInputDims_;
66  size_t reserveNbytes_;
67  size_t cudnnWsNbytes_;
68 
69  private:
70 };
71 
// Re-exposes the operator helpers and the protected members of the dependent
// base RecurrentBaseOp<T> inside a derived class template. This lets them be
// named without a `this->` prefix.
// Usage: write it once in the derived class body, with no trailing semicolon
// (the expansion already ends with one). The derived class's template
// parameter must be named `T`, because the expansion refers to
// RecurrentBaseOp<T>. The using-declarations take the access of the section
// where the macro is expanded.
#define USE_RECURRENT_BASE_FUNCTIONS               \
  USE_OPERATOR_FUNCTIONS(CUDAContext);             \
  using RecurrentBaseOp<T>::initialize;            \
  using RecurrentBaseOp<T>::cudnn_wrapper_;        \
  using RecurrentBaseOp<T>::dropoutDesc_;          \
  using RecurrentBaseOp<T>::rnnDesc_;              \
  using RecurrentBaseOp<T>::wDesc_;                \
  using RecurrentBaseOp<T>::hxDesc_;               \
  using RecurrentBaseOp<T>::cxDesc_;               \
  using RecurrentBaseOp<T>::hyDesc_;               \
  using RecurrentBaseOp<T>::cyDesc_;               \
  using RecurrentBaseOp<T>::xDesc_;                \
  using RecurrentBaseOp<T>::yDesc_;                \
  using RecurrentBaseOp<T>::reserveNbytes_;        \
  using RecurrentBaseOp<T>::cudnnWsNbytes_;        \
  using RecurrentBaseOp<T>::cachedInputDims_;
88 
89 template <typename T>
90 class RecurrentOp : public RecurrentBaseOp<T> {
91  public:
92  USE_RECURRENT_BASE_FUNCTIONS
93  template <class... Args>
94  explicit RecurrentOp(Args&&... args)
95  : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
96 
97  bool RunOnDevice() override;
98 
99  protected:
100  INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
101  OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
102 };
103 
// Template argument of RecurrentParamAccessOp selecting its access mode. The
// names suggest writing (SET_PARAM) vs. reading (GET_PARAM) parameters; the
// implementation is not in this header.
enum RecurrentParamOpMode {
  SET_PARAM = 0,
  GET_PARAM = 1,
};
105 
106 template <typename T, RecurrentParamOpMode mode>
107 class RecurrentParamAccessOp : public RecurrentBaseOp<T> {
108  public:
109  USE_RECURRENT_BASE_FUNCTIONS
110  template <class... Args>
111  explicit RecurrentParamAccessOp(Args&&... args)
112  : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
113 
114  bool RunOnDevice() override;
115 };
116 
117 template <typename T>
118 class RecurrentGradientOp : public RecurrentBaseOp<T> {
119  public:
120  USE_RECURRENT_BASE_FUNCTIONS
121  template <class... Args>
122  explicit RecurrentGradientOp(Args&&... args)
123  : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}
124 
125  bool RunOnDevice() override;
126 
127  protected:
128  INPUT_TAGS(
129  INPUT,
130  HIDDEN_INPUT,
131  CELL_INPUT,
132  WEIGHT,
133  RNN_SCRATCH,
134  OUTPUT,
135  GRAD_OUTPUT,
136  GRAD_HIDDEN_OUTPUT,
137  GRAD_CELL_OUTPUT);
138  OUTPUT_TAGS(
139  GRAD_INPUT,
140  GRAD_HIDDEN_INPUT,
141  GRAD_CELL_INPUT,
142  GRAD_WEIGHT,
143  DROPOUT_STATES,
144  RNN_SCRATCH_OUT);
145 };
146 
147 
148 } // namespace caffe2
149 
150 #endif // CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13