1 #ifndef SPATIAL_SOFTMAX_WITH_LOSS_OP_H_ 2 #define SPATIAL_SOFTMAX_WITH_LOSS_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
typename T,
class Context>
14 template <
class... Args>
17 scale_(this->
template GetSingleArgument<float>(
"scale", 1.)),
18 order_(StringToStorageOrder(
19 this->
template GetSingleArgument<string>(
"order",
"NCHW"))) {
20 CAFFE_ENFORCE(scale_ >= 0);
22 order_, StorageOrder::NCHW,
"Only NCHW order is supported right now.");
24 USE_OPERATOR_CONTEXT_FUNCTIONS;
26 bool RunOnDevice()
override;
37 Tensor scratch_{Context::GetDeviceType()};
40 template <
typename T,
class Context>
43 template <
class... Args>
46 scale_(this->
template GetSingleArgument<float>(
"scale", 1.)),
47 order_(StringToStorageOrder(
48 this->
template GetSingleArgument<string>(
"order",
"NCHW"))),
49 only_loss_(this->
template GetSingleArgument<bool>(
"only_loss",
false)) {
50 CAFFE_ENFORCE(scale_ >= 0);
52 order_, StorageOrder::NCHW,
"Only NCHW order is supported right now.");
54 USE_OPERATOR_CONTEXT_FUNCTIONS;
56 bool RunOnDevice()
override;
65 Tensor scratch_{Context::GetDeviceType()};
70 #endif // SPATIAL_SOFTMAX_WITH_LOSS_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...