Caffe2 - C++ API
A deep learning, cross platform ML framework
sigmoid_cross_entropy_with_logits_cpu.cc
#include <cmath>

#include <ATen/core/dispatch/KernelRegistration.h>
#include "caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h"
#include "caffe2/utils/math.h"
#include "caffe2/core/tensor.h"
5 
6 using caffe2::Tensor;
7 
8 namespace caffe2 {
9 namespace {
/**
 * Numerically stable softplus: returns log(1 + exp(lgt)).
 *
 * Uses the identity log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|)), so exp()
 * is only ever evaluated on a non-positive argument and cannot overflow.
 * log1p keeps the small tail that log(1 + tiny) would round away, and the
 * |x| form stays finite-correct at the extremes (+inf -> +inf, -inf -> 0).
 *
 * @param lgt the logit.
 * @return log(1 + exp(lgt)); NaN if lgt is NaN.
 */
inline float sigmoid_partition(float lgt) {
  const float positive_part = lgt >= 0 ? lgt : 0.0f;
  return positive_part + std::log1p(std::exp(-std::fabs(lgt)));
}
14 
/**
 * Per-element log-likelihood of a sigmoid output:
 *   tgt * lgt - log(1 + exp(lgt))
 * which equals tgt * log(sigmoid(lgt)) + (1 - tgt) * log(1 - sigmoid(lgt)).
 *
 * Evaluated in the stable form
 *   lgt * (tgt - [lgt >= 0]) - log1p(exp(-|lgt|))
 * so exp() never sees a positive argument. log1p preserves the small
 * residual for large |lgt| that log(1 + exp(.)) would round to zero.
 *
 * @param lgt the logit.
 * @param tgt the target value for this element.
 * @return the (non-negated) log-likelihood contribution.
 */
inline float sigmoid_xent_forward(float lgt, float tgt) {
  const float is_non_negative = lgt >= 0 ? 1.0f : 0.0f;
  return lgt * (tgt - is_non_negative) -
      std::log1p(std::exp(-std::fabs(lgt)));
}
18 
19 inline float sigmoid_xent_forward_with_log_d_trick(float lgt, float tgt) {
20  return (2 * tgt - 1.) * (lgt - sigmoid_partition(lgt));
21 }
22 
/**
 * Per-element term used when `unjoined_lr_loss` is set:
 *   lgt * tgt - (1 - tgt) * log(1 + exp(lgt))
 *
 * Evaluated in the stable form
 *   lgt * tgt + (tgt - 1) * lgt * [lgt >= 0] - (1 - tgt) * log1p(exp(-|lgt|))
 * so exp() never sees a positive argument. log1p preserves the small
 * residual for large |lgt| that log(1 + exp(.)) would round to zero.
 *
 * @param lgt the logit.
 * @param tgt the target value for this element.
 * @return the (non-negated) contribution for this element.
 */
inline float unjoined_sigmoid_xent_forward(float lgt, float tgt) {
  const float is_non_negative = lgt >= 0 ? 1.0f : 0.0f;
  const float log_tail = std::log1p(std::exp(-std::fabs(lgt)));
  return lgt * tgt + (tgt - 1) * lgt * is_non_negative -
      (1 - tgt) * log_tail;
}
27 
28 void sigmoid_cross_entropy_with_logits_op_cpu_impl(
29  const at::Tensor& logits_,
30  const at::Tensor& targets_,
31  const at::Tensor& out_,
32  bool log_D_trick,
33  bool unjoined_lr_loss) {
34  Tensor logits{C10Tensor(logits_)};
35  Tensor targets{C10Tensor(targets_)};
36  Tensor out{C10Tensor(out_)};
37 
38  CAFFE_ENFORCE_EQ(logits.sizes(), targets.sizes());
39  const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
40  const auto outer_size = logits.numel() / inner_size;
41 
42  if (logits.dim() == 0) {
43  out.Resize(std::vector<int64_t>{});
44  } else {
45  std::vector<int64_t> dims(logits.sizes().begin(), logits.sizes().end() - 1);
46  out.Resize(dims);
47  }
48  auto* out_ptr = out.mutable_data<float>();
49 
50  auto* logits_ptr = logits.data<float>();
51  auto* targets_ptr = targets.data<float>();
52 
53  auto in_idx = 0;
54  for (int i = 0; i < outer_size; ++i) {
55  float value = 0;
56  for (int j = 0; j < inner_size; ++j) {
57  if (unjoined_lr_loss) {
58  value += unjoined_sigmoid_xent_forward(
59  logits_ptr[in_idx], targets_ptr[in_idx]);
60  } else {
61  value +=
62  (log_D_trick ? sigmoid_xent_forward_with_log_d_trick(
63  logits_ptr[in_idx], targets_ptr[in_idx])
64  : sigmoid_xent_forward(
65  logits_ptr[in_idx], targets_ptr[in_idx]));
66  }
67  ++in_idx;
68  }
69  out_ptr[i] = -value / inner_size;
70  }
71 }
72 } // namespace
73 } // namespace caffe2
74 
// Kernel registration: binds caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl
// to the caffe2::ops::SigmoidCrossEntropyWithLogits operator schema.
// NOTE(review): C10_REGISTER_KERNEL is expected to appear in exactly one
// translation unit per kernel -- confirm against the macro's documentation.
75 namespace c10 {
76 C10_REGISTER_KERNEL(caffe2::ops::SigmoidCrossEntropyWithLogits)
// The function type and the function pointer are both passed as template
// arguments to the registration builder.
77  .kernel<decltype(caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl), &caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>()
// Dispatch key for this kernel; presumably selects it for CPU tensors.
78  .dispatchKey(CPUTensorId());
79 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
Definition: tensor.h:25
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7