#include "softmax_with_loss_op.h"
#include "softmax_shared.h"

REGISTER_CPU_OPERATOR(SoftmaxWithLoss, SoftmaxWithLossOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    SoftmaxWithLossGradient,
    SoftmaxWithLossGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(SoftmaxWithLoss)
    .NumInputs(2, 3)
    .NumOutputs(2)
    .TensorInferenceFunction(
        [](const OperatorDef& def, const vector<TensorShape>& in) {
          ArgumentHelper helper(def);
          auto axis = helper.GetSingleArgument<int32_t>("axis", 1);

          vector<TensorShape> out(2);

          auto logits = in[0]; // Tensor with shape [batch_size, num_classes]
          auto labels = in[1]; // Tensor with shape [batch_size, ]
          const auto canonical_axis =
              canonical_axis_index_(axis, logits.dims().size());
          const int batch_size =
              size_to_dim_(canonical_axis, GetDimsVector(logits));
          const int num_classes =
              size_from_dim_(canonical_axis, GetDimsVector(logits));

          out[0].set_data_type(logits.data_type());
          out[0].add_dims(batch_size);
          out[0].add_dims(num_classes);

          return out;
        })
    .SetDoc(R"DOC(
Combined Softmax and Cross-Entropy loss operator. The operator first computes the softmax normalized values for each layer in the batch of the given input, then computes cross-entropy loss. This operator is numerically more stable than separate `Softmax` and `CrossEntropy` ops. The inputs are a 2-D tensor `logits` of size (batch_size x input_feature_dimensions), which represents the unscaled log probabilities, and a 1-dimensional integer `labels` tensor for ground truth. An optional third input blob (`weight_tensor`) can be used to weight the samples for the loss, which is useful if the training set is unbalanced. This operator outputs a `softmax` tensor which contains the probability for each label for each example (same shape as the `logits` input), and a scalar `loss` value, which is the averaged cross-entropy loss between the softmax probabilities and the ground truth values. Use parameter `label_prob`=1 to enable inputting labels as a probability distribution.

Softmax cross-entropy loss function:

$$loss(x, class) = -\log{\biggl(\frac{\exp(x[class])}{\sum_{j} \exp(x[j])}\biggr)} = -x[class] + \log{\biggl(\sum_{j} \exp(x[j])\biggr)}$$

or if the `weight_tensor` has been passed:

$$loss(x, class) = weight[class]\biggl(-x[class] + \log{\biggl(\sum_{j} \exp(x[j])\biggr)}\biggr)$$

The `logits` input does not need to explicitly be a 2D vector; rather, it will be coerced into one. For an arbitrary n-dimensional tensor `X` in $[a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}]$, where k is the `axis` provided, `X` will be coerced into a 2-dimensional tensor with dimensions $[(a_0 * ... * a_{k-1}), (a_k * ... * a_{n-1})]$. For the default case where `axis`=1, the `X` tensor will be coerced into a 2D tensor of dimensions $[a_0, (a_1 * ... * a_{n-1})]$, where $a_0$ is often the batch size. In this situation, we must have $a_0 = N$ and $a_1 * ... * a_{n-1} = D$. Each of these dimensions must be matched correctly, or else the operator will throw errors.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/softmax_with_loss_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "SoftmaxWithLoss",
    ["logits", "labels"],
    ["softmax", "avgloss"]
)

workspace.FeedBlob("logits", np.random.randn(1, 5).astype(np.float32))
workspace.FeedBlob("labels", np.asarray([4]).astype(np.int32))
print("logits:", workspace.FetchBlob("logits"))
print("labels:", workspace.FetchBlob("labels"))
workspace.RunOperatorOnce(op)
print("softmax:", workspace.FetchBlob("softmax"))
print("avgloss:", workspace.FetchBlob("avgloss"))

```

**Result**

```
logits: [[-0.3429451 -0.80375195 0.23104447 1.4569176 -0.5268362 ]]
labels: [4]
softmax: [[0.09721052 0.0613179 0.17258129 0.58800864 0.0808817 ]]

```

</details>

<details>

<summary> <b>Example 2</b> </summary>

**Code**

```
from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "SoftmaxWithLoss",
    ["logits", "labels"],
    ["softmax", "avgloss"],
)

workspace.FeedBlob("logits", np.asarray([[.1, .4, .7, 1.5, .2]]).astype(np.float32))
workspace.FeedBlob("labels", np.asarray([4]).astype(np.int32))
print("logits:", workspace.FetchBlob("logits"))
print("labels:", workspace.FetchBlob("labels"))
workspace.RunOperatorOnce(op)
print("softmax:", workspace.FetchBlob("softmax"))
print("avgloss:", workspace.FetchBlob("avgloss"))

```

**Result**

```
logits: [[0.1 0.4 0.7 1.5 0.2]]
labels: [4]
softmax: [[0.10715417 0.144643 0.19524762 0.4345316 0.11842369]]

```

</details>
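As a quick illustration of the loss formula above (a standalone sketch, not part of the operator's API), the per-example cross-entropy for the logits used in Example 2 can be reproduced with plain NumPy via the numerically stable log-sum-exp form:

```
import numpy as np

# Logits and label mirror Example 2 above.
logits = np.array([.1, .4, .7, 1.5, .2], dtype=np.float32)
label = 4

# Numerically stable log-sum-exp: subtract the row max before exponentiating.
m = logits.max()
log_sum_exp = m + np.log(np.exp(logits - m).sum())

softmax = np.exp(logits - m) / np.exp(logits - m).sum()
loss = -logits[label] + log_sum_exp  # equals -log(softmax[label])

print("softmax:", softmax)
print("loss:", loss)
```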
"*(type: int; default: 0)* Setting to 1 enables inputting labels as probability distribution.")
    .Arg(
        "axis",
        "*(type: int; default: 1)* Axis of the inputs when coerced to 2D.")
    .Arg(
        "scale",
        "*(type: float)* Average loss output scaling factor (must be >= 0).")
    .Arg(
        "order",
        "*(type: string; default: 'NCHW')* Order of blob dimensions (only 'NCHW' is supported currently).")
    .Input(
        0,
        "logits",
        "*(type: Tensor`<float>`)* Input tensor.")
    .Input(
        1,
        "labels",
        "*(type: Tensor`<float>`)* Ground truth label tensor.")
144 "*(type: Tensor`<float>`)* [OPTIONAL] Blob used to weight the samples for the loss.")
    .Output(
        0,
        "softmax",
        "*(type: Tensor`<float>`)* Softmax output tensor.")
    .Output(
        1,
        "loss",
        "*(type: float)* Averaged cross-entropy loss output.");
OPERATOR_SCHEMA(SoftmaxWithLossGradient).NumOutputs(1);
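// Implementation for the CPU context.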
#define DONT_CARE (-1)

template <>
bool SoftmaxWithLossOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0); // Logits
  auto& T = Input(1); // Labels / targets

  const auto canonical_axis = X.canonical_axis_index(axis_);
  int N, D;
  N = X.size_to_dim(canonical_axis); // batch size
  D = X.size_from_dim(canonical_axis);
  auto* P =
      Output(0, X.sizes(), at::dtype<float>()); // Probabilities from softmax
  float* Pdata = P->template mutable_data<float>();
  const float* weights = (InputSize() > 2 ? Input(2).data<float>() : nullptr);
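  // Check that the label tensor has the expected shape: a full [N, D]
  // probability distribution per example in label_prob mode, otherwise a
  // length-N vector of class indices.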
  if (label_prob_mode_) {
    CAFFE_ENFORCE_GE(T.dim(), 2);
    CAFFE_ENFORCE_EQ(T.size_to_dim(canonical_axis), N);
    CAFFE_ENFORCE_EQ(T.size_from_dim(canonical_axis), D);
  } else {
    if (T.dim() == canonical_axis) {
      CAFFE_ENFORCE_EQ(T.numel(), N);
    } else {
      CAFFE_ENFORCE_EQ(T.size_to_dim(canonical_axis), N);
      CAFFE_ENFORCE_EQ(T.size_from_dim(canonical_axis), 1);
    }
  }
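  // Lazily (re)allocate scratch buffers: sum_multiplier_ is a length-D vector
  // of ones used by the softmax kernel, while losses_ and rowmax_ hold per-row
  // intermediates.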
  if (!sum_multiplier_.defined()) {
    sum_multiplier_ = caffe2::empty({D}, at::dtype<float>().device(CPU));
    math::Set<float, CPUContext>(
        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
  } else if (sum_multiplier_.numel() != D) {
    sum_multiplier_.Resize(D);
    math::Set<float, CPUContext>(
        D, 1.f, sum_multiplier_.mutable_data<float>(), &context_);
  }
  if (!losses_.defined()) {
    losses_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
  } else if (losses_.numel() != N) {
    losses_.Resize(N);
  }

  if (!rowmax_.defined()) {
    rowmax_ = caffe2::empty({N}, at::dtype<float>().device(CPU));
  } else if (rowmax_.numel() != N) {
    rowmax_.Resize(N);
  }
  SoftmaxCPU(
      context_,
      N,
      D,
      X.data<float>(),
      Pdata,
      losses_.mutable_data<float>(),
      sum_multiplier_.data<float>(),
      !label_prob_mode_,
      rowmax_.mutable_data<float>());
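  // Accumulate the (optionally weighted) per-example loss. In index-label mode
  // Pdata holds log-probabilities at this point, so row i contributes
  // -Pdata[i * D + label[i]]; math::Exp below converts Pdata back to
  // probabilities for the softmax output.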
  float loss_sum = 0.0;
  float weight_sum = 0.0;
  if (!label_prob_mode_) {
    const int* label_data = T.data<int>();
    const float* Xdata = X.data<float>();

    for (int i = 0; i < N; ++i) {
      CAFFE_ENFORCE(
          label_data[i] < D && label_data[i] >= 0,
          "Label seems incorrect: label value larger than number of classes: ",
          label_data[i],
          " vs ",
          D);
      float weight = weights ? weights[i] : 1.0;
      float l = -Pdata[i * D + label_data[i]] * weight;
      loss_sum += l;
      weight_sum += weight;
    }
    math::Exp(N * D, Pdata, Pdata, &context_);
  } else {
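    // Probability-distribution labels: accumulate the full cross-entropy
    // against the per-class probabilities, which must be nonnegative and sum
    // to 1 for every example.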
    const float* label_data = T.data<float>();

    for (int i = 0; i < N; ++i) {
      float l = 0.0;
      float total_prob = 0.0;
      float weight = weights ? weights[i] : 1.0;
      for (int j = 0; j < D; ++j) {
        CAFFE_ENFORCE(
            label_data[i * D + j] >= 0,
            "Label prob seems incorrect: label prob value must be nonnegative:",
            " ",
            label_data[i * D + j]);
        l += -log(std::max(Pdata[i * D + j], 1e-20f)) * label_data[i * D + j] *
            weight;
        total_prob += label_data[i * D + j];
      }
      loss_sum += l;
      CAFFE_ENFORCE(
          std::abs(total_prob - 1.) < 1e-5f,
          "Label prob seems incorrect: label prob values do not sum to 1.0: ",
          total_prob,
          " vs 1.0 (+/- 1e-5)");
      weight_sum += weight;
    }
  }
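  // Scalar average loss output: scale_ * loss_sum / weight_sum, or 0 when the
  // total weight is zero.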
  auto* avg_loss = Output(1, vector<int64_t>(), at::dtype<float>());

  float* avg_loss_data = avg_loss->template mutable_data<float>();
  if (weight_sum != 0.0) {
    avg_loss_data[0] = loss_sum * scale_ / weight_sum;
  } else {
    avg_loss_data[0] = 0.0;
  }
  return true;
}
template <>
bool SoftmaxWithLossGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0); // Logits
  auto& T = Input(1); // Labels / targets
  // Input(2) is the weight tensor, if given
  auto& P = Input(InputSize() - 2); // Probabilities from softmax
  auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. the avg loss
  const float* weights = (InputSize() > 4 ? Input(2).data<float>() : nullptr);
  const auto canonical_axis = X.canonical_axis_index(axis_);
  int N, D;
  N = X.size_to_dim(canonical_axis); // batch size
  D = X.size_from_dim(canonical_axis);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  if (label_prob_mode_) {
    CAFFE_ENFORCE_GE(T.dim(), 2);
    CAFFE_ENFORCE_EQ(T.size_to_dim(canonical_axis), N);
    CAFFE_ENFORCE_EQ(T.size_from_dim(canonical_axis), D);
  } else {
    if (T.dim() == canonical_axis) {
      CAFFE_ENFORCE_EQ(T.numel(), N);
    } else {
      CAFFE_ENFORCE_EQ(T.size_to_dim(canonical_axis), N);
      CAFFE_ENFORCE_EQ(T.size_from_dim(canonical_axis), 1);
    }
  }
  const float* Pdata = P.data<float>();
  float* dX_data = dX->template mutable_data<float>();
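  // Start from dX = P: for softmax cross-entropy the gradient w.r.t. each
  // logit is its softmax probability, corrected below at the labeled class.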
  context_.CopyFromCPU<float>(P.numel(), Pdata, dX_data);
  float total_weight = 0.0f;
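  // dX = P - onehot(label) for index labels, or dX = P - label_prob for
  // probability labels, with each row optionally scaled by its sample weight.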
  if (!label_prob_mode_) {
    const int* label_data = T.data<int>();

    if (weights) {
      for (int i = 0; i < N; ++i) {
        int idx = i * D + label_data[i];
        float weight = weights[i];
        dX_data[idx] = Pdata[idx] - 1.0;
        for (int d = 0; d < D; d++) {
          int k = i * D + d;
          dX_data[k] *= weight;
        }
        total_weight += weight;
      }
    } else {
      for (int i = 0; i < N; ++i) {
        int idx = i * D + label_data[i];
        dX_data[idx] = Pdata[idx] - 1.0f;
      }
      total_weight = N;
    }
  } else {
    const float* label_data = T.data<float>();

    if (weights) {
      for (int i = 0; i < N; ++i) {
        float weight = weights[i];
        for (int j = 0; j < D; ++j) {
          int idx = i * D + j;
          dX_data[idx] = (Pdata[idx] - label_data[idx]) * weight;
        }
        total_weight += weight;
      }
    } else {
      for (int i = 0; i < N; ++i) {
        for (int j = 0; j < D; ++j) {
          int idx = i * D + j;
          dX_data[idx] = Pdata[idx] - label_data[idx];
        }
      }
      total_weight = N;
    }
  }
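  // Finally scale the whole gradient by scale_ * d_avg_loss / total_weight.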
  if (total_weight > 0) {
    math::Scale<float, float, CPUContext>(
        dX->numel(),
        scale_ / total_weight * d_avg_loss.data<float>()[0],
        dX->data<float>(),
        dX_data,
        &context_);
  }

  return true;
}
class GetSoftmaxWithLossGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    vector<string> blob_names{
        {I(0), I(1), O(0), GO(1)},
    };

    // Add the weight blob, if given
    if (def_.input_size() == 3) {
      blob_names.emplace(blob_names.begin() + 2, I(2));
    }
    return SingleGradientDef(
        "SoftmaxWithLossGradient", "", blob_names, vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(SoftmaxWithLoss, GetSoftmaxWithLossGradient);