Caffe2 - C++ API
A deep learning, cross platform ML framework
top_k.h
1 #ifndef CAFFE2_OPERATORS_TOP_K_H_
2 #define CAFFE2_OPERATORS_TOP_K_H_
3 
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class TopKOp : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14 
15  template <class... Args>
16  explicit TopKOp(Args&&... args)
17  : Operator<Context>(std::forward<Args>(args)...),
18  OP_SINGLE_ARG(int, "k", k_, -1),
19  OP_SINGLE_ARG(int, "axis", axis_, -1) {
20  CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
21  }
22 
23  ~TopKOp() {}
24 
25  bool RunOnDevice() override;
26 
27  private:
28  const int k_;
29  int axis_;
30 };
31 
32 template <typename T, class Context>
33 class TopKGradientOp : public Operator<Context> {
34  public:
35  USE_OPERATOR_CONTEXT_FUNCTIONS;
36 
37  template <class... Args>
38  explicit TopKGradientOp(Args&&... args)
39  : Operator<Context>(std::forward<Args>(args)...),
40  OP_SINGLE_ARG(int, "axis", axis_, -1) {}
41 
42  ~TopKGradientOp() {}
43 
44  bool RunOnDevice() override;
45 
46  private:
47  int axis_;
48 };
49 
50 } // namespace caffe2
51 
52 #endif // CAFFE2_OPERATORS_TOP_K_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13