// Name string shared by this file: it is the first argument of every
// torch::check* call in cudnn_relu_check and the name given to m.def() in the
// PYBIND11_MODULE block.
10 #include <torch/extension.h> 12 #include <ATen/cuda/Exceptions.h> 13 #include <ATen/cudnn/Descriptors.h> 14 #include <ATen/cudnn/Handle.h> 18 const char* cudnn_relu_name =
"cudnn_relu";
// Argument validation for cudnn_relu. Every check is tagged with
// cudnn_relu_name.
// NOTE(review): the parameter list and the definitions of arg_inputs /
// arg_outputs are not visible in this excerpt — confirm against the full file.
21 void cudnn_relu_check(
// Checks on the inputs argument: contiguity, scalar type torch::kFloat, and
// backend torch::Backend::CUDA.
31 torch::checkContiguous(cudnn_relu_name, arg_inputs);
32 torch::checkScalarType(cudnn_relu_name, arg_inputs, torch::kFloat);
33 torch::checkBackend(cudnn_relu_name, arg_inputs.tensor, torch::Backend::CUDA);
// The same contiguity and scalar-type checks on the outputs argument.
34 torch::checkContiguous(cudnn_relu_name, arg_outputs);
35 torch::checkScalarType(cudnn_relu_name, arg_outputs, torch::kFloat);
// Argument list of a call whose callee is not visible here; the arguments
// mirror the checkBackend call made for the inputs (presumably the CUDA
// backend check for outputs — confirm).
37 cudnn_relu_name, arg_outputs.tensor, torch::Backend::CUDA);
// Inputs and outputs are checked against each other for matching size.
38 torch::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
// Body of the ReLU entry point (its signature is not visible in this excerpt).
// Step 1: validate the arguments before touching cuDNN.
45 cudnn_relu_check(inputs, outputs);
// Step 2: fetch the cuDNN handle through torch.
47 cudnnHandle_t cuDnn = torch::native::getCudnnHandle();
// Tensor descriptor built from `inputs`; the second argument is 4
// (presumably a minimum/padded dimension count — confirm against
// torch::native::TensorDescriptor).
50 torch::native::TensorDescriptor input_tensor_desc(inputs, 4);
51 cudnnActivationDescriptor_t activationDesc;
// Create the activation descriptor, then configure it with
// CUDNN_ACTIVATION_RELU. The remaining arguments of the set call are not
// visible in this excerpt.
53 AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activationDesc));
54 AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
56 CUDNN_ACTIVATION_RELU,
// Step 3: run the forward activation. input_tensor_desc.desc() is passed
// twice — presumably once as the input descriptor and once as the output
// descriptor (confirm; the other arguments are not visible here).
62 AT_CUDNN_CHECK(cudnnActivationForward(
66 input_tensor_desc.desc(),
69 input_tensor_desc.desc(),
// Step 4: destroy the activation descriptor created above.
// NOTE(review): if AT_CUDNN_CHECK throws on failure (confirm), an error in the
// set/forward calls would skip this destroy and leak activationDesc; consider
// an RAII guard.
72 AT_CUDNN_CHECK(cudnnDestroyActivationDescriptor(activationDesc));
// Python module definition: exposes cudnn_relu under the name stored in
// cudnn_relu_name, with the docstring "CuDNN ReLU".
77 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
79 m.def(cudnn_relu_name, &cudnn_relu,
"CuDNN ReLU");