1 #ifndef CAFFE2_OPERATORS_REDUCE_FRONT_BACK_MAX_OPS_H_ 2 #define CAFFE2_OPERATORS_REDUCE_FRONT_BACK_MAX_OPS_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
typename T,
class Context,
bool FIRSTDIMS>
14 template <
class... Args>
18 this->
template GetSingleArgument<int32_t>(
"num_reduce_dim", 1)) {}
20 USE_OPERATOR_CONTEXT_FUNCTIONS;
26 num_reduce_dims_ >= 0 && num_reduce_dims_ <= X.dim(),
27 "For N-dim input tensor, support num_reduce_dims in range [0, N].");
29 const int rows = FIRSTDIMS ? X.size_to_dim(num_reduce_dims_)
30 : X.size_to_dim(X.dim() - num_reduce_dims_);
31 const int cols = FIRSTDIMS ? X.size_from_dim(num_reduce_dims_)
32 : X.size_from_dim(X.dim() - num_reduce_dims_);
34 vector<int64_t> output_shape;
35 int start_index = FIRSTDIMS ? num_reduce_dims_ : 0;
37 FIRSTDIMS ? X.dim() : X.dim() - num_reduce_dims_;
39 for (
int i = start_index; i < end_index; ++i) {
40 output_shape.push_back(X.sizes()[i]);
42 auto* Y = Output(0, output_shape, at::dtype<float>());
43 float* out_data = Y->template mutable_data<float>();
45 if (cols == 0 || rows == 0) {
46 math::Set(Y->numel(),
static_cast<float>(0), out_data, &context_);
50 const int32_t* lengths_data =
nullptr;
51 if (InputSize() > 1) {
52 const auto& lengths =
Input(1);
53 lengths_data = lengths.template data<int32_t>();
55 num_reduce_dims_ == 1,
56 "Given lengths input, the number of reduce dimensions should be one.");
57 const int batch_size = FIRSTDIMS ? cols : rows;
59 lengths.numel() == batch_size,
60 "The size of lengths vector doesn't match the batch size.");
63 const float* data = X.template data<float>();
64 Compute(rows, cols, data, lengths_data, out_data);
73 const int32_t* lengths_data,
79 template <
typename T,
class Context,
bool FIRSTDIMS>
82 template <
class... Args>
86 this->
template GetSingleArgument<int32_t>(
"num_reduce_dim", 1)) {}
88 USE_OPERATOR_CONTEXT_FUNCTIONS;
90 bool RunOnDevice()
override {
95 auto* dX = Output(0, X.sizes(), at::dtype<float>());
96 const int rows = FIRSTDIMS ? X.size_to_dim(num_reduce_dims_)
97 : X.size_to_dim(X.dim() - num_reduce_dims_);
98 const int cols = FIRSTDIMS ? X.size_from_dim(num_reduce_dims_)
99 : X.size_from_dim(X.dim() - num_reduce_dims_);
101 const float* dYdata = dY.template data<float>();
102 const float* Xdata = X.template data<float>();
103 const float* Ydata = Y.template data<float>();
105 const int32_t* lengths_data =
nullptr;
106 if (InputSize() > 3) {
107 const auto& lengths =
Input(3);
108 lengths_data = lengths.template data<int32_t>();
110 num_reduce_dims_ == 1,
111 "Given lengths input, the number of reduce dimensions should be one.");
112 const int batch_size = FIRSTDIMS ? cols : rows;
114 lengths.numel() == batch_size,
115 "The size of lengths vector doesn't match the batch size.");
118 float* dXdata = dX->template mutable_data<float>();
119 Compute(rows, cols, dYdata, Xdata, Ydata, lengths_data, dXdata);
130 const int32_t* lengths_data,
133 int num_reduce_dims_;
138 #endif // CAFFE2_OPERATORS_REDUCE_FRONT_BACK_MAX_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...