1 #ifndef SOFTMAX_WITH_LOSS_OP_H_ 2 #define SOFTMAX_WITH_LOSS_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
typename T,
class Context>
14 template <
class... Args>
17 scale_(this->
template GetSingleArgument<float>(
"scale", 1.)),
19 this->
template GetSingleArgument<int>(
"label_prob", 0)),
20 order_(StringToStorageOrder(
21 this->
template GetSingleArgument<string>(
"order",
"NCHW"))),
22 axis_(this->
template GetSingleArgument<int>(
"axis", 1)) {
23 CAFFE_ENFORCE(scale_ >= 0);
25 order_, StorageOrder::NCHW,
"Only NCHW order is supported right now.");
27 USE_OPERATOR_CONTEXT_FUNCTIONS;
29 bool RunOnDevice()
override;
43 Tensor scratch_{Context::GetDeviceType()};
46 template <
typename T,
class Context>
49 template <
class... Args>
52 scale_(this->
template GetSingleArgument<float>(
"scale", 1.)),
54 this->
template GetSingleArgument<int>(
"label_prob", 0)),
55 order_(StringToStorageOrder(
56 this->
template GetSingleArgument<string>(
"order",
"NCHW"))),
57 only_loss_(this->
template GetSingleArgument<bool>(
"only_loss",
false)),
58 axis_(this->
template GetSingleArgument<int>(
"axis", 1)) {
59 CAFFE_ENFORCE(scale_ >= 0);
61 order_, StorageOrder::NCHW,
"Only NCHW order is supported right now.");
63 USE_OPERATOR_CONTEXT_FUNCTIONS;
65 bool RunOnDevice()
override;
71 Tensor sum_multiplier_{Context::GetDeviceType()};
77 Tensor scratch_{Context::GetDeviceType()};
82 #endif // SOFTMAX_WITH_LOSS_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...