1 #include "caffe2/operators/logit_op.h" 6 #include "caffe2/operators/elementwise_ops.h" 7 #include "caffe2/utils/eigen_utils.h" 13 bool LogitFunctor<CPUContext>::
14 operator()(
const int size,
const T* X,
T* Y, CPUContext* )
const {
15 ConstEigenVectorMap<T> X_vec(X, size);
16 EigenVectorMap<T> Y_vec(Y, size);
17 Y_vec = X_vec.array().min(static_cast<T>(1.0f - eps_));
18 Y_vec = Y_vec.array().max(eps_);
19 Y_vec = (Y_vec.array() / (
T(1) - Y_vec.array())).log();
24 bool LogitGradientOp<float, CPUContext>::RunOnDevice() {
25 const auto& X = Input(0);
26 const auto& dY = Input(1);
28 auto* dX = Output(0, X.sizes(), at::dtype<float>());
29 int channels = X.dim32(X.dim() - 1);
30 ConstEigenArrayMap<float> Xmat(
31 X.template data<float>(), channels, X.numel() / channels);
32 ConstEigenArrayMap<float> dYmat(
33 dY.template data<float>(), channels, X.numel() / channels);
34 EigenArrayMap<float> dXmat(
35 dX->template mutable_data<float>(), channels, X.numel() / channels);
36 dXmat = (Xmat < eps_ || Xmat > 1.0 - eps_)
37 .select(0, dYmat * ((1 - Xmat) * Xmat).inverse());
// Registers the CPU forward kernel, which wraps LogitFunctor<CPUContext> in
// UnaryElementwiseWithArgsOp.
// NOTE(review): source lines appear to be missing from this extract (the
// embedded numbers jump 41 -> 43 -> 46). The operator name and the leading
// template arguments are not visible; presumably `Logit`, an input-type list
// and `CPUContext`. Confirm against the full file.
41 REGISTER_CPU_OPERATOR(
43 UnaryElementwiseWithArgsOp<
46 LogitFunctor<CPUContext>>);
48 REGISTER_CPU_OPERATOR(LogitGradient, LogitGradientOp<float, CPUContext>);
// Schema for the forward Logit operator. It allows in-place use (input 0 ->
// output 0) and declares the output to have the input's type and shape.
// NOTE(review): the argument is registered here as "eps (optional)", while the
// LogitGradient schema documents it as "eps". The argument name actually read
// into eps_ is not visible; presumably "eps", in which case this Arg name is
// misleading and should be "eps". Confirm.
// NOTE(review): the doc text below appears to have lost its SetDoc wrapper in
// this extract. It also misspells "clamped" and writes the range as
// (eps, 1-eps); the CPU functor clamps with min/max, i.e. to the closed range
// [eps, 1-eps].
50 OPERATOR_SCHEMA(Logit)
53 .AllowInplace({{0, 0}})
54 .IdenticalTypeAndShape()
56 Elementwise logit transform: logit(x) = log(x / (1 - x)), where x is the 57 input data clampped in (eps, 1-eps). 59 .Arg("eps (optional)",
"small positive epsilon value, the default is 1e-6.")
// Input 0: the tensor to transform; Output 0: its elementwise logit.
60 .Input(0,
"X",
"input float tensor")
61 .Output(0,
"Y",
"output float tensor");
// Schema for LogitGradient, the backward op of Logit.
//   Input 0  "X"  : the forward op's input.
//   Input 1  "dY" : gradient with respect to the forward output.
//   Output 0 "dX" : gradient with respect to X.
// The CPU kernel scales dY by 1 / (X * (1 - X)) and writes 0 where
// X < eps or X > 1 - eps.
// NOTE(review): presumably "eps" reaches this op from the forward Logit op's
// arguments via the gradient maker. Confirm: a different value would shift the
// zeroed region away from the forward clamp.
// NOTE(review): original lines 64-65 are not visible in this extract.
63 OPERATOR_SCHEMA(LogitGradient)
66 .Input(0,
"X",
"input float tensor")
67 .Input(1,
"dY",
"input float tensor")
68 .Output(0,
"dX",
"output float tensor")
69 .Arg(
"eps",
"small positive epsilon value, the default is 1e-6.");
// Gradient maker for Logit. It builds a single gradient operator whose inputs
// are the forward input I(0) and the output gradient GO(0), and whose output is
// GI(0). By GradientMakerBase convention, GI(0) is the gradient blob for
// forward input 0.
// NOTE(review): the operator type/name arguments of CreateOperatorDef
// (presumably "LogitGradient") and the closing braces of the method and class
// are not visible in this extract. Confirm against the full file.
73 class GetLogitGradient :
public GradientMakerBase {
74 using GradientMakerBase::GradientMakerBase;
75 vector<OperatorDef> GetGradientDefs()
override {
// Input list: {forward input, output gradient}; output list: {input gradient}.
76 return vector<OperatorDef>{CreateOperatorDef(
79 std::vector<std::string>{I(0), GO(0)},
80 std::vector<std::string>{GI(0)})};
86 REGISTER_GRADIENT(Logit, GetLogitGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...