#ifndef CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_
#define CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_

#include "caffe2/core/context.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"

namespace caffe2 {
namespace detail {

template <typename T>
class TensorDescriptors {
 public:
  TensorDescriptors(
      size_t n,
      const std::vector<int>& dim,
      const std::vector<int>& stride);
  ~TensorDescriptors();
  const cudnnTensorDescriptor_t* descs() const {
    return descs_.data();
  }

 private:
  std::vector<cudnnTensorDescriptor_t> descs_;
};
} // namespace detail
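// Illustrative usage sketch (not part of the original header): the variable
// names and shapes below are assumptions, chosen to show how one descriptor
// per timestep would be built and handed to the cuDNN RNN API.
//
//   detail::TensorDescriptors<float> xDescs(
//       seqLength,                 // number of per-timestep descriptors
//       {batchSize, inputDim, 1},  // dims of each 3D per-step tensor
//       {inputDim, 1, 1});         // matching strides
//   // xDescs.descs() yields the cudnnTensorDescriptor_t array expected by
//   // calls such as cudnnRNNForwardTraining().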
template <typename T>
class RecurrentBaseOp : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);
  template <class... Args>
  explicit RecurrentBaseOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_) {
    CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropoutDesc_));
    CUDNN_ENFORCE(cudnnCreateRNNDescriptor(&rnnDesc_));
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&wDesc_));
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&hxDesc_));
  }
  virtual ~RecurrentBaseOp();
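  // A sketch of what the out-of-line destructor is expected to do; the real
  // body lives in the corresponding .cc file, so the calls below are an
  // assumption, not the definitive implementation:
  //
  //   template <typename T>
  //   RecurrentBaseOp<T>::~RecurrentBaseOp() {
  //     CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropoutDesc_));
  //     CUDNN_ENFORCE(cudnnDestroyRNNDescriptor(rnnDesc_));
  //     CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(wDesc_));
  //     CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(hxDesc_));
  //   }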
 protected:
  void initialize(
      const Tensor& input,
      Tensor* dropoutStates = nullptr,
      // If passed, reshaped to the appropriate size
      Tensor* output = nullptr,
      Tensor* hiddenOutput = nullptr,
      Tensor* cellOutput = nullptr);
  CuDNNWrapper cudnn_wrapper_;
  cudnnDropoutDescriptor_t dropoutDesc_;
  cudnnRNNDescriptor_t rnnDesc_;
  cudnnFilterDescriptor_t wDesc_;
  cudnnTensorDescriptor_t hxDesc_;
  cudnnTensorDescriptor_t cxDesc_;
  cudnnTensorDescriptor_t hyDesc_;
  cudnnTensorDescriptor_t cyDesc_;

  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;

  std::vector<int64_t> cachedInputDims_;
  size_t reserveNbytes_;
  size_t cudnnWsNbytes_;
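  // reserveNbytes_ and cudnnWsNbytes_ cache the sizes reported by the cuDNN
  // size queries so buffers can be reused across runs. An illustrative,
  // non-authoritative sketch of how initialize() might fill them (seqLength
  // is assumed to be derived from the input shape):
  //
  //   CUDNN_ENFORCE(cudnnGetRNNWorkspaceSize(
  //       cudnn_wrapper_.inline_cudnn_handle(),
  //       rnnDesc_, seqLength, xDesc_->descs(), &cudnnWsNbytes_));
  //   CUDNN_ENFORCE(cudnnGetRNNTrainingReserveSize(
  //       cudnn_wrapper_.inline_cudnn_handle(),
  //       rnnDesc_, seqLength, xDesc_->descs(), &reserveNbytes_));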
};

#define USE_RECURRENT_BASE_FUNCTIONS          \
  USE_OPERATOR_FUNCTIONS(CUDAContext);        \
  using RecurrentBaseOp<T>::cudnn_wrapper_;   \
  using RecurrentBaseOp<T>::dropoutDesc_;     \
  using RecurrentBaseOp<T>::rnnDesc_;         \
  using RecurrentBaseOp<T>::wDesc_;           \
  using RecurrentBaseOp<T>::hxDesc_;          \
  using RecurrentBaseOp<T>::cxDesc_;          \
  using RecurrentBaseOp<T>::hyDesc_;          \
  using RecurrentBaseOp<T>::cyDesc_;          \
  using RecurrentBaseOp<T>::xDesc_;           \
  using RecurrentBaseOp<T>::yDesc_;           \
  using RecurrentBaseOp<T>::cachedInputDims_; \
  using RecurrentBaseOp<T>::reserveNbytes_;   \
  using RecurrentBaseOp<T>::cudnnWsNbytes_;   \
  using RecurrentBaseOp<T>::initialize;

template <typename T>
class RecurrentOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  template <class... Args>
  explicit RecurrentOp(Args&&... args)
      : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override;

 protected:
  INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
  OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
};
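// The tag enums above index the operator's inputs and outputs inside
// RunOnDevice(). An illustrative (non-authoritative) sketch of how they are
// typically used; the tensor shapes in the comments are assumptions:
//
//   template <typename T>
//   bool RecurrentOp<T>::RunOnDevice() {
//     const auto& X = Input(INPUT);    // e.g. seqLength x batch x inputDim
//     const auto& W = Input(WEIGHT);   // packed cuDNN weight blob
//     auto* Y = Output(OUTPUT);
//     // initialize() sets up the descriptors, then the cuDNN RNN forward
//     // call fills RNN_SCRATCH / DROPOUT_STATES as workspace outputs.
//     return true;
//   }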
enum RecurrentParamOpMode { SET_PARAM, GET_PARAM };
template <typename T, RecurrentParamOpMode mode>
class RecurrentParamAccessOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  template <class... Args>
  explicit RecurrentParamAccessOp(Args&&... args)
      : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override;
};
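// The same template serves both parameter directions via the mode argument.
// An illustrative sketch of how the two instantiations might be registered in
// the matching .cc file (the operator names here are assumptions):
//
//   REGISTER_CUDNN_OPERATOR(
//       RecurrentParamSet, RecurrentParamAccessOp<float, SET_PARAM>);
//   REGISTER_CUDNN_OPERATOR(
//       RecurrentParamGet, RecurrentParamAccessOp<float, GET_PARAM>);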
template <typename T>
class RecurrentGradientOp : public RecurrentBaseOp<T> {
 public:
  USE_RECURRENT_BASE_FUNCTIONS
  template <class... Args>
  explicit RecurrentGradientOp(Args&&... args)
      : RecurrentBaseOp<T>(std::forward<Args>(args)...) {}

  bool RunOnDevice() override;
};
} // namespace caffe2

#endif // CAFFE2_OPERATORS_RECURRENT_OP_CUDNN_H_