1 #include "caffe2/operators/arg_ops.h" 5 #include "caffe2/utils/math.h" 11 template <
typename T,
class Compare,
class Context>
// ComputeArgImpl: for every (i, j) output position, stores in Y the index k
// along the reduced dimension of the element chosen by `comp`.
// (The function signature is not visible in this chunk.)
// X is indexed as a flattened [prev_size, n, next_size] array: element
// (i, k, j) lives at X[i * n * next_size + k * next_size + j]. Y holds one
// int64 index per (i, j), at Y[i * next_size + j].
// NOTE(review): offsets such as i * n * next_size are computed with an `int`
// loop index; if prev_size / n / next_size are also `int`, very large tensors
// could overflow here — confirm the size limits callers enforce.
//
// Start every output at index 0 so that slice k = 0 is the initial best
// candidate; the k loop below therefore begins at 1.
20 math::Set<int64_t, Context>(prev_size * next_size, int64_t(0), Y, context);
21 for (
int i = 0; i < prev_size; ++i) {
// cur_X starts at slice k = 1 of outer row i (next_size elements past the
// row start). It presumably advances by one element per inner iteration;
// that increment is on a line not visible in this chunk.
22 const T* cur_X = X + i * n * next_size + next_size;
23 for (
int k = 1; k < n; ++k) {
24 for (
int j = 0; j < next_size; ++j) {
25 int64_t* cur_Y = Y + i * next_size + j;
// Compare candidate X(i, k, j) against the current best X(i, Y[i, j], j).
// The branch body (presumably recording k into *cur_Y) is not visible here.
26 if (comp(*cur_X, X[i * n * next_size + *cur_Y * next_size + j])) {
// ArgMax reducer on CPU: delegates to ComputeArgImpl using std::greater<T> as
// the comparator, so the index of the largest element along the reduced axis
// is written to Y. std::greater is a strict comparison; assuming
// ComputeArgImpl only updates the stored index when comp returns true, ties
// keep the earliest index. (The middle parameters of this operator() are not
// visible in this chunk.)
39 bool ArgMaxReducer<CPUContext>::operator()(
45 CPUContext* context)
const {
46 ComputeArgImpl(prev_size, next_size, n, std::greater<T>(), X, Y, context);
// ArgMin reducer on CPU: delegates to ComputeArgImpl using std::less<T> as
// the comparator, so the index of the smallest element along the reduced axis
// is written to Y. std::less is a strict comparison; assuming ComputeArgImpl
// only updates the stored index when comp returns true, ties keep the
// earliest index. (The middle parameters of this operator() are not visible
// in this chunk.)
52 bool ArgMinReducer<CPUContext>::operator()(
58 CPUContext* context)
const {
59 ComputeArgImpl(prev_size, next_size, n, std::less<T>(), X, Y, context);
// Register the CPU kernels. Both ops share the ArgOp implementation and
// differ only in the reducer (ArgMaxReducer vs ArgMinReducer) that selects
// the index.
63 REGISTER_CPU_OPERATOR(ArgMax, ArgOp<CPUContext, ArgMaxReducer<CPUContext>>);
64 REGISTER_CPU_OPERATOR(ArgMin, ArgOp<CPUContext, ArgMinReducer<CPUContext>>);
// Shape inference shared by ArgMax and ArgMin (both schemas register it).
// Builds a single output shape from in[0]: the input dims before `axis` and
// after `axis` are copied in order, and the output dtype is set to INT64
// because the op returns indices. Reads the "axis" argument (default -1) and
// the "keepdims" argument (default true). The lines that apply keep_dims are
// not visible in this chunk; per the schema docs, the `axis` dim becomes 1
// when keepdims is true and is dropped otherwise.
68 std::vector<TensorShape> InferTensor(
69 const OperatorDef& def,
70 const std::vector<TensorShape>& in) {
71 std::vector<TensorShape> out(1);
72 ArgumentHelper helper(def);
73 int axis = helper.GetSingleArgument(
"axis", -1);
74 const bool keep_dims = helper.GetSingleArgument(
"keepdims",
true);
75 const auto& in_dims = in[0].dims();
76 auto* out_dims = out[0].mutable_dims();
// Resolve the axis to the last input dimension.
// NOTE(review): the condition guarding this assignment is not visible in
// this chunk. If it only tests axis == -1, other negative axes (e.g. -2) are
// never wrapped, and the second loop below would call in_dims.Get(i) with a
// negative index. Confirm, and consider normalizing any negative axis.
78 axis = in_dims.size() - 1;
// Copy the dims that precede the reduced axis.
80 for (
int i = 0; i < axis; ++i) {
81 out_dims->Add(in_dims.Get(i));
// Copy the dims that follow the reduced axis.
86 for (
int i = axis + 1; i < in_dims.size(); ++i) {
87 out_dims->Add(in_dims.Get(i));
// The output holds int64 indices regardless of the input dtype.
89 out[0].set_data_type(TensorProto::INT64);
// Schema for ArgMax: shape inference via InferTensor (int64 index output).
// Documents input X (float tensor), the index output, and the "axis"
// (default -1) and "keepdims" (default True) arguments.
95 OPERATOR_SCHEMA(ArgMax)
98 .TensorInferenceFunction(InferTensor)
100 Retrieve the argmax of an axis dimension specified by the `axis` 101 argument. Given an input tensor and two arguments (`axis` and 102 `keepdims`), returns a tensor containing the indices of the largest 103 element along the given axis. If the `keepdims` arg is *True* (default), 104 the shape of the output tensor matches the input tensor except the 105 `axis` dimension equals 1. Else, the `axis` dimension of the output 110 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/arg_ops.cc 114 <summary> <b>Example</b> </summary> 119 workspace.ResetWorkspace() 121 op = core.CreateOperator( 129 workspace.FeedBlob("X", (np.random.randint(10, size=(3,3,3))).astype(np.float32)) 130 print("X:", workspace.FetchBlob("X")) 131 workspace.RunOperatorOnce(op) 132 print("Indices:", workspace.FetchBlob("Indices")) 159 .Input(0, "X",
"*(type: Tensor`<float>`)* Input tensor.")
163 "*(type: Tensor`<float>`)* Tensor of indices for the largest values.")
164 .Arg(
"axis",
"*(type: int; default: -1)* The axis to get argmax.")
167 "*(type: bool; default: True)* If True (default), the output tensor " 168 "shape will match the input tensor shape except the `axis` dimension " 169 "equals 1. Else, the `axis` dimension of the output tensor is removed.");
// Schema for ArgMin: shape inference via InferTensor (int64 index output).
// Documents input X (float tensor), the index output, and the "axis"
// (default -1) and "keepdims" (default True) arguments.
171 OPERATOR_SCHEMA(ArgMin)
174 .TensorInferenceFunction(InferTensor)
176 Retrieve the argmin of an axis dimension specified by the `axis` 177 argument. Given an input tensor and two arguments (`axis` and 178 `keepdims`), returns a tensor containing the indices of the smallest 179 element along the given axis. If the `keepdims` arg is *True* (default), 180 the shape of the output tensor matches the input tensor except the 181 `axis` dimension equals 1. Else, the `axis` dimension of the output 186 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/arg_ops.cc 190 <summary> <b>Example</b> </summary> 195 workspace.ResetWorkspace() 197 op = core.CreateOperator( 204 workspace.FeedBlob("X", (np.random.randint(10, size=(5,5))).astype(np.float32)) 205 print("X:", workspace.FetchBlob("X")) 206 workspace.RunOperatorOnce(op) 207 print("Indices:", workspace.FetchBlob("Indices")) 231 .Input(0, "X",
"*(type: Tensor`<float>`)* Input tensor.")
235 "*(type: Tensor`<float>`)* Tensor of indices for the smallest values.")
236 .Arg(
"axis",
"*(type: int; default: -1)* The axis to get argmin.")
239 "*(type: bool; default: True)* If True (default), the output tensor " 240 "shape will match the input tensor shape except the `axis` dimension " 241 "equals 1. Else, the `axis` dimension of the output tensor is removed.");
// ArgMax/ArgMin output integer indices; these declarations mark both ops as
// ones for which no gradient should be generated.
243 SHOULD_NOT_DO_GRADIENT(ArgMax);
244 SHOULD_NOT_DO_GRADIENT(ArgMin);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...