1 #ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_ 2 #define CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_ 4 #include "caffe2/core/context_gpu.h" 5 #include "caffe2/core/cudnn_wrappers.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/tensor.h" 8 #include "caffe2/core/types.h" 16 template <
class... Args>
19 cudnn_wrapper_(&context_) {
// Constructor body: cudnn_wrapper_ is initialised from &context_ above, then
// the two cuDNN descriptor objects held in data_desc_ and act_desc_ are
// created. Each cuDNN call is wrapped in CUDNN_ENFORCE.
20 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
21 CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&act_desc_));
// Destructor: destroys the tensor descriptor and the activation descriptor
// that the constructor body created.
// NOTE(review): CUDNN_ENFORCE is invoked from a destructor here; presumably it
// raises on a non-success status — confirm that this is acceptable during
// destruction.
24 virtual ~CuDNNActivationOpBase() {
25 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
26 CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(act_desc_));
// Refreshes the cached cuDNN tensor descriptor when the requested size changes.
//
// data_type: cuDNN data type for the descriptor.
// data_size: size to describe; compared against the cached input_size_.
//
// The descriptor is only re-set when data_size differs from input_size_, so
// repeated calls with the same size do no cuDNN work.
// NOTE(review): the guard visible here compares only data_size, not data_type;
// confirm whether a dtype change with an unchanged size needs a refresh.
30 void SetTensorDescriptor(
31 const cudnnDataType_t data_type,
32 const int data_size) {
33 if (data_size != input_size_) {
// Remember the new size so the next call with the same size is skipped.
36 input_size_ = data_size;
// Arguments of this call are not visible in this excerpt; presumably it
// configures data_desc_ — TODO confirm.
37 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
// cuDNN descriptor handles owned by this class: both are created in the
// constructor body and destroyed in the destructor.
49 cudnnTensorDescriptor_t data_desc_;
50 cudnnActivationDescriptor_t act_desc_;
// Template header taking the cuDNN activation mode as a compile-time
// parameter. The declaration it introduces is not visible in this excerpt.
55 template <cudnnActivationMode_t kCuDNNActivationMode>
// Variadic template header; presumably for a forwarding constructor whose
// signature is on lines not shown here — TODO confirm.
60 template <
class... Args>
// Configures act_desc_ with the template's activation mode,
// CUDNN_PROPAGATE_NAN, and a trailing 0.0 argument.
63 CUDNN_ENFORCE(cudnnSetActivationDescriptor(
64 act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
// RunOnDevice override; its body is not visible in this excerpt.
67 bool RunOnDevice()
override {
// Forward activation for element type T (T is declared on a line not visible
// in this excerpt). Reads input 0 and produces output 0 via
// cudnnActivationForward.
72 bool DoRunWithType() {
73 const auto& X =
Input(0);
// Output 0 is created with X's sizes and dtype T.
75 auto* Y = Output(0, X.sizes(), at::dtype<T>());
// NOTE(review): this standalone mutable_data call is presumably inside an
// empty-input guard on a line not shown here — TODO confirm.
77 Y->template mutable_data<T>();
// The cuDNN call uses the handle from cudnn_wrapper_; most arguments are on
// lines not visible in this excerpt.
81 CUDNN_ENFORCE(cudnnActivationForward(
82 this->cudnn_wrapper_.inline_cudnn_handle(),
// Last argument: Y's mutable buffer.
89 Y->template mutable_data<T>()));
// Template header taking the cuDNN activation mode as a compile-time
// parameter. The declaration it introduces is not visible in this excerpt.
94 template <cudnnActivationMode_t kCuDNNActivationMode>
// Variadic template header; presumably for a forwarding constructor whose
// signature is on lines not shown here — TODO confirm.
99 template <
class... Args>
// Configures act_desc_ with the template's activation mode,
// CUDNN_PROPAGATE_NAN, and a trailing 0.0 argument.
102 CUDNN_ENFORCE(cudnnSetActivationDescriptor(
103 act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
// RunOnDevice override; its body is not visible in this excerpt.
106 bool RunOnDevice()
override {
// Backward activation for element type T. Reads input 0 (Y) and input 1 (dY)
// and produces output 0 (dX) via cudnnActivationBackward.
110 template <
typename T>
111 bool DoRunWithType() {
112 const auto& Y =
Input(0);
113 const auto& dY =
Input(1);
// dX is created with Y's sizes and dtype T.
115 auto* dX = Output(0, Y.sizes(), at::dtype<T>());
// Empty input: dX's buffer is still requested via mutable_data. What follows
// inside this branch is not visible in this excerpt (presumably an early
// return — TODO confirm).
116 if (Y.numel() == 0) {
117 dX->template mutable_data<T>();
// The cuDNN call uses the handle from cudnn_wrapper_; the descriptor and
// scaling arguments are on lines not visible in this excerpt.
121 CUDNN_ENFORCE(cudnnActivationBackward(
122 this->cudnn_wrapper_.inline_cudnn_handle(),
126 Y.template data<T>(),
128 dY.template data<T>(),
// Y's data is passed a second time here. NOTE(review): presumably it fills the
// forward-input (x) slot of cudnnActivationBackward because this operator does
// not take X as an input — confirm against the cuDNN API.
130 Y.template data<T>(),
// Last argument: dX's mutable buffer.
133 dX->template mutable_data<T>()));
140 #endif // CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
const Tensor & Input(int idx, DeviceType type=CUDAContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...