1 #ifndef CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_ 2 #define CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
typename T,
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 bool RunOnDevice()
override;
19 static constexpr
T kLOG_THRESHOLD() {
20 return static_cast<T>(1e-20);
26 template <
typename T,
class Context>
30 USE_OPERATOR_CONTEXT_FUNCTIONS;
31 bool RunOnDevice()
override;
36 static constexpr
T kLOG_THRESHOLD() {
37 return static_cast<T>(1e-20);
43 template <
typename T,
class Context>
47 USE_OPERATOR_CONTEXT_FUNCTIONS;
48 bool RunOnDevice()
override;
55 template <
typename T,
class Context>
59 USE_OPERATOR_CONTEXT_FUNCTIONS;
60 bool RunOnDevice()
override;
67 template <
typename T,
class Context>
70 USE_OPERATOR_CONTEXT_FUNCTIONS;
71 template <
class... Args>
75 this->
template GetSingleArgument<bool>(
"log_D_trick",
false)),
77 this->
template GetSingleArgument<bool>(
"unjoined_lr_loss",
false)) {
79 !(log_D_trick_ && unjoined_lr_loss_),
80 "log_D_trick_ and unjoined_lr_loss_ cannot be set as True simultaneously");
83 bool RunOnDevice()
override;
87 bool unjoined_lr_loss_;
90 template <
typename T,
class Context>
93 USE_OPERATOR_CONTEXT_FUNCTIONS;
94 template <
class... Args>
98 this->
template GetSingleArgument<bool>(
"log_D_trick",
false)),
100 this->
template GetSingleArgument<bool>(
"unjoined_lr_loss",
false)) {
103 bool RunOnDevice()
override;
107 bool unjoined_lr_loss_;
110 template <
typename T,
class Context>
114 USE_OPERATOR_CONTEXT_FUNCTIONS;
115 bool RunOnDevice()
override;
118 template <
typename T,
class Context>
123 USE_OPERATOR_CONTEXT_FUNCTIONS;
124 bool RunOnDevice()
override;
127 template <
typename T,
class Context>
131 USE_OPERATOR_CONTEXT_FUNCTIONS;
132 bool RunOnDevice()
override;
137 static constexpr
T kLOG_THRESHOLD() {
138 return static_cast<T>(1e-20);
142 template <
typename T,
class Context>
146 USE_OPERATOR_CONTEXT_FUNCTIONS;
147 bool RunOnDevice()
override;
152 static constexpr
T kLOG_THRESHOLD() {
153 return static_cast<T>(1e-20);
159 #endif // CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...