17 #ifndef SIGMOID_CROSS_ENTROPY_LOSS_OP_H_ 18 #define SIGMOID_CROSS_ENTROPY_LOSS_OP_H_ 20 #include "caffe2/core/context.h" 21 #include "caffe2/core/logging.h" 22 #include "caffe2/core/operator.h" 23 #include "caffe2/utils/math.h" 27 template <
typename T,
class Context>
32 scale_(this->
template GetSingleArgument<float>(
"scale", 1.)),
33 normalize_(this->
template GetSingleArgument<int>(
"normalize", 1)) {
34 CAFFE_ENFORCE(scale_ >= 0);
35 CAFFE_ENFORCE(normalize_ == 0 || normalize_ == 1);
37 USE_OPERATOR_CONTEXT_FUNCTIONS;
39 bool RunOnDevice()
override {
41 CAFFE_NOT_IMPLEMENTED;
47 Tensor losses_{Context::GetDeviceType()};
48 Tensor counts_{Context::GetDeviceType()};
52 template <
typename T,
class Context>
57 scale_(this->
template GetSingleArgument<float>(
"scale", 1.)),
58 normalize_(this->
template GetSingleArgument<int>(
"normalize", 1)) {
59 CAFFE_ENFORCE(scale_ >= 0);
60 CAFFE_ENFORCE(normalize_ == 0 || normalize_ == 1);
62 USE_OPERATOR_CONTEXT_FUNCTIONS;
64 bool RunOnDevice()
override {
66 CAFFE_NOT_IMPLEMENTED;
72 Tensor counts_{Context::GetDeviceType()};
78 #endif // SIGMOID_CROSS_ENTROPY_LOSS_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...