1 #include <ATen/core/dispatch/KernelRegistration.h> 2 #include "caffe2/operators/experimental/c10/schemas/sigmoid_cross_entropy_with_logits.h" 3 #include "caffe2/utils/math.h" 4 #include "caffe2/core/tensor.h" 10 inline float sigmoid_partition(
float lgt) {
12 return lgt * (lgt >= 0) + log(1 + exp(lgt - 2 * lgt * (lgt >= 0)));
// Per-element log-likelihood of target `tgt` under Bernoulli(sigmoid(lgt)):
//   tgt * log(sigmoid(lgt)) + (1 - tgt) * log(1 - sigmoid(lgt))
//     == lgt * (tgt - [lgt >= 0]) - log(1 + exp(-|lgt|))
// The second form never calls exp() on a positive argument and, unlike
// lgt * tgt - softplus(lgt), avoids cancellation when tgt is close to 1.
inline float sigmoid_xent_forward(float lgt, float tgt) {
  const float coeff = tgt - (lgt >= 0);
  // Skip the product when coeff == 0 so that a saturated infinite logit whose
  // target agrees with it (lgt = +inf with tgt = 1, or lgt = -inf with
  // tgt = 0) yields 0 instead of inf * 0 = NaN. For finite lgt this only
  // changes the sign of a zero.
  const float linear = coeff == 0 ? 0.0f : lgt * coeff;
  // log1p keeps precision when exp(-|lgt|) is below machine epsilon.
  return linear - log1p(exp(-fabs(lgt)));
}
19 inline float sigmoid_xent_forward_with_log_d_trick(
float lgt,
float tgt) {
20 return (2 * tgt - 1.) * (lgt - sigmoid_partition(lgt));
// Per-element term for the "unjoined" LR loss:
//   lgt * tgt - (1 - tgt) * log(1 + exp(lgt))
// with the softplus expanded as max(lgt, 0) + log1p(exp(-|lgt|)) so exp() is
// never evaluated on a positive argument.
inline float unjoined_sigmoid_xent_forward(float lgt, float tgt) {
  const float positive_part = lgt >= 0 ? lgt : 0.0f;
  // log1p keeps precision when exp(-|lgt|) is below machine epsilon, where
  // log(1 + x) would round to 0.
  return lgt * tgt + (tgt - 1) * positive_part -
      (1 - tgt) * log1p(exp(-fabs(lgt)));
}
// CPU implementation of the SigmoidCrossEntropyWithLogits operator.
//
// The last dimension of `logits` is the reduction axis. For each of the
// outer_size = numel / inner_size rows, the kernel sums a per-element term over
// the inner_size elements of that row and writes the negated mean to
// out_ptr[row]. The per-element term is selected by the flags:
//   unjoined_lr_loss -> unjoined_sigmoid_xent_forward
//   otherwise        -> log_D_trick ? sigmoid_xent_forward_with_log_d_trick
//                                   : sigmoid_xent_forward
// `logits` and `targets` must have identical sizes (enforced below).
//
// NOTE(review): this listing appears to be a lossy extract. The embedded line
// numbers jump (28 -> 33, 45 -> 48, 51 -> 54, 59 -> 62, ...), so the remaining
// parameter declarations (presumably logits_, targets_, out_ and log_D_trick),
// the declarations of the accumulator `value` and of the flat index `in_idx`,
// the `else` branches and the closing braces are not present here. Confirm
// against the real source.
28 void sigmoid_cross_entropy_with_logits_op_cpu_impl(
33 bool unjoined_lr_loss) {
// Wrap the incoming c10 tensors as caffe2 Tensors.
34 Tensor logits{C10Tensor(logits_)};
35 Tensor targets{C10Tensor(targets_)};
36 Tensor out{C10Tensor(out_)};
38 CAFFE_ENFORCE_EQ(logits.sizes(), targets.sizes());
// Reduce over the last dimension; a 0-d tensor is treated as one element.
39 const auto inner_size = logits.dim() > 0 ? logits.sizes().back() : 1;
// NOTE(review): if the last dimension has size 0, inner_size is 0 and this
// (presumably integer) division divides by zero. Consider guarding.
40 const auto outer_size = logits.numel() / inner_size;
// The output shape is the input shape with the last dimension dropped: a
// scalar for 0-d input, otherwise `dims` (presumably applied via out.Resize in
// the else branch that is missing from this listing).
42 if (logits.dim() == 0) {
43 out.Resize(std::vector<int64_t>{});
45 std::vector<int64_t> dims(logits.sizes().begin(), logits.sizes().end() - 1);
48 auto* out_ptr = out.mutable_data<
float>();
50 auto* logits_ptr = logits.data<
float>();
51 auto* targets_ptr = targets.data<
float>();
// One output value per outer row.
// NOTE(review): `int i` is compared against outer_size, which is presumably
// 64-bit. This could overflow for very large tensors.
54 for (
int i = 0; i < outer_size; ++i) {
56 for (
int j = 0; j < inner_size; ++j) {
// Accumulate the per-element term for the current element. logits/targets
// are indexed by in_idx, whose declaration and increment are not shown here.
57 if (unjoined_lr_loss) {
58 value += unjoined_sigmoid_xent_forward(
59 logits_ptr[in_idx], targets_ptr[in_idx]);
// (The lines preceding this continuation, presumably `} else {` and
// `value +=`, are missing from this listing.)
62 (log_D_trick ? sigmoid_xent_forward_with_log_d_trick(
63 logits_ptr[in_idx], targets_ptr[in_idx])
64 : sigmoid_xent_forward(
65 logits_ptr[in_idx], targets_ptr[in_idx]));
// Store the negated mean of the per-element terms over the last dimension.
69 out_ptr[i] = -value / inner_size;
// Register caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl as the kernel
// for the caffe2::ops::SigmoidCrossEntropyWithLogits operator under the CPU
// dispatch key (CPUTensorId).
76 C10_REGISTER_KERNEL(caffe2::ops::SigmoidCrossEntropyWithLogits)
77 .kernel<decltype(caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl), &caffe2::sigmoid_cross_entropy_with_logits_op_cpu_impl>()
78 .dispatchKey(CPUTensorId());
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...