1 #include <ATen/cudnn/Types.h> 5 namespace at {
namespace native {
7 cudnnDataType_t getCudnnDataType(
const at::Tensor& tensor) {
8 if (tensor.scalar_type() == at::kFloat) {
9 return CUDNN_DATA_FLOAT;
10 }
else if (tensor.scalar_type() == at::kDouble) {
11 return CUDNN_DATA_DOUBLE;
12 }
else if (tensor.scalar_type() == at::kHalf) {
13 return CUDNN_DATA_HALF;
15 std::string msg(
"getCudnnDataType() not supported for ");
16 msg += toString(tensor.scalar_type());
17 throw std::runtime_error(msg);
20 int64_t cudnn_version() {
Flush-To-Zero and Denormals-Are-Zero mode.