Caffe2 - C++ API
A deep-learning, cross-platform ML framework
arg_ops.cc
1 #include "caffe2/operators/arg_ops.h"
2 
3 #include <functional>
4 
5 #include "caffe2/utils/math.h"
6 
7 namespace caffe2 {
8 
9 namespace {
10 
11 template <typename T, class Compare, class Context>
12 void ComputeArgImpl(
13  const int prev_size,
14  const int next_size,
15  const int n,
16  const Compare& comp,
17  const T* X,
18  int64_t* Y,
19  Context* context) {
20  math::Set<int64_t, Context>(prev_size * next_size, int64_t(0), Y, context);
21  for (int i = 0; i < prev_size; ++i) {
22  const T* cur_X = X + i * n * next_size + next_size;
23  for (int k = 1; k < n; ++k) {
24  for (int j = 0; j < next_size; ++j) {
25  int64_t* cur_Y = Y + i * next_size + j;
26  if (comp(*cur_X, X[i * n * next_size + *cur_Y * next_size + j])) {
27  *cur_Y = k;
28  }
29  ++cur_X;
30  }
31  }
32  }
33 }
34 
35 } // namespace
36 
37 template <>
38 template <typename T>
39 bool ArgMaxReducer<CPUContext>::operator()(
40  const int prev_size,
41  const int next_size,
42  const int n,
43  const T* X,
44  int64_t* Y,
45  CPUContext* context) const {
46  ComputeArgImpl(prev_size, next_size, n, std::greater<T>(), X, Y, context);
47  return true;
48 }
49 
50 template <>
51 template <typename T>
52 bool ArgMinReducer<CPUContext>::operator()(
53  const int prev_size,
54  const int next_size,
55  const int n,
56  const T* X,
57  int64_t* Y,
58  CPUContext* context) const {
59  ComputeArgImpl(prev_size, next_size, n, std::less<T>(), X, Y, context);
60  return true;
61 }
62 
63 REGISTER_CPU_OPERATOR(ArgMax, ArgOp<CPUContext, ArgMaxReducer<CPUContext>>);
64 REGISTER_CPU_OPERATOR(ArgMin, ArgOp<CPUContext, ArgMinReducer<CPUContext>>);
65 
66 namespace {
67 
68 std::vector<TensorShape> InferTensor(
69  const OperatorDef& def,
70  const std::vector<TensorShape>& in) {
71  std::vector<TensorShape> out(1);
72  ArgumentHelper helper(def);
73  int axis = helper.GetSingleArgument("axis", -1);
74  const bool keep_dims = helper.GetSingleArgument("keepdims", true);
75  const auto& in_dims = in[0].dims();
76  auto* out_dims = out[0].mutable_dims();
77  if (axis == -1) {
78  axis = in_dims.size() - 1;
79  }
80  for (int i = 0; i < axis; ++i) {
81  out_dims->Add(in_dims.Get(i));
82  }
83  if (keep_dims) {
84  out_dims->Add(1);
85  }
86  for (int i = axis + 1; i < in_dims.size(); ++i) {
87  out_dims->Add(in_dims.Get(i));
88  }
89  out[0].set_data_type(TensorProto::INT64);
90  return out;
91 }
92 
93 } // namespace
94 
95 OPERATOR_SCHEMA(ArgMax)
96  .NumInputs(1)
97  .NumOutputs(1)
98  .TensorInferenceFunction(InferTensor)
99  .SetDoc(R"DOC(
100 Retrieve the argmax of an axis dimension specified by the `axis`
101 argument. Given an input tensor and two arguments (`axis` and
102 `keepdims`), returns a tensor containing the indices of the largest
103 element along the given axis. If the `keepdims` arg is *True* (default),
104 the shape of the output tensor matches the input tensor except the
105 `axis` dimension equals 1. Else, the `axis` dimension of the output
106 tensor is removed.
107 
108 Github Links:
109 
110 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/arg_ops.cc
111 
112 <details>
113 
114 <summary> <b>Example</b> </summary>
115 
116 **Code**
117 
118 ```
119 workspace.ResetWorkspace()
120 
121 op = core.CreateOperator(
122  "ArgMax",
123  ["X"],
124  ["Indices"],
125  axis=2,
126  keepdims=False
127 )
128 
129 workspace.FeedBlob("X", (np.random.randint(10, size=(3,3,3))).astype(np.float32))
130 print("X:", workspace.FetchBlob("X"))
131 workspace.RunOperatorOnce(op)
132 print("Indices:", workspace.FetchBlob("Indices"))
133 
134 ```
135 
136 **Result**
137 
138 ```
139 X: [[[4. 9. 6.]
140  [6. 6. 1.]
141  [9. 5. 4.]]
142 
143  [[6. 7. 4.]
144  [7. 9. 1.]
145  [3. 2. 8.]]
146 
147  [[3. 4. 6.]
148  [5. 2. 7.]
149  [1. 5. 7.]]]
150 Indices: [[1 0 0]
151  [1 1 2]
152  [2 2 2]]
153 
154 ```
155 
156 </details>
157 
158  )DOC")
159  .Input(0, "X", "*(type: Tensor`<float>`)* Input tensor.")
160  .Output(
161  0,
162  "Indices",
163  "*(type: Tensor`<float>`)* Tensor of indices for the largest values.")
164  .Arg("axis", "*(type: int; default: -1)* The axis to get argmax.")
165  .Arg(
166  "keepdims",
167  "*(type: bool; default: True)* If True (default), the output tensor "
168  "shape will match the input tensor shape except the `axis` dimension "
169  "equals 1. Else, the `axis` dimension of the output tensor is removed.");
170 
171 OPERATOR_SCHEMA(ArgMin)
172  .NumInputs(1)
173  .NumOutputs(1)
174  .TensorInferenceFunction(InferTensor)
175  .SetDoc(R"DOC(
176 Retrieve the argmin of an axis dimension specified by the `axis`
177 argument. Given an input tensor and two arguments (`axis` and
178 `keepdims`), returns a tensor containing the indices of the smallest
179 element along the given axis. If the `keepdims` arg is *True* (default),
180 the shape of the output tensor matches the input tensor except the
181 `axis` dimension equals 1. Else, the `axis` dimension of the output
182 tensor is removed.
183 
184 Github Links:
185 
186 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/arg_ops.cc
187 
188 <details>
189 
190 <summary> <b>Example</b> </summary>
191 
192 **Code**
193 
194 ```
195 workspace.ResetWorkspace()
196 
197 op = core.CreateOperator(
198  "ArgMin",
199  ["X"],
200  ["Indices"],
201  axis=1
202 )
203 
204 workspace.FeedBlob("X", (np.random.randint(10, size=(5,5))).astype(np.float32))
205 print("X:", workspace.FetchBlob("X"))
206 workspace.RunOperatorOnce(op)
207 print("Indices:", workspace.FetchBlob("Indices"))
208 
209 ```
210 
211 **Result**
212 
213 ```
214 
215 X: [[9. 4. 6. 4. 1.]
216  [5. 9. 8. 3. 4.]
217  [6. 1. 0. 2. 9.]
218  [7. 8. 2. 4. 9.]
219  [3. 9. 4. 9. 4.]]
220 Indices: [[4]
221  [3]
222  [2]
223  [2]
224  [0]]
225 
226 ```
227 
228 </details>
229 
230  )DOC")
231  .Input(0, "X", "*(type: Tensor`<float>`)* Input tensor.")
232  .Output(
233  0,
234  "Indices",
235  "*(type: Tensor`<float>`)* Tensor of indices for the smallest values.")
236  .Arg("axis", "*(type: int; default: -1)* The axis to get argmin.")
237  .Arg(
238  "keepdims",
239  "*(type: bool; default: True)* If True (default), the output tensor "
240  "shape will match the input tensor shape except the `axis` dimension "
241  "equals 1. Else, the `axis` dimension of the output tensor is removed.");
242 
243 SHOULD_NOT_DO_GRADIENT(ArgMax);
244 SHOULD_NOT_DO_GRADIENT(ArgMin);
245 
246 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13