Caffe2 - C++ API
A deep learning, cross platform ML framework
softmax_op_cudnn.cc
1 
17 #include "caffe2/core/context_gpu.h"
18 #include "caffe2/core/cudnn_wrappers.h"
19 #include "caffe2/core/types.h"
20 #include "caffe2/operators/softmax_op.h"
21 
22 namespace caffe2 {
23 
// File-local integer constants. By their names they describe descriptor
// counts and descriptor slot indices.
// NOTE(review): none of these are referenced by the operators in this file;
// they look like candidates for removal — confirm before deleting.
namespace {
// Slot indices (by name: bottom/input, top/output, top gradient).
constexpr int BOTTOM_DESC_ID{0};
constexpr int TOP_DESC_ID{1};
constexpr int TOP_GRADIENT_DESC_ID{2};

// Descriptor counts (by name: forward pass vs. gradient pass).
constexpr int NUM_DESCRIPTORS{2};
constexpr int GRADIENT_NUM_DESCRIPTORS{3};
} // namespace
31 
32 class CuDNNSoftmaxOp final : public Operator<CUDAContext> {
33  public:
34  explicit CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws)
35  : Operator<CUDAContext>(def, ws),
36  cudnn_wrapper_(&context_),
37  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
38  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
39  }
40 
41  ~CuDNNSoftmaxOp() {
42  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
43  }
44 
45  template <typename T>
46  bool DoRunWithType() {
47  auto& X = Input(0);
48  auto* Y = Output(0);
49  const auto canonical_axis = X.canonical_axis_index(axis_);
50  const int N = X.size_to_dim(canonical_axis);
51  const int D = X.size_from_dim(canonical_axis);
52 
53  Y->ResizeLike(X);
54  if (dims_ != X.dims()) {
55  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
56  desc_,
57  GetCudnnTensorFormat(StorageOrder::NCHW),
59  N,
60  D,
61  1,
62  1));
63  dims_ = X.dims();
64  }
65  CUDNN_ENFORCE(cudnnSoftmaxForward(
66  cudnn_wrapper_.inline_cudnn_handle(),
67  CUDNN_SOFTMAX_ACCURATE,
68  CUDNN_SOFTMAX_MODE_INSTANCE,
70  desc_,
71  X.template data<T>(),
73  desc_,
74  Y->template mutable_data<T>()));
75  return true;
76  }
77 
78  bool RunOnDevice() override {
79  return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
80  }
81 
82  protected:
83  CuDNNWrapper cudnn_wrapper_;
84  int axis_;
85  cudnnTensorDescriptor_t desc_;
86  vector<TIndex> dims_;
87 };
88 
89 
90 class CuDNNSoftmaxGradientOp final : public Operator<CUDAContext> {
91  public:
92  explicit CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws)
93  : Operator<CUDAContext>(def, ws),
94  cudnn_wrapper_(&context_),
95  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {
96  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
97  }
98 
100  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
101  }
102 
103  template <typename T>
104  bool DoRunWithType() {
105  auto& Y = Input(0);
106  auto& dY = Input(1);
107  auto* dX = Output(0);
108  const auto canonical_axis = Y.canonical_axis_index(axis_);
109  const int N = Y.size_to_dim(canonical_axis);
110  const int D = Y.size_from_dim(canonical_axis);
111 
112  CHECK_EQ(Y.dims(), dY.dims());
113  dX->ResizeLike(Y);
114  if (dims_ != Y.dims()) {
115  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
116  desc_,
117  GetCudnnTensorFormat(StorageOrder::NCHW),
119  N,
120  D,
121  1,
122  1));
123  dims_ = Y.dims();
124  }
125  CUDNN_ENFORCE(cudnnSoftmaxBackward(
126  cudnn_wrapper_.inline_cudnn_handle(),
127  CUDNN_SOFTMAX_ACCURATE,
128  CUDNN_SOFTMAX_MODE_INSTANCE,
130  desc_,
131  Y.template data<T>(),
132  desc_,
133  dY.template data<T>(),
135  desc_,
136  dX->template mutable_data<T>()));
137  return true;
138  }
139 
140  bool RunOnDevice() override {
141  return DispatchHelper<TensorTypes<float, float16>>::call(this, Input(0));
142  }
143 
144  protected:
145  CuDNNWrapper cudnn_wrapper_;
146  int axis_;
147  cudnnTensorDescriptor_t desc_;
148  vector<TIndex> dims_;
149 };
150 
// Register the cuDNN implementations above for the "Softmax" and
// "SoftmaxGradient" operator names. Wrapped in an anonymous namespace so
// anything the registration macros define stays local to this translation
// unit.
namespace {
REGISTER_CUDNN_OPERATOR(Softmax, CuDNNSoftmaxOp);
REGISTER_CUDNN_OPERATOR(SoftmaxGradient, CuDNNSoftmaxGradientOp);
} // namespace
155 } // namespace caffe2
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
Definition: common_cudnn.h:199
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...
Definition: common_cudnn.h:127