Caffe2 - C++ API
A deep learning, cross-platform ML framework
Related Pages
Modules
Data Structures
Files
C++ API
Python API
GitHub
File List
Globals
caffe2
sgd
yellowfin_op.cc
1
#include "caffe2/sgd/yellowfin_op.h"
2
3
namespace
caffe2
{
4
5
// Registers YellowFinOp<float, CPUContext> under the operator name "YellowFin"
// for CPU execution.
REGISTER_CPU_OPERATOR(YellowFin, YellowFinOp<float, CPUContext>);

// Schema for YellowFin: 10 inputs and 8 outputs. AllowInplace lists the pairs
// {i, i} for i = 0..7, i.e. each of the first eight inputs may share storage
// with the output of the same index. Inputs 8 ("grad") and 9 ("iter") have no
// matching output.
OPERATOR_SCHEMA(YellowFin)
    .NumInputs(10)
    .NumOutputs(8)
    .AllowInplace(
        {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {6, 6}, {7, 7}})
    .SetDoc(R"DOC(

Computes the YellowFin update (https://arxiv.org/abs/1706.03471) and performs
momentum SGD optimization step. lr and mu are not being shared between
parameters. curv_win, g_avg, g2_avg and scalars_memory are just auxiliary
memory for computing moving averages (see the publication). Takes arguments
beta: coefficient for moving averages,
curv_win_width: timeframe when average squared gradient is being stored,
epsilon: for numerical purposes,
nesterov and zero_debias for debias of moving average.

)DOC")
    // Inputs 0-7 are the optimizer state that is carried between iterations.
    .Input(0, "param", "Parameters to be updated")
    .Input(1, "moment", "Momentum")
    .Input(2, "lr", "Learning rate")
    .Input(3, "mu", "Momentum coefficient")
    .Input(4, "curv_win", "Memory for latest curvature ranges")
    .Input(5, "g_avg", "Moving average of gradient")
    .Input(6, "g2_avg", "Moving average of squared gradient")
    .Input(7, "scalars_memory", "Memory for stateful scalars")
    // Inputs 8-9 are read-only as far as this schema is concerned.
    .Input(8, "grad", "Gradient computed")
    .Input(9, "iter", "Iteration number")
    // Outputs mirror inputs 0-7 index by index.
    .Output(0, "output_param", "Parameters to be updated")
    .Output(1, "output_moment", "Momentum")
    .Output(2, "output_lr", "Output learning rate")
    .Output(3, "output_mu", "Output momentum coefficient")
    .Output(4, "output_curv_win", "Output memory for latest curvature ranges")
    .Output(5, "output_g_avg", "Output moving average of gradient")
    .Output(6, "output_g2_avg", "Output moving average of squared gradient")
    .Output(7, "output_scalars_memory", "Output memory for stateful scalars")
    // Operator arguments; the description strings only state the defaults.
    .Arg("beta", "Default 0.999")
    .Arg("curv_win_width", "Default 20")
    .Arg("epsilon", "Default 1e-6")
    .Arg("nesterov", "Default false")
    .Arg("zero_debias", "Default true");

// Marks YellowFin as an operator for which no gradient should be computed.
SHOULD_NOT_DO_GRADIENT(YellowFin);
48
49
#define CAFFE2_YELLOWFIN_GETLRMU(T) \
50
template <> \
51
void YellowFinOp<T, CPUContext>::GetLrMu() { \
52
const T curv_ratio = std::sqrt(*g_norm2_max_deb_ / *g_norm2_min_deb_); \
53
const T mu_limit = (curv_ratio - 1.0f) / (curv_ratio + 1.0f); \
54
const T pre_p = *distance_deb_ * *g_norm2_min_deb_; \
55
const T p = (pre_p * pre_p) / (2.0f * *variance_); \
56
const T w3 = (-std::sqrt(p * p + 4.0f / 27.0f * p * p * p) - p) / 2.0f; \
57
const T w3_sign = w3 > 0.0f ? 1.0f : -1.0f; \
58
const T w = w3_sign * std::pow(std::abs(w3), 1.0f / 3.0f); \
59
const T y = w - p / 3.0f / w; \
60
const T root = y + 1.0f; \
61
*mu_ = std::max(root * root, mu_limit * mu_limit); \
62
*lr_ = std::pow(1.0f - std::sqrt(*mu_), 2) / *g_norm2_min_deb_; \
63
MovingAverage(1, mu_, mu_avg_, mu_avg_out_, mu_deb_); \
64
MovingAverage(1, lr_, lr_avg_, lr_avg_out_, lr_deb_); \
65
}
66
67
CAFFE2_YELLOWFIN_GETLRMU(
float
)
68
#undef CAFFE2_YELLOWFIN_GETLRMU
69
70
// Usually moment_ == moment_out_ && param_ == param_out_
71
#define CAFFE2_YELLOWFIN_MOMENTUMSGDUPDATE(T) \
72
template <> \
73
void YellowFinOp<T, CPUContext>::MomentumSgdUpdate() { \
74
const T mu = *mu_avg_out_; \
75
const T lr = *lr_avg_out_; \
76
if (!nesterov_) { \
77
for (int i = 0; i < D_; ++i) { \
78
moment_out_[i] = mu * moment_[i] + lr * grad_[i]; \
79
param_out_[i] = param_[i] - moment_out_[i]; \
80
} \
81
} else { \
82
for (int i = 0; i < D_; ++i) { \
83
const T moment_i = moment_[i]; \
84
moment_out_[i] = mu * moment_i + lr * grad_[i]; \
85
param_out_[i] = param_[i] - (1 + mu) * moment_out_[i] + mu * moment_i; \
86
} \
87
} \
88
}
89
90
CAFFE2_YELLOWFIN_MOMENTUMSGDUPDATE(
float
)
91
#undef CAFFE2_YELLOWFIN_MOMENTUMSGDUPDATE
92
93
}
// caffe2
caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition:
blob.h:13
Generated on Thu Mar 21 2019 13:06:21 for Caffe2 - C++ API by
1.8.11