3 #include "caffe2/core/operator.h" 10 : alphaInv(1.0 / op->GetSingleArgument<
float>(
"alpha", 0.005f)),
11 beta(op->GetSingleArgument<
float>(
"beta", 1.0f)),
12 lambda1(op->GetSingleArgument<
float>(
"lambda1", 0.001f)),
13 lambda2(op->GetSingleArgument<
float>(
"lambda2", 0.001f)) {}
21 template <
typename T,
class Context>
24 USE_OPERATOR_CONTEXT_FUNCTIONS;
28 !HasArgument(
"alpha") || ALPHA >= InputSize(),
29 "Cannot specify alpha by both input and argument");
31 bool RunOnDevice()
override;
35 INPUT_TAGS(VAR, N_Z, GRAD, ALPHA);
36 OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z);
45 !HasArgument(
"alpha") || ALPHA >= InputSize(),
46 "Cannot specify alpha by both input and argument");
49 bool RunOnDevice()
override {
51 if (ALPHA < InputSize()) {
52 CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1,
"alpha should be real-valued");
53 params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
56 auto& indices = Input(INDICES);
57 if (indices.template IsType<int32_t>()) {
59 }
else if (indices.template IsType<int64_t>()) {
62 LOG(FATAL) <<
"Unsupported type of INDICES in SparseFtrlOp: " 63 << indices.dtype().name();
70 INPUT_TAGS(VAR, N_Z, INDICES, GRAD, ALPHA);
71 OUTPUT_TAGS(OUTPUT_VAR, OUTPUT_N_Z);
74 template <
typename SIndex>
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...