Caffe2 - C++ API
A deep learning, cross platform ML framework
elu_op_cudnn.cc
1 #include "caffe2/operators/elu_op.h"
2 
3 #include "caffe2/operators/activation_ops_cudnn.h"
4 
5 namespace caffe2 {
6 
7 template <>
8 class CuDNNActivationOp<CUDNN_ACTIVATION_ELU> final
9  : public CuDNNActivationOpBase {
10  public:
11  USE_OPERATOR_FUNCTIONS(CUDAContext);
12 
13  template <class... Args>
14  explicit CuDNNActivationOp(Args&&... args)
15  : CuDNNActivationOpBase(std::forward<Args>(args)...),
16  OP_SINGLE_ARG(float, "alpha", alpha_, 1.0f) {
17  CUDNN_ENFORCE(cudnnSetActivationDescriptor(
18  act_desc_,
19  CUDNN_ACTIVATION_ELU,
20  CUDNN_PROPAGATE_NAN,
21  static_cast<double>(alpha_)));
22  }
23 
24  bool RunOnDevice() override {
25  return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
26  }
27 
28  template <typename T>
29  bool DoRunWithType() {
30  const auto& X = Input(0);
31 
32  auto* Y = Output(0, X.sizes(), at::dtype<T>());
33  if (X.numel() == 0) {
34  Y->template mutable_data<T>();
35  return true;
36  }
37  this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, X.numel());
38  CUDNN_ENFORCE(cudnnActivationForward(
39  this->cudnn_wrapper_.inline_cudnn_handle(),
40  this->act_desc_,
42  this->data_desc_,
43  X.template data<T>(),
45  this->data_desc_,
46  Y->template mutable_data<T>()));
47  return true;
48  }
49 
50  private:
51  const float alpha_;
52 };
53 
54 template <>
55 class CuDNNActivationGradientOp<CUDNN_ACTIVATION_ELU> final
56  : public CuDNNActivationOpBase {
57  public:
58  USE_OPERATOR_FUNCTIONS(CUDAContext);
59 
60  template <class... Args>
61  explicit CuDNNActivationGradientOp(Args&&... args)
62  : CuDNNActivationOpBase(std::forward<Args>(args)...),
63  OP_SINGLE_ARG(float, "alpha", alpha_, 1.0f) {
64  CUDNN_ENFORCE(cudnnSetActivationDescriptor(
65  act_desc_,
66  CUDNN_ACTIVATION_ELU,
67  CUDNN_PROPAGATE_NAN,
68  static_cast<double>(alpha_)));
69  }
70 
71  bool RunOnDevice() override {
72  return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
73  }
74 
75  template <typename T>
76  bool DoRunWithType() {
77  const auto& Y = Input(0);
78  const auto& dY = Input(1);
79 
80  auto* dX = Output(0, Y.sizes(), at::dtype<T>());
81  if (Y.numel() == 0) {
82  dX->template mutable_data<T>();
83  return true;
84  }
85  this->SetTensorDescriptor(cudnnTypeWrapper<T>::type, Y.numel());
86  CUDNN_ENFORCE(cudnnActivationBackward(
87  this->cudnn_wrapper_.inline_cudnn_handle(),
88  this->act_desc_,
90  this->data_desc_,
91  Y.template data<T>(),
92  this->data_desc_,
93  dY.template data<T>(),
94  this->data_desc_,
95  Y.template data<T>(), // Use Y_data as placeholder here.
97  this->data_desc_,
98  dX->template mutable_data<T>()));
99  return true;
100  }
101 
102  private:
103  const float alpha_;
104 };
105 
106 REGISTER_CUDNN_OPERATOR(Elu, CuDNNActivationOp<CUDNN_ACTIVATION_ELU>);
107 REGISTER_CUDNN_OPERATOR(
108  EluGradient,
110 
111 } // namespace caffe2
const Tensor & Input(int idx, DeviceType type=CUDAContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...
Definition: common_cudnn.h:120