Caffe2 - C++ API
A deep-learning, cross-platform ML framework
top_k.h
1 
17 #ifndef CAFFE2_OPERATORS_TOP_K_H_
18 #define CAFFE2_OPERATORS_TOP_K_H_
19 
20 #include "caffe2/core/logging.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context>
27 class TopKOp : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30 
31  TopKOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws), OP_SINGLE_ARG(int, "k", k_, -1) {
33  CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
34  }
35 
36  bool RunOnDevice() override;
37 
38  private:
39  int k_;
40 };
41 
42 template <typename T, class Context>
43 class TopKGradientOp : public Operator<Context> {
44  public:
45  USE_OPERATOR_CONTEXT_FUNCTIONS;
46 
47  TopKGradientOp(const OperatorDef& operator_def, Workspace* ws)
48  : Operator<Context>(operator_def, ws) {}
49 
50  bool RunOnDevice() override;
51 };
52 
53 } // namespace caffe2
54 
55 #endif // CAFFE2_OPERATORS_TOP_K_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.