Caffe2 - C++ API
A deep learning, cross-platform ML framework
ftrl_op.cc
1 #include "ftrl_op.h"
2 
3 namespace caffe2 {
4 
// Sign of x, converted to T: 0 when x compares equal to zero, -1 when x is
// negative, and 1 otherwise.
template <class T>
inline T sgn(const T x) {
  if (x == 0) {
    return 0;
  }
  return x < 0 ? -1 : 1;
}
9 
10 template <typename T>
11 inline void ftrl_compute(
12  const T w,
13  const T n,
14  const T z,
15  const T g,
16  T& nw,
17  T& nn,
18  T& nz,
19  const FtrlParams<T>& params) {
20  auto new_n = n + g * g;
21  auto sigma = (sqrt(new_n) - sqrt(n)) * params.alphaInv;
22  nn = new_n;
23  nz = z + g - sigma * w;
24  // update the weight
25  if (std::abs(nz) > params.lambda1) {
26  nw = (params.lambda1 * sgn(nz) - nz) /
27  ((params.beta + sqrt(new_n)) * params.alphaInv + params.lambda2);
28  } else {
29  nw = 0.0;
30  }
31 }
32 
33 // TODO(dzhulgakov): implement SIMD-based version
34 template <typename Context, typename T>
35 void ftrl_update(
36  int N,
37  const T* w,
38  const T* nz,
39  const T* g,
40  T* new_w,
41  T* new_nz,
42  const FtrlParams<T>& params,
43  Context* /*context*/) {
44  // TODO(cxj): use OMP when it is reliable
45  // #pragma omp parallel for
46  for (auto i = 0; i < N; ++i) {
47  ftrl_compute(
48  w[i],
49  nz[i * 2],
50  nz[i * 2 + 1],
51  g[i],
52  new_w[i],
53  new_nz[i * 2],
54  new_nz[i * 2 + 1],
55  params);
56  }
57 }
58 
59 template <typename T, typename Context>
60 bool FtrlOp<T, Context>::RunOnDevice() {
61  // run time learning rate override
62  if (ALPHA < InputSize()) {
63  CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be real-valued");
64  params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
65  }
66  CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(VAR).numel());
67  CAFFE_ENFORCE_EQ(Input(GRAD).numel() * 2, Input(N_Z).numel());
68  Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
69  Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
70  ftrl_update<Context>(
71  Input(GRAD).numel(),
72  Input(VAR).template data<T>(),
73  Input(N_Z).template data<T>(),
74  Input(GRAD).template data<T>(),
75  Output(OUTPUT_VAR)->template mutable_data<T>(),
76  Output(OUTPUT_N_Z)->template mutable_data<T>(),
77  params_,
78  &context_);
79  return true;
80 }
81 
82 template <typename T>
83 template <typename SIndex>
84 void SparseFtrlOp<T>::DoRun() {
85  auto* var = Output(OUTPUT_VAR);
86  auto* n_z = Output(OUTPUT_N_Z);
87  auto& indices = Input(INDICES);
88  auto& grad = Input(GRAD);
89  CAFFE_ENFORCE_EQ(&Input(VAR), var, "In place operation is required");
90  CAFFE_ENFORCE_EQ(&Input(N_Z), n_z, "In place operation is required");
91  int64_t M = var->numel();
92  int64_t N = var->size(0);
93  int64_t block_size = M / N;
94  int64_t K = indices.numel();
95  DCHECK_EQ(M * 2, n_z->numel());
96  DCHECK_EQ(grad.numel(), K * block_size);
97  T* w = var->template mutable_data<T>();
98  T* nz = n_z->template mutable_data<T>();
99  const SIndex* idxs = indices.template data<SIndex>();
100  const T* g = grad.template data<T>();
101 
102  // TODO(cxj): use OMP when it is reliable
103  // #pragma omp parallel for
104  for (int64_t i = 0; i < K; ++i) {
105  SIndex idx = idxs[i];
106  DCHECK(0 <= idx && idx < N) << "Index out of bounds: " << idx
107  << ", range 0 to " << N;
108  if (block_size == 1) {
109  ftrl_compute(
110  w[idx],
111  nz[idx * 2],
112  nz[idx * 2 + 1],
113  g[i],
114  w[idx],
115  nz[idx * 2],
116  nz[idx * 2 + 1],
117  params_);
118  } else {
119  int64_t x = block_size * idx;
120  ftrl_update(
121  block_size,
122  w + x,
123  nz + x * 2,
124  g + i * block_size,
125  w + x,
126  nz + x * 2,
127  params_,
128  &context_);
129  }
130  }
131 }
132 
133 namespace {
134 REGISTER_CPU_OPERATOR(Ftrl, FtrlOp<float, CPUContext>);
135 OPERATOR_SCHEMA(Ftrl).NumInputs(3, 4).NumOutputs(2).AllowInplace({{0, 0},
136  {1, 1}});
137 SHOULD_NOT_DO_GRADIENT(Ftrl);
138 
139 REGISTER_CPU_OPERATOR(SparseFtrl, SparseFtrlOp<float>);
140 OPERATOR_SCHEMA(SparseFtrl)
141  .NumInputs(4, 5)
142  .NumOutputs(2)
143  .EnforceInplace({{0, 0}, {1, 1}});
144 SHOULD_NOT_DO_GRADIENT(SparseFtrl);
145 }
146 
147 }
Definition: any.cpp:108
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13