Caffe2 - C++ API
A deep learning, cross platform ML framework
adam_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 
3 namespace caffe2 {
4 
// Performs one Adam moment update and stores the scaled step (not new
// weights) in `ng`.
//
// For every i in [0, N):
//   nm[i] = beta1 * m[i] + (1 - beta1) * g[i]
//   nv[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i]
//   ng[i] = lr[0] * correction * nm[i] / (sqrt(nv[i]) + eps_hat)
//
// Parameters:
//   N          number of elements to process; nothing is read or written
//              when N <= 0.
//   g, m, v    input gradient, first moment and second moment.
//   ng, nm, nv output step, first moment and second moment. Each element is
//              read from the inputs before the matching output element is
//              written, so an output may be the same buffer as its input.
//   beta1, beta2, eps_hat, correction  scalar coefficients used as shown above.
//   lr         pointer to the learning rate; only lr[0] is read. It is read
//              once up front, so it must not point into an output buffer.
void adam_ideep_update(
    int N,
    const float* g,
    const float* m,
    const float* v,
    float* ng,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr) {
  // Keep the empty-range case from dereferencing lr at all.
  if (N <= 0) {
    return;
  }
  // lr[0] * correction is identical for every element. Hoisting it keeps the
  // same left-to-right evaluation order as `lr[0] * correction * mi` while
  // removing a per-element load the compiler cannot hoist on its own (lr may
  // alias the float outputs as far as it can tell).
  const float step_scale = lr[0] * correction;
#ifdef _OPENMP
  #pragma omp parallel for schedule(static)
#endif
  for (int i = 0; i < N; ++i) {
    const float gi = g[i];
    const float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    const float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    ng[i] = step_scale * mi / (std::sqrt(vi) + eps_hat);
  }
}
28 
// Performs one Adam step and writes the updated weights.
//
// For every i in [0, N):
//   nm[i] = beta1 * m[i] + (1 - beta1) * g[i]
//   nv[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i]
//   nw[i] = w[i] + lr[0] * correction * nm[i] / (sqrt(nv[i]) + eps_hat)
//
// Note the step is *added* to w; the sign of the update therefore comes from
// lr[0] (presumably callers pass a negative rate for descent - confirm
// against the caller).
//
// Parameters:
//   N          number of elements to process; nothing is read or written
//              when N <= 0.
//   w, g, m, v input weights, gradient, first moment and second moment.
//   nw, nm, nv output weights, first moment and second moment. Each element
//              is read from the inputs before the matching output element is
//              written, so an output may be the same buffer as its input.
//   beta1, beta2, eps_hat, correction  scalar coefficients used as shown above.
//   lr         pointer to the learning rate; only lr[0] is read. It is read
//              once up front, so it must not point into an output buffer.
void adam_ideep_compute(
    int N,
    const float* w,
    const float* g,
    const float* m,
    const float* v,
    float* nw,
    float* nm,
    float* nv,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr) {
  // Keep the empty-range case from dereferencing lr at all.
  if (N <= 0) {
    return;
  }
  // Loop-invariant factor, hoisted with the original left-to-right
  // evaluation order (lr[0] * correction first, then * mi).
  const float step_scale = lr[0] * correction;
#ifdef _OPENMP
  #pragma omp parallel for schedule(static)
#endif
  for (int i = 0; i < N; ++i) {
    const float gi = g[i];
    const float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    const float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    nw[i] = w[i] + step_scale * mi / (std::sqrt(vi) + eps_hat);
  }
}
53 
// Performs one Adam step, writing both the updated weights and the
// effective (learning-rate-free) gradient that was applied.
//
// For every i in [0, N):
//   nm[i] = beta1 * m[i] + (1 - beta1) * g[i]
//   nv[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i]
//   ng[i] = correction * nm[i] / (sqrt(nv[i]) + eps_hat)
//   nw[i] = w[i] + lr[0] * ng[i]
//
// Parameters:
//   N          number of elements to process; nothing is read or written
//              when N <= 0.
//   w, g, m, v input weights, gradient, first moment and second moment.
//   nw, nm, nv, ng  output weights, first moment, second moment and effective
//              gradient. Each element is read from the inputs before the
//              matching output element is written, so an output may be the
//              same buffer as its input.
//   beta1, beta2, eps_hat, correction  scalar coefficients used as shown above.
//   lr         pointer to the learning rate; only lr[0] is read. It is read
//              once up front, so it must not point into an output buffer.
void adam_ideep_compute_output_grad(
    int N,
    const float* w,
    const float* g,
    const float* m,
    const float* v,
    float* nw,
    float* nm,
    float* nv,
    float* ng,
    float beta1,
    float beta2,
    float eps_hat,
    float correction,
    const float* lr) {
  // Keep the empty-range case from dereferencing lr at all.
  if (N <= 0) {
    return;
  }
  // The learning rate is constant across the loop; load it once instead of
  // re-reading lr[0] for every element.
  const float rate = lr[0];
#ifdef _OPENMP
  #pragma omp parallel for schedule(static)
#endif
  for (int i = 0; i < N; ++i) {
    const float gi = g[i];
    const float mi = nm[i] = m[i] * beta1 + gi * (1 - beta1);
    const float vi = nv[i] = v[i] * beta2 + gi * gi * (1 - beta2);
    const float ngi = ng[i] = correction * mi / (std::sqrt(vi) + eps_hat);
    nw[i] = w[i] + rate * ngi;
  }
}
81 
82 template <typename T>
83 class IDEEPAdamOp final : public IDEEPOperator {
84  public:
85  USE_IDEEP_DEF_ALIASES();
86  USE_IDEEP_OPERATOR_FUNCTIONS();
87 
88  IDEEPAdamOp(const OperatorDef& operator_def, Workspace* ws)
89  : IDEEPOperator(operator_def, ws),
90  beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
91  beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
92  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
93  bool RunOnDevice() override {
94  // Iter live on the CPU
95  CAFFE_ENFORCE(OperatorBase::InputIsTensorType(ITER, CPU));
96  const auto& params = Input(PARAM);
97  const auto& moment_1 = Input(MOMENT_1);
98  const auto& moment_2 = Input(MOMENT_2);
99  const auto& grad = Input(GRAD);
100  // TODO: Use itensor after 0-dim is supported. Now use CPU tensor.
101  const auto& lr = OperatorBase::Input<TensorCPU>(LR, CPU);
102  auto* out_params = Output(OUTPUT_PARAM);
103  auto* out_moment1 = Output(OUTPUT_MOMENT_1);
104  auto* out_moment2 = Output(OUTPUT_MOMENT_2);
105 
106  CAFFE_ENFORCE(lr.size() == 1);
107  CAFFE_ENFORCE(grad.get_nelems() == params.get_nelems());
108  CAFFE_ENFORCE(grad.get_nelems() == moment_1.get_nelems());
109  CAFFE_ENFORCE(grad.get_nelems() == moment_2.get_nelems());
110  if (params != *out_params)
111  out_params->reinit(params.get_descriptor());
112  if (moment_1 != *out_moment1)
113  out_moment1->reinit(moment_1.get_descriptor());
114  if (moment_2 != *out_moment2)
115  out_moment2->reinit(moment_2.get_descriptor());
116  const auto w = static_cast<float *>(params.get_data_handle());
117  const auto g = static_cast<float *>(grad.get_data_handle());
118  const auto m = static_cast<float *>(moment_1.get_data_handle());
119  const auto v = static_cast<float *>(moment_2.get_data_handle());
120  auto nw = static_cast<float *>(out_params->get_data_handle());
121  auto nm = static_cast<float *>(out_moment1->get_data_handle());
122  auto nv = static_cast<float *>(out_moment2->get_data_handle());
123  const auto nlr = lr.template data<T>();
124  const auto iter =
125  OperatorBase::Input<TensorCPU>(ITER, CPU).template data<int64_t>()[0];
126  const auto t = iter + 1;
127  const auto correction =
128  std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
129  if (OutputSize() == 3) {
130  adam_ideep_compute(
131  grad.get_nelems(),
132  w,
133  g,
134  m,
135  v,
136  nw,
137  nm,
138  nv,
139  beta1_,
140  beta2_,
141  epsilon_,
142  correction,
143  nlr);
144  } else {
145  auto* out_grad = Output(OUTPUT_GRAD);
146  if (grad != *out_grad)
147  out_grad->reinit(grad.get_descriptor());
148  auto ng = static_cast<float *>(out_grad->get_data_handle());
149  adam_ideep_compute_output_grad(
150  grad.get_nelems(),
151  w,
152  g,
153  m,
154  v,
155  nw,
156  nm,
157  nv,
158  ng,
159  beta1_,
160  beta2_,
161  epsilon_,
162  correction,
163  nlr);
164  }
165 
166  return true;
167  }
168 
169  protected:
170  T beta1_{0.9};
171  T beta2_{0.999};
172  T epsilon_{1e-8};
173  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, GRAD, LR, ITER);
174  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2, OUTPUT_GRAD);
175 };
176 
// Registers the float instantiation of IDEEPAdamOp under the operator name "Adam".
177 REGISTER_IDEEP_OPERATOR(Adam, IDEEPAdamOp<float>);
178 
179 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13