1 #include "caffe2/operators/elu_op.h" 3 #include "caffe2/operators/activation_ops_cudnn.h" 13 template <
class... Args>
16 OP_SINGLE_ARG(
float,
"alpha", alpha_, 1.0f) {
17 CUDNN_ENFORCE(cudnnSetActivationDescriptor(
21 static_cast<double>(alpha_)));
24 bool RunOnDevice()
override {
29 bool DoRunWithType() {
30 const auto& X =
Input(0);
32 auto* Y = Output(0, X.sizes(), at::dtype<T>());
34 Y->template mutable_data<T>();
38 CUDNN_ENFORCE(cudnnActivationForward(
46 Y->template mutable_data<T>()));
60 template <
class... Args>
63 OP_SINGLE_ARG(
float,
"alpha", alpha_, 1.0f) {
64 CUDNN_ENFORCE(cudnnSetActivationDescriptor(
68 static_cast<double>(alpha_)));
71 bool RunOnDevice()
override {
76 bool DoRunWithType() {
77 const auto& Y =
Input(0);
78 const auto& dY =
Input(1);
80 auto* dX = Output(0, Y.sizes(), at::dtype<T>());
82 dX->template mutable_data<T>();
86 CUDNN_ENFORCE(cudnnActivationBackward(
87 this->cudnn_wrapper_.inline_cudnn_handle(),
93 dY.template data<T>(),
98 dX->template mutable_data<T>()));
107 REGISTER_CUDNN_OPERATOR(
const Tensor & Input(int idx, DeviceType type=CUDAContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...