#include <ATen/cudnn/Descriptors.h>

#include <algorithm>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

namespace at {
namespace native {
// Map an ATen scalar type to the corresponding cuDNN data type.
inline cudnnDataType_t getDataType(const at::Tensor& t) {
  auto scalar_type = t.scalar_type();
  if (scalar_type == at::kFloat) {
    return CUDNN_DATA_FLOAT;
  } else if (scalar_type == at::kHalf) {
    return CUDNN_DATA_HALF;
  } else if (scalar_type == at::kDouble) {
    return CUDNN_DATA_DOUBLE;
  }
  throw std::runtime_error("TensorDescriptor only supports double, float and half tensors");
}
void TensorDescriptor::set(const at::Tensor &t, size_t pad) {
  set(getDataType(t), t.sizes(), t.strides(), pad);
}
void TensorDescriptor::set(cudnnDataType_t datatype, IntArrayRef t_sizes, IntArrayRef t_strides, size_t pad) {
  size_t dim = t_sizes.size();
  if (dim > CUDNN_DIM_MAX || pad > CUDNN_DIM_MAX)
#define _STR(X) #X
#define STR(X) _STR(X)
    throw std::runtime_error("cuDNN supports only up to " STR(CUDNN_DIM_MAX) " dimensions");
#undef _STR
#undef STR
  int size[CUDNN_DIM_MAX];
  int stride[CUDNN_DIM_MAX];
  for (size_t i = 0; i < dim; ++i) {
    size[i] = static_cast<int>(t_sizes[i]);
    stride[i] = static_cast<int>(t_strides[i]);
  }
  // Pad trailing dimensions with size 1 and stride 1 up to the requested rank.
  for (size_t i = dim; i < pad; ++i) {
    size[i] = 1;
    stride[i] = 1;
  }
  set(datatype, static_cast<int>(std::max(dim, pad)), size, stride);
}
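// A minimal usage sketch (illustrative only; `x` is a hypothetical caller-side
// tensor, not a symbol defined in this file). The `pad` argument is what lets a
// low-rank tensor be described to cuDNN at a higher rank by appending
// singleton dimensions:
//
//   TensorDescriptor desc;
//   at::Tensor x = /* a contiguous float CUDA tensor of shape {8, 16} */;
//   desc.set(x, 4);  // cuDNN sees dims {8, 16, 1, 1}, strides {16, 1, 1, 1}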
std::string cudnnTypeToString(cudnnDataType_t dtype) {
  switch (dtype) {
    case CUDNN_DATA_FLOAT:
      return "CUDNN_DATA_FLOAT";
    case CUDNN_DATA_DOUBLE:
      return "CUDNN_DATA_DOUBLE";
    case CUDNN_DATA_HALF:
      return "CUDNN_DATA_HALF";
    case CUDNN_DATA_INT8:
      return "CUDNN_DATA_INT8";
    case CUDNN_DATA_INT32:
      return "CUDNN_DATA_INT32";
    case CUDNN_DATA_INT8x4:
      return "CUDNN_DATA_INT8x4";
#if CUDNN_VERSION >= 7100
    case CUDNN_DATA_UINT8:
      return "CUDNN_DATA_UINT8";
    case CUDNN_DATA_UINT8x4:
      return "CUDNN_DATA_UINT8x4";
#endif
    default:
      std::ostringstream oss;
      oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
      return oss.str();
  }
}
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
  out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
  int nbDims;
  int dimA[CUDNN_DIM_MAX];
  int strideA[CUDNN_DIM_MAX];
  cudnnDataType_t dtype;
  cudnnGetTensorNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &nbDims, dimA, strideA);
  out << "  type = " << cudnnTypeToString(dtype) << "\n";
  out << "  nbDims = " << nbDims << "\n";
  // Only the first nbDims entries of dimA/strideA are valid.
  out << "  dimA = ";
  for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
    out << i << ", ";
  }
  out << "\n";
  out << "  strideA = ";
  for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
    out << i << ", ";
  }
  out << "\n";
  return out;
}

void TensorDescriptor::print() { std::cout << *this; }
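// Hedged debugging note: operator<< queries the live cudnnTensorDescriptor_t
// via cudnnGetTensorNdDescriptor, so it reports what cuDNN will actually see
// (post-padding), not the originating ATen tensor's metadata. For example
// (placeholder name, not defined here):
//
//   TensorDescriptor desc;
//   desc.set(some_float_cuda_tensor, 4);
//   std::cout << desc;  // prints type, nbDims, dimA, strideA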
void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
  auto dim = t.ndimension();
  if (dim > CUDNN_DIM_MAX || pad > CUDNN_DIM_MAX)
#define _STR(X) #X
#define STR(X) _STR(X)
    throw std::runtime_error("cuDNN supports only up to " STR(CUDNN_DIM_MAX) " dimensions");
#undef _STR
#undef STR
  if (!t.is_contiguous()) {
    // cuDNN filter descriptors carry no stride information, so the weight
    // tensor must be fully contiguous.
    throw std::runtime_error("cuDNN filters (a.k.a. weights) must be contiguous");
  }
  int size[CUDNN_DIM_MAX];
  for (int i = 0; i < dim; ++i) {
    size[i] = (int) t.size(i);
  }
  // Pad trailing dimensions with 1 up to the requested rank.
  for (int i = dim; i < pad; ++i) {
    size[i] = (int) 1;
  }
  dim = std::max(dim, pad);
  set(getDataType(t), (int) dim, size);
}

} // namespace native
} // namespace at