Caffe2 - C++ API
A deep learning, cross-platform ML framework
softmax_op.cc
#include "caffe2/operators/softmax_op.h"
#include "caffe2/operators/softmax_shared.h"

namespace caffe2 {

// Implementation for the CPU context.
template <>
bool SoftmaxOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);

  const auto canonical_axis = X.canonical_axis_index(axis_);
  const int N = X.size_to_dim(canonical_axis);
  const int D = X.size_from_dim(canonical_axis);
  auto* Y = Output(0, X.sizes(), at::dtype<float>());
  float* Ydata = Y->template mutable_data<float>();
  // First, get scales
  if (!scale_.defined()) {
    scale_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
  } else if (scale_.numel() != N) {
    scale_.Resize(N);
  }

  if (!rowmax_.defined()) {
    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
  } else if (rowmax_.numel() != N) {
    rowmax_.Resize(N);
  }

  if (!sum_multiplier_.defined()) {
    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
    math::Set<float, CPUContext>(
        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
  } else if (sum_multiplier_.numel() != D) {
    sum_multiplier_.Resize(D);
    math::Set<float, CPUContext>(
        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
  }

  SoftmaxCPU(
      context_,
      N,
      D,
      X.data<float>(),
      Ydata,
      scale_.mutable_data<float>(),
      sum_multiplier_.data<float>(),
      false,
      rowmax_.mutable_data<float>());
  return true;
}
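
// Note: SoftmaxCPU itself is declared in softmax_shared.h (included above)
// and defined elsewhere. As a sketch (an assumption about the shared helper,
// not a verbatim copy), it computes a numerically stable softmax over each of
// the N rows of length D:
//
//   rowmax[i] = max_j X[i][j]            // per-row max, subtracted for stability
//   Y[i][j]   = exp(X[i][j] - rowmax[i]) // shifted exponentials
//   scale[i]  = sum_j Y[i][j]            // per-row normalizer
//   Y[i][j]  /= scale[i]                 // each row now sums to 1
//
// sum_multiplier_ is a length-D vector of ones, which lets per-row reductions
// and broadcasts be phrased as GEMM calls against BLAS.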

// Implementation for the CPU context.
template <>
bool SoftmaxGradientOp<float, CPUContext>::RunOnDevice() {
  auto& Y = Input(0);
  auto& dY = Input(1);

  const auto canonical_axis = Y.canonical_axis_index(axis_);
  const int64_t N = Y.size_to_dim(canonical_axis);
  const int64_t D = Y.size_from_dim(canonical_axis);
  // First, get scales
  if (!scale_.defined()) {
    scale_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
  } else if (scale_.numel() != N) {
    scale_.Resize(N);
  }

  if (!sum_multiplier_.defined()) {
    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
    math::Set<float, CPUContext>(
        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
  } else if (sum_multiplier_.numel() != D) {
    sum_multiplier_.Resize(D);
    math::Set<float, CPUContext>(
        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
  }

  auto* dX = Output(0, Y.sizes(), at::dtype<float>());
  const float* Ydata = Y.data<float>();
  const float* dYdata = dY.data<float>();
  float* dXdata = dX->mutable_data<float>();
  if (N == 0) {
    return true;
  }
  context_.CopySameDevice<float>(Y.numel(), dYdata, dXdata);
  float* scaledata = scale_.mutable_data<float>();
  for (int i = 0; i < N; ++i) {
    math::Dot<float, CPUContext>(
        D, Ydata + i * D, dYdata + i * D, scaledata + i, &context_);
  }
  math::Gemm<float, CPUContext>(
      CblasNoTrans,
      CblasNoTrans,
      N,
      D,
      1,
      -1,
      scaledata,
      sum_multiplier_.data<float>(),
      1,
      dXdata,
      &context_);
  math::Mul<float, CPUContext>(Y.numel(), dXdata, Ydata, dXdata, &context_);
  return true;
}
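
// The three math calls above implement the softmax Jacobian-vector product
// row by row. With s_i = sum_j Y[i][j] * dY[i][j] (computed by the Dot loop
// into scale_), the gradient of softmax is
//
//   dX[i][j] = Y[i][j] * (dY[i][j] - s_i)
//
// The Gemm call is a rank-1 update dX -= scale * ones^T (alpha = -1,
// beta = 1, with sum_multiplier_ broadcasting each row's dot product across
// all D columns of that row), and the final Mul supplies the elementwise
// factor of Y.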

REGISTER_CPU_OPERATOR(Softmax, SoftmaxOp<float, CPUContext>);
REGISTER_CPU_GRADIENT_OPERATOR(
    SoftmaxGradient,
    SoftmaxGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(Softmax)
    .NumInputs(1)
    .NumOutputs(1)
    .IdenticalTypeAndShape()
    .SetDoc(R"DOC(

Applies the Softmax function to an n-dimensional input tensor, rescaling it so
that the elements of the n-dimensional output tensor lie in the range (0,1) and
sum to 1. The softmax operator is typically the last layer in a classifier
network, as its output can be interpreted as confidence probabilities of an
input belonging to each class. The input is a 2-D tensor of size (batch_size x
input_feature_dimensions). The output tensor has the same shape and contains
the softmax-normalized values of the corresponding input. The softmax function
is defined as follows:

$$softmax(x_i) = \frac{\exp(x_i)}{\sum_{j} \exp(x_j)}$$

The input does not need to be a 2D matrix explicitly; rather, it will be
coerced into one. For an arbitrary n-dimensional tensor `X` with shape
$[a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}]$, where $k$ is the `axis`
provided, `X` will be coerced into a 2-dimensional tensor with dimensions
$[(a_0 * ... * a_{k-1}), (a_k * ... * a_{n-1})]$. For the default case where
`axis` = 1, the `X` tensor will be coerced into a 2D tensor of dimensions
$[a_0, (a_1 * ... * a_{n-1})]$, where $a_0$ is often the batch size. In this
situation, we must have $a_0 = N$ and $a_1 * ... * a_{n-1} = D$. Each of these
dimensions must be matched correctly, or else the operator will throw an
error.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/softmax_op.h
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/softmax_op.cc


<details>

<summary> <b>Example</b> </summary>

**Code**

```
from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Softmax",
    ["X"],
    ["Y"]
)

workspace.FeedBlob("X", np.random.randn(1, 5).astype(np.float32))
print("input:", workspace.FetchBlob("X"))
workspace.RunOperatorOnce(op)
print("softmax:", workspace.FetchBlob("Y"))

```

**Result**

```
input: [[ 0.0417839 0.61960053 -0.23150268 -0.64389366 -3.0000346 ]]
softmax: [[0.24422921 0.43525138 0.18582782 0.12303016 0.01166145]]

```

</details>

)DOC")
    .Arg(
        "axis",
        "*(type: int; default: 1)* Axis of the inputs when coerced to a 2D matrix.")
    .Input(
        0,
        "X",
        "*(type: Tensor`<float>`)* Input tensor that is coerced into a 2D matrix of size (N x D) as described above.")
    .Output(
        0,
        "Y",
        "*(type: Tensor`<float>`)* The softmax-normalized output tensor with the same shape as the input tensor.")
    .InheritOnnxSchema();
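
// Shape bookkeeping, by way of an illustrative example (not part of the
// schema above): an input X of shape [2, 3, 4] with the default axis = 1 is
// treated as an N x D = 2 x 12 matrix, so each of the 2 rows is normalized
// over its 12 trailing elements; with axis = 2 it would instead be viewed as
// a 6 x 4 matrix, normalizing over groups of 4.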

// Input: Y, dY. Output: dX.
GRADIENT_OPERATOR_SCHEMA(SoftmaxGradient).NumInputs(2).NumOutputs(1);

class GetSoftmaxGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        def_.type() + "Gradient",
        "",
        vector<string>{O(0), GO(0)},
        vector<string>{GI(0)});
  }
};
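
// GetGradientDefs wires up the backward pass: for the forward op "Softmax"
// it emits a "SoftmaxGradient" op whose inputs are the forward output
// O(0) = Y and its incoming gradient GO(0) = dY, and whose single output
// GI(0) is the gradient with respect to the forward input X. This matches
// the (Y, dY) -> dX signature declared in GRADIENT_OPERATOR_SCHEMA above.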
REGISTER_GRADIENT(Softmax, GetSoftmaxGradient);
REGISTER_GRADIENT(SoftmaxFp16, GetSoftmaxGradient);

} // namespace caffe2