Caffe2 - C++ API
A deep learning, cross-platform ML framework
sigmoid_focal_loss_op.h
1 
17 #ifndef SIGMOID_FOCAL_LOSS_OP_H_
18 #define SIGMOID_FOCAL_LOSS_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class SigmoidFocalLossOp final : public Operator<Context> {
29  public:
30  SigmoidFocalLossOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  scale_(this->template GetSingleArgument<float>("scale", 1.)),
33  num_classes_(this->template GetSingleArgument<int>("num_classes", 80)),
34  gamma_(this->template GetSingleArgument<float>("gamma", 1.)),
35  alpha_(this->template GetSingleArgument<float>("alpha", 0.25)) {
36  CAFFE_ENFORCE(scale_ >= 0);
37  }
38  USE_OPERATOR_CONTEXT_FUNCTIONS;
39 
40  bool RunOnDevice() override {
41  // No CPU implementation for now
42  CAFFE_NOT_IMPLEMENTED;
43  }
44 
45  protected:
46  float scale_;
47  int num_classes_;
48  float gamma_;
49  float alpha_;
50  Tensor losses_{Context::GetDeviceType()};
51  Tensor counts_{Context::GetDeviceType()};
52 };
53 
54 template <typename T, class Context>
55 class SigmoidFocalLossGradientOp final : public Operator<Context> {
56  public:
57  SigmoidFocalLossGradientOp(const OperatorDef& def, Workspace* ws)
58  : Operator<Context>(def, ws),
59  scale_(this->template GetSingleArgument<float>("scale", 1.)),
60  num_classes_(this->template GetSingleArgument<int>("num_classes", 80)),
61  gamma_(this->template GetSingleArgument<float>("gamma", 1.)),
62  alpha_(this->template GetSingleArgument<float>("alpha", 0.25)) {
63  CAFFE_ENFORCE(scale_ >= 0);
64  }
65  USE_OPERATOR_CONTEXT_FUNCTIONS;
66 
67  bool RunOnDevice() override {
68  // No CPU implementation for now
69  CAFFE_NOT_IMPLEMENTED;
70  }
71 
72  protected:
73  float scale_;
74  int num_classes_;
75  float gamma_;
76  float alpha_;
77  Tensor counts_{Context::GetDeviceType()};
78  Tensor weights_{Context::GetDeviceType()}; // unignored weights
79 };
80 
81 } // namespace caffe2
82 
83 #endif // SIGMOID_FOCAL_LOSS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13