Caffe2 - C++ API
A deep learning, cross platform ML framework
recurrent_op_cudnn.h
1 
17 #ifndef CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
18 #define CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/context_gpu.h"
22 #include "caffe2/core/cudnn_wrappers.h"
23 #include "caffe2/core/logging.h"
24 #include "caffe2/core/operator.h"
25 
26 namespace caffe2 {
27 namespace detail {
28 
// Holder for a sequence of cuDNN tensor descriptors (detail::TensorDescriptors<T>,
// as referenced by RecurrentBaseOp's xDesc_/yDesc_ members).
// NOTE(review): the class-head line, the constructor-name line and the line
// between the constructor and descs() (original lines 30, 32 and 36) are
// absent from this listing; the remaining lines are reproduced unchanged.
29 template <typename T>
31  public:
// Constructor parameter list: an element count `n` plus `dim` and `stride`
// vectors. Presumably `n` descriptors are created with that shape/stride --
// the definition is not visible here, confirm against the .cc file.
33  size_t n,
34  const std::vector<int>& dim,
35  const std::vector<int>& stride);
// Returns a pointer to the first element of the internal descriptor vector.
37  const cudnnTensorDescriptor_t* descs() const {
38  return descs_.data();
39  }
40 
41  private:
// Backing storage for the descriptors handed out by descs().
42  std::vector<cudnnTensorDescriptor_t> descs_;
43 };
44 
45 } // namespace detail
46 
// Common base class for the cuDNN recurrent operators declared in this header
// (RecurrentOp, RecurrentParamAccessOp, RecurrentGradientOp all derive from
// RecurrentBaseOp<T>). It is a CUDA operator (Operator<CUDAContext>) that
// declares the cuDNN descriptors and size bookkeeping its subclasses share.
// Only declarations are visible here; the constructor, destructor and
// initialize() are defined elsewhere.
47 template <typename T>
48 class RecurrentBaseOp : public Operator<CUDAContext> {
49  public:
50  USE_OPERATOR_FUNCTIONS(CUDAContext);
51  RecurrentBaseOp(const OperatorDef& operator_def, Workspace* ws);
// Virtual so derived operators can be destroyed through a base pointer.
52  virtual ~RecurrentBaseOp();
53 
54  protected:
// Set-up hook taking the input tensor plus up to four optional tensors; each
// optional pointer defaults to nullptr. What exactly is (re)built is not
// visible in this header -- presumably the descriptors and byte counts
// declared below; confirm against the definition.
55  void initialize(
56  const Tensor<CUDAContext>& input,
57  Tensor<CUDAContext>* dropoutStates = nullptr,
58  // If passed, reshapes to the appropriate size
59  Tensor<CUDAContext>* output = nullptr,
60  Tensor<CUDAContext>* hiddenOutput = nullptr,
61  Tensor<CUDAContext>* cellOutput = nullptr);
62 
// cuDNN wrapper object and raw cuDNN descriptor handles (dropout, RNN, filter,
// and four tensor descriptors). NOTE(review): creation/destruction of these
// handles is presumably done by the constructor/destructor -- not shown here.
63  CuDNNWrapper cudnn_wrapper_;
64  cudnnDropoutDescriptor_t dropoutDesc_;
65  cudnnRNNDescriptor_t rnnDesc_;
66  cudnnFilterDescriptor_t wDesc_;
67  cudnnTensorDescriptor_t hxDesc_;
68  cudnnTensorDescriptor_t cxDesc_;
69  cudnnTensorDescriptor_t hyDesc_;
70  cudnnTensorDescriptor_t cyDesc_;
71 
// Owned descriptor sequences (see detail::TensorDescriptors<T>); the x/y
// naming suggests input/output sequences -- TODO confirm.
72  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
73  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;
74 
// Presumably the input dimensions seen by the last initialize() call, and the
// cuDNN reserve-space / workspace sizes in bytes -- verify in the .cc file.
75  std::vector<TIndex> cachedInputDims_;
76  size_t reserveNbytes_;
77  size_t cudnnWsNbytes_;
78 
// Intentionally empty: no private members are declared.
79  private:
80 };
81 
// Convenience macro for class templates deriving from RecurrentBaseOp<T>.
// It expands USE_OPERATOR_FUNCTIONS(CUDAContext) and then adds one
// using-declaration per protected member of RecurrentBaseOp<T> (all the
// descriptors, the cached dims, the byte counts and initialize()), so the
// derived class can refer to them by unqualified name even though the base is
// a dependent type. The expansion site must have a template parameter named T.
// No trailing semicolon is needed after the macro: the last using-declaration
// already ends with one.
82 #define USE_RECURRENT_BASE_FUNCTIONS \
83  USE_OPERATOR_FUNCTIONS(CUDAContext); \
84  using RecurrentBaseOp<T>::cudnn_wrapper_; \
85  using RecurrentBaseOp<T>::dropoutDesc_; \
86  using RecurrentBaseOp<T>::rnnDesc_; \
87  using RecurrentBaseOp<T>::wDesc_; \
88  using RecurrentBaseOp<T>::hxDesc_; \
89  using RecurrentBaseOp<T>::cxDesc_; \
90  using RecurrentBaseOp<T>::hyDesc_; \
91  using RecurrentBaseOp<T>::cyDesc_; \
92  using RecurrentBaseOp<T>::xDesc_; \
93  using RecurrentBaseOp<T>::yDesc_; \
94  using RecurrentBaseOp<T>::cachedInputDims_; \
95  using RecurrentBaseOp<T>::reserveNbytes_; \
96  using RecurrentBaseOp<T>::cudnnWsNbytes_; \
97  using RecurrentBaseOp<T>::initialize;
98 
// Recurrent operator built on RecurrentBaseOp<T>. The constructor only
// forwards its arguments to the base class; the work happens in
// RunOnDevice(), which is declared here and defined elsewhere.
99 template <typename T>
100 class RecurrentOp : public RecurrentBaseOp<T> {
101  public:
102  USE_RECURRENT_BASE_FUNCTIONS
103  RecurrentOp(const OperatorDef& operator_def, Workspace* ws)
104  : RecurrentBaseOp<T>(operator_def, ws) {}
105 
// Overrides the operator's device entry point; returns bool.
106  bool RunOnDevice() override;
107 
108  protected:
// Symbolic names for the operator's four inputs and five outputs.
// NOTE(review): INPUT_TAGS/OUTPUT_TAGS are defined outside this file;
// presumably they assign positional indices in the listed order -- confirm.
109  INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
110  OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
111 };
112 
// Mode selector used as the non-type template parameter `mode` of the
// parameter-access operator template declared right after this enum.
// The enumerators are unscoped, so they are visible directly in namespace caffe2.
113 enum RecurrentParamOpMode { SET_PARAM, GET_PARAM };
114 
// Operator template parameterised on the element type T and a
// RecurrentParamOpMode value (SET_PARAM or GET_PARAM).
// NOTE(review): the class-head line (original line 116) is absent from this
// listing; from the constructor below, the class is RecurrentParamAccessOp and
// its base is RecurrentBaseOp<T>.
115 template <typename T, RecurrentParamOpMode mode>
117  public:
118  USE_RECURRENT_BASE_FUNCTIONS
// Forwards both arguments to the RecurrentBaseOp<T> constructor; no extra state.
119  RecurrentParamAccessOp(const OperatorDef& operator_def, Workspace* ws)
120  : RecurrentBaseOp<T>(operator_def, ws) {}
121 
// Device entry point, defined elsewhere. Presumably its behaviour depends on
// `mode` -- confirm against the definition.
122  bool RunOnDevice() override;
123 };
124 
// Gradient counterpart of the recurrent operator.
// NOTE(review): the class-head line (original line 126) is absent from this
// listing; from the constructor below, the class is RecurrentGradientOp and
// its base is RecurrentBaseOp<T>.
125 template <typename T>
127  public:
128  USE_RECURRENT_BASE_FUNCTIONS
// Forwards both arguments to the RecurrentBaseOp<T> constructor; no extra state.
129  RecurrentGradientOp(const OperatorDef& operator_def, Workspace* ws)
130  : RecurrentBaseOp<T>(operator_def, ws) {}
131 
// Device entry point, declared here and defined elsewhere.
132  bool RunOnDevice() override;
133 
134  protected:
// Nine named inputs: the four forward inputs, the RNN scratch blob, the
// forward OUTPUT, and three GRAD_* tensors for the forward outputs.
// NOTE(review): INPUT_TAGS/OUTPUT_TAGS are defined outside this file;
// presumably they assign positional indices in the listed order -- confirm.
135  INPUT_TAGS(
136  INPUT,
137  HIDDEN_INPUT,
138  CELL_INPUT,
139  WEIGHT,
140  RNN_SCRATCH,
141  OUTPUT,
142  GRAD_OUTPUT,
143  GRAD_HIDDEN_OUTPUT,
144  GRAD_CELL_OUTPUT);
// Six named outputs: gradients for the four forward inputs, plus
// DROPOUT_STATES and RNN_SCRATCH_OUT.
145  OUTPUT_TAGS(
146  GRAD_INPUT,
147  GRAD_HIDDEN_INPUT,
148  GRAD_CELL_INPUT,
149  GRAD_WEIGHT,
150  DROPOUT_STATES,
151  RNN_SCRATCH_OUT);
152 };
153 
154 
155 } // namespace caffe2
156 
157 #endif // CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
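Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.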
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.