#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <ATen/cudnn/cudnn-wrapper.h>

#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>

namespace at { namespace native {
inline int dataSize(cudnnDataType_t dataType)
{
  switch (dataType) {
    case CUDNN_DATA_HALF: return 2;
    case CUDNN_DATA_FLOAT: return 4;
    default: return 8;
  }
}
// cuDNN requires "packed" descriptors: if a dimension has size 1, its stride
// must equal the product of the sizes of all inner dimensions, even though
// that stride is never actually used to advance a pointer.
static inline void fixSizeOneDimStride(int dim, const int *size, int *stride) {
  int64_t z = 1;
  for (int d = dim - 1; d >= 0; d--) {
    if (size[d] == 1) {
      stride[d] = z;
    } else {
      z *= size[d];
    }
  }
}
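
// Illustrative sketch: a hypothetical helper (not part of the original API
// surface) showing the effect of the fix on a concrete shape.  For size
// {2, 1, 3}, the stride of the size-1 middle dimension is rewritten to 3,
// the product of the inner sizes, so the descriptor is fully "packed".
inline void fixSizeOneDimStride_example() {
  int size[3]   = {2, 1, 3};
  int stride[3] = {3, 1, 1};   // the stride of the size-1 dim is arbitrary
  fixSizeOneDimStride(3, size, stride);
  // stride is now {3, 3, 1}
}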
template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      AT_CUDNN_CHECK(dtor(x));
    }
  }
};
// Generic RAII wrapper for a cuDNN descriptor type.  Subclasses supply the
// underlying struct type together with its create/destroy functions and
// define a set() method that actually fills in the descriptor.  Descriptors
// default-construct to nullptr and are created lazily on first use.
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
class AT_CUDA_API Descriptor
{
public:
  // desc() gives read-only access to the underlying descriptor pointer;
  // it returns nullptr if the descriptor was never initialized.
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }

  // Use mut_desc() when you intend to modify what the descriptor points to
  // (e.g. via cudnnSetFooDescriptor); it initializes the descriptor if needed.
  T* mut_desc() { init(); return desc_.get(); }

protected:
  void init() {
    if (desc_ == nullptr) {
      T* raw_desc;
      AT_CUDNN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
  }

private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
class AT_CUDA_API TensorDescriptor
  : public Descriptor<cudnnTensorStruct,
                      &cudnnCreateTensorDescriptor,
                      &cudnnDestroyTensorDescriptor>
{
public:
  // Pad the dimensionality of the descriptor out to at least `pad` dims
  // (several cuDNN APIs only accept 4-d or 5-d tensor descriptors).
  void set(const at::Tensor &t, size_t pad = 0);

private:
  void set(cudnnDataType_t dataType, int dim, int* size, int* stride) {
    fixSizeOneDimStride(dim, size, stride);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride));
  }
};
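
// Illustrative sketch: a hypothetical helper showing how a caller might
// describe a tensor to cuDNN with this wrapper.  It assumes `t` is a
// contiguous CUDA tensor of a dtype cuDNN supports; padding to 4 dims
// matches what the convolution APIs expect.
inline cudnnTensorDescriptor_t describe_tensor_example(TensorDescriptor& desc,
                                                       const at::Tensor& t) {
  desc.set(t, /*pad=*/4);   // lazily creates the raw descriptor via mut_desc()
  return desc.desc();       // read-only access to the underlying descriptor
}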
class FilterDescriptor
  : public Descriptor<cudnnFilterStruct,
                      &cudnnCreateFilterDescriptor,
                      &cudnnDestroyFilterDescriptor>
{
public:
  void set(const at::Tensor &t, int64_t pad = 0);

private:
  void set(cudnnDataType_t dataType, int dim, int* size) {
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, CUDNN_TENSOR_NCHW, dim, size));
  }
};
struct AT_CUDA_API ConvolutionDescriptor
  : public Descriptor<cudnnConvolutionStruct,
                      &cudnnCreateConvolutionDescriptor,
                      &cudnnDestroyConvolutionDescriptor>
{
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride,
           int* upscale /* aka dilation */, int groups) {
    // Half-precision convolutions still accumulate in float.
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                                   CUDNN_CROSS_CORRELATION, mathType));
    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
    // Enable Tensor Core math for half-precision convolutions.
    if (dataType == CUDNN_DATA_HALF)
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
  }
};
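
// Illustrative sketch: a hypothetical helper filling in a descriptor for an
// ordinary 2-d convolution.  The padding/stride/dilation values and group
// count are arbitrary example numbers; `dim` counts spatial dimensions only
// (2 for a 4-d NCHW input).
inline void describe_conv2d_example(ConvolutionDescriptor& conv) {
  int pad[2]      = {1, 1};
  int stride[2]   = {1, 1};
  int dilation[2] = {1, 1};   // passed as the `upscale` argument
  conv.set(CUDNN_DATA_HALF, /*dim=*/2, pad, stride, dilation, /*groups=*/1);
  // With CUDNN_DATA_HALF, set() selects float accumulation and Tensor Op math.
}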
struct AT_CUDA_API SpatialTransformerDescriptor
  : public Descriptor<cudnnSpatialTransformerStruct,
                      &cudnnCreateSpatialTransformerDescriptor,
                      &cudnnDestroySpatialTransformerDescriptor>
{
  void set(cudnnDataType_t dataType, int dim, int* size) {
    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
  }
};
struct AT_CUDA_API DropoutDescriptor
  : public Descriptor<cudnnDropoutStruct,
                      &cudnnCreateDropoutDescriptor,
                      &cudnnDestroyDropoutDescriptor>
{
  at::Tensor state;

  // Initialize a dropout descriptor's RNG state.  This is expensive (it
  // allocates and seeds the cuDNN RNG state on the GPU), so avoid calling it
  // more often than necessary.
  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed,
                      const TensorOptions& options) {
    AT_ASSERTM(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    size_t state_size;
    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
    AT_ASSERT(options.device().type() == kCUDA);
    AT_ASSERT(options.dtype() == kByte);
    state = at::empty({static_cast<int64_t>(state_size)}, options);
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
  }

  // Restore a dropout descriptor given a dropout probability and an
  // existing RNG state tensor.
  void set(cudnnHandle_t handle, float dropout, at::Tensor state_) {
    AT_ASSERTM(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    state = state_;
    void *state_ptr = state.data_ptr();
    size_t state_size = state.size(0);
    // The seed does not matter when restoring an existing state, so pass 0.
    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
  }

  // Restore a dropout descriptor corresponding to no dropout.  When
  // dropout == 0 no RNG state is needed, and the call is cheap.
  void set_no_dropout(cudnnHandle_t handle) {
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */));
  }
};
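
// Illustrative sketch: a hypothetical helper showing the intended call
// pattern.  The RNG state is created once with initialize_rng() and reused
// through set() afterwards; with p == 0, set_no_dropout() avoids allocating
// any state.  The seed value and tensor options here are assumptions.
inline void dropout_descriptor_example(cudnnHandle_t handle, float p,
                                       at::Tensor& cached_state) {
  DropoutDescriptor dropout;
  if (p == 0) {
    dropout.set_no_dropout(handle);
  } else if (!cached_state.defined()) {
    dropout.initialize_rng(handle, p, /*seed=*/42,
                           at::TensorOptions().device(kCUDA).dtype(kByte));
    cached_state = dropout.state;          // keep the state tensor for reuse
  } else {
    dropout.set(handle, p, cached_state);  // cheap: restores an existing state
  }
}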
struct AT_CUDA_API RNNDescriptor
  : public Descriptor<cudnnRNNStruct,
                      &cudnnCreateRNNDescriptor,
                      &cudnnDestroyRNNDescriptor>
{
  DropoutDescriptor dropout_desc_;

  void set(cudnnHandle_t handle, int hidden_size, int num_layers, DropoutDescriptor&& dropout_desc,
           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo) {
    // Keep the dropout descriptor alive for as long as the RNN descriptor is.
    dropout_desc_ = std::move(dropout_desc);
    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
          handle, mut_desc(), hidden_size, num_layers,
          dropout_desc_.desc(),
          input_mode, bidirectional, mode, algo, datatype));
#if CUDA_VERSION >= 9000
    // On Volta and newer, use Tensor Core math for half-precision inputs.
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
      } else {
        // Technically unnecessary: CUDNN_DEFAULT_MATH is already the default.
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
      }
    }
#endif
  }
};
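
// Illustrative sketch: a hypothetical helper configuring a single-direction
// LSTM descriptor.  The hidden size, layer count, dropout choice and
// algorithm are arbitrary example values, not prescribed by this header.
inline void describe_lstm_example(cudnnHandle_t handle) {
  DropoutDescriptor dropout;
  dropout.set_no_dropout(handle);
  RNNDescriptor rnn;
  rnn.set(handle, /*hidden_size=*/512, /*num_layers=*/2, std::move(dropout),
          CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL, CUDNN_LSTM,
          /*datatype=*/CUDNN_DATA_FLOAT, /*input_type=*/CUDNN_DATA_FLOAT,
          CUDNN_RNN_ALGO_STANDARD);
}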
struct AT_CUDA_API CTCLossDescriptor
  : public Descriptor<cudnnCTCLossStruct,
                      &cudnnCreateCTCLossDescriptor,
                      &cudnnDestroyCTCLossDescriptor>
{
  void set(cudnnDataType_t datatype) {
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
  }
};
union Constant
{
  float f;
  double d;
  Constant(cudnnDataType_t dataType, double value) {
    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
      f = static_cast<float>(value);
    } else {
      d = value;
    }
  }
};
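
// Illustrative sketch: a hypothetical helper showing why Constant is a union.
// cuDNN takes alpha/beta scaling factors as const void*, and the pointee must
// be float for half/float tensors but double for double tensors; Constant
// stores the host value in the matching representation.
inline void scaling_factors_example(cudnnDataType_t dataType) {
  Constant alpha(dataType, 1.0);
  Constant beta(dataType, 0.0);
  const void* alpha_ptr = &alpha;   // pass to e.g. cudnnConvolutionForward
  const void* beta_ptr  = &beta;
  (void)alpha_ptr; (void)beta_ptr;
}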
}}  // namespace at::native