Caffe2 - C++ API
A deep learning, cross platform ML framework
minmax_ops.h
1 #ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_
2 #define CAFFE2_OPERATORS_MINMAX_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/core/types.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <typename T, class Context>
13 class MaxOp final : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16 
17  USE_SIMPLE_CTOR_DTOR(MaxOp)
18 
19  bool RunOnDevice() override {
20  const auto& X0 = Input(0);
21  auto* Y = Output(0);
22  Y->ResizeLike(X0);
23  const T* X0_data = X0.template data<T>();
24  T* Y_data = Y->template mutable_data<T>();
25  const int N = X0.numel();
26  if (InputSize() == 1) {
27  if (Y != &X0) {
28  context_.template CopySameDevice<T>(N, X0_data, Y_data);
29  }
30  return true;
31  }
32  const auto& X1 = Input(1);
33  CAFFE_ENFORCE_EQ(
34  X0.sizes(),
35  Y->sizes(),
36  "Description: Input #1, input dimension:",
37  X1.sizes(),
38  " should match output dimension: ",
39  Y->sizes());
40  const T* X1_data = X1.template data<T>();
41  math::Max<T, Context>(N, X0_data, X1_data, Y_data, &context_);
42  for (int i = 2; i < InputSize(); ++i) {
43  const auto& Xi = Input(i);
44  CAFFE_ENFORCE_EQ(
45  Xi.sizes(),
46  Y->sizes(),
47  "Description: Input #",
48  i,
49  ", input dimension:",
50  Input(i).sizes(),
51  " should match output dimension: ",
52  Y->sizes());
53  const T* Xi_data = Xi.template data<T>();
54  math::Max<T, Context>(N, Y_data, Xi_data, Y_data, &context_);
55  }
56  return true;
57  }
58 };
59 
60 template <typename T, class Context>
61 class MinOp final : public Operator<Context> {
62  public:
63  USE_OPERATOR_CONTEXT_FUNCTIONS;
64 
65  USE_SIMPLE_CTOR_DTOR(MinOp)
66 
67  bool RunOnDevice() override {
68  const auto& X0 = Input(0);
69  auto* Y = Output(0);
70  Y->ResizeLike(X0);
71  const T* X0_data = X0.template data<T>();
72  T* Y_data = Y->template mutable_data<T>();
73  const int N = X0.numel();
74  if (InputSize() == 1) {
75  if (Y != &X0) {
76  context_.template CopySameDevice<T>(N, X0_data, Y_data);
77  }
78  return true;
79  }
80  const auto& X1 = Input(1);
81  CAFFE_ENFORCE_EQ(
82  X0.sizes(),
83  Y->sizes(),
84  "Description: Input #1, input dimension:",
85  X1.sizes(),
86  " should match output dimension: ",
87  Y->sizes());
88  const T* X1_data = X1.template data<T>();
89  math::Min<T, Context>(N, X0_data, X1_data, Y_data, &context_);
90  for (int i = 2; i < InputSize(); ++i) {
91  const auto& Xi = Input(i);
92  CAFFE_ENFORCE_EQ(
93  Xi.sizes(),
94  Y->sizes(),
95  "Description: Input #",
96  i,
97  ", input dimension:",
98  Input(i).sizes(),
99  " should match output dimension: ",
100  Y->sizes());
101  const T* Xi_data = Xi.template data<T>();
102  math::Min<T, Context>(N, Y_data, Xi_data, Y_data, &context_);
103  }
104  return true;
105  }
106 };
107 
108 template <typename T, class Context>
109 class SelectGradientOpBase : public Operator<Context> {
110  public:
111  USE_OPERATOR_CONTEXT_FUNCTIONS;
112  USE_SIMPLE_CTOR_DTOR(SelectGradientOpBase)
113 
114  bool RunOnDevice() override;
115 };
116 
117 template <typename T, class Context>
118 class MaxGradientOp final : public SelectGradientOpBase<T, Context> {
119  public:
120  template <class... Args>
121  explicit MaxGradientOp(Args&&... args)
122  : SelectGradientOpBase<T, Context>(std::forward<Args>(args)...) {}
123 
124  ~MaxGradientOp() = default;
125 };
126 
127 template <typename T, class Context>
128 class MinGradientOp final : public SelectGradientOpBase<T, Context> {
129  public:
130  template <class... Args>
131  explicit MinGradientOp(Args&&... args)
132  : SelectGradientOpBase<T, Context>(std::forward<Args>(args)...) {}
133 
134  ~MinGradientOp() = default;
135 };
136 
137 } // namespace caffe2
138 
139 #endif // CAFFE2_OPERATORS_MINMAX_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13