Caffe2 - C++ API
A deep learning, cross-platform ML framework
minmax_gradient_ops.cc
1 
17 #include "caffe2/operators/minmax_ops.h"
18 
19 namespace caffe2 {
20 
21 REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
22 REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
23 
24 OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
25 OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
26 
27 template <typename T, class Context>
28 bool SelectGradientOpBase<T, Context>::RunOnDevice() {
29  auto& output = Input(0);
30  auto& grad_output = Input(1);
31  const int kInputStartOffset = 2;
32 
33  const T* data = output.template data<T>();
34  ConstEigenArrayMap<T> output_array(
35  output.template data<T>(), 1, output.size());
36  ConstEigenArrayMap<T> grad_out_array(
37  grad_output.template data<T>(), 1, grad_output.size());
38 
39  for (int i = 0; i < OutputSize(); i++) {
40  auto& input = Input(i + kInputStartOffset);
41  ConstEigenArrayMap<T> input_array(
42  input.template data<T>(), 1, input.size());
43 
44  auto* grad_input = Output(i);
45  grad_input->ResizeLike(input);
46  EigenArrayMap<T> grad_in_array(
47  grad_input->template mutable_data<T>(), 1, grad_input->size());
48  grad_in_array = grad_out_array *
49  input_array.cwiseEqual(output_array).template cast<T>();
50  }
51  return true;
52 }
53 
55  using GradientMakerBase::GradientMakerBase;
56  vector<OperatorDef> GetGradientDefs() override {
57  auto gradInputs = vector<string>();
58  auto inputs = vector<string>{O(0), GO(0)};
59  for (int i = 0; i < def_.input_size(); i++) {
60  gradInputs.push_back(GI(i));
61  inputs.push_back(I(i));
62  }
63  return SingleGradientDef("MaxGradient", "", inputs, gradInputs);
64  }
65 };
66 REGISTER_GRADIENT(Max, GetMaxGradient);
67 
69  using GradientMakerBase::GradientMakerBase;
70  vector<OperatorDef> GetGradientDefs() override {
71  auto gradInputs = vector<string>();
72  auto inputs = vector<string>{O(0), GO(0)};
73  for (int i = 0; i < def_.input_size(); i++) {
74  gradInputs.push_back(GI(i));
75  inputs.push_back(I(i));
76  }
77  return SingleGradientDef("MinGradient", "", inputs, gradInputs);
78  }
79 };
80 REGISTER_GRADIENT(Min, GetMinGradient);
81 
82 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function for creating a single operator def, which is the usual case for many ...