#include "caffe2/operators/cross_entropy_op.h"
#include "caffe2/utils/eigen_utils.h"

namespace caffe2 {

namespace {

inline float sigmoid_xent_forward(float lgt, float tgt) {
  return lgt * (tgt - (lgt >= 0)) - log(1 + exp(lgt - 2 * lgt * (lgt >= 0)));
}
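// For a logit x and target t, sigmoid_xent_forward above evaluates
// t * x - log(1 + exp(x)) -- the log-likelihood of a sigmoid unit -- in a
// numerically stable way by branching on the sign of x; callers negate and
// average the result to obtain the cross-entropy loss. sigmoid_xent_backward
// below is its derivative with respect to the logit, t - sigmoid(x).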
inline float sigmoid_xent_backward(float lgt, float tgt) {
  return tgt - 1. / (1. + exp(-lgt));
}
inline float sigmoid_partition(float lgt) {
  // Computes log(1 + exp(lgt)) while only ever exponentiating a non-positive
  // argument, so the intermediate exp() cannot overflow.
  return lgt * (lgt >= 0) + log(1 + exp(lgt - 2 * lgt * (lgt >= 0)));
}
inline float sigmoid_xent_forward_with_log_d_trick(float lgt, float tgt) {
  return (2 * tgt - 1.) * (lgt - sigmoid_partition(lgt));
}
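// Since lgt - sigmoid_partition(lgt) = log(sigmoid(lgt)), the forward pass
// above equals (2 * tgt - 1) * log(sigmoid(lgt)): the "log D trick" of
// Goodfellow et al. (2014), which keeps gradients from vanishing when the
// prediction is confidently wrong. The backward pass below is its derivative,
// (2 * tgt - 1) * (1 - sigmoid(lgt)) = (2 * tgt - 1) / (1 + exp(lgt)).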
inline float sigmoid_xent_backward_with_log_d_trick(float lgt, float tgt) {
  return (2 * tgt - 1.) / (1. + exp(lgt));
}
inline float unjoined_sigmoid_xent_forward(float lgt, float tgt) {
  return lgt * tgt + (tgt - 1) * lgt * (lgt >= 0) -
      (1 - tgt) * log(1 + exp(lgt - 2 * lgt * (lgt >= 0)));
}
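// The "unjoined" variant above evaluates
// tgt * lgt - (1 - tgt) * log(1 + exp(lgt)) in the same overflow-safe form;
// it is used when training on an unjoined dataset, where an example labeled
// negative now may reappear later as a true positive (see the
// unjoined_lr_loss arg below). Its derivative, used by the backward helper
// that follows, is tgt - (1 - tgt) * sigmoid(lgt).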
inline float unjoined_sigmoid_xent_backward(float lgt, float tgt) {
  return tgt - (1. - tgt) / (1. + exp(-lgt));
}

} // namespace
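// LabelCrossEntropyOp: for integer labels, Y[i] = -log(X[i, label[i]]), with
// the probability clamped from below by kLOG_THRESHOLD() so the log stays
// finite.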
template <>
bool LabelCrossEntropyOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& label = Input(1);

  int N, D;
  if (X.dim() > 1) {
    N = X.dim32(0);
    D = X.size_from_dim(1);
  } else {
    N = 1;
    D = X.dim32(0);
  }
  CAFFE_ENFORCE(
      (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == 1));
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  auto* Y = Output(0, {N}, at::dtype<float>());
  const auto* Xdata = X.data<float>();
  const auto* labelData = label.data<int>();
  auto* Ydata = Y->template mutable_data<float>();
  CAFFE_ENFORCE(
      (ConstEigenVectorArrayMap<int>(labelData, N) < D).all() &&
          (ConstEigenVectorArrayMap<int>(labelData, N) >= 0).all(),
      "Label seems to be outside of supported range. Supported labels are in "
      "range [0,",
      D,
      ")");
  for (int i = 0; i < N; ++i) {
    Ydata[i] = -log(std::max(Xdata[i * D + labelData[i]], kLOG_THRESHOLD()));
  }
  return true;
}
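// SigmoidCrossEntropyWithLogitsOp: treats the last dimension as the class
// dimension, accumulates the element-wise sigmoid cross entropy over it
// (using the log-D-trick or unjoined variants when the corresponding args
// are set), and writes the negated mean as the per-example loss.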
template <>
bool SigmoidCrossEntropyWithLogitsOp<float, CPUContext>::RunOnDevice() {
  auto& logits = Input(0);
  auto& targets = Input(1);
  CAFFE_ENFORCE_EQ(logits.sizes(), targets.sizes());
  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
  const auto outer_size = logits.numel() / inner_size;

  std::vector<int64_t> dims;
  if (logits.dim() != 0) {
    dims =
        std::vector<int64_t>(logits.sizes().begin(), logits.sizes().end() - 1);
  }
  auto* out = Output(0, dims, at::dtype<float>());
  auto* out_ptr = out->template mutable_data<float>();

  auto* logits_ptr = logits.data<float>();
  auto* targets_ptr = targets.data<float>();

  auto in_idx = 0;
  for (int i = 0; i < outer_size; ++i) {
    float value = 0;
    for (int j = 0; j < inner_size; ++j) {
      if (unjoined_lr_loss_) {
        value += unjoined_sigmoid_xent_forward(
            logits_ptr[in_idx], targets_ptr[in_idx]);
      } else {
        value +=
            (log_D_trick_ ? sigmoid_xent_forward_with_log_d_trick(
                                logits_ptr[in_idx], targets_ptr[in_idx])
                          : sigmoid_xent_forward(
                                logits_ptr[in_idx], targets_ptr[in_idx]));
      }
      ++in_idx;
    }
    out_ptr[i] = -value / inner_size;
  }
  return true;
}
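// Gradient of the above: each output element is
// -(g[i] / inner_size) * d(loss)/d(logit), where g is the upstream gradient
// of the per-example losses and the derivative matches whichever forward
// variant was selected.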
template <>
bool SigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>::RunOnDevice() {
  auto& g = Input(0);
  auto& logits = Input(1);
  auto& targets = Input(2);
  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
  const auto outer_size = logits.numel() / inner_size;
  CAFFE_ENFORCE(g.numel() == outer_size);

  auto* out = Output(0, logits.sizes(), at::dtype<float>());
  auto* out_ptr = out->template mutable_data<float>();

  auto* logits_ptr = logits.data<float>();
  auto* targets_ptr = targets.data<float>();
  auto* g_ptr = g.data<float>();

  auto in_idx = 0;
  for (int i = 0; i < outer_size; ++i) {
    auto g_factor = -g_ptr[i] / inner_size;
    for (int j = 0; j < inner_size; ++j) {
      if (unjoined_lr_loss_) {
        out_ptr[in_idx] = g_factor *
            unjoined_sigmoid_xent_backward(
                logits_ptr[in_idx], targets_ptr[in_idx]);
      } else {
        out_ptr[in_idx] = g_factor *
            (log_D_trick_ ? sigmoid_xent_backward_with_log_d_trick(
                                logits_ptr[in_idx], targets_ptr[in_idx])
                          : sigmoid_xent_backward(
                                logits_ptr[in_idx], targets_ptr[in_idx]));
      }
      ++in_idx;
    }
  }
  return true;
}
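// WeightedSigmoidCrossEntropyWithLogitsOp: same per-row averaging as above,
// but every element's cross entropy is first multiplied by the matching
// entry of the weights blob.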
template <>
bool WeightedSigmoidCrossEntropyWithLogitsOp<float, CPUContext>::RunOnDevice() {
  auto& logits = Input(0);
  auto& targets = Input(1);
  auto& weights = Input(2);
  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
  CAFFE_ENFORCE(weights.sizes() == targets.sizes());
  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
  const auto outer_size = logits.numel() / inner_size;

  std::vector<int64_t> dims;
  if (logits.dim() != 0) {
    dims =
        std::vector<int64_t>(logits.sizes().begin(), logits.sizes().end() - 1);
  }
  auto* out = Output(0, dims, at::dtype<float>());
  auto* out_ptr = out->template mutable_data<float>();

  auto* logits_ptr = logits.data<float>();
  auto* targets_ptr = targets.data<float>();
  auto* weights_ptr = weights.data<float>();

  auto in_idx = 0;
  for (int i = 0; i < outer_size; ++i) {
    float value = 0;
    for (int j = 0; j < inner_size; ++j) {
      value += sigmoid_xent_forward(logits_ptr[in_idx], targets_ptr[in_idx]) *
          weights_ptr[in_idx];
      ++in_idx;
    }
    out_ptr[i] = -value / inner_size;
  }
  return true;
}
template <>
bool WeightedSigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>::
    RunOnDevice() {
  auto& g = Input(0);
  auto& logits = Input(1);
  auto& targets = Input(2);
  auto& weights = Input(3);
  CAFFE_ENFORCE(logits.sizes() == targets.sizes());
  CAFFE_ENFORCE(weights.sizes() == targets.sizes());
  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
  const auto outer_size = logits.numel() / inner_size;
  CAFFE_ENFORCE(g.numel() == outer_size);

  auto* out = Output(0, logits.sizes(), at::dtype<float>());
  auto* out_ptr = out->template mutable_data<float>();

  auto* logits_ptr = logits.data<float>();
  auto* targets_ptr = targets.data<float>();
  auto* weights_ptr = weights.data<float>();
  auto* g_ptr = g.data<float>();

  auto in_idx = 0;
  for (int i = 0; i < outer_size; ++i) {
    auto g_factor = -g_ptr[i] / inner_size;
    for (int j = 0; j < inner_size; ++j) {
      out_ptr[in_idx] = g_factor *
          sigmoid_xent_backward(logits_ptr[in_idx], targets_ptr[in_idx]) *
          weights_ptr[in_idx];
      ++in_idx;
    }
  }
  return true;
}
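// Gradient of LabelCrossEntropy: dX[i, label[i]] =
// -dY[i] / max(X[i, label[i]], kLOG_THRESHOLD()); all other entries of dX
// are zero.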
template <>
bool LabelCrossEntropyGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& label = Input(1);
  auto& dY = Input(2);

  int N, D;
  if (X.dim() > 1) {
    N = X.dim32(0);
    D = X.size_from_dim(1);
  } else {
    N = 1;
    D = X.dim32(0);
  }
  CAFFE_ENFORCE(
      (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == 1));
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  CAFFE_ENFORCE_EQ(dY.dim(), 1);
  CAFFE_ENFORCE_EQ(dY.dim32(0), N);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  math::Set<float, CPUContext>(
      dX->numel(), 0.f, dX->template mutable_data<float>(), &context_);
  const float* Xdata = X.data<float>();
  const float* dYdata = dY.data<float>();
  const int* labelData = label.data<int>();
  float* dXdata = dX->template mutable_data<float>();
  for (int i = 0; i < N; ++i) {
    dXdata[i * D + labelData[i]] =
        -dYdata[i] / std::max(Xdata[i * D + labelData[i]], kLOG_THRESHOLD());
  }
  return true;
}
template <>
bool MakeTwoClassOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);

  auto shape = X.sizes().vec();
  shape.push_back(2);
  int64_t N = X.numel();
  auto* Y = Output(0, shape, at::dtype<float>());
  const auto* Xdata = X.data<float>();
  auto* Ydata = Y->template mutable_data<float>();
  for (int64_t i = 0; i < N; ++i) {
    DCHECK_GE(Xdata[i], 0.0);
    DCHECK_LE(Xdata[i], 1.0);
    Ydata[i * 2] = 1.0 - Xdata[i];
    Ydata[i * 2 + 1] = Xdata[i];
  }
  return true;
}
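// Since Y stacks [1 - X, X] along the new trailing dimension, the gradient
// of any loss with respect to X is dY[..., 1] - dY[..., 0], which is what
// the gradient op below computes.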
template <>
bool MakeTwoClassGradientOp<float, CPUContext>::RunOnDevice() {
  auto& dY = Input(0);

  auto shape = dY.sizes().vec();
  CAFFE_ENFORCE_GE(shape.size(), 1);
  CAFFE_ENFORCE_EQ(shape.back(), 2);
  shape.pop_back();
  auto* dX = Output(0, shape, at::dtype<float>());
  const float* dYdata = dY.data<float>();
  float* dXdata = dX->template mutable_data<float>();
  int64_t N = dX->numel();
  for (int64_t i = 0; i < N; ++i) {
    dXdata[i] = dYdata[i * 2 + 1] - dYdata[i * 2];
  }
  return true;
}
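// CrossEntropyOp: soft-label cross entropy,
// Y[i] = -sum_j label[i][j] * log(max(X[i][j], kLOG_THRESHOLD())).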
template <>
bool CrossEntropyOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& label = Input(1);

  int N, D;
  if (X.dim() > 1) {
    N = X.dim32(0);
    D = X.size_from_dim(1);
  } else {
    N = 1;
    D = X.dim32(0);
  }
  CAFFE_ENFORCE(
      (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == D));
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  auto* Y = Output(0, vector<int64_t>{N}, at::dtype<float>());
  const float* Xdata = X.data<float>();
  const float* labelData = label.data<float>();
  auto* Ydata = Y->template mutable_data<float>();
  CAFFE_ENFORCE(
      (ConstEigenArrayMap<float>(labelData, D, N) <= 1.0f).all() &&
          (ConstEigenArrayMap<float>(labelData, D, N) >= 0.0f).all(),
      "Soft label seems incorrect: label value should be a probability ",
      "between 0 and 1.0. You may be using the wrong cross entropy operator; ",
      "use LabelCrossEntropy if the labels are integers whose values are at ",
      "most the number of classes, ",
      D);
  EigenArrayMap<float>(Ydata, 1, N) =
      -(ConstEigenArrayMap<float>(labelData, D, N) *
        ConstEigenArrayMap<float>(Xdata, D, N).cwiseMax(kLOG_THRESHOLD()).log())
           .colwise()
           .sum();
  return true;
}
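// Gradient of the soft-label cross entropy:
// dX[i][j] = -dY[i] * label[i][j] / max(X[i][j], kLOG_THRESHOLD()).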
template <>
bool CrossEntropyGradientOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& label = Input(1);
  auto& dY = Input(2);

  int N, D;
  if (X.dim() > 1) {
    N = X.dim32(0);
    D = X.size_from_dim(1);
  } else {
    N = 1;
    D = X.dim32(0);
  }
  CAFFE_ENFORCE(
      (label.dim() == 1) || (label.dim() == 2 && label.dim32(1) == D));
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  CAFFE_ENFORCE_EQ(dY.dim(), 1);
  CAFFE_ENFORCE_EQ(dY.dim32(0), N);
  auto* dX = Output(0, X.sizes(), at::dtype<float>());
  math::Set<float, CPUContext>(
      dX->numel(), 0.f, dX->template mutable_data<float>(), &context_);
  const float* Xdata = X.data<float>();
  const float* dYdata = dY.data<float>();
  const float* labelData = label.data<float>();
  float* dXdata = dX->template mutable_data<float>();
  EigenArrayMap<float>(dXdata, D, N) =
      (ConstEigenArrayMap<float>(labelData, D, N) /
       ConstEigenArrayMap<float>(Xdata, D, N).cwiseMax(kLOG_THRESHOLD()))
          .rowwise() *
      (-ConstEigenVectorArrayMap<float>(dYdata, N).transpose());
  return true;
}
REGISTER_CPU_OPERATOR(LabelCrossEntropy, LabelCrossEntropyOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    LabelCrossEntropyGradient,
    LabelCrossEntropyGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(LabelCrossEntropy)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
This operator computes the cross entropy between a $NxD$ dimensional input data tensor $X$ and a one dimensional input label tensor $label$. The op produces a single length $N$ output tensor $Y$. Here, $N$ is considered the batch size and $D$ is the size of each element in the batch. In practice, it is most commonly used at the end of models as a part of the loss computation, after the SoftMax operator and before the AveragedLoss operator. The cross entropy operation is defined as follows

$$Y_i = -log(X_{ij})$$

where $i$ is the index of the example in the batch and $j = label_i$ is the correct class for that example. Each log has a lower limit for numerical stability.

The difference between *LabelCrossEntropy* and *CrossEntropy* is how the labels are specified. Here, the labels are a length $N$ list of integers, whereas in CrossEntropy the labels are a $NxD$ dimensional matrix of one-hot label vectors. The results of the computation, however, should be the same, as shown in the two examples.

Github Links:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/cross_entropy_op.h
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/cross_entropy_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "LabelCrossEntropy",
    ["X", "label"],
    ["Y"]
)

# Create X: Sample softmax output for 5-class model
X = np.array([[.01, .05, .02, .02, .9],[.03, .1, .42, .05, .4]])
print("X:\n",X)

# Create label: Sample ground truth integer label indices
label = np.array([4,2])
print("label:\n",label)

# Feed X & label into workspace
workspace.FeedBlob("X", X.astype(np.float32))
workspace.FeedBlob("label", label.astype(np.int32))

# Run op
workspace.RunOperatorOnce(op)

# Collect output
print("Y:\n", workspace.FetchBlob("Y"))
```

**Result**

```
X:
 [[0.01 0.05 0.02 0.02 0.9 ]
 [0.03 0.1  0.42 0.05 0.4 ]]
label:
 [4 2]
Y:
 [0.10536055 0.8675006 ]
```

</details>

)DOC")
    .Input(
        0,
        "X",
        "Input tensor which is almost always the result of a softmax operation. $X$ is a 2D array of size $NxD$, where $N$ is the batch size and $D$ is the number of classes.")
    .Input(
        1,
        "label",
        "Blob containing the labels used to compare the input. $label$ is a length $N$ list of integers, where each element is the integer label for the $n$th element of the batch.")
    .Output(
        0,
        "Y",
        "Output blob from the cross entropy computation. $Y$ is a 1D tensor of length $N$.");
OPERATOR_SCHEMA(LabelCrossEntropyGradient)
    .NumInputs(3)
    .NumOutputs(1);
class GetLabelCrossEntropyGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "LabelCrossEntropyGradient",
        "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(LabelCrossEntropy, GetLabelCrossEntropyGradient);
REGISTER_CPU_OPERATOR(MakeTwoClass, MakeTwoClassOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    MakeTwoClassGradient,
    MakeTwoClassGradientOp<float, CPUContext>);

REGISTER_CPU_OPERATOR(
    SigmoidCrossEntropyWithLogits,
    SigmoidCrossEntropyWithLogitsOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    SigmoidCrossEntropyWithLogitsGradient,
    SigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>);

REGISTER_CPU_OPERATOR(
    WeightedSigmoidCrossEntropyWithLogits,
    WeightedSigmoidCrossEntropyWithLogitsOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    WeightedSigmoidCrossEntropyWithLogitsGradient,
    WeightedSigmoidCrossEntropyWithLogitsGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(MakeTwoClass)
    .NumInputs(1)
    .NumOutputs(1)
    .TensorInferenceFunction(
        [](const OperatorDef& /* unused */, const vector<TensorShape>& in) {
          vector<TensorShape> out(1);
          out[0].add_dims(in[0].dims(0));
          out[0].add_dims(2);
          return out;
        })
    .SetDoc(R"DOC(
Given a vector of probabilities, this operator transforms this into a 2-column
matrix with complementary probabilities for binary classification. In explicit
terms, given the vector X, the i-th row of the output Y is [1 - X[i], X[i]].
  )DOC")
    .Input(0, "X", "Input vector of probabilities")
    .Output(
        0,
        "Y",
        "2-column matrix with complementary probabilities of X for "
        "binary classification");
OPERATOR_SCHEMA(MakeTwoClassGradient)
    .NumInputs(1)
    .NumOutputs(1);
OPERATOR_SCHEMA(SigmoidCrossEntropyWithLogits)
    .Arg("log_D_trick", R"DOC(
default is false; if enabled, will use the log d trick to avoid the vanishing
gradients early on; see Goodfellow et al. (2014)
)DOC")
    .Arg("unjoined_lr_loss", R"DOC(
default is false; if enabled, the model will be allowed to train on an unjoined
dataset, where some examples might be false negatives and might appear
in the dataset later as (true) positive examples.
)DOC")
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
Given two matrices logits and targets, of the same shape,
(batch_size, num_classes), computes the sigmoid cross entropy between the two.
Returns a tensor of shape (batch_size,) of losses for each example.
)DOC")
    .Input(0, "logits", "matrix of logits for each example and class.")
    .Input(1, "targets", "matrix of targets, same shape as logits.")
    .Output(0, "xentropy", "Vector with the total xentropy for each example.");
OPERATOR_SCHEMA(SigmoidCrossEntropyWithLogitsGradient)
    .NumInputs(3)
    .NumOutputs(1);
OPERATOR_SCHEMA(WeightedSigmoidCrossEntropyWithLogits)
    .NumInputs(3)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
Given three matrices: logits, targets, weights, all of the same shape,
(batch_size, num_classes), computes the weighted sigmoid cross entropy between
logits and targets. Specifically, at each position r,c, this computes
weights[r, c] * crossentropy(sigmoid(logits[r, c]), targets[r, c]), and then
averages over each row.
Returns a tensor of shape (batch_size,) of losses for each example.
)DOC")
    .Input(0, "logits", "matrix of logits for each example and class.")
    .Input(1, "targets", "matrix of targets, same shape as logits.")
    .Input(2, "weights", "matrix of weights, same shape as logits.")
    .Output(0, "xentropy", "Vector with the total xentropy for each example.");
OPERATOR_SCHEMA(WeightedSigmoidCrossEntropyWithLogitsGradient)
    .NumInputs(4)
    .NumOutputs(1);
class GetMakeTwoClassGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "MakeTwoClassGradient",
        "",
        vector<string>{GO(0)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(MakeTwoClass, GetMakeTwoClassGradient);
class GetSigmoidCrossEntropyWithLogitsGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "SigmoidCrossEntropyWithLogitsGradient",
        "",
        vector<string>{GO(0), I(0), I(1)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(
    SigmoidCrossEntropyWithLogits,
    GetSigmoidCrossEntropyWithLogitsGradient);
class GetWeightedSigmoidCrossEntropyWithLogitsGradient
    : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "WeightedSigmoidCrossEntropyWithLogitsGradient",
        "",
        vector<string>{GO(0), I(0), I(1), I(2)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(
    WeightedSigmoidCrossEntropyWithLogits,
    GetWeightedSigmoidCrossEntropyWithLogitsGradient);
REGISTER_CPU_OPERATOR(CrossEntropy, CrossEntropyOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    CrossEntropyGradient,
    CrossEntropyGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(CrossEntropy)
    .NumInputs(2)
    .NumOutputs(1)
    .IdenticalTypeAndShapeOfInputDim(0, 0)
    .SetDoc(R"DOC(
This operator computes the cross entropy between a $NxD$ dimensional input data tensor $X$ and a $NxD$ dimensional input label tensor $label$. The op produces a single length $N$ output tensor $Y$. Here, $N$ is considered the batch size and $D$ is the size of each element in the batch. In practice, it is most commonly used at the end of models as a part of the loss computation, after the SoftMax operator and before the AveragedLoss operator. The cross entropy operation is defined as follows

$$Y_i = -\sum_j (label_{ij} * log(X_{ij}))$$

where $i$ indexes the examples in the batch, $j$ indexes the classes, and $label_{ij}$ is the (soft) label probability of class $j$ for example $i$. Each log has a lower limit for numerical stability.

Github Links:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/cross_entropy_op.h
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/cross_entropy_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

op = core.CreateOperator(
    "CrossEntropy",
    ["X", "label"],
    ["Y"]
)

# Create X: Sample softmax output for 5-class model
X = np.array([[.01, .05, .02, .02, .9],[.03, .1, .42, .05, .4]])
print("X:\n",X)

# Create label: Sample 1-hot ground truth label vectors
label = np.array([[0.,0.,0.,0.,1.],[0.,0.,1.,0.,0.]])
print("label:\n",label)

# Feed X & label into workspace
workspace.FeedBlob("X", X.astype(np.float32))
workspace.FeedBlob("label", label.astype(np.float32))

# Run op
workspace.RunOperatorOnce(op)

# Collect output
print("Y:\n", workspace.FetchBlob("Y"))
```

**Result**

```
X:
 [[0.01 0.05 0.02 0.02 0.9 ]
 [0.03 0.1  0.42 0.05 0.4 ]]
label:
 [[0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0.]]
Y:
 [0.10536055 0.8675006 ]
```

</details>

)DOC")
    .Input(
        0,
        "X",
        "Input tensor which is almost always the result of a softmax operation. $X$ is a 2D array of size $NxD$, where $N$ is the batch size and $D$ is the number of classes.")
    .Input(
        1,
        "label",
        "Blob containing the labels used to compare the input. $label$ is the same shape as $X$.")
    .Output(
        0,
        "Y",
        "Output blob from the cross entropy computation. $Y$ is a 1D tensor of length $N$.");
OPERATOR_SCHEMA(CrossEntropyGradient)
    .NumInputs(3)
    .NumOutputs(1);
class GetCrossEntropyGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "CrossEntropyGradient",
        "",
        vector<string>{I(0), I(1), GO(0)},
        vector<string>{GI(0)});
  }
};
REGISTER_GRADIENT(CrossEntropy, GetCrossEntropyGradient);

} // namespace caffe2