1 #ifndef CAFFE2_CORE_COMMON_CUDNN_H_ 2 #define CAFFE2_CORE_COMMON_CUDNN_H_ 7 #include "caffe2/core/common.h" 8 #include "caffe2/core/context.h" 9 #include "caffe2/core/logging.h" 10 #include "caffe2/core/types.h" 12 #ifndef CAFFE2_USE_CUDNN 13 #error("This Caffe2 install is not built with cudnn, so you should not include this file."); 19 CUDNN_VERSION >= 5000,
20 "Caffe2 requires cudnn version 5.0 or above.");
22 #if CUDNN_VERSION < 6000 23 #pragma message "CUDNN version under 6.0 is supported at best effort." 24 #pragma message "We strongly encourage you to move to 6.0 and above." 25 #pragma message "This message is intended to annoy you enough to update." 26 #endif // CUDNN_VERSION < 6000 28 #define CUDNN_VERSION_MIN(major, minor, patch) \ 29 (CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch))) 37 inline const char* cudnnGetErrorString(cudnnStatus_t status) {
39 case CUDNN_STATUS_SUCCESS:
40 return "CUDNN_STATUS_SUCCESS";
41 case CUDNN_STATUS_NOT_INITIALIZED:
42 return "CUDNN_STATUS_NOT_INITIALIZED";
43 case CUDNN_STATUS_ALLOC_FAILED:
44 return "CUDNN_STATUS_ALLOC_FAILED";
45 case CUDNN_STATUS_BAD_PARAM:
46 return "CUDNN_STATUS_BAD_PARAM";
47 case CUDNN_STATUS_INTERNAL_ERROR:
48 return "CUDNN_STATUS_INTERNAL_ERROR";
49 case CUDNN_STATUS_INVALID_VALUE:
50 return "CUDNN_STATUS_INVALID_VALUE";
51 case CUDNN_STATUS_ARCH_MISMATCH:
52 return "CUDNN_STATUS_ARCH_MISMATCH";
53 case CUDNN_STATUS_MAPPING_ERROR:
54 return "CUDNN_STATUS_MAPPING_ERROR";
55 case CUDNN_STATUS_EXECUTION_FAILED:
56 return "CUDNN_STATUS_EXECUTION_FAILED";
57 case CUDNN_STATUS_NOT_SUPPORTED:
58 return "CUDNN_STATUS_NOT_SUPPORTED";
59 case CUDNN_STATUS_LICENSE_ERROR:
60 return "CUDNN_STATUS_LICENSE_ERROR";
62 return "Unknown cudnn error number";
// CUDNN_ENFORCE(condition): evaluates `condition` once into a cudnnStatus_t,
// compares it against CUDNN_STATUS_SUCCESS and passes the text from
// ::caffe2::internal::cudnnGetErrorString(status) along as part of the
// failure message.
// CUDNN_CHECK(condition): same evaluation, but reports through
// CHECK(status == CUDNN_STATUS_SUCCESS) and streams the error string into it.
// NOTE(review): several continuation lines of both macros, and the body of
// cudnnCompiledVersion() whose header ends this line, are not visible in this
// excerpt — confirm against the full header before editing.
69 #define CUDNN_ENFORCE(condition) \ 71 cudnnStatus_t status = condition; \ 74 CUDNN_STATUS_SUCCESS, \ 80 ::caffe2::internal::cudnnGetErrorString(status)); \ 82 #define CUDNN_CHECK(condition) \ 84 cudnnStatus_t status = condition; \ 85 CHECK(status == CUDNN_STATUS_SUCCESS) \ 86 << ::caffe2::internal::cudnnGetErrorString(status); \ 90 inline size_t cudnnCompiledVersion() {
94 inline size_t cudnnRuntimeVersion() {
95 return cudnnGetVersion();
99 inline void CheckCuDNNVersions() {
105 bool version_match = cudnnCompiledVersion() == cudnnRuntimeVersion();
106 bool compiled_with_7 = cudnnCompiledVersion() >= 7000;
107 bool backwards_compatible_7 = compiled_with_7 && cudnnRuntimeVersion() >= cudnnCompiledVersion();
108 bool patch_compatible = compiled_with_7 && (cudnnRuntimeVersion() / 100) == (cudnnCompiledVersion() / 100);
109 CAFFE_ENFORCE(version_match || backwards_compatible_7 || patch_compatible,
110 "cuDNN compiled (", cudnnCompiledVersion(),
") and " 111 "runtime (", cudnnRuntimeVersion(),
") versions mismatch");
// Per-data-type trait bodies. Each group below exposes:
//   type             - the cudnnDataType_t enumerator for the group,
//   ScalingParamType - const-qualified type used for the kOne()/kZero() values,
//   BNParamType      - the matching non-const parameter type,
//   kOne() / kZero() - accessors built around a function-local static holding
//                      1 or 0 (their return statements are not visible in this
//                      excerpt; presumably they return the static's address).
// NOTE(review): the class headers that enclose each group are not visible
// here — confirm which C++ type each group is specialised for.
119 template <
typename T>
// Group for CUDNN_DATA_FLOAT: float scaling and parameter types.
125 static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
126 typedef const float ScalingParamType;
127 typedef float BNParamType;
128 static ScalingParamType* kOne() {
129 static ScalingParamType v = 1.0;
// ScalingParamType is already const, so the extra `const` on this return type
// is redundant; the double and half groups below omit it.
132 static const ScalingParamType* kZero() {
133 static ScalingParamType v = 0.0;
// Group for CUDNN_DATA_INT32: int scaling and parameter types. Only compiled
// when CUDNN_VERSION_MIN(6, 0, 0) holds.
138 #if CUDNN_VERSION_MIN(6, 0, 0) 142 static const cudnnDataType_t type = CUDNN_DATA_INT32;
143 typedef const int ScalingParamType;
144 typedef int BNParamType;
145 static ScalingParamType* kOne() {
146 static ScalingParamType v = 1;
149 static const ScalingParamType* kZero() {
150 static ScalingParamType v = 0;
// End of the version-guarded INT32 group; the CUDNN_DATA_DOUBLE group (double
// scaling and parameter types) starts on the same line.
154 #endif // CUDNN_VERSION_MIN(6, 0, 0) 159 static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
160 typedef const double ScalingParamType;
161 typedef double BNParamType;
162 static ScalingParamType* kOne() {
163 static ScalingParamType v = 1.0;
166 static ScalingParamType* kZero() {
167 static ScalingParamType v = 0.0;
// Group for CUDNN_DATA_HALF: note that the scaling and parameter types are
// float here, not a half-precision type.
175 static const cudnnDataType_t type = CUDNN_DATA_HALF;
176 typedef const float ScalingParamType;
177 typedef float BNParamType;
178 static ScalingParamType* kOne() {
179 static ScalingParamType v = 1.0;
182 static ScalingParamType* kZero() {
183 static ScalingParamType v = 0.0;
// Switch body mapping a StorageOrder value to a cuDNN tensor format:
//   StorageOrder::NHWC -> CUDNN_TENSOR_NHWC
//   StorageOrder::NCHW -> CUDNN_TENSOR_NCHW
// Any other value is reported through LOG(FATAL).
// NOTE(review): the function signature and the switch header are not visible
// in this excerpt.
194 case StorageOrder::NHWC:
195 return CUDNN_TENSOR_NHWC;
196 case StorageOrder::NCHW:
197 return CUDNN_TENSOR_NCHW;
199 LOG(FATAL) <<
"Unknown cudnn equivalent for order: " << order;
// NOTE(review): this return follows the LOG(FATAL) statement; presumably it
// only exists so that every path yields a value — confirm.
202 return CUDNN_TENSOR_NCHW;
// Fragment of the tensor-descriptor wrapper: it creates a
// cudnnTensorDescriptor_t, destroys it, and reconfigures it through
// Descriptor(). Several original lines (class header, constructor/destructor
// signatures, parts of the method bodies) are not visible in this excerpt.

// Creation of the wrapped descriptor is checked with CUDNN_ENFORCE.
213 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
// Destruction is checked with CUDNN_CHECK rather than CUDNN_ENFORCE.
216 CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_));
// Returns the wrapped descriptor for the requested format, data type and
// dims.
219 inline cudnnTensorDescriptor_t Descriptor(
220 const cudnnTensorFormat_t format,
221 const cudnnDataType_t type,
222 const vector<int>& dims,
// The request is first compared against the cached type_/format_/dims_; the
// body of this branch is not visible here (presumably it returns the cached
// descriptor unchanged — confirm).
224 if (type_ == type && format_ == format && dims_ == dims) {
// Only 4-element dims are accepted.
231 dims.size(), 4,
"Currently only 4-dimensional descriptor supported.");
// dims_ is read according to the format: for CUDNN_TENSOR_NCHW the three
// arguments below are dims_[1], dims_[2], dims_[3]; otherwise they are
// dims_[3], dims_[1], dims_[2]. The leading arguments of the call are not
// visible in this excerpt.
235 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
240 (format == CUDNN_TENSOR_NCHW ? dims_[1] : dims_[3]),
241 (format == CUDNN_TENSOR_NCHW ? dims_[2] : dims_[1]),
242 (format == CUDNN_TENSOR_NCHW ? dims_[3] : dims_[2])));
// Templated overload taking a StorageOrder instead of an explicit cuDNN
// format and data type; its body is not visible in this excerpt.
248 template <
typename T>
249 inline cudnnTensorDescriptor_t Descriptor(
250 const StorageOrder& order,
251 const vector<int>& dims) {
// Wrapped descriptor plus the format and data type compared against in
// Descriptor().
257 cudnnTensorDescriptor_t desc_;
258 cudnnTensorFormat_t format_;
259 cudnnDataType_t type_;
// Fragment of the filter-descriptor wrapper: it creates a
// cudnnFilterDescriptor_t, destroys it, and reconfigures it through
// Descriptor(). Several original lines (class header, constructor/destructor
// signatures, parts of the method bodies) are not visible in this excerpt.

// Creation of the wrapped descriptor is checked with CUDNN_ENFORCE.
267 CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&desc_));
// Destruction is checked with CUDNN_CHECK rather than CUDNN_ENFORCE.
270 CUDNN_CHECK(cudnnDestroyFilterDescriptor(desc_));
// Returns the wrapped filter descriptor for the requested storage order,
// data type and dims.
273 inline cudnnFilterDescriptor_t Descriptor(
274 const StorageOrder& order,
275 const cudnnDataType_t type,
276 const vector<int>& dims,
// The request is first compared against the cached type_/order_/dims_; the
// body of this branch is not visible here (presumably it returns the cached
// descriptor unchanged — confirm).
278 if (type_ == type && order_ == order && dims_ == dims) {
// Only 4-element dims are accepted.
285 dims.size(), 4,
"Currently only 4-dimensional descriptor supported.");
// dims_ is read according to the storage order: for StorageOrder::NCHW the
// three arguments below are dims_[1], dims_[2], dims_[3]; otherwise they are
// dims_[3], dims_[1], dims_[2]. The leading arguments of the call are not
// visible in this excerpt.
289 CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
295 (order == StorageOrder::NCHW ? dims_[1] : dims_[3]),
296 (order == StorageOrder::NCHW ? dims_[2] : dims_[1]),
297 (order == StorageOrder::NCHW ? dims_[3] : dims_[2])));
// Templated overload that takes only the storage order and dims; its body is
// not visible in this excerpt.
303 template <
typename T>
304 inline cudnnFilterDescriptor_t Descriptor(
305 const StorageOrder& order,
306 const vector<int>& dims) {
// Wrapped descriptor and the data type compared against in Descriptor().
311 cudnnFilterDescriptor_t desc_;
313 cudnnDataType_t type_;
321 #endif // CAFFE2_CORE_COMMON_CUDNN_H_
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function that converts the Caffe2 storage order (StorageOrder) to the corresponding cuDNN tensor format enum value (cudnnTensorFormat_t).
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Flush-To-Zero and Denormals-Are-Zero mode.
cudnnTensorDescWrapper is the placeholder that wraps around a cudnnTensorDescriptor_t, allowing us to do descriptor change as-needed during runtime.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...