17 #ifndef SIGMOID_FOCAL_LOSS_OP_H_ 18 #define SIGMOID_FOCAL_LOSS_OP_H_ 20 #include "caffe2/core/context.h" 21 #include "caffe2/core/logging.h" 22 #include "caffe2/core/operator.h" 23 #include "caffe2/utils/math.h" 27 template <
typename T,
class Context>
32 scale_(this->
template GetSingleArgument<float>(
"scale", 1.)),
33 num_classes_(this->
template GetSingleArgument<int>(
"num_classes", 80)),
34 gamma_(this->
template GetSingleArgument<float>(
"gamma", 1.)),
35 alpha_(this->
template GetSingleArgument<float>(
"alpha", 0.25)) {
36 CAFFE_ENFORCE(scale_ >= 0);
38 USE_OPERATOR_CONTEXT_FUNCTIONS;
40 bool RunOnDevice()
override {
42 CAFFE_NOT_IMPLEMENTED;
50 Tensor losses_{Context::GetDeviceType()};
51 Tensor counts_{Context::GetDeviceType()};
54 template <
typename T,
class Context>
59 scale_(this->
template GetSingleArgument<float>(
"scale", 1.)),
60 num_classes_(this->
template GetSingleArgument<int>(
"num_classes", 80)),
61 gamma_(this->
template GetSingleArgument<float>(
"gamma", 1.)),
62 alpha_(this->
template GetSingleArgument<float>(
"alpha", 0.25)) {
63 CAFFE_ENFORCE(scale_ >= 0);
65 USE_OPERATOR_CONTEXT_FUNCTIONS;
67 bool RunOnDevice()
override {
69 CAFFE_NOT_IMPLEMENTED;
77 Tensor counts_{Context::GetDeviceType()};
78 Tensor weights_{Context::GetDeviceType()};
83 #endif // SIGMOID_FOCAL_LOSS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...