1 #ifndef CAFFE2_OPERATORS_ARG_OPS_H_ 2 #define CAFFE2_OPERATORS_ARG_OPS_H_ 8 #include "caffe2/core/context.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/types.h" 14 template <
class Context,
class Reducer>
17 USE_OPERATOR_CONTEXT_FUNCTIONS;
19 template <
class... Args>
20 explicit ArgOp(Args&&... args)
22 OP_SINGLE_ARG(
int,
"axis", axis_, -1),
23 OP_SINGLE_ARG(
bool,
"keepdims", keep_dims_,
true) {}
25 bool RunOnDevice()
override {
32 bool DoRunWithType() {
33 const auto& X =
Input(0);
35 const int ndim = X.dim();
39 CAFFE_ENFORCE_GE(axis_, 0);
40 CAFFE_ENFORCE_LT(axis_, ndim);
41 const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
42 std::vector<int64_t> Y_dims;
46 for (
int i = 0; i < axis_; ++i) {
47 Y_dims.push_back(X_dims[i]);
48 prev_size *= X_dims[i];
53 for (
int i = axis_ + 1; i < ndim; ++i) {
54 Y_dims.push_back(X_dims[i]);
55 next_size *= X_dims[i];
57 auto* Y = Output(0, Y_dims, at::dtype<int64_t>());
58 const int n = X_dims[axis_];
64 Y->template mutable_data<int64_t>(),
70 const bool keep_dims_;
74 template <
class Context>
83 Context* context)
const;
86 template <
class Context>
95 Context* context)
const;
100 #endif // CAFFE2_OPERATORS_ARG_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...