Caffe2 - C++ API
A deep learning, cross-platform ML framework
arg_ops.h
1 #ifndef CAFFE2_OPERATORS_ARG_OPS_H_
2 #define CAFFE2_OPERATORS_ARG_OPS_H_
3 
4 #include <algorithm>
5 #include <iterator>
6 #include <vector>
7 
8 #include "caffe2/core/context.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/types.h"
11 
12 namespace caffe2 {
13 
14 template <class Context, class Reducer>
15 class ArgOp final : public Operator<Context> {
16  public:
17  USE_OPERATOR_CONTEXT_FUNCTIONS;
18 
19  template <class... Args>
20  explicit ArgOp(Args&&... args)
21  : Operator<Context>(std::forward<Args>(args)...),
22  OP_SINGLE_ARG(int, "axis", axis_, -1),
23  OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}
24 
25  bool RunOnDevice() override {
26  return DispatchHelper<
28  call(this, Input(0));
29  }
30 
31  template <typename T>
32  bool DoRunWithType() {
33  const auto& X = Input(0);
34 
35  const int ndim = X.dim();
36  if (axis_ == -1) {
37  axis_ = ndim - 1;
38  }
39  CAFFE_ENFORCE_GE(axis_, 0);
40  CAFFE_ENFORCE_LT(axis_, ndim);
41  const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
42  std::vector<int64_t> Y_dims;
43  Y_dims.reserve(ndim);
44  int prev_size = 1;
45  int next_size = 1;
46  for (int i = 0; i < axis_; ++i) {
47  Y_dims.push_back(X_dims[i]);
48  prev_size *= X_dims[i];
49  }
50  if (keep_dims_) {
51  Y_dims.push_back(1);
52  }
53  for (int i = axis_ + 1; i < ndim; ++i) {
54  Y_dims.push_back(X_dims[i]);
55  next_size *= X_dims[i];
56  }
57  auto* Y = Output(0, Y_dims, at::dtype<int64_t>());
58  const int n = X_dims[axis_];
59  return reducer_(
60  prev_size,
61  next_size,
62  n,
63  X.template data<T>(),
64  Y->template mutable_data<int64_t>(),
65  &context_);
66  }
67 
68  private:
69  int axis_;
70  const bool keep_dims_;
71  Reducer reducer_{};
72 };
73 
/// Reducer plugged into ArgOp; by its name, presumably selects the index of
/// the maximum element along the reduced axis (the definition is not in this
/// header — confirm against the per-Context implementation).
template <class Context>
struct ArgMaxReducer {
  /// Invoked by ArgOp with the input viewed as [prev_size, n, next_size].
  /// @param prev_size product of the dimensions before the reduced axis
  /// @param next_size product of the dimensions after the reduced axis
  /// @param n         length of the reduced axis
  /// @param X         input element buffer
  /// @param Y         output buffer receiving int64_t indices
  /// @param context   execution context of the calling operator
  /// @return the success flag that ArgOp forwards from RunOnDevice
  template <typename T>
  bool operator()(
      int prev_size,
      int next_size,
      int n,
      const T* X,
      int64_t* Y,
      Context* context) const;
};
85 
/// Reducer plugged into ArgOp; by its name, presumably selects the index of
/// the minimum element along the reduced axis (the definition is not in this
/// header — confirm against the per-Context implementation).
template <class Context>
struct ArgMinReducer {
  /// Invoked by ArgOp with the input viewed as [prev_size, n, next_size].
  /// @param prev_size product of the dimensions before the reduced axis
  /// @param next_size product of the dimensions after the reduced axis
  /// @param n         length of the reduced axis
  /// @param X         input element buffer
  /// @param Y         output buffer receiving int64_t indices
  /// @param context   execution context of the calling operator
  /// @return the success flag that ArgOp forwards from RunOnDevice
  template <typename T>
  bool operator()(
      int prev_size,
      int next_size,
      int n,
      const T* X,
      int64_t* Y,
      Context* context) const;
};
97 
98 } // namespace caffe2
99 
100 #endif // CAFFE2_OPERATORS_ARG_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13