1 #include "caffe2/operators/minmax_ops.h" 6 #include "caffe2/utils/eigen_utils.h" 10 template <
typename T,
class Context>
11 bool SelectGradientOpBase<T, Context>::RunOnDevice() {
12 const auto& Y = Input(0);
13 const auto& dY = Input(1);
14 const int N = Y.numel();
15 ConstEigenVectorArrayMap<T> Y_arr(Y.template data<T>(), N);
16 ConstEigenVectorArrayMap<T> dY_arr(dY.template data<T>(), N);
17 for (
int i = 0; i < OutputSize(); i++) {
18 const auto& Xi = Input(i + 2);
19 auto* dXi = Output(i, Xi.sizes(), at::dtype<T>());
20 ConstEigenVectorArrayMap<T> Xi_arr(Xi.template data<T>(), N);
21 EigenVectorArrayMap<T> dXi_arr(dXi->template mutable_data<T>(), N);
22 dXi_arr = (Xi_arr == Y_arr).
template cast<T>() * dY_arr;
27 REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
28 REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
30 OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
31 OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
35 class GetMaxGradient :
public GradientMakerBase {
36 using GradientMakerBase::GradientMakerBase;
37 std::vector<OperatorDef> GetGradientDefs()
override {
38 std::vector<std::string> inputs = {O(0), GO(0)};
39 std::vector<std::string> grad_inputs;
40 for (
int i = 0; i < def_.input_size(); ++i) {
41 inputs.push_back(I(i));
42 grad_inputs.push_back(GI(i));
44 return SingleGradientDef(
"MaxGradient",
"", inputs, grad_inputs);
48 class GetMinGradient :
public GradientMakerBase {
49 using GradientMakerBase::GradientMakerBase;
50 vector<OperatorDef> GetGradientDefs()
override {
51 std::vector<std::string> inputs = {O(0), GO(0)};
52 std::vector<std::string> grad_inputs;
53 for (
int i = 0; i < def_.input_size(); ++i) {
54 inputs.push_back(I(i));
55 grad_inputs.push_back(GI(i));
57 return SingleGradientDef(
"MinGradient",
"", inputs, grad_inputs);
63 REGISTER_GRADIENT(Max, GetMaxGradient);
64 REGISTER_GRADIENT(Min, GetMinGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...