Caffe2 - C++ API
A deep learning, cross platform ML framework
stump_func_op.h
1 
18 #ifndef CAFFE2_FB_OPERATORS_UTILITY_OPS_H_
19 #define CAFFE2_FB_OPERATORS_UTILITY_OPS_H_
20 
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/operator.h"
23 
24 namespace caffe2 {
25 
26 // Converts each input element into either high_ or low_value
27 // based on the given threshold.
28 //
29 // out[i] = low_value if in[i] <= threshold else high_value
30 template <typename TIN, typename TOUT, class Context>
31 class StumpFuncOp final : public Operator<Context> {
32  public:
33  USE_OPERATOR_CONTEXT_FUNCTIONS;
34 
35  template <class... Args>
36  explicit StumpFuncOp(Args&&... args)
37  : Operator<Context>(std::forward<Args>(args)...),
38  threshold_(this->template GetSingleArgument<TIN>("threshold", 0)),
39  low_value_(this->template GetSingleArgument<TOUT>("low_value", 0)),
40  high_value_(this->template GetSingleArgument<TOUT>("high_value", 0)) {}
41 
42  bool RunOnDevice() override;
43 
44  protected:
45  TIN threshold_;
46  TOUT low_value_;
47  TOUT high_value_;
48 
49  // Input: label, output: weight
50 };
51 
52 template <typename TIN, typename TOUT, class Context>
53 class StumpFuncIndexOp final : public Operator<Context> {
54  public:
55  USE_OPERATOR_CONTEXT_FUNCTIONS;
56 
57  template <class... Args>
58  explicit StumpFuncIndexOp(Args&&... args)
59  : Operator<Context>(std::forward<Args>(args)...),
60  threshold_(this->template GetSingleArgument<TIN>("threshold", 0)) {}
61 
62  bool RunOnDevice() override;
63 
64  protected:
65  TIN threshold_;
66  // Input: label, output: indices
67 };
68 
69 } // caffe2
70 
71 #endif // CAFFE2_FB_OPERATORS_UTILITY_OPS_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13