Caffe2 - C++ API
A deep learning, cross-platform ML framework
cross_entropy_op.cc
#include "caffe2/operators/cross_entropy_op.h"
#include "caffe2/utils/eigen_utils.h"

namespace caffe2 {

namespace {

// Numerically stable computation of tgt * lgt - log(1 + exp(lgt)), i.e. the
// (negative) sigmoid cross entropy for a single logit/target pair, without
// overflowing exp() for large positive logits.
inline float sigmoid_xent_forward(float lgt, float tgt) {
  return lgt * (tgt - (lgt >= 0)) - log(1 + exp(lgt - 2 * lgt * (lgt >= 0)));
}

// Derivative of the above with respect to lgt: tgt - sigmoid(lgt).
inline float sigmoid_xent_backward(float lgt, float tgt) {
  return tgt - 1. / (1. + exp(-lgt));
}

inline float sigmoid_partition(float lgt) {
  // computes log(1 + exp(lgt)) with only exp(x) function when x >= 0
  return lgt * (lgt >= 0) + log(1 + exp(lgt - 2 * lgt * (lgt >= 0)));
}

// (2 * tgt - 1) * log(sigmoid(lgt)); the "log D trick" variant.
inline float sigmoid_xent_forward_with_log_d_trick(float lgt, float tgt) {
  return (2 * tgt - 1.) * (lgt - sigmoid_partition(lgt));
}

inline float sigmoid_xent_backward_with_log_d_trick(float lgt, float tgt) {
  return (2 * tgt - 1.) / (1. + exp(lgt));
}

// Numerically stable tgt * lgt - (1 - tgt) * log(1 + exp(lgt)), used for the
// unjoined logistic regression loss.
inline float unjoined_sigmoid_xent_forward(float lgt, float tgt) {
  return lgt * tgt + (tgt - 1) * lgt * (lgt >= 0) -
      (1 - tgt) * log(1 + exp(lgt - 2 * lgt * (lgt >= 0)));
}

inline float unjoined_sigmoid_xent_backward(float lgt, float tgt) {
  return tgt - (1. - tgt) / (1. + exp(-lgt));
}

} // namespace
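The helpers above are branch-free, numerically stable forms: `sigmoid_xent_forward` computes `tgt * lgt - log(1 + exp(lgt))` without overflowing `exp` for large positive logits. A minimal numpy sketch (illustrative only, not part of the Caffe2 sources; the function names here are ours) checking the stable form against the naive one:

```

import numpy as np

def stable_forward(lgt, tgt):
    # Mirrors the C++ helper: safe for any logit magnitude.
    pos = (lgt >= 0).astype(lgt.dtype)
    return lgt * (tgt - pos) - np.log1p(np.exp(lgt - 2 * lgt * pos))

def naive_forward(lgt, tgt):
    # exp(lgt) overflows once lgt is large and positive.
    return tgt * lgt - np.log1p(np.exp(lgt))

lgt = np.linspace(-30, 30, 7)
tgt = np.full_like(lgt, 0.3)
assert np.allclose(stable_forward(lgt, tgt), naive_forward(lgt, tgt))
print(stable_forward(np.array([1000.0]), np.array([1.0])))  # finite, ~0.0

```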

template <>
bool LabelCrossEntropyOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& label = Input(1);

  int N, D;
  if (X.dim() > 1) {
    N = X.dim32(0);
    D = X.size_from_dim(1);
  } else {
    N = 1;
    D = X.dim32(0);
  }
  CAFFE_ENFORCE(
      (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == 1));
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  auto* Y = Output(0, {N}, at::dtype<float>());
  const auto* Xdata = X.data<float>();
  const auto* labelData = label.data<int>();
  auto* Ydata = Y->template mutable_data<float>();
  CAFFE_ENFORCE(
      (ConstEigenVectorArrayMap<int>(labelData, N) < D).all() &&
          (ConstEigenVectorArrayMap<int>(labelData, N) >= 0).all(),
      "Label seems to be outside of supported range. Supported labels are in "
      "range [0,",
      D,
      ")");
  for (int i = 0; i < N; ++i) {
    Ydata[i] = -log(std::max(Xdata[i * D + labelData[i]], kLOG_THRESHOLD()));
  }
  return true;
}

template <>
bool SigmoidCrossEntropyWithLogitsOp<float, CPUContext>::RunOnDevice() {
  auto& logits = Input(0);
  auto& targets = Input(1);
  CAFFE_ENFORCE_EQ(logits.sizes(), targets.sizes());
  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
  const auto outer_size = logits.numel() / inner_size;

  std::vector<int64_t> dims;
  if (logits.dim() != 0) {
    dims =
        std::vector<int64_t>(logits.sizes().begin(), logits.sizes().end() - 1);
  }
  auto* out = Output(0, dims, at::dtype<float>());
  auto* out_ptr = out->template mutable_data<float>();

  auto* logits_ptr = logits.data<float>();
  auto* targets_ptr = targets.data<float>();

  auto in_idx = 0;
  for (int i = 0; i < outer_size; ++i) {
    float value = 0;
    for (int j = 0; j < inner_size; ++j) {
      if (unjoined_lr_loss_) {
        value += unjoined_sigmoid_xent_forward(
            logits_ptr[in_idx], targets_ptr[in_idx]);
      } else {
        value +=
            (log_D_trick_ ? sigmoid_xent_forward_with_log_d_trick(
                                logits_ptr[in_idx], targets_ptr[in_idx])
                          : sigmoid_xent_forward(
                                logits_ptr[in_idx], targets_ptr[in_idx]));
      }
      ++in_idx;
    }
    out_ptr[i] = -value / inner_size;
  }
  return true;
}

template <>
bool SigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>::RunOnDevice() {
  auto& g = Input(0);
  auto& logits = Input(1);
  auto& targets = Input(2);
  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
  const auto outer_size = logits.numel() / inner_size;
  CAFFE_ENFORCE(g.numel() == outer_size);

  auto* out = Output(0, logits.sizes(), at::dtype<float>());
  auto* out_ptr = out->template mutable_data<float>();

  auto* logits_ptr = logits.data<float>();
  auto* targets_ptr = targets.data<float>();
  auto* g_ptr = g.data<float>();

  auto in_idx = 0;
  for (int i = 0; i < outer_size; ++i) {
    auto g_factor = -g_ptr[i] / inner_size;
    for (int j = 0; j < inner_size; ++j) {
      if (unjoined_lr_loss_) {
        out_ptr[in_idx] = g_factor *
            unjoined_sigmoid_xent_backward(
                logits_ptr[in_idx], targets_ptr[in_idx]);
      } else {
        out_ptr[in_idx] = g_factor *
            (log_D_trick_ ? sigmoid_xent_backward_with_log_d_trick(
                                logits_ptr[in_idx], targets_ptr[in_idx])
                          : sigmoid_xent_backward(
                                logits_ptr[in_idx], targets_ptr[in_idx]));
      }
      ++in_idx;
    }
  }
  return true;
}
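Restating the default branch of the loop above (a reading aid, with $\sigma$ the logistic function and $D$ the inner size): each gradient element is

$$dX_{ij} = -\frac{g_i}{D}\left(t_{ij} - \sigma(x_{ij})\right)$$

which is the derivative of the averaged sigmoid cross entropy, scaled by the incoming gradient $g_i$.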

template <>
bool WeightedSigmoidCrossEntropyWithLogitsOp<float, CPUContext>::RunOnDevice() {
  auto& logits = Input(0);
  auto& targets = Input(1);
  auto& weights = Input(2);
  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
  CAFFE_ENFORCE(weights.sizes() == targets.sizes());
  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
  const auto outer_size = logits.numel() / inner_size;

  std::vector<int64_t> dims;
  if (logits.dim() != 0) {
    dims =
        std::vector<int64_t>(logits.sizes().begin(), logits.sizes().end() - 1);
  }

  auto* out = Output(0, dims, at::dtype<float>());
  auto* out_ptr = out->template mutable_data<float>();

  auto* logits_ptr = logits.data<float>();
  auto* targets_ptr = targets.data<float>();
  auto* weights_ptr = weights.data<float>();

  auto in_idx = 0;
  for (int i = 0; i < outer_size; ++i) {
    float value = 0;
    for (int j = 0; j < inner_size; ++j) {
      value += sigmoid_xent_forward(logits_ptr[in_idx], targets_ptr[in_idx]) *
          weights_ptr[in_idx];
      ++in_idx;
    }
    out_ptr[i] = -value / inner_size;
  }
  return true;
}

template <>
bool WeightedSigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>::
    RunOnDevice() {
  auto& g = Input(0);
  auto& logits = Input(1);
  auto& targets = Input(2);
  auto& weights = Input(3);
  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
  CAFFE_ENFORCE(weights.sizes() == targets.sizes());
  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
  const auto outer_size = logits.numel() / inner_size;
  CAFFE_ENFORCE(g.numel() == outer_size);

  auto* out = Output(0, logits.sizes(), at::dtype<float>());
  auto* out_ptr = out->template mutable_data<float>();

  auto* logits_ptr = logits.data<float>();
  auto* targets_ptr = targets.data<float>();
  auto* weights_ptr = weights.data<float>();
  auto* g_ptr = g.data<float>();

  auto in_idx = 0;
  for (int i = 0; i < outer_size; ++i) {
    auto g_factor = -g_ptr[i] / inner_size;
    for (int j = 0; j < inner_size; ++j) {
      out_ptr[in_idx] = g_factor *
          sigmoid_xent_backward(logits_ptr[in_idx], targets_ptr[in_idx]) *
          weights_ptr[in_idx];
      ++in_idx;
    }
  }
  return true;
}

template <>
bool LabelCrossEntropyGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& label = Input(1);
  auto& dY = Input(2);

  int N, D;
  if (X.dim() > 1) {
    N = X.dim32(0);
    D = X.size_from_dim(1);
  } else {
    N = 1;
    D = X.dim32(0);
  }
  CAFFE_ENFORCE(
      (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == 1));
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  CAFFE_ENFORCE_EQ(dY.dim(), 1);
  CAFFE_ENFORCE_EQ(dY.dim32(0), N);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  math::Set<float, CPUContext>(
      dX->numel(), 0.f, dX->template mutable_data<float>(), &context_);
  const float* Xdata = X.data<float>();
  const float* dYdata = dY.data<float>();
  const int* labelData = label.data<int>();
  float* dXdata = dX->template mutable_data<float>();
  for (int i = 0; i < N; ++i) {
    dXdata[i * D + labelData[i]] =
        -dYdata[i] / std::max(Xdata[i * D + labelData[i]], kLOG_THRESHOLD());
  }
  return true;
}

template <>
bool MakeTwoClassOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);

  auto shape = X.sizes().vec();
  shape.push_back(2);
  int64_t N = X.numel();
  auto* Y = Output(0, shape, at::dtype<float>());
  const auto* Xdata = X.data<float>();
  auto* Ydata = Y->template mutable_data<float>();
  for (int64_t i = 0; i < N; ++i) {
    DCHECK_GE(Xdata[i], 0.0);
    DCHECK_LE(Xdata[i], 1.0);
    Ydata[i * 2] = 1.0 - Xdata[i];
    Ydata[i * 2 + 1] = Xdata[i];
  }
  return true;
}

template <>
bool MakeTwoClassGradientOp<float, CPUContext>::RunOnDevice() {
  auto& dY = Input(0);

  auto shape = dY.sizes().vec();
  CAFFE_ENFORCE_GE(shape.size(), 1);
  CAFFE_ENFORCE_EQ(shape.back(), 2);
  shape.pop_back();
  auto* dX = Output(0, shape, at::dtype<float>());
  const float* dYdata = dY.data<float>();
  float* dXdata = dX->template mutable_data<float>();
  int64_t N = dX->numel();
  // use eigen?
  for (int64_t i = 0; i < N; ++i) {
    dXdata[i] = dYdata[i * 2 + 1] - dYdata[i * 2];
  }
  return true;
}

template <>
bool CrossEntropyOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& label = Input(1);

  int N, D;
  if (X.dim() > 1) {
    N = X.dim32(0);
    D = X.size_from_dim(1);
  } else {
    N = 1;
    D = X.dim32(0);
  }
  CAFFE_ENFORCE(
      (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == D));
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  auto* Y = Output(0, vector<int64_t>{N}, at::dtype<float>());
  const float* Xdata = X.data<float>();
  const float* labelData = label.data<float>();
  auto* Ydata = Y->template mutable_data<float>();
  CAFFE_ENFORCE(
      (ConstEigenArrayMap<float>(labelData, D, N) <= 1.0f).all() &&
          (ConstEigenArrayMap<float>(labelData, D, N) >= 0.0f).all(),
      "Soft label seems incorrect: label value should be a probability ",
      "between 0 and 1.0. You may be using the wrong cross entropy operator; ",
      "use LabelCrossEntropy if the labels are integers whose values are at ",
      "most the number of classes, ",
      D,
      ".");
  EigenArrayMap<float>(Ydata, 1, N) =
      -(ConstEigenArrayMap<float>(labelData, D, N) *
        ConstEigenArrayMap<float>(Xdata, D, N).cwiseMax(kLOG_THRESHOLD()).log())
           .colwise()
           .sum();
  return true;
}

template <>
bool CrossEntropyGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& label = Input(1);
  auto& dY = Input(2);

  int N, D;
  if (X.dim() > 1) {
    N = X.dim32(0);
    D = X.size_from_dim(1);
  } else {
    N = 1;
    D = X.dim32(0);
  }
  CAFFE_ENFORCE(
      (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == D));
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  CAFFE_ENFORCE_EQ(dY.dim(), 1);
  CAFFE_ENFORCE_EQ(dY.dim32(0), N);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  math::Set<float, CPUContext>(
      dX->numel(), 0.f, dX->template mutable_data<float>(), &context_);
  const float* Xdata = X.data<float>();
  const float* dYdata = dY.data<float>();
  const float* labelData = label.data<float>();
  float* dXdata = dX->template mutable_data<float>();
  EigenArrayMap<float>(dXdata, D, N) =
      (ConstEigenArrayMap<float>(labelData, D, N) /
       ConstEigenArrayMap<float>(Xdata, D, N).cwiseMax(kLOG_THRESHOLD()))
          .rowwise() *
      (-ConstEigenVectorArrayMap<float>(dYdata, N).transpose());
  return true;
}
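For reference, the Eigen expression above is a restatement of (with $\epsilon$ = `kLOG_THRESHOLD()`):

$$dX_{ij} = -\,dY_i \cdot \frac{label_{ij}}{\max(X_{ij},\, \epsilon)}$$

i.e. the gradient of $Y_i = -\sum_j label_{ij} \log(X_{ij})$, with the same lower clamp on $X$ as the forward pass.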

REGISTER_CPU_OPERATOR(LabelCrossEntropy,
                      LabelCrossEntropyOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(LabelCrossEntropyGradient,
                      LabelCrossEntropyGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(LabelCrossEntropy)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
This operator computes the cross entropy between a $NxD$ dimensional input data tensor $X$ and a one-dimensional input label tensor $label$. The op produces a single length $N$ output tensor $Y$. Here, $N$ is considered the batch size and $D$ is the size of each element in the batch. In practice, it is most commonly used at the end of models as part of the loss computation, after the SoftMax operator and before the AveragedLoss operator. The cross entropy operation is defined as follows

$$Y_i = -log(X_{ij})$$

where $X_{ij}$ is the classifier's predicted probability for the correct class $j = label_i$ of the $i$th example in the batch. Each log has a lower limit for numerical stability.

The difference between *LabelCrossEntropy* and *CrossEntropy* is how the labels are specified. Here, the labels are a length $N$ list of integers, whereas in CrossEntropy the labels are a $NxD$ dimensional matrix of one-hot label vectors. However, the results of the computation should be the same, as shown in the two examples.

Github Links:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/cross_entropy_op.h
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/cross_entropy_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LabelCrossEntropy",
    ["X", "label"],
    ["Y"]
)

# Create X: Sample softmax output for 5-class model
X = np.array([[.01, .05, .02, .02, .9],[.03, .1, .42, .05, .4]])
print("X:\n",X)

# Create label: Sample integer ground truth labels
label = np.array([4,2])
print("label:\n",label)

# Feed X & label into workspace
workspace.FeedBlob("X", X.astype(np.float32))
workspace.FeedBlob("label", label.astype(np.int32))

# Run op
workspace.RunOperatorOnce(op)

# Collect Output
print("Y:\n", workspace.FetchBlob("Y"))

```

**Result**

```

X:
 [[0.01 0.05 0.02 0.02 0.9 ]
 [0.03 0.1 0.42 0.05 0.4 ]]
label:
 [4 2]
Y:
 [0.10536055 0.8675006 ]

```

</details>


)DOC")
    .Input(
        0,
        "X",
        "Input tensor which is almost always the result of a softmax operation. $X$ is a 2D array of size $NxD$, where $N$ is the batch size and $D$ is the number of classes.")
    .Input(
        1,
        "label",
        "Blob containing the labels used to compare the input. $label$ is a length $N$ list of integers, where each element is the integer label for the $n$th element of the batch.")
    .Output(
        0,
        "Y",
        "Output blob from the cross entropy computation. $Y$ is a 1D tensor of length $N$.");
OPERATOR_SCHEMA(LabelCrossEntropyGradient)
    .NumInputs(3)
    .NumOutputs(1);

class GetLabelCrossEntropyGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "LabelCrossEntropyGradient", "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(LabelCrossEntropy, GetLabelCrossEntropyGradient);
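The gradient maker wires the gradient op's inputs as (X, label, dY). A hedged sketch of exercising it via automatic differentiation from Python (illustrative; the net and blob names are arbitrary, and it assumes the standard caffe2.python API):

```

from caffe2.python import core, workspace
import numpy as np

net = core.Net("xent_example")
net.LabelCrossEntropy(["X", "label"], "Y")
loss = net.AveragedLoss("Y", "loss")
net.AddGradientOperators([loss])  # inserts LabelCrossEntropyGradient

workspace.FeedBlob("X", np.array([[.1, .9]], dtype=np.float32))
workspace.FeedBlob("label", np.array([1], dtype=np.int32))
workspace.RunNetOnce(net)
print(workspace.FetchBlob("X_grad"))  # nonzero only at the label column

```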

REGISTER_CPU_OPERATOR(MakeTwoClass,
                      MakeTwoClassOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(MakeTwoClassGradient,
                      MakeTwoClassGradientOp<float, CPUContext>);

REGISTER_CPU_OPERATOR(
    SigmoidCrossEntropyWithLogits,
    SigmoidCrossEntropyWithLogitsOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    SigmoidCrossEntropyWithLogitsGradient,
    SigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>);

REGISTER_CPU_OPERATOR(
    WeightedSigmoidCrossEntropyWithLogits,
    WeightedSigmoidCrossEntropyWithLogitsOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    WeightedSigmoidCrossEntropyWithLogitsGradient,
    WeightedSigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(MakeTwoClass)
    .NumInputs(1)
    .NumOutputs(1)
    .TensorInferenceFunction(
        [](const OperatorDef& /* unused */, const vector<TensorShape>& in) {
          vector<TensorShape> out(1);
          out[0].add_dims(in[0].dims(0));
          out[0].add_dims(2);
          return out;
        })
    .SetDoc(R"DOC(
Given a vector of probabilities, this operator transforms it into a 2-column
matrix with complementary probabilities for binary classification. In explicit
terms, given the vector X, each row i of the output Y is [1 - X[i], X[i]].
)DOC")
    .Input(0, "X", "Input vector of probabilities")
    .Output(
        0,
        "Y",
        "2-column matrix with complementary probabilities of X for "
        "binary classification");

OPERATOR_SCHEMA(MakeTwoClassGradient)
    .NumInputs(1)
    .NumOutputs(1);

OPERATOR_SCHEMA(SigmoidCrossEntropyWithLogits)
    .Arg("log_D_trick", R"DOC(
default is false; if enabled, will use the log d trick to avoid the vanishing
gradients early on; see Goodfellow et al. (2014)
)DOC")
    .Arg("unjoined_lr_loss", R"DOC(
default is false; if enabled, the model will be allowed to train on an unjoined
dataset, where some examples might be false negatives and might appear
in the dataset later as (true) positive examples.
)DOC")
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
Given two matrices logits and targets, of same shape,
(batch_size, num_classes), computes the sigmoid cross entropy between the two.
Returns a tensor of shape (batch_size,) of losses for each example.
)DOC")
    .Input(0, "logits", "matrix of logits for each example and class.")
    .Input(1, "targets", "matrix of targets, same shape as logits.")
    .Output(0, "xentropy", "Vector with the total xentropy for each example.");
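A hedged usage sketch (illustrative values, not from the official docs), cross-checked against a direct numpy computation of the same numerically stable formula:

```

from caffe2.python import core, workspace
import numpy as np

logits = np.array([[-2., 0., 3.]], dtype=np.float32)
targets = np.array([[0., 1., 1.]], dtype=np.float32)

op = core.CreateOperator(
    "SigmoidCrossEntropyWithLogits", ["logits", "targets"], ["xentropy"])
workspace.FeedBlob("logits", logits)
workspace.FeedBlob("targets", targets)
workspace.RunOperatorOnce(op)

# Reference: mean over classes of max(x,0) - x*t + log(1 + exp(-|x|)).
ref = np.mean(np.maximum(logits, 0) - logits * targets
              + np.log1p(np.exp(-np.abs(logits))), axis=1)
print(workspace.FetchBlob("xentropy"), ref)  # the two should agree

```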

OPERATOR_SCHEMA(SigmoidCrossEntropyWithLogitsGradient)
    .NumInputs(3)
    .NumOutputs(1);

OPERATOR_SCHEMA(WeightedSigmoidCrossEntropyWithLogits)
    .NumInputs(3)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
Given three matrices: logits, targets, weights, all of the same shape,
(batch_size, num_classes), computes the weighted sigmoid cross entropy between
logits and targets. Specifically, at each position r,c, this computes
weights[r, c] * crossentropy(sigmoid(logits[r, c]), targets[r, c]), and then
averages over each row.
Returns a tensor of shape (batch_size,) of losses for each example.
)DOC")
    .Input(0, "logits", "matrix of logits for each example and class.")
    .Input(1, "targets", "matrix of targets, same shape as logits.")
    .Input(2, "weights", "matrix of weights, same shape as logits.")
    .Output(0, "xentropy", "Vector with the total xentropy for each example.");
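And a matching sketch for the weighted variant (again with illustrative values), where each element's cross entropy is scaled by its weight before the row average:

```

from caffe2.python import core, workspace
import numpy as np

logits = np.array([[1., -1.]], dtype=np.float32)
targets = np.array([[1., 0.]], dtype=np.float32)
weights = np.array([[2., 0.5]], dtype=np.float32)

op = core.CreateOperator(
    "WeightedSigmoidCrossEntropyWithLogits",
    ["logits", "targets", "weights"], ["xentropy"])
for name, arr in [("logits", logits), ("targets", targets),
                  ("weights", weights)]:
    workspace.FeedBlob(name, arr)
workspace.RunOperatorOnce(op)

# Reference: row mean of weights * elementwise sigmoid cross entropy.
elem = np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
print(workspace.FetchBlob("xentropy"), np.mean(weights * elem, axis=1))

```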

OPERATOR_SCHEMA(WeightedSigmoidCrossEntropyWithLogitsGradient)
    .NumInputs(4)
    .NumOutputs(1);

class GetMakeTwoClassGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "MakeTwoClassGradient",
        "",
        vector<string>{GO(0)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(MakeTwoClass, GetMakeTwoClassGradient);

class GetSigmoidCrossEntropyWithLogitsGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "SigmoidCrossEntropyWithLogitsGradient",
        "",
        vector<string>{GO(0), I(0), I(1)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(
    SigmoidCrossEntropyWithLogits,
    GetSigmoidCrossEntropyWithLogitsGradient);

class GetWeightedSigmoidCrossEntropyWithLogitsGradient
    : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "WeightedSigmoidCrossEntropyWithLogitsGradient",
        "",
        vector<string>{GO(0), I(0), I(1), I(2)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(
    WeightedSigmoidCrossEntropyWithLogits,
    GetWeightedSigmoidCrossEntropyWithLogitsGradient);

REGISTER_CPU_OPERATOR(CrossEntropy,
                      CrossEntropyOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(CrossEntropyGradient,
                      CrossEntropyGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(CrossEntropy)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
This operator computes the cross entropy between a $NxD$ dimensional input data tensor $X$ and a $NxD$ dimensional input label tensor $label$. The op produces a single length $N$ output tensor $Y$. Here, $N$ is considered the batch size and $D$ is the size of each element in the batch. In practice, it is most commonly used at the end of models as part of the loss computation, after the SoftMax operator and before the AveragedLoss operator. The cross entropy operation is defined as follows

$$Y_i = -\sum_j (label_{ij} * log(X_{ij}))$$

where $X_{ij}$ is the classifier's predicted probability of the $j$th class for the $i$th example in the batch. Each log has a lower limit for numerical stability.

Github Links:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/cross_entropy_op.h
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/cross_entropy_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "CrossEntropy",
    ["X", "label"],
    ["Y"]
)

# Create X: Sample softmax output for 5-class model
X = np.array([[.01, .05, .02, .02, .9],[.03, .1, .42, .05, .4]])
print("X:\n",X)

# Create label: Sample 1-hot ground truth label vectors
label = np.array([[0.,0.,0.,0.,1.],[0.,0.,1.,0.,0.]])
print("label:\n",label)

# Feed X & label into workspace
workspace.FeedBlob("X", X.astype(np.float32))
workspace.FeedBlob("label", label.astype(np.float32))

# Run op
workspace.RunOperatorOnce(op)

# Collect Output
print("Y:\n", workspace.FetchBlob("Y"))

```

**Result**

```

X:
 [[0.01 0.05 0.02 0.02 0.9 ]
 [0.03 0.1 0.42 0.05 0.4 ]]
label:
 [[0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0.]]
Y:
 [0.10536055 0.8675006 ]

```

</details>


)DOC")
    .Input(
        0,
        "X",
        "Input tensor which is almost always the result of a softmax operation. $X$ is a 2D array of size $NxD$, where $N$ is the batch size and $D$ is the number of classes.")
    .Input(
        1,
        "label",
        "Blob containing the labels used to compare the input. $label$ is the same shape as $X$.")
    .Output(
        0,
        "Y",
        "Output blob from the cross entropy computation. $Y$ is a 1D tensor of length $N$.");
OPERATOR_SCHEMA(CrossEntropyGradient)
    .NumInputs(3)
    .NumOutputs(1);

class GetCrossEntropyGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "CrossEntropyGradient", "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(CrossEntropy, GetCrossEntropyGradient);

} // namespace caffe2