1 #include "caffe2/core/context_gpu.h" 2 #include "caffe2/core/cudnn_wrappers.h" 3 #include "caffe2/core/types.h" 4 #include "caffe2/operators/softmax_op.h" 9 constexpr
int NUM_DESCRIPTORS = 2;
10 constexpr
int GRADIENT_NUM_DESCRIPTORS = 3;
11 constexpr
int BOTTOM_DESC_ID = 0;
12 constexpr
int TOP_DESC_ID = 1;
13 constexpr
int TOP_GRADIENT_DESC_ID = 2;
18 template <
class... Args>
21 cudnn_wrapper_(&context_),
22 axis_(OperatorBase::GetSingleArgument<int>(
"axis", 1)) {
23 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
26 ~CuDNNSoftmaxOp()
override {
27 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
31 bool DoRunWithType() {
34 const auto canonical_axis = X.canonical_axis_index(axis_);
35 const int N = X.size_to_dim(canonical_axis);
36 const int D = X.size_from_dim(canonical_axis);
38 auto* Y = Output(0, X.sizes(), at::dtype<T>());
39 auto* Y_data = Y->template mutable_data<T>();
43 if (dims_ != X.sizes()) {
44 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
52 dims_ = X.sizes().vec();
54 CUDNN_ENFORCE(cudnnSoftmaxForward(
56 CUDNN_SOFTMAX_ACCURATE,
57 CUDNN_SOFTMAX_MODE_INSTANCE,
67 bool RunOnDevice()
override {
74 cudnnTensorDescriptor_t desc_;
75 vector<int64_t> dims_;
81 template <
class... Args>
84 cudnn_wrapper_(&context_),
85 axis_(OperatorBase::GetSingleArgument<int>(
"axis", 1)) {
86 CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&desc_));
89 ~CuDNNSoftmaxGradientOp()
override {
90 CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(desc_));
94 bool DoRunWithType() {
98 const auto canonical_axis = Y.canonical_axis_index(axis_);
99 const int N = Y.size_to_dim(canonical_axis);
100 const int D = Y.size_from_dim(canonical_axis);
102 CHECK_EQ(Y.sizes(), dY.sizes());
103 auto* dX = Output(0, Y.sizes(), at::dtype<T>());
104 auto* dX_data = dX->template mutable_data<T>();
108 if (dims_ != Y.sizes()) {
109 CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
117 dims_ = Y.sizes().vec();
119 CUDNN_ENFORCE(cudnnSoftmaxBackward(
120 cudnn_wrapper_.inline_cudnn_handle(),
121 CUDNN_SOFTMAX_ACCURATE,
122 CUDNN_SOFTMAX_MODE_INSTANCE,
125 Y.template data<T>(),
127 dY.template data<T>(),
134 bool RunOnDevice()
override {
141 cudnnTensorDescriptor_t desc_;
142 vector<int64_t> dims_;
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
const Tensor & Input(int idx, DeviceType type=CUDAContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnHandle_t inline_cudnn_handle()
Returns the inline cudnn handle that executes on the current thread's cuda_stream.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...