1 #ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_ 2 #define CAFFE2_OPERATORS_MINMAX_OPS_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/core/types.h" 8 #include "caffe2/utils/math.h" 12 template <
typename T,
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
17 USE_SIMPLE_CTOR_DTOR(
MaxOp)
19 bool RunOnDevice()
override {
20 const auto& X0 =
Input(0);
23 const T* X0_data = X0.template data<T>();
24 T* Y_data = Y->template mutable_data<T>();
25 const int N = X0.numel();
26 if (InputSize() == 1) {
28 context_.template CopySameDevice<T>(N, X0_data, Y_data);
32 const auto& X1 =
Input(1);
36 "Description: Input #1, input dimension:",
38 " should match output dimension: ",
40 const T* X1_data = X1.template data<T>();
41 math::Max<T, Context>(N, X0_data, X1_data, Y_data, &context_);
42 for (
int i = 2; i < InputSize(); ++i) {
43 const auto& Xi =
Input(i);
47 "Description: Input #",
51 " should match output dimension: ",
53 const T* Xi_data = Xi.template data<T>();
54 math::Max<T, Context>(N, Y_data, Xi_data, Y_data, &context_);
60 template <
typename T,
class Context>
63 USE_OPERATOR_CONTEXT_FUNCTIONS;
65 USE_SIMPLE_CTOR_DTOR(
MinOp)
67 bool RunOnDevice()
override {
68 const auto& X0 =
Input(0);
71 const T* X0_data = X0.template data<T>();
72 T* Y_data = Y->template mutable_data<T>();
73 const int N = X0.numel();
74 if (InputSize() == 1) {
76 context_.template CopySameDevice<T>(N, X0_data, Y_data);
80 const auto& X1 =
Input(1);
84 "Description: Input #1, input dimension:",
86 " should match output dimension: ",
88 const T* X1_data = X1.template data<T>();
89 math::Min<T, Context>(N, X0_data, X1_data, Y_data, &context_);
90 for (
int i = 2; i < InputSize(); ++i) {
91 const auto& Xi =
Input(i);
95 "Description: Input #",
99 " should match output dimension: ",
101 const T* Xi_data = Xi.template data<T>();
102 math::Min<T, Context>(N, Y_data, Xi_data, Y_data, &context_);
108 template <
typename T,
class Context>
111 USE_OPERATOR_CONTEXT_FUNCTIONS;
114 bool RunOnDevice()
override;
117 template <
typename T,
class Context>
120 template <
class... Args>
124 ~MaxGradientOp() =
default;
127 template <
typename T,
class Context>
130 template <
class... Args>
134 ~MinGradientOp() =
default;
139 #endif // CAFFE2_OPERATORS_MINMAX_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...