Caffe2 - C++ API
A deep learning, cross platform ML framework
activation_ops_miopen.h
1 #ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_MIOPEN_H_
2 #define CAFFE2_OPERATORS_ACTIVATION_OPS_MIOPEN_H_
3 
4 #include "caffe2/core/hip/context_gpu.h"
5 #include "caffe2/core/hip/miopen_wrapper.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/tensor.h"
8 #include "caffe2/core/types.h"
9 
10 namespace caffe2 {
11 
// Shared base for the MIOpen-backed activation operators in this file.
//
// Derives from Operator<HIPContext> and holds the MIOpen state that the
// forward and gradient operators both use: a MIOPENWrapper built from the
// operator's context, one tensor descriptor and one activation descriptor.
// Both descriptors are created in the constructor and destroyed in the
// destructor; every MIOpen call is wrapped in MIOPEN_ENFORCE.
12 class MIOPENActivationOpBase : public Operator<HIPContext> {
13  public:
14  USE_OPERATOR_FUNCTIONS(HIPContext);
15 
// Forwards operator_def and ws to Operator<HIPContext>, binds the MIOpen
// wrapper to &context_, then creates the two (still unconfigured)
// descriptors. Derived classes in this file configure act_desc_ in their
// own constructors and data_desc_ lazily at run time.
16  MIOPENActivationOpBase(const OperatorDef& operator_def, Workspace* ws)
17  : Operator<HIPContext>(operator_def, ws),
18  miopen_wrapper_(&context_) {
19  MIOPEN_ENFORCE(miopenCreateTensorDescriptor(&data_desc_));
20  MIOPEN_ENFORCE(miopenCreateActivationDescriptor(&act_desc_));
21  }
22 
// Releases both descriptors created by the constructor.
// NOTE(review): MIOPEN_ENFORCE is defined outside this file; if it throws on
// a failed status, that would happen inside a destructor -- confirm that
// this is acceptable here.
23  virtual ~MIOPENActivationOpBase() {
24  MIOPEN_ENFORCE(miopenDestroyTensorDescriptor(data_desc_));
25  MIOPEN_ENFORCE(miopenDestroyActivationDescriptor(act_desc_));
26  }
27 
28  protected:
// Wrapper from which the derived operators obtain the MIOpen handle
// (via inline_miopen_handle()).
29  MIOPENWrapper miopen_wrapper_;
// Tensor descriptor; the derived operators use it for every tensor they
// pass to MIOpen.
30  miopenTensorDescriptor_t data_desc_;
// Activation descriptor; the derived operators set its mode once, in their
// constructors.
31  miopenActivationDescriptor_t act_desc_;
// Shape last written into data_desc_ by a derived operator; compared with
// the incoming tensor's sizes to decide whether data_desc_ must be reset.
32  vector<int64_t> mio_dims_;
33 
34 };
35 
// Forward activation operator backed by miopenActivationForward.
//
// kMIOPENActivationMode is the MIOpen activation mode written into act_desc_
// at construction time.
//
// NOTE(review): this listing is incomplete. The embedded listing numbers skip
// 37 (between the template header and "public:", i.e. where the class header
// would be) and 81 and 84 (inside the miopenActivationForward call). Restore
// those lines from the original header before compiling; do not infer them
// from this extract.
36 template <miopenActivationMode_t kMIOPENActivationMode>
38  public:
39  USE_OPERATOR_FUNCTIONS(HIPContext);
40 
// Builds the base-class state, then configures the activation descriptor
// with the compile-time mode.
// NOTE(review): the three trailing 1.0 arguments are passed unconditionally
// for every mode; their meaning is defined by MIOpen, not by this file --
// confirm against the miopenSetActivationDescriptor documentation.
41  MIOPENActivationOp(const OperatorDef& operator_def, Workspace* ws)
42  : MIOPENActivationOpBase(operator_def, ws) {
43  MIOPEN_ENFORCE(miopenSetActivationDescriptor(
44  act_desc_, kMIOPENActivationMode, 1.0, 1.0, 1.0));
45  }
46 
// Dispatches on Input(0) over the element types float and at::Half.
// Presumably DispatchHelper ends up calling DoRunWithType<T>() below --
// its definition is not in this file.
47  bool RunOnDevice() override {
48  return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
49  }
50 
// Computes Output(0) = activation(Input(0)) for element type T.
// Always returns true; MIOpen failures are reported through MIOPEN_ENFORCE.
51  template <typename T>
52  bool DoRunWithType() {
53  const auto& X = Input(0);
54  auto* Y = Output(0);
55  Y->ResizeLike(X);
// Empty input: no MIOpen call is made, but mutable_data<T>() is still
// invoked on the output before returning.
56  if (X.size() == 0) {
57  Y->template mutable_data<T>();
58  return true;
59  }
60  // See if we need to reshape.
// data_desc_ is only rewritten when the input shape differs from the cached
// mio_dims_.
// NOTE(review): the descriptor also encodes miopenTypeWrapper<T>::type, but
// the cache check looks at the shape only. A run with the same shape and a
// different T would keep the previous data type in data_desc_ -- confirm
// whether that can happen in practice.
61  if (X.sizes() != mio_dims_) {
62  VLOG(1) << "Setting descriptors.";
63  mio_dims_ = X.sizes().vec();
64  int C = 1, H = 1, W = 1;
65  if (X.ndim() == 4) {
66  // Normal 4-dimensional tensors for images.
67  C = X.dim32(1);
68  H = X.dim32(2);
69  W = X.dim32(3);
70  } else {
71  // If X is not 4-dimensional, we will simply use H = 1 and W = 1
72  // and wrap everything into C.
73  C = X.size() / X.dim32(0);
74  }
// The descriptor is always 4-D: (dim 0, C, H, W).
75  MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
76  data_desc_, miopenTypeWrapper<T>::type, X.dim32(0), C, H, W));
77  }
// The same descriptor describes both the input X and the output Y.
// Listing lines 81 and 84 (two arguments of this call) are missing from
// this extract.
78  MIOPEN_ENFORCE(miopenActivationForward(
79  this->miopen_wrapper_.inline_miopen_handle(),
80  this->act_desc_,
82  this->data_desc_,
83  X.template data<T>(),
85  this->data_desc_,
86  Y->template mutable_data<T>()));
87  return true;
88  }
89 };
90 
// Gradient (backward) activation operator backed by miopenActivationBackward.
//
// Inputs:  Input(0) = Y  (named as the forward output in this code),
//          Input(1) = dY (gradient with respect to Y).
// Output:  Output(0) = dX, resized to Y's shape.
//
// kMIOPENActivationMode is the MIOpen activation mode written into act_desc_
// at construction time.
//
// NOTE(review): this listing is incomplete. The embedded listing numbers skip
// 92 (between the template header and "public:", i.e. where the class header
// would be) and 137 and 144 (inside the miopenActivationBackward call).
// Restore those lines from the original header before compiling; do not
// infer them from this extract.
91 template <miopenActivationMode_t kMIOPENActivationMode>
93  public:
94  USE_OPERATOR_FUNCTIONS(HIPContext);
95 
// Builds the base-class state, then configures the activation descriptor
// with the compile-time mode.
// NOTE(review): the three trailing 1.0 arguments are passed unconditionally
// for every mode; their meaning is defined by MIOpen, not by this file --
// confirm against the miopenSetActivationDescriptor documentation.
96  MIOPENActivationGradientOp(const OperatorDef& operator_def, Workspace* ws)
97  : MIOPENActivationOpBase(operator_def, ws) {
98  MIOPEN_ENFORCE(miopenSetActivationDescriptor(
99  act_desc_, kMIOPENActivationMode, 1.0, 1.0, 1.0));
100  }
101 
// Dispatches on Input(0) over the element types float and at::Half.
// Presumably DispatchHelper ends up calling DoRunWithType<T>() below --
// its definition is not in this file.
102  bool RunOnDevice() override {
103  return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
104  }
105 
// Computes dX from Y and dY for element type T.
// Always returns true; MIOpen failures are reported through MIOPEN_ENFORCE.
106  template <typename T>
107  bool DoRunWithType() {
108  const auto& Y = Input(0);
109  const auto& dY = Input(1);
110  auto* dX = Output(0);
111  dX->ResizeLike(Y);
// Empty input: no MIOpen call is made, but mutable_data<T>() is still
// invoked on the output before returning.
112  if (Y.size() == 0) {
113  dX->template mutable_data<T>();
114  return true;
115  }
116  // See if we need to reshape.
// data_desc_ is only rewritten when Y's shape differs from the cached
// mio_dims_.
// NOTE(review): the descriptor also encodes miopenTypeWrapper<T>::type, but
// the cache check looks at the shape only. A run with the same shape and a
// different T would keep the previous data type in data_desc_ -- confirm
// whether that can happen in practice.
117  if (Y.sizes() != mio_dims_) {
118  VLOG(1) << "Setting descriptors.";
119  mio_dims_ = Y.sizes().vec();
120  int C = 1, H = 1, W = 1;
121  if (Y.ndim() == 4) {
122  // Normal 4-dimensional tensors for images.
123  C = Y.dim32(1);
124  H = Y.dim32(2);
125  W = Y.dim32(3);
126  } else {
127  // If Y is not 4-dimensional, we will simply use H = 1 and W = 1
128  // and wrap everything into C.
129  C = Y.size() / Y.dim32(0);
130  }
// The descriptor is always 4-D: (dim 0, C, H, W).
131  MIOPEN_ENFORCE(miopenSet4dTensorDescriptor(
132  data_desc_, miopenTypeWrapper<T>::type, Y.dim32(0), C, H, W));
133  }
// One descriptor (built from Y's shape) describes all four tensors passed
// below; nothing here checks that dY actually has that shape.
// Y's data pointer is passed twice (first and third tensor arguments).
// NOTE(review): presumably the second occurrence stands in for the forward
// input, which this operator does not receive -- confirm against the
// miopenActivationBackward documentation.
// Listing lines 137 and 144 (two arguments of this call) are missing from
// this extract.
134  MIOPEN_ENFORCE(miopenActivationBackward(
135  this->miopen_wrapper_.inline_miopen_handle(),
136  this->act_desc_,
138  this->data_desc_,
139  Y.template data<T>(),
140  this->data_desc_,
141  dY.template data<T>(),
142  this->data_desc_,
143  Y.template data<T>(),
145  this->data_desc_,
146  dX->template mutable_data<T>()));
147  return true;
148  }
149 };
150 
151 } // namespace caffe2
152 
153 #endif // CAFFE2_OPERATORS_ACTIVATION_OPS_MIOPEN_H_
miopenTypeWrapper is a wrapper class that allows us to refer to the miopen type in a template functio...
Definition: common_miopen.h:90
MIOPENWrapper is a class that wraps the miopen handles and miopen workspaces.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=HIPContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64