1 #include "caffe2/operators/jsd_op.h" 7 static constexpr
// kLOG_THRESHOLD(): compile-time float threshold. Within this file it is used
// by logit() to clamp its argument and by entropy() to detect probabilities
// that are too close to 0 or 1.
// NOTE(review): the function body (and therefore the actual threshold value)
// is not visible in this excerpt.
float kLOG_THRESHOLD() {
// Logit (log-odds) of a probability p, evaluated as -log(1/x - 1), which is
// algebraically log(x / (1 - x)), where x is p after clamping.
//
// p: probability value; it is clamped rather than validated.
// Returns the log-odds of the clamped value, narrowed to float.
11 inline float logit(
float p) {
// Clamp p into [kLOG_THRESHOLD(), 1 - kLOG_THRESHOLD()]; the intent appears to
// be keeping 1/x and the log argument away from the p == 0 / p == 1 extremes.
14 float x = std::min(std::max(p, kLOG_THRESHOLD()), 1 - kLOG_THRESHOLD());
// The `1.` literals are doubles, so the division and subtraction are carried
// out in double before the result is converted back to float on return.
15 return -log(1. / x - 1.);
// Entropy helper for a Bernoulli distribution with success probability p.
// The visible general-case result is -p*log(p) - q*log(q).
//
// p: probability value.
18 inline float entropy(
float p) {
// Probabilities within kLOG_THRESHOLD() of 0 or of 1 take a separate branch.
// NOTE(review): the body of that branch is not visible in this excerpt.
19 if (p < kLOG_THRESHOLD() || 1 - p < kLOG_THRESHOLD()) {
// NOTE(review): the definition of q is not visible here; presumably q = 1 - p
// (the complementary probability) -- confirm against the full source.
23 return -p * log(p) - q * log(q);
// CPU/float forward pass: for each element i, combines the model probability
// x_data[i] and the empirical (target) probability t_data[i] into a
// Jensen-Shannon divergence value using the entropy() helper.
// NOTE(review): the declarations of X, T and N, the store of `jsd` into
// l_data, and the return statement are not visible in this excerpt.
29 bool BernoulliJSDOp<float, CPUContext>::RunOnDevice() {
// T must contain exactly N elements, since both arrays are indexed by the same
// i in [0, N) below.
33 CAFFE_ENFORCE_EQ(
T.numel(), N);
// Output 0 is allocated with the same sizes as X and float dtype.
34 auto* L = Output(0, X.sizes(), at::dtype<float>());
35 auto* x_data = X.data<
float>();
36 auto* t_data =
T.data<
float>();
37 auto* l_data = L->template mutable_data<float>();
38 for (
int i = 0; i < N; i++) {
// p_mdl: model probability; p_emp: empirical/target probability.
39 auto p_mdl = x_data[i];
40 auto p_emp = t_data[i];
// Midpoint of the two probabilities; `2.` is a double literal, so p_avg is
// computed in double and narrowed when passed to entropy(float).
41 auto p_avg = (p_mdl + p_emp) / 2.;
// JSD = H(avg) - (H(mdl) + H(emp)) / 2, with H being entropy().
42 auto jsd = entropy(p_avg) - (entropy(p_mdl) + entropy(p_emp)) / 2.;
// CPU/float backward pass: writes, for each element i, the incoming gradient
// go_data[i] scaled by the local derivative of the JSD with respect to the
// model probability x_data[i].
// NOTE(review): the declarations of go, X, T and N and the return statement
// are not visible in this excerpt.
49 bool BernoulliJSDGradientOp<float, CPUContext>::RunOnDevice() {
// Output 0 (the input gradient) has the same sizes as X and float dtype.
55 auto* gi = Output(0, X.sizes(), at::dtype<float>());
56 auto* go_data = go.data<
float>();
57 auto* x_data = X.data<
float>();
58 auto* t_data =
T.data<
float>();
59 auto* gi_data = gi->template mutable_data<float>();
60 for (
int i = 0; i < N; i++) {
// p_mdl: model probability; p_emp: empirical/target probability.
61 auto p_mdl = x_data[i];
62 auto p_emp = t_data[i];
// Midpoint of the two probabilities (computed in double because of `2.`).
63 auto p_avg = (p_mdl + p_emp) / 2.;
// Local derivative w.r.t. p_mdl: (logit(p_mdl) - logit(p_avg)) / 2.
64 auto g_jsd = (logit(p_mdl) - logit(p_avg)) / 2.;
// Chain rule: upstream gradient times local derivative.
65 gi_data[i] = go_data[i] * g_jsd;
// Register the <float, CPUContext> specialization of BernoulliJSDOp as the CPU
// implementation of the "BernoulliJSD" operator.
69 REGISTER_CPU_OPERATOR(BernoulliJSD, BernoulliJSDOp<float, CPUContext>);
// Register the <float, CPUContext> specialization of BernoulliJSDGradientOp.
// NOTE(review): the operator-name argument of this invocation is not visible
// in this excerpt.
70 REGISTER_CPU_OPERATOR(
72 BernoulliJSDGradientOp<float, CPUContext>);
// Schema for the BernoulliJSD forward operator. It documents two distinct
// inputs -- X (predicted probabilities) at index 0 and T (target
// probabilities) at index 1 -- and a single output L (the JSD losses) at
// index 0. These indices match the gradient maker in this file, which refers
// to the forward inputs as I(0) and I(1).
73 OPERATOR_SCHEMA(BernoulliJSD)
77 Computes the Jensen-Shannon divergence (JSD) between two Bernoulli distributions 78 where each is parametrized by a single probability. 80 .Input(0, "X",
"array of probabilities for prediction")
// T is the second input, so it is documented at index 1. It was previously
// declared at index 0, which duplicated X's slot and left input 1 undocumented.
81 .Input(1,
"T",
"array of probabilities for target")
82 .Output(0,
"L",
"array of JSD losses");
// Schema for the gradient operator: exactly three inputs and one output. The
// gradient maker in this file supplies {GO(0), I(0), I(1)} as the inputs and
// {GI(0)} as the output.
83 OPERATOR_SCHEMA(BernoulliJSDGradient).NumInputs(3).NumOutputs(1);
// Inherit the constructors of GradientMakerBase.
86 using GradientMakerBase::GradientMakerBase;
// Describes how to build the gradient of BernoulliJSD: an operator of type
// "BernoulliJSDGradient" whose inputs are the output gradient GO(0) and the
// two forward inputs I(0) and I(1), and whose only output is GI(0), i.e. a
// gradient is produced for the first forward input only.
// NOTE(review): the call that receives these arguments (and any arguments on
// the omitted lines) is not visible in this excerpt.
87 vector<OperatorDef> GetGradientDefs()
override {
89 "BernoulliJSDGradient",
91 vector<string>{GO(0), I(0), I(1)},
92 vector<string>{GI(0)});
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...