1 #ifndef CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_     2 #define CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_     4 #include "caffe2/core/context_gpu.h"     5 #include "caffe2/core/cudnn_wrappers.h"     6 #include "caffe2/core/operator.h"     7 #include "caffe2/core/tensor.h"     8 #include "caffe2/core/types.h"    16   template <
class... Args>
    19         cudnn_wrapper_(&context_) {
    20     CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
    21     CUDNN_ENFORCE(cudnnCreateActivationDescriptor(&act_desc_));
    24   virtual ~CuDNNActivationOpBase() {
    25     CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
    26     CUDNN_ENFORCE(cudnnDestroyActivationDescriptor(act_desc_));
    30   void SetTensorDescriptor(
    31       const cudnnDataType_t data_type,
    32       const int data_size) {
    33     if (data_size != input_size_) {
    36       input_size_ = data_size;
    37       CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
    49   cudnnTensorDescriptor_t data_desc_;
    50   cudnnActivationDescriptor_t act_desc_;
    55 template <cudnnActivationMode_t kCuDNNActivationMode>
    60   template <
class... Args>
    63     CUDNN_ENFORCE(cudnnSetActivationDescriptor(
    64         act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
    67   bool RunOnDevice()
 override {
    72   bool DoRunWithType() {
    73     const auto& X = 
Input(0);
    75     auto* Y = Output(0, X.sizes(), at::dtype<T>());
    77       Y->template mutable_data<T>();
    81     CUDNN_ENFORCE(cudnnActivationForward(
    82         this->cudnn_wrapper_.inline_cudnn_handle(),
    89         Y->template mutable_data<T>()));
    94 template <cudnnActivationMode_t kCuDNNActivationMode>
    99   template <
class... Args>
   102     CUDNN_ENFORCE(cudnnSetActivationDescriptor(
   103         act_desc_, kCuDNNActivationMode, CUDNN_PROPAGATE_NAN, 0.0));
   106   bool RunOnDevice()
 override {
   110   template <
typename T>
   111   bool DoRunWithType() {
   112     const auto& Y = 
Input(0);
   113     const auto& dY = 
Input(1);
   115     auto* dX = Output(0, Y.sizes(), at::dtype<T>());
   116     if (Y.numel() == 0) {
   117       dX->template mutable_data<T>();
   121     CUDNN_ENFORCE(cudnnActivationBackward(
   122         this->cudnn_wrapper_.inline_cudnn_handle(),
   126         Y.template data<T>(),
   128         dY.template data<T>(),
   130         Y.template data<T>(), 
   133         dX->template mutable_data<T>()));
   140 #endif // CAFFE2_OPERATORS_ACTIVATION_OPS_CUDNN_H_ 
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function that converts a Caffe2 storage order (StorageOrder) to the corresponding cuDNN tensor format enum value (cudnnTensorFormat_t). [Description truncated in this extract.]
 
const Tensor & Input(int idx, DeviceType type=CUDAContext::GetDeviceType())
Retrieves a non-owning reference to the input tensor at position `idx` for this operator. [Description truncated in this extract.]
 
A global dictionary that holds information about which Caffe2 modules have been loaded in the current [description truncated in this extract]
 
CuDNNWrapper is a class that wraps the cuDNN handles and cuDNN workspaces.
 
cudnnTypeWrapper is a wrapper class that allows us to refer to the cuDNN type in a template function. [Description truncated in this extract.]