// Sign helper: evaluates to 0 when x == 0, -1 when x < 0, and 1 otherwise.
// NOTE(review): the template header declaring T and the closing brace are
// not visible in this view.
6 inline T sgn(
const T x) {
7 return (x == 0 ? 0 : (x < 0 ? -1 : 1));
// Single-element update step; presumably the per-coordinate FTRL update
// (TODO confirm against the full parameter list, which is not visible here).
// Reads w, n, z, g and `params`; writes nz and nw.
11 inline void ftrl_compute(
19 const FtrlParams<T>& params) {
// Accumulate the squared gradient into n.
20 auto new_n = n + g * g;
// Change in sqrt(n), scaled by params.alphaInv.
21 auto sigma = (sqrt(new_n) - sqrt(n)) * params.alphaInv;
// New z: old z plus the gradient, minus sigma times the current weight.
23 nz = z + g - sigma * w;
// The weight is computed from the expression below only when |nz| exceeds
// lambda1; the branch taken otherwise is not visible in this view.
25 if (std::abs(nz) > params.lambda1) {
26 nw = (params.lambda1 * sgn(nz) - nz) /
27 ((params.beta + sqrt(new_n)) * params.alphaInv + params.lambda2);
// Function templated on Context and T that takes FtrlParams<T> by const
// reference and loops over N elements.
// NOTE(review): the function name, the remaining parameters and the loop
// body are not visible in this view.
34 template <
typename Context,
typename T>
42 const FtrlParams<T>& params,
// `auto i = 0` deduces int; NOTE(review): confirm N's declared type so the
// comparison `i < N` does not mix widths or signedness.
46 for (
auto i = 0; i < N; ++i) {
// Validates the operator inputs, sizes the outputs and forwards the data
// pointers to a call whose callee is not visible in this view.
59 template <
typename T,
typename Context>
60 bool FtrlOp<T, Context>::RunOnDevice() {
// ALPHA is only read when enough inputs were supplied. It must hold exactly
// one element, and its reciprocal replaces params_.alphaInv.
62 if (ALPHA < InputSize()) {
63 CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1,
"alpha should be real-valued");
64 params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
// GRAD must have as many elements as VAR, and N_Z exactly twice as many as
// GRAD.
66 CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(VAR).numel());
67 CAFFE_ENFORCE_EQ(Input(GRAD).numel() * 2, Input(N_Z).numel());
// Each output is resized to match its corresponding input.
68 Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
69 Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
// Arguments of the update call: read-only input buffers followed by the
// mutable output buffers.
72 Input(VAR).template data<T>(),
73 Input(N_Z).template data<T>(),
74 Input(GRAD).template data<T>(),
75 Output(OUTPUT_VAR)->template mutable_data<T>(),
76 Output(OUTPUT_N_Z)->template mutable_data<T>(),
// Sparse variant templated on the index type SIndex: walks the entries of
// INDICES and works on the block of VAR / N_Z addressed by each index.
// NOTE(review): the per-index update bodies are not visible in this view.
83 template <
typename SIndex>
84 void SparseFtrlOp<T>::DoRun() {
85 auto* var = Output(OUTPUT_VAR);
86 auto* n_z = Output(OUTPUT_N_Z);
87 auto& indices = Input(INDICES);
88 auto& grad = Input(GRAD);
// The VAR and N_Z inputs must be the very same tensor objects as the
// outputs (address comparison).
89 CAFFE_ENFORCE_EQ(&Input(VAR), var,
"In place operation is required");
90 CAFFE_ENFORCE_EQ(&Input(N_Z), n_z,
"In place operation is required");
// M: total element count of var. N: extent of its first dimension.
// block_size: elements per first-dimension slice.
91 int64_t
M = var->numel();
92 int64_t N = var->size(0);
// NOTE(review): no visible guard against N == 0 before this division.
93 int64_t block_size = M / N;
// K: number of indices to process.
94 int64_t K = indices.numel();
// Shape checks: n_z holds two values per element of var, and grad holds one
// block per index.
// NOTE(review): DCHECK is typically compiled out of release builds; confirm
// whether these (and the bounds check in the loop) should be CAFFE_ENFORCE.
95 DCHECK_EQ(M * 2, n_z->numel());
96 DCHECK_EQ(grad.numel(), K * block_size);
97 T* w = var->template mutable_data<T>();
98 T* nz = n_z->template mutable_data<T>();
99 const SIndex* idxs = indices.template data<SIndex>();
100 const T* g = grad.template data<T>();
104 for (int64_t i = 0; i < K; ++i) {
105 SIndex idx = idxs[i];
// Each index must fall inside [0, N).
106 DCHECK(0 <= idx && idx < N) <<
"Index out of bounds: " << idx
107 <<
", range 0 to " << N;
// A block of one element is special-cased (body not visible).
108 if (block_size == 1) {
// Offset of the first element of block idx.
119 int64_t x = block_size * idx;
// Registers the dense Ftrl operator on CPU for float. Its schema accepts 3
// to 4 inputs and produces 2 outputs, and permits in-place use of input 0 as
// output 0 (the rest of the in-place list is not visible in this view).
134 REGISTER_CPU_OPERATOR(Ftrl, FtrlOp<float, CPUContext>);
135 OPERATOR_SCHEMA(Ftrl).NumInputs(3, 4).NumOutputs(2).AllowInplace({{0, 0},
137 SHOULD_NOT_DO_GRADIENT(Ftrl);
// Registers the sparse variant on CPU for float. In-place operation is
// enforced for the pairs (input 0, output 0) and (input 1, output 1); the
// rest of the schema chain is not visible in this view.
139 REGISTER_CPU_OPERATOR(SparseFtrl, SparseFtrlOp<float>);
140 OPERATOR_SCHEMA(SparseFtrl)
143 .EnforceInplace({{0, 0}, {1, 1}});
144 SHOULD_NOT_DO_GRADIENT(SparseFtrl);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...