Caffe2 - C++ API
A deep learning, cross-platform ML framework
recurrent_op_miopen.h
1 #ifndef CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_
2 #define CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/hip/context_gpu.h"
6 #include "caffe2/core/hip/miopen_wrapper.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 
10 namespace caffe2 {
11 namespace detail {
12 
13 template <typename T>
15  public:
17  size_t n,
18  // dim and stride are not declared as const as opposed to cuDNN
19  // since miopenSetTensorDescriptor doesn't take const arguments
20  std::vector<int>& dim,
21  std::vector<int>& stride);
23  const miopenTensorDescriptor_t* descs() const {
24  return descs_.data();
25  }
26 
27  private:
28  std::vector<miopenTensorDescriptor_t> descs_;
29 };
30 
31 } // namespace detail
32 
33 template <typename T>
34 class RecurrentBaseOp : public Operator<HIPContext> {
35  public:
36  USE_OPERATOR_FUNCTIONS(HIPContext);
37  RecurrentBaseOp(const OperatorDef& operator_def, Workspace* ws);
38  virtual ~RecurrentBaseOp();
39 
40  protected:
41  void initialize(
42  const Tensor& input,
43  // If passed, reshapes to the appropriate size
44  Tensor* output = nullptr,
45  Tensor* hiddenOutput = nullptr,
46  Tensor* cellOutput = nullptr);
47 
48  MIOPENWrapper miopen_wrapper_;
49  miopenRNNDescriptor_t rnnDesc_;
50  miopenTensorDescriptor_t wDesc_;
51  miopenTensorDescriptor_t hxDesc_;
52  miopenTensorDescriptor_t cxDesc_;
53  miopenTensorDescriptor_t hyDesc_;
54  miopenTensorDescriptor_t cyDesc_;
55 
56  std::unique_ptr<detail::TensorDescriptors<T>> xDesc_;
57  std::unique_ptr<detail::TensorDescriptors<T>> yDesc_;
58 
59  std::vector<int64_t> cachedInputDims_;
60  size_t reserveNbytes_;
61  size_t miopenWsNbytes_;
62 
63  private:
64 };
65 
// Brings the protected members of RecurrentBaseOp<T> into scope inside a
// derived class template, together with whatever USE_OPERATOR_FUNCTIONS(HIPContext)
// provides. Because RecurrentBaseOp<T> is a dependent base, its members would
// otherwise need an explicit `this->` or `RecurrentBaseOp<T>::` qualifier.
// The expanding class must name its template type parameter `T`.
// (Comments live above the #define: a `//` comment on a continued line would
// swallow the next line of the macro.)
#define USE_RECURRENT_BASE_FUNCTIONS                \
  USE_OPERATOR_FUNCTIONS(HIPContext);               \
  using RecurrentBaseOp<T>::initialize;             \
  using RecurrentBaseOp<T>::miopen_wrapper_;        \
  using RecurrentBaseOp<T>::rnnDesc_;               \
  using RecurrentBaseOp<T>::wDesc_;                 \
  using RecurrentBaseOp<T>::hxDesc_;                \
  using RecurrentBaseOp<T>::cxDesc_;                \
  using RecurrentBaseOp<T>::hyDesc_;                \
  using RecurrentBaseOp<T>::cyDesc_;                \
  using RecurrentBaseOp<T>::xDesc_;                 \
  using RecurrentBaseOp<T>::yDesc_;                 \
  using RecurrentBaseOp<T>::cachedInputDims_;       \
  using RecurrentBaseOp<T>::reserveNbytes_;         \
  using RecurrentBaseOp<T>::miopenWsNbytes_;
81 
82 template <typename T>
83 class RecurrentOp : public RecurrentBaseOp<T> {
84  public:
85  USE_RECURRENT_BASE_FUNCTIONS
86  RecurrentOp(const OperatorDef& operator_def, Workspace* ws)
87  : RecurrentBaseOp<T>(operator_def, ws) {}
88 
89  bool RunOnDevice() override;
90 
91  protected:
92  INPUT_TAGS(INPUT, HIDDEN_INPUT, CELL_INPUT, WEIGHT);
93  OUTPUT_TAGS(OUTPUT, HIDDEN_OUTPUT, CELL_OUTPUT, RNN_SCRATCH, DROPOUT_STATES);
94 };
95 
// Template argument of RecurrentParamAccessOp that picks its direction.
// By naming, SET_PARAM presumably writes a parameter into the packed weights and
// GET_PARAM reads one out. TODO(review): confirm against the .cc.
// Left as an unscoped enum so the enumerators stay usable unqualified.
enum RecurrentParamOpMode {
  SET_PARAM = 0,
  GET_PARAM = 1,
};
97 
98 template <typename T, RecurrentParamOpMode mode>
100  public:
101  USE_RECURRENT_BASE_FUNCTIONS
102  RecurrentParamAccessOp(const OperatorDef& operator_def, Workspace* ws)
103  : RecurrentBaseOp<T>(operator_def, ws) {}
104 
105  bool RunOnDevice() override;
106 };
107 
108 template <typename T>
110  public:
111  USE_RECURRENT_BASE_FUNCTIONS
112  RecurrentGradientOp(const OperatorDef& operator_def, Workspace* ws)
113  : RecurrentBaseOp<T>(operator_def, ws) {}
114 
115  bool RunOnDevice() override;
116 
117  protected:
118  INPUT_TAGS(
119  INPUT,
120  HIDDEN_INPUT,
121  CELL_INPUT,
122  WEIGHT,
123  RNN_SCRATCH,
124  OUTPUT,
125  GRAD_OUTPUT,
126  GRAD_HIDDEN_OUTPUT,
127  GRAD_CELL_OUTPUT);
128  OUTPUT_TAGS(
129  GRAD_INPUT,
130  GRAD_HIDDEN_INPUT,
131  GRAD_CELL_INPUT,
132  GRAD_WEIGHT,
133  DROPOUT_STATES,
134  RNN_SCRATCH_OUT);
135 };
136 
137 
138 } // namespace caffe2
139 
140 #endif // CAFFE2_OPERATORS_RECURRENT_OP_MIOPEN_H_
MIOPENWrapper is a class that wraps the MIOpen handles and MIOpen workspaces.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13