8 inline void gftrl_compute(
18 const GFtrlParams<T>& params) {
19 auto new_n = n + g * g;
20 auto sigma = (sqrt(new_n) - sqrt(n)) * params.alphaInv;
22 nz = z + g - sigma * w;
24 if (z_norm > params.lambda1 * std::sqrt(OutputDim)) {
25 nw = nz * (params.lambda1 * std::sqrt(OutputDim) / z_norm - 1) /
26 ((params.beta + sqrt(new_n)) * params.alphaInv + params.lambda2);
32 template <
typename Context,
typename T>
41 const GFtrlParams<T>& params,
43 for (
auto j = 0; j < InputDim; ++j) {
45 for (
auto i = 0; i < OutputDim; ++i) {
46 int idx = i * InputDim + j;
47 auto new_n = nz[idx * 2] + g[idx] * g[idx];
48 auto sigma = (sqrt(new_n) - sqrt(nz[idx * 2])) * params.alphaInv;
49 auto new_z = nz[idx * 2 + 1] + g[idx] - sigma * w[idx];
50 z_norm = z_norm + new_z * new_z;
53 z_norm = sqrt(z_norm);
54 for (
auto i = 0; i < OutputDim; ++i) {
55 int idx = i * InputDim + j;
71 template <
typename T,
typename Context>
72 bool GFtrlOp<T, Context>::RunOnDevice() {
74 if (ALPHA < InputSize()) {
75 CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1,
"alpha should be real-valued");
76 params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
79 CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(VAR).numel());
80 CAFFE_ENFORCE_EQ(Input(GRAD).numel() * 2, Input(N_Z).numel());
81 Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
82 Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
83 gftrl_update<Context>(
85 Input(GRAD).numel() / Input(GRAD).size(0),
86 Input(VAR).template data<T>(),
87 Input(N_Z).template data<T>(),
88 Input(GRAD).template data<T>(),
89 Output(OUTPUT_VAR)->template mutable_data<T>(),
90 Output(OUTPUT_N_Z)->template mutable_data<T>(),
// Register the float / CPUContext instantiation of GFtrlOp as the CPU
// implementation of the "GFtrl" operator.
97 REGISTER_CPU_OPERATOR(GFtrl, GFtrlOp<float, CPUContext>);
98 OPERATOR_SCHEMA(GFtrl).NumInputs(3, 4).NumOutputs(2).AllowInplace({{0, 0},
// Declare that no gradient should be computed for GFtrl. The op consumes
// GRAD and writes updated state (OUTPUT_VAR, OUTPUT_N_Z) rather than
// producing a differentiable value.
// NOTE(review): presumably it is only used as an optimizer update step — confirm.
100 SHOULD_NOT_DO_GRADIENT(GFtrl);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...