Caffe2 - C++ API
A deep-learning, cross-platform ML framework
select_smooth_l1_loss_op.h
1 
17 #ifndef SELECT_SMOOTH_L1_LOSS_OP_H_
18 #define SELECT_SMOOTH_L1_LOSS_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class SelectSmoothL1LossOp final : public Operator<Context> {
29  public:
30  SelectSmoothL1LossOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  beta_(this->template GetSingleArgument<float>("beta", 1.)),
33  scale_(this->template GetSingleArgument<float>("scale", 1.)) {
34  CAFFE_ENFORCE(beta_ > 0);
35  CAFFE_ENFORCE(scale_ >= 0);
36  }
37  USE_OPERATOR_CONTEXT_FUNCTIONS;
38 
39  bool RunOnDevice() override {
40  // No CPU implementation for now
41  CAFFE_NOT_IMPLEMENTED;
42  }
43 
44  protected:
45  float beta_; // Transition point from L1 to L2 loss
46  float scale_; // Scale the loss by scale_
47  int dim_; // dimension for 1 anchor prediction
48  Tensor buff_{Context::GetDeviceType()}; // Buffer for element-wise differences
49 };
50 
51 template <typename T, class Context>
52 class SelectSmoothL1LossGradientOp final : public Operator<Context> {
53  public:
54  SelectSmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
55  : Operator<Context>(def, ws),
56  beta_(this->template GetSingleArgument<float>("beta", 1.)),
57  scale_(this->template GetSingleArgument<float>("scale", 1.)) {
58  CAFFE_ENFORCE(beta_ > 0);
59  CAFFE_ENFORCE(scale_ >= 0);
60  }
61  USE_OPERATOR_CONTEXT_FUNCTIONS;
62 
63  bool RunOnDevice() override {
64  // No CPU implementation for now
65  CAFFE_NOT_IMPLEMENTED;
66  }
67 
68  protected:
69  float beta_; // Transition point from L1 to L2 loss
70  float scale_; // Scale the loss by scale_
71  int dim_; // dimension for 1 anchor prediction
72  Tensor buff_{Context::GetDeviceType()}; // Buffer for element-wise differences
73 };
74 
75 } // namespace caffe2
76 
77 #endif // SELECT_SMOOTH_L1_LOSS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13