1 #ifndef CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_ 2 #define CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/hip/context_gpu.h" 6 #include "caffe2/core/hip/miopen_wrapper.h" 7 #include "caffe2/core/logging.h" 8 #include "caffe2/core/operator.h" 20 std::vector<int>& dim,
21 std::vector<int>& stride);
23 const miopenTensorDescriptor_t* descs()
const {
28 std::vector<miopenTensorDescriptor_t> descs_;
36 USE_OPERATOR_FUNCTIONS(HIPContext);
45 Tensor* hiddenOutput =
nullptr,
46 Tensor* cellOutput =
nullptr);
49 miopenRNNDescriptor_t rnnDesc_;
50 miopenTensorDescriptor_t wDesc_;
51 miopenTensorDescriptor_t hxDesc_;
52 miopenTensorDescriptor_t cxDesc_;
53 miopenTensorDescriptor_t hyDesc_;
54 miopenTensorDescriptor_t cyDesc_;
56 std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
57 std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;
59 std::vector<int64_t> cachedInputDims_;
60 size_t reserveNbytes_;
61 size_t miopenWsNbytes_;
66 #define USE_RECURRENT_BASE_FUNCTIONS \ 67 USE_OPERATOR_FUNCTIONS(HIPContext); \ 68 using RecurrentBaseOp<T>::miopen_wrapper_; \ 69 using RecurrentBaseOp<T>::rnnDesc_; \ 70 using RecurrentBaseOp<T>::wDesc_; \ 71 using RecurrentBaseOp<T>::hxDesc_; \ 72 using RecurrentBaseOp<T>::cxDesc_; \ 73 using RecurrentBaseOp<T>::hyDesc_; \ 74 using RecurrentBaseOp<T>::cyDesc_; \ 75 using RecurrentBaseOp<T>::xDesc_; \ 76 using RecurrentBaseOp<T>::yDesc_; \ 77 using RecurrentBaseOp<T>::cachedInputDims_; \ 78 using RecurrentBaseOp<T>::reserveNbytes_; \ 79 using RecurrentBaseOp<T>::miopenWsNbytes_; \ 80 using RecurrentBaseOp<T>::initialize; 85 USE_RECURRENT_BASE_FUNCTIONS
89 bool RunOnDevice()
override;
92 INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
93 OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
// Mode selector for the recurrent parameter operator; it is used as a
// non-type template argument (`template <typename T, RecurrentParamOpMode mode>`).
// NOTE(review): by name, SET_PARAM presumably writes a parameter into the
// RNN weight blob and GET_PARAM reads one out — confirm against the .cc file.
// Enumerators are left unscoped (not `enum class`) so existing unqualified
// uses of SET_PARAM / GET_PARAM keep compiling; the explicit values match the
// implicit ones the original declaration produced.
enum RecurrentParamOpMode {
  SET_PARAM = 0,
  GET_PARAM = 1,
};
98 template <
typename T, RecurrentParamOpMode mode>
101 USE_RECURRENT_BASE_FUNCTIONS
105 bool RunOnDevice()
override;
108 template <
typename T>
111 USE_RECURRENT_BASE_FUNCTIONS
115 bool RunOnDevice()
override;
140 #endif // CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_
MIOPENWrapper is a class that wraps MIOpen handles and MIOpen workspaces.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...