1 #include "caffe2/core/context_gpu.h" 2 #include "caffe2/core/cudnn_wrappers.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/types.h" 10 #if CUDNN_VERSION_MIN(7,0,0) 12 class CuDNNDropoutOp final :
public Operator<CUDAContext> {
14 USE_OPERATOR_FUNCTIONS(CUDAContext);
16 explicit CuDNNDropoutOp(
const OperatorDef& operator_def, Workspace* ws)
17 : Operator<CUDAContext>(operator_def, ws),
18 cudnn_wrapper_(&context_),
19 ratio_(OperatorBase::GetSingleArgument<float>(
"ratio", 0.5)),
20 is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
21 states_initialized_(false),
22 random_seed_(operator_def.device_option().random_seed()) {
23 CAFFE_ENFORCE_GE(ratio_, 0);
24 CAFFE_ENFORCE_LT(ratio_, 1);
25 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
27 CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropout_desc_));
28 CUDNN_ENFORCE(cudnnDropoutGetStatesSize(
29 cudnn_wrapper_.inline_cudnn_handle(),
30 reinterpret_cast<size_t*
>(&states_size_in_bytes_)));
33 scratch_blob_ = ws->CreateBlob(scratch_blob_name(operator_def.output(1)));
34 CAFFE_ENFORCE(scratch_blob_);
38 ~CuDNNDropoutOp() noexcept
override {
39 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
40 CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropout_desc_));
43 template <
typename T,
typename M>
46 bool RunOnDevice()
override;
48 static string scratch_blob_name(
string mask_blob_name) {
49 return "cudnn_dropout_scratch_" + mask_blob_name;
53 CuDNNWrapper cudnn_wrapper_;
54 cudnnTensorDescriptor_t data_desc_;
55 cudnnDropoutDescriptor_t dropout_desc_;
57 vector<int64_t> cudnn_input_dims_;
62 Blob* scratch_blob_ =
nullptr;
64 size_t states_size_in_bytes_, reserve_space_size_in_bytes_;
68 bool states_initialized_;
71 unsigned long long random_seed_;
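// The gradient op below mirrors CuDNNDropoutOp. The two ops cooperate through
// two pieces of shared state: the "mask" output of the forward op, which is
// the raw cuDNN reserve space consumed by cudnnDropoutBackward, and the
// scratch blob holding the cuDNN RNG states, which the gradient op looks up
// by name and uses to restore the dropout descriptor.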
class CuDNNDropoutGradientOp final : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  explicit CuDNNDropoutGradientOp(
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<CUDAContext>(operator_def, ws),
        cudnn_wrapper_(&context_),
        ratio_(OperatorBase::GetSingleArgument<float>("ratio", 0.5)),
        is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
        states_initialized_(false),
        random_seed_(operator_def.device_option().random_seed()) {
    CAFFE_ENFORCE_GE(ratio_, 0);
    CAFFE_ENFORCE_LT(ratio_, 1);
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));

    CUDNN_ENFORCE(cudnnCreateDropoutDescriptor(&dropout_desc_));
    CUDNN_ENFORCE(cudnnDropoutGetStatesSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        reinterpret_cast<size_t*>(&states_size_in_bytes_)));

    // Reuse the RNG states blob created by the corresponding forward op.
    scratch_blob_ =
        ws->GetBlob(CuDNNDropoutOp::scratch_blob_name(operator_def.input(1)));
    CAFFE_ENFORCE(scratch_blob_);
  }

  ~CuDNNDropoutGradientOp() noexcept override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
    CUDNN_ENFORCE(cudnnDestroyDropoutDescriptor(dropout_desc_));
  }

  template <typename T, typename M>
  bool DoRunWithType();

  bool RunOnDevice() override;

 protected:
  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t data_desc_;
  cudnnDropoutDescriptor_t dropout_desc_;

  vector<int64_t> cudnn_input_dims_;

  Blob* scratch_blob_ = nullptr;

  float ratio_;
  bool is_test_;

  size_t states_size_in_bytes_, reserve_space_size_in_bytes_;
  // Input: dY, mask_and_states; Output: dX

  // Only need to initialize the dropout states once.
  bool states_initialized_;

  // Random seed taken from the operator's device option.
  unsigned long long random_seed_;
};
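// Forward pass. In test mode dropout is the identity and the input is copied
// straight to the output. In training mode the op lazily (re)configures the
// tensor descriptor when the input shape changes, sizes the RNG states blob,
// initializes the dropout descriptor exactly once, and then runs
// cudnnDropoutForward, which also writes the reserve space ("mask") output
// needed later by the gradient op.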
template <typename T, typename M>
bool CuDNNDropoutOp::DoRunWithType() {
  const auto& X = Input(0);
  auto* Y = Output(0);

  auto size_prod = 1;
  for (auto dim : X.sizes()) {
    size_prod *= dim;
  }

  if (is_test_) {
    // In test mode dropout is the identity: just copy X into Y.
    if (Y != &X) {
      context_.CopySameDevice<T>(
          X.numel(), X.template data<T>(), Y->template mutable_data<T>());
    }
    return true;
  } else {
    // Reshape tensor descriptors if necessary.
    if (X.sizes() != cudnn_input_dims_) {
      CAFFE_ENFORCE(scratch_blob_);
      Tensor* states = BlobGetMutableTensor(scratch_blob_, CUDA);
      cudnn_input_dims_ = X.sizes().vec();
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          data_desc_,
          GetCudnnTensorFormat(StorageOrder::NCHW),
          cudnnTypeWrapper<T>::type,
          size_prod,
          1,
          1,
          1));

      // Get the reserve space we need.
      CUDNN_ENFORCE(cudnnDropoutGetReserveSpaceSize(
          data_desc_, &reserve_space_size_in_bytes_));

      states->Resize(states_size_in_bytes_);

      if (!states_initialized_) {
        // Set the dropout descriptor; the states data must be allocated
        // before acquiring the mutex.
        uint8_t* states_data = states->template mutable_data<uint8_t>();
        {
          // Need to protect as it clashes with NCCL.
          std::lock_guard<std::mutex> lk(CUDAContext::mutex());
          CUDNN_ENFORCE(cudnnSetDropoutDescriptor(
              dropout_desc_,
              cudnn_wrapper_.inline_cudnn_handle(),
              ratio_,
              states_data,
              states_size_in_bytes_,
              random_seed_));
        }
        states_initialized_ = true;
      }
    }
    // The second output holds the cuDNN reserve space ("mask").
    auto* mask = Output(
        1,
        {static_cast<int64_t>(reserve_space_size_in_bytes_)},
        at::dtype<uint8_t>());
    CUDNN_ENFORCE(cudnnDropoutForward(
        cudnn_wrapper_.inline_cudnn_handle(),
        dropout_desc_,
        data_desc_,
        X.template data<T>(),
        data_desc_,
        Y->template mutable_data<T>(),
        mask->template mutable_data<uint8_t>(),
        reserve_space_size_in_bytes_));
    return true;
  }
}
bool CuDNNDropoutOp::RunOnDevice() {
  // Dispatch based on the input data type.
  const auto& X = Input(0);
  auto* Y = Output(0);
  Y->ResizeLike(X);
  if (X.IsType<float>()) {
    return DoRunWithType<float, float>();
  } else if (X.IsType<at::Half>()) {
    return DoRunWithType<at::Half, float>();
  }
  return false;
}
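// Backward pass. The dropout descriptor is restored from the RNG states
// shared with the forward op (cudnnRestoreDropoutDescriptor reuses the saved
// states rather than re-initializing them), and cudnnDropoutBackward uses the
// forward reserve space to zero and rescale dY into dX.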
template <typename T, typename M>
bool CuDNNDropoutGradientOp::DoRunWithType() {
  const auto& dY = Input(0);
  const auto& mask = Input(1);
  const Tensor& states = scratch_blob_->Get<Tensor>();
  auto* dX = Output(0);

  auto size_prod = 1;
  for (auto dim : dY.sizes()) {
    size_prod *= dim;
  }

  if (!states_initialized_) {
    // Restore the dropout descriptor from the states shared with the
    // forward op.
    {
      // Need to protect as it clashes with NCCL.
      std::lock_guard<std::mutex> lk(CUDAContext::mutex());
      CUDNN_ENFORCE(cudnnRestoreDropoutDescriptor(
          dropout_desc_,
          cudnn_wrapper_.inline_cudnn_handle(),
          ratio_,
          const_cast<uint8_t*>(states.data<uint8_t>()),
          states_size_in_bytes_,
          random_seed_));
    }
    states_initialized_ = true;
  }

  if (dY.sizes() != cudnn_input_dims_) {
    cudnn_input_dims_ = dY.sizes().vec();
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        data_desc_,
        GetCudnnTensorFormat(StorageOrder::NCHW),
        cudnnTypeWrapper<T>::type,
        size_prod,
        1,
        1,
        1));

    // Get the reserve space we need.
    CUDNN_ENFORCE(cudnnDropoutGetReserveSpaceSize(
        data_desc_, &reserve_space_size_in_bytes_));
  }

  // Run the computation.
  void* mask_data = const_cast<void*>(mask.raw_data());
  CUDNN_ENFORCE(cudnnDropoutBackward(
      cudnn_wrapper_.inline_cudnn_handle(),
      dropout_desc_,
      data_desc_,
      dY.data<T>(),
      data_desc_,
      dX->template mutable_data<T>(),
      mask_data,
      reserve_space_size_in_bytes_));
  return true;
}
bool CuDNNDropoutGradientOp::RunOnDevice() {
  // Dispatch based on the input data type.
  const auto& dY = Input(0);
  auto* dX = Output(0);
  dX->ResizeLike(dY);
  if (dY.IsType<float>()) {
    return DoRunWithType<float, float>();
  } else if (dY.IsType<at::Half>()) {
    return DoRunWithType<at::Half, float>();
  }
  return false;
}
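// Register both ops under the CUDNN engine so they are picked up when
// Dropout / DropoutGrad run on a CUDA device with that engine selected.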
REGISTER_CUDNN_OPERATOR(Dropout, CuDNNDropoutOp);
REGISTER_CUDNN_OPERATOR(DropoutGrad, CuDNNDropoutGradientOp);

#endif // CUDNN_VERSION_MIN(7, 0, 0)
} // namespace caffe2