Caffe2 - C++ API
A deep learning, cross-platform ML framework
softmax_focal_loss_op.h
1 
17 #ifndef SOFTMAX_FOCAL_LOSS_OP_H_
18 #define SOFTMAX_FOCAL_LOSS_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class SoftmaxFocalLossOp final : public Operator<Context> {
29  public:
30  SoftmaxFocalLossOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)),
33  gamma_(OperatorBase::GetSingleArgument<float>("gamma", 1.)),
34  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0.25)),
35  num_classes_(OperatorBase::GetSingleArgument<int>("num_classes", 81)),
36  order_(StringToStorageOrder(
37  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
38  CAFFE_ENFORCE(scale_ >= 0);
39  CAFFE_ENFORCE_EQ(
40  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
41  }
42  USE_OPERATOR_CONTEXT_FUNCTIONS;
43 
44  bool RunOnDevice() override {
45  // No CPU implementation for now
46  CAFFE_NOT_IMPLEMENTED;
47  }
48 
49  protected:
50  float scale_;
51  float gamma_;
52  float alpha_;
53  int num_classes_;
54  StorageOrder order_;
55  Tensor<Context> losses_;
56 };
57 
58 template <typename T, class Context>
59 class SoftmaxFocalLossGradientOp final : public Operator<Context> {
60  public:
61  SoftmaxFocalLossGradientOp(const OperatorDef& def, Workspace* ws)
62  : Operator<Context>(def, ws),
63  scale_(OperatorBase::GetSingleArgument<float>("scale", 1.)),
64  gamma_(OperatorBase::GetSingleArgument<float>("gamma", 1.)),
65  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0.25)),
66  num_classes_(OperatorBase::GetSingleArgument<int>("num_classes", 81)),
67  order_(StringToStorageOrder(
68  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
69  CAFFE_ENFORCE(scale_ >= 0);
70  CAFFE_ENFORCE_EQ(
71  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
72  }
73  USE_OPERATOR_CONTEXT_FUNCTIONS;
74 
75  bool RunOnDevice() override {
76  // No CPU implementation for now
77  CAFFE_NOT_IMPLEMENTED;
78  }
79 
80  protected:
81  float scale_;
82  float gamma_;
83  float alpha_;
84  int num_classes_;
85  StorageOrder order_;
86  Tensor<Context> buff_;
87 };
88 
89 } // namespace caffe2
90 
91 #endif // SOFTMAX_FOCAL_LOSS_OP_H_
Tensor is the basic class in Caffe2 that stores contiguous memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.