1 #include "caffe2/core/context_gpu.h" 2 #include "caffe2/core/cudnn_wrappers.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/types.h" 12 template <
class... Args>
15 cudnn_wrapper_(&context_),
16 size_(OperatorBase::GetSingleArgument<int>(
"size", 0)),
17 alpha_(OperatorBase::GetSingleArgument<float>(
"alpha", 0)),
18 beta_(OperatorBase::GetSingleArgument<float>(
"beta", 0)),
19 bias_(OperatorBase::GetSingleArgument<float>(
"bias", 1)) {
20 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
22 CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
24 cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
27 ~CuDNNLRNOp()
override {
28 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
29 CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
32 template <
typename T,
typename M>
35 bool RunOnDevice()
override;
39 cudnnTensorDescriptor_t data_desc_;
40 cudnnLRNDescriptor_t norm_desc_;
42 vector<int64_t> cudnn_input_dims_;
class CuDNNLRNGradientOp final : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  template <class... Args>
  explicit CuDNNLRNGradientOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        size_(OperatorBase::GetSingleArgument<int>("size", 0)),
        alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
        beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
        bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
    CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
    CUDNN_ENFORCE(
        cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
  }

  ~CuDNNLRNGradientOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
    CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
  }

  template <typename T, typename M>
  bool DoRunWithType();

  bool RunOnDevice() override;

 protected:
  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t data_desc_;
  cudnnLRNDescriptor_t norm_desc_;

  vector<int64_t> cudnn_input_dims_;

  const int size_;
  const float alpha_;
  const float beta_;
  const float bias_;

  // Input: X, Y, dY; Output: dX
};
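// Usage sketch (not part of the upstream file): the "size", "alpha", "beta",
// and "bias" arguments read by the constructors above come from the OperatorDef.
// A minimal way to build such a def, assuming the CreateOperatorDef/MakeArgument
// helpers from caffe2/utils/proto_utils.h and illustrative parameter values:
//
//   OperatorDef def = CreateOperatorDef(
//       "LRN",
//       "",
//       std::vector<string>{"X"},
//       std::vector<string>{"Y"},
//       std::vector<Argument>{
//           MakeArgument<int>("size", 5),
//           MakeArgument<float>("alpha", 1e-4f),
//           MakeArgument<float>("beta", 0.75f),
//           MakeArgument<float>("bias", 1.0f)});
//   def.set_engine("CUDNN"); // select the cuDNN implementation of LRN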
template <typename T, typename M>
bool CuDNNLRNOp::DoRunWithType() {
  const auto& X = Input(0);
  auto* Y = Output(0);

  // Reconfigure the cuDNN tensor descriptor if the input shape changed.
  if (X.sizes() != cudnn_input_dims_) {
    VLOG(1) << "Setting descriptors";
    cudnn_input_dims_ = X.sizes().vec();
    // Normal 4-dimensional NCHW tensors for images.
    const int C = X.dim32(1), H = X.dim32(2), W = X.dim32(3);
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        data_desc_,
        GetCudnnTensorFormat(StorageOrder::NCHW),
        cudnnTypeWrapper<T>::type,
        X.dim32(0), C, H, W));
  }

  // Run the cross-channel LRN forward pass.
  CUDNN_ENFORCE(cudnnLRNCrossChannelForward(
      cudnn_wrapper_.inline_cudnn_handle(),
      norm_desc_,
      CUDNN_LRN_CROSS_CHANNEL_DIM1,
      cudnnTypeWrapper<T>::kOne(),
      data_desc_,
      X.template data<T>(),
      cudnnTypeWrapper<T>::kZero(),
      data_desc_,
      Y->template mutable_data<T>()));
  return true;
}
bool CuDNNLRNOp::RunOnDevice() {
  // Dispatch based on the input data type.
  const auto& X = Input(0);
  auto* Y = Output(0);
  Y->ResizeLike(X);

  if (X.IsType<float>()) {
    return DoRunWithType<float, float>();
  } else if (X.IsType<at::Half>()) {
    return DoRunWithType<at::Half, float>();
  } else {
    CAFFE_THROW("Unsupported input type");
  }
  return false;
}
template <typename T, typename M>
bool CuDNNLRNGradientOp::DoRunWithType() {
  const auto& X = Input(0);
  const auto& Y = Input(1);
  const auto& dY = Input(2);
  auto* dX = Output(0);

  // Reconfigure the cuDNN tensor descriptor if the gradient shape changed.
  if (dY.sizes() != cudnn_input_dims_) {
    VLOG(1) << "Setting descriptors";
    cudnn_input_dims_ = dY.sizes().vec();
    // Normal 4-dimensional NCHW tensors for images.
    const int C = dY.dim32(1), H = dY.dim32(2), W = dY.dim32(3);
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        data_desc_,
        GetCudnnTensorFormat(StorageOrder::NCHW),
        cudnnTypeWrapper<T>::type,
        dY.dim32(0), C, H, W));
  }

  // Run the cross-channel LRN backward pass.
  CUDNN_ENFORCE(cudnnLRNCrossChannelBackward(
      cudnn_wrapper_.inline_cudnn_handle(),
      norm_desc_,
      CUDNN_LRN_CROSS_CHANNEL_DIM1,
      cudnnTypeWrapper<T>::kOne(),
      data_desc_,
      Y.template data<T>(),
      data_desc_,
      dY.template data<T>(),
      data_desc_,
      X.template data<T>(),
      cudnnTypeWrapper<T>::kZero(),
      data_desc_,
      dX->template mutable_data<T>()));
  return true;
}
bool CuDNNLRNGradientOp::RunOnDevice() {
  // Dispatch based on the gradient data type.
  const auto& X = Input(0);
  const auto& Y = Input(1);
  const auto& dY = Input(2);
  auto* dX = Output(0);
  dX->ResizeLike(dY);

  if (dY.IsType<float>()) {
    return DoRunWithType<float, float>();
  } else if (dY.IsType<at::Half>()) {
    return DoRunWithType<at::Half, float>();
  } else {
    CAFFE_THROW("Unsupported input type");
  }
  return false;
}