Caffe2 - C++ API
A deep-learning, cross-platform ML framework
smooth_l1_loss_op.h
1 
17 #ifndef SMOOTH_L1_LOSS_OP_H_
18 #define SMOOTH_L1_LOSS_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class SmoothL1LossOp final : public Operator<Context> {
29  public:
30  SmoothL1LossOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  beta_(OperatorBase::GetSingleArgument<float>("beta", 1.)),
33  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)) {
34  CAFFE_ENFORCE(beta_ > 0);
35  CAFFE_ENFORCE(scale_ >= 0);
36  }
37  USE_OPERATOR_CONTEXT_FUNCTIONS;
38 
39  bool RunOnDevice() override {
40  // No CPU implementation for now
41  CAFFE_NOT_IMPLEMENTED;
42  }
43 
44  protected:
45  float beta_; // Transition point from L1 to L2 loss
46  float scale_; // Scale the loss by scale_
47  Tensor<Context> buff_; // Buffer for element-wise differences
48 };
49 
50 template <typename T, class Context>
51 class SmoothL1LossGradientOp final : public Operator<Context> {
52  public:
53  SmoothL1LossGradientOp(const OperatorDef& def, Workspace* ws)
54  : Operator<Context>(def, ws),
55  beta_(OperatorBase::GetSingleArgument<float>("beta", 1.)),
56  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)) {
57  CAFFE_ENFORCE(beta_ > 0);
58  CAFFE_ENFORCE(scale_ >= 0);
59  }
60  USE_OPERATOR_CONTEXT_FUNCTIONS;
61 
62  bool RunOnDevice() override {
63  // No CPU implementation for now
64  CAFFE_NOT_IMPLEMENTED;
65  }
66 
67  protected:
68  float beta_; // Transition point from L1 to L2 loss
69  float scale_; // Scale the loss by scale_
70  Tensor<Context> buff_; // Buffer for element-wise differences
71 };
72 
73 } // namespace caffe2
74 
75 #endif // SMOOTH_L1_LOSS_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory along with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.