Caffe2 - C++ API
A deep learning, cross platform ML framework
gftrl_op.cc
1 #include "gftrl_op.h"
2 
3 namespace caffe2 {
4 
5 // Computes one coordinate
6 template <typename T>
7 
8 inline void gftrl_compute(
9  const T& w,
10  const T& n,
11  const T& z,
12  const T& g,
13  T& nw,
14  T& nn,
15  T& nz,
16  const T& z_norm,
17  const int OutputDim,
18  const GFtrlParams<T>& params) {
19  auto new_n = n + g * g;
20  auto sigma = (sqrt(new_n) - sqrt(n)) * params.alphaInv;
21  nn = new_n;
22  nz = z + g - sigma * w;
23  // update the weight
24  if (z_norm > params.lambda1 * std::sqrt(OutputDim)) {
25  nw = nz * (params.lambda1 * std::sqrt(OutputDim) / z_norm - 1) /
26  ((params.beta + sqrt(new_n)) * params.alphaInv + params.lambda2);
27  } else {
28  nw = 0.0;
29  }
30 }
31 
32 template <typename Context, typename T>
33 void gftrl_update(
34  int OutputDim, // # of output nodes
35  int InputDim, // # of input features
36  const T* w,
37  const T* nz,
38  const T* g,
39  T* new_w,
40  T* new_nz,
41  const GFtrlParams<T>& params,
42  Context* /*context*/) {
43  for (auto j = 0; j < InputDim; ++j) {
44  T z_norm = 0.0;
45  for (auto i = 0; i < OutputDim; ++i) {
46  int idx = i * InputDim + j;
47  auto new_n = nz[idx * 2] + g[idx] * g[idx];
48  auto sigma = (sqrt(new_n) - sqrt(nz[idx * 2])) * params.alphaInv;
49  auto new_z = nz[idx * 2 + 1] + g[idx] - sigma * w[idx];
50  z_norm = z_norm + new_z * new_z;
51  }
52 
53  z_norm = sqrt(z_norm);
54  for (auto i = 0; i < OutputDim; ++i) {
55  int idx = i * InputDim + j;
56  gftrl_compute(
57  w[idx],
58  nz[idx * 2],
59  nz[idx * 2 + 1],
60  g[idx],
61  new_w[idx],
62  new_nz[idx * 2],
63  new_nz[idx * 2 + 1],
64  z_norm,
65  OutputDim,
66  params);
67  }
68  }
69 }
70 
71 template <typename T, typename Context>
72 bool GFtrlOp<T, Context>::RunOnDevice() {
73  // run time learning rate override
74  if (ALPHA < InputSize()) {
75  CAFFE_ENFORCE_EQ(Input(ALPHA).numel(), 1, "alpha should be real-valued");
76  params_.alphaInv = 1.0 / *(Input(ALPHA).template data<T>());
77  }
78 
79  CAFFE_ENFORCE_EQ(Input(GRAD).numel(), Input(VAR).numel());
80  CAFFE_ENFORCE_EQ(Input(GRAD).numel() * 2, Input(N_Z).numel());
81  Output(OUTPUT_VAR)->ResizeLike(Input(VAR));
82  Output(OUTPUT_N_Z)->ResizeLike(Input(N_Z));
83  gftrl_update<Context>(
84  Input(GRAD).size(0), // # of output nodes
85  Input(GRAD).numel() / Input(GRAD).size(0), // # of input features
86  Input(VAR).template data<T>(),
87  Input(N_Z).template data<T>(),
88  Input(GRAD).template data<T>(),
89  Output(OUTPUT_VAR)->template mutable_data<T>(),
90  Output(OUTPUT_N_Z)->template mutable_data<T>(),
91  params_,
92  &context_);
93  return true;
94 }
95 
96 namespace {
97 REGISTER_CPU_OPERATOR(GFtrl, GFtrlOp<float, CPUContext>);
98 OPERATOR_SCHEMA(GFtrl).NumInputs(3, 4).NumOutputs(2).AllowInplace({{0, 0},
99  {1, 1}});
100 SHOULD_NOT_DO_GRADIENT(GFtrl);
101 
102 } // namespace
103 
104 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13