Caffe2 - C++ API
A deep learning, cross-platform ML framework
minmax_gradient_ops.cc
1 #include "caffe2/operators/minmax_ops.h"
2 
3 #include <string>
4 #include <vector>
5 
6 #include "caffe2/utils/eigen_utils.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 bool SelectGradientOpBase<T, Context>::RunOnDevice() {
12  const auto& Y = Input(0);
13  const auto& dY = Input(1);
14  const int N = Y.numel();
15  ConstEigenVectorArrayMap<T> Y_arr(Y.template data<T>(), N);
16  ConstEigenVectorArrayMap<T> dY_arr(dY.template data<T>(), N);
17  for (int i = 0; i < OutputSize(); i++) {
18  const auto& Xi = Input(i + 2);
19  auto* dXi = Output(i, Xi.sizes(), at::dtype<T>());
20  ConstEigenVectorArrayMap<T> Xi_arr(Xi.template data<T>(), N);
21  EigenVectorArrayMap<T> dXi_arr(dXi->template mutable_data<T>(), N);
22  dXi_arr = (Xi_arr == Y_arr).template cast<T>() * dY_arr;
23  }
24  return true;
25 }
26 
// Register the float CPU kernels for the MaxGradient / MinGradient ops, which
// are the op types emitted by the Max / Min gradient makers in this file.
27 REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
28 REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
29 
// Schemas for the gradient ops. Inputs are [Y, dY, X_0, ..., X_{n-1}], so at
// least 3 are required; outputs are [dX_0, ..., dX_{n-1}], so at least 1.
// NOTE(review): the schema does not tie the output count to the input count.
// RunOnDevice reads Input(i + 2) for every output i, so a caller must supply
// at least OutputSize() + 2 inputs.
30 OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
31 OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
32 
33 namespace {
34 
35 class GetMaxGradient : public GradientMakerBase {
36  using GradientMakerBase::GradientMakerBase;
37  std::vector<OperatorDef> GetGradientDefs() override {
38  std::vector<std::string> inputs = {O(0), GO(0)};
39  std::vector<std::string> grad_inputs;
40  for (int i = 0; i < def_.input_size(); ++i) {
41  inputs.push_back(I(i));
42  grad_inputs.push_back(GI(i));
43  }
44  return SingleGradientDef("MaxGradient", "", inputs, grad_inputs);
45  }
46 };
47 
48 class GetMinGradient : public GradientMakerBase {
49  using GradientMakerBase::GradientMakerBase;
50  vector<OperatorDef> GetGradientDefs() override {
51  std::vector<std::string> inputs = {O(0), GO(0)};
52  std::vector<std::string> grad_inputs;
53  for (int i = 0; i < def_.input_size(); ++i) {
54  inputs.push_back(I(i));
55  grad_inputs.push_back(GI(i));
56  }
57  return SingleGradientDef("MinGradient", "", inputs, grad_inputs);
58  }
59 };
60 
61 } // namespace
62 
// Attach the gradient makers to the forward Max and Min operators, so that
// building the backward pass for Max/Min emits MaxGradient/MinGradient ops.
63 REGISTER_GRADIENT(Max, GetMaxGradient);
64 REGISTER_GRADIENT(Min, GetMinGradient);
65 
66 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13