Caffe2 - C++ API
A deep learning, cross-platform ML framework
softmax_op_cudnn.cc
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/types.h"
#include "caffe2/operators/softmax_op.h"

namespace caffe2 {

namespace {
constexpr int NUM_DESCRIPTORS = 2;
constexpr int GRADIENT_NUM_DESCRIPTORS = 3;
constexpr int BOTTOM_DESC_ID = 0;
constexpr int TOP_DESC_ID = 1;
constexpr int TOP_GRADIENT_DESC_ID = 2;
} // namespace

class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
 public:
  template <class... Args>
  explicit CuDNNSoftmaxOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
  }

  ~CuDNNSoftmaxOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
  }

  template <typename T>
  bool DoRunWithType() {
    auto& X = Input(0);

    // Flatten X into an N x D matrix around the softmax axis.
    const auto canonical_axis = X.canonical_axis_index(axis_);
    const int N = X.size_to_dim(canonical_axis);
    const int D = X.size_from_dim(canonical_axis);

    auto* Y = Output(0, X.sizes(), at::dtype<T>());
    auto* Y_data = Y->template mutable_data<T>();
    if (N == 0) {
      return true;
    }
    // Only (re)configure the tensor descriptor when the input shape changes.
    if (dims_ != X.sizes()) {
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          desc_,
          GetCudnnTensorFormat(StorageOrder::NCHW),
          cudnnTypeWrapper<T>::type,
          N,
          D,
          1,
          1));
      dims_ = X.sizes().vec();
    }
    CUDNN_ENFORCE(cudnnSoftmaxForward(
        cudnn_wrapper_.inline_cudnn_handle(),
        CUDNN_SOFTMAX_ACCURATE,
        CUDNN_SOFTMAX_MODE_INSTANCE,
        cudnnTypeWrapper<T>::kOne(),
        desc_,
        X.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        desc_,
        Y_data));
    return true;
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

 protected:
  CuDNNWrapper cudnn_wrapper_;
  int axis_;
  cudnnTensorDescriptor_t desc_;
  vector<int64_t> dims_;
};

class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
 public:
  template <class... Args>
  explicit CuDNNSoftmaxGradientOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
  }

  ~CuDNNSoftmaxGradientOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
  }

  template <typename T>
  bool DoRunWithType() {
    auto& Y = Input(0);
    auto& dY = Input(1);

    // Flatten Y (and dY) into an N x D matrix around the softmax axis.
    const auto canonical_axis = Y.canonical_axis_index(axis_);
    const int N = Y.size_to_dim(canonical_axis);
    const int D = Y.size_from_dim(canonical_axis);

    CHECK_EQ(Y.sizes(), dY.sizes());
    auto* dX = Output(0, Y.sizes(), at::dtype<T>());
    auto* dX_data = dX->template mutable_data<T>();
    if (N == 0) {
      return true;
    }
    if (dims_ != Y.sizes()) {
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          desc_,
          GetCudnnTensorFormat(StorageOrder::NCHW),
          cudnnTypeWrapper<T>::type,
          N,
          D,
          1,
          1));
      dims_ = Y.sizes().vec();
    }
    CUDNN_ENFORCE(cudnnSoftmaxBackward(
        cudnn_wrapper_.inline_cudnn_handle(),
        CUDNN_SOFTMAX_ACCURATE,
        CUDNN_SOFTMAX_MODE_INSTANCE,
        cudnnTypeWrapper<T>::kOne(),
        desc_,
        Y.template data<T>(),
        desc_,
        dY.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        desc_,
        dX_data));
    return true;
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

 protected:
  CuDNNWrapper cudnn_wrapper_;
  int axis_;
  cudnnTensorDescriptor_t desc_;
  vector<int64_t> dims_;
};

namespace {
REGISTER_CUDNN_OPERATOR(Softmax, CuDNNSoftmaxOp);
REGISTER_CUDNN_OPERATOR(SoftmaxGradient, CuDNNSoftmaxGradientOp);
} // namespace
} // namespace caffe2
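
For reference, the descriptor set up above presents the input to cuDNN as an N x D matrix (N = size_to_dim(canonical_axis), D = size_from_dim(canonical_axis)), and CUDNN_SOFTMAX_MODE_INSTANCE with the "accurate" algorithm then normalizes each of the N rows independently after subtracting the row maximum. A minimal CPU-side sketch of that computation follows; it is illustrative only and not part of the operator, and the helper name reference_softmax is invented for this example.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference computation of what the operator asks cuDNN to do: the input is
// treated as N rows of D elements, and a softmax is applied to each row.
std::vector<float> reference_softmax(const std::vector<float>& x, int N, int D) {
  std::vector<float> y(x.size());
  for (int n = 0; n < N; ++n) {
    const float* row = x.data() + n * D;
    float* out = y.data() + n * D;
    // Subtract the row max for numerical stability (CUDNN_SOFTMAX_ACCURATE).
    const float max_val = *std::max_element(row, row + D);
    float sum = 0.f;
    for (int d = 0; d < D; ++d) {
      out[d] = std::exp(row[d] - max_val);
      sum += out[d];
    }
    for (int d = 0; d < D; ++d) {
      out[d] /= sum;
    }
  }
  return y;
}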
Referenced symbols:
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder& order) - a wrapper function that converts the Caffe2 storage order to the corresponding cudnn storage order enum value. Definition: common_cudnn.h:192
const Tensor& Input(int idx, DeviceType type = CUDAContext::GetDeviceType()) - retrieves a non-owning reference to the input at position 'idx' for this operator. Definition: operator.h:702
CuDNNWrapper - a class that wraps the cudnn handles and cudnn workspaces.
cudnnHandle_t inline_cudnn_handle() - returns the inline cudnn handle that executes on the current thread's cuda_stream.
cudnnTypeWrapper - a wrapper class that allows us to refer to the cudnn type in a template function. Definition: common_cudnn.h:120
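
For comparison with the wrapped calls in the listing, here is a standalone sketch of the same forward pass issued through the raw cuDNN API, with plain 1.0f / 0.0f constants playing the role of cudnnTypeWrapper<float>::kOne() and kZero() (cuDNN's alpha/beta blending factors). The function name softmax_forward_raw and the check_cudnn helper are invented for this illustration; the cudnnHandle_t and the device buffers d_x and d_y are assumed to exist already.

#include <cudnn.h>
#include <cstdio>
#include <cstdlib>

// Minimal status check, standing in for the CUDNN_ENFORCE macro above.
static void check_cudnn(cudnnStatus_t status) {
  if (status != CUDNN_STATUS_SUCCESS) {
    std::fprintf(stderr, "cuDNN error: %s\n", cudnnGetErrorString(status));
    std::abort();
  }
}

// d_x and d_y are device buffers holding N * D floats.
void softmax_forward_raw(cudnnHandle_t handle, const float* d_x, float* d_y, int N, int D) {
  cudnnTensorDescriptor_t desc;
  check_cudnn(cudnnCreateTensorDescriptor(&desc));
  // Same layout the operator configures: an N x D tensor with H = W = 1,
  // so CUDNN_SOFTMAX_MODE_INSTANCE normalizes each of the N rows.
  check_cudnn(cudnnSetTensor4dDescriptor(
      desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, N, D, 1, 1));
  const float alpha = 1.0f;
  const float beta = 0.0f;
  check_cudnn(cudnnSoftmaxForward(
      handle,
      CUDNN_SOFTMAX_ACCURATE,
      CUDNN_SOFTMAX_MODE_INSTANCE,
      &alpha,
      desc,
      d_x,
      &beta,
      desc,
      d_y));
  check_cudnn(cudnnDestroyTensorDescriptor(desc));
}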