Caffe2 - C++ API
A deep learning, cross platform ML framework
ftrl_op.cc
1 
17 #include "ftrl_op.h"
18 
19 namespace caffe2 {
20 
21 template <class T>
22 inline T sgn(const T x) {
23  return (x == 0 ? 0 : (x < 0 ? -1 : 1));
24 }
25 
26 template <typename T>
27 inline void ftrl_compute(
28  const T w,
29  const T n,
30  const T z,
31  const T g,
32  T& nw,
33  T& nn,
34  T& nz,
35  const FtrlParams<T>& params) {
36  auto new_n = n + g * g;
37  auto sigma = (sqrt(new_n) - sqrt(n)) * params.alphaInv;
38  nn = new_n;
39  nz = z + g - sigma * w;
40  // update the weight
41  if (std::abs(nz) > params.lambda1) {
42  nw = (params.lambda1 * sgn(nz) - nz) /
43  ((params.beta + sqrt(new_n)) * params.alphaInv + params.lambda2);
44  } else {
45  nw = 0.0;
46  }
47 }
48 
49 // TODO(dzhulgakov): implement SIMD-based version
50 template <typename Context, typename T>
51 void ftrl_update(
52  int N,
53  const T* w,
54  const T* nz,
55  const T* g,
56  T* new_w,
57  T* new_nz,
58  const FtrlParams<T>& params,
59  Context* /*context*/) {
60  // TODO(cxj): use OMP when it is reliable
61  // #pragma omp parallel for
62  for (auto i = 0; i < N; ++i) {
63  ftrl_compute(
64  w[i],
65  nz[i * 2],
66  nz[i * 2 + 1],
67  g[i],
68  new_w[i],
69  new_nz[i * 2],
70  new_nz[i * 2 + 1],
71  params);
72  }
73 }
74 
75 template <typename T, typename Context>
76 bool FtrlOp<T, Context>::RunOnDevice() {
77  // run time learning rate override
78  if (ALPHA < InputSize()) {
79  CAFFE_ENFORCE_EQ(Input(ALPHA).size(), 1, "alpha should be real-valued");
80  params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
81  }
82  CAFFE_ENFORCE_EQ(Input(GRAD).size(), Input(VAR).size());
83  CAFFE_ENFORCE_EQ(Input(GRAD).size() * 2, Input(N_Z).size());
84  Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
85  Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
86  ftrl_update<Context>(
87  Input(GRAD).size(),
88  Input(VAR).template data<T>(),
89  Input(N_Z).template data<T>(),
90  Input(GRAD).template data<T>(),
91  Output(OUTPUT_VAR)->template mutable_data<T>(),
92  Output(OUTPUT_N_Z)->template mutable_data<T>(),
93  params_,
94  &context_);
95  return true;
96 }
97 
98 template <typename T>
99 template <typename SIndex>
100 void SparseFtrlOp<T>::DoRun() {
101  auto* var = Output(OUTPUT_VAR);
102  auto* n_z = Output(OUTPUT_N_Z);
103  auto& indices = Input(INDICES);
104  auto& grad = Input(GRAD);
105  CAFFE_ENFORCE_EQ(&Input(VAR), var, "In place operation is required");
106  CAFFE_ENFORCE_EQ(&Input(N_Z), n_z, "In place operation is required");
107  TIndex M = var->size();
108  TIndex N = var->dim(0);
109  TIndex block_size = M / N;
110  TIndex K = indices.size();
111  DCHECK_EQ(M * 2, n_z->size());
112  DCHECK_EQ(grad.size(), K * block_size);
113  T* w = var->template mutable_data<T>();
114  T* nz = n_z->template mutable_data<T>();
115  const SIndex* idxs = indices.template data<SIndex>();
116  const T* g = grad.template data<T>();
117 
118  // TODO(cxj): use OMP when it is reliable
119  // #pragma omp parallel for
120  for (TIndex i = 0; i < K; ++i) {
121  SIndex idx = idxs[i];
122  DCHECK(0 <= idx && idx < N) << "Index out of bounds: " << idx
123  << ", range 0 to " << N;
124  if (block_size == 1) {
125  ftrl_compute(
126  w[idx],
127  nz[idx * 2],
128  nz[idx * 2 + 1],
129  g[i],
130  w[idx],
131  nz[idx * 2],
132  nz[idx * 2 + 1],
133  params_);
134  } else {
135  TIndex x = block_size * idx;
136  ftrl_update(
137  block_size,
138  w + x,
139  nz + x * 2,
140  g + i * block_size,
141  w + x,
142  nz + x * 2,
143  params_,
144  &context_);
145  }
146  }
147 }
148 
149 namespace {
150 REGISTER_CPU_OPERATOR(Ftrl, FtrlOp<float, CPUContext>);
151 OPERATOR_SCHEMA(Ftrl).NumInputs(3, 4).NumOutputs(2).AllowInplace({{0, 0},
152  {1, 1}});
153 SHOULD_NOT_DO_GRADIENT(Ftrl);
154 
155 REGISTER_CPU_OPERATOR(SparseFtrl, SparseFtrlOp<float>);
156 OPERATOR_SCHEMA(SparseFtrl)
157  .NumInputs(4, 5)
158  .NumOutputs(2)
159  .EnforceInplace({{0, 0}, {1, 1}});
160 SHOULD_NOT_DO_GRADIENT(SparseFtrl);
161 }
162 
163 }
Copyright (c) 2016-present, Facebook, Inc.