Caffe2 - C++ API
A deep learning, cross platform ML framework
adagrad.cc
1 #include "caffe2/perfkernels/adagrad.h"
2 
3 #include <cmath>
4 
5 #include "caffe2/perfkernels/common.h"
6 
7 namespace caffe2 {
8 
9 void adagrad_update__base(
10  int N,
11  const float* w,
12  const float* g,
13  const float* h,
14  float* nw,
15  float* nh,
16  float epsilon,
17  float decay,
18  const float lr) {
19  internal::adagrad_update_base_inlined(N, w, g, h, nw, nh, decay, epsilon, lr);
20 }
21 
22 void adagrad_update_prefetch__base(
23  int N,
24  const float* w,
25  const float* /* w_n */, // prefetch ptr
26 
27  const float* g,
28 
29  const float* h,
30  const float* /* h_n */, // prefetch ptr
31 
32  float* nw,
33  float* /* nw_n */, // prefetch ptr
34 
35  float* nh,
36  float* /* nh_n */, // prefetch ptr
37 
38  float epsilon,
39  float lr) {
40  adagrad_update__base(N, w, g, h, nw, nh, epsilon, 1.0f, lr);
41 }
42 
43 void adagrad_fp16_update_prefetch__base(
44  int N,
45  const at::Half* w,
46  const at::Half* /* w_n */, // prefetch ptr
47  const float* g,
48  const at::Half* h,
49  const at::Half* /* h_n */, // prefetch ptr
50  at::Half* nw,
51  at::Half* /* nw_n */, // prefetch ptr
52  at::Half* nh,
53  at::Half* /* nh_n */, // prefetch ptr
54  float epsilon,
55  float lr) {
56  internal::adagrad_update_base_inlined(N, w, g, h, nw, nh, 1.0f, epsilon, lr);
57 }
58 
59 void rowwise_adagrad_update__base(
60  int N,
61  float* w,
62  float* w_n, // prefetch ptr
63 
64  const float* g,
65 
66  float* h,
67  float* h_n, // prefetch ptr
68 
69  float epsilon,
70  float lr) {
71  internal::rowwise_adagrad_update_inlined(N, w, w_n, g, h, h_n, epsilon, lr);
72 }
73 
74 // version without prefetching
75 decltype(adagrad_update__base) adagrad_update__avx_f16c;
76 void adagrad_update(
77  int N,
78  const float* w,
79  const float* g,
80  const float* h,
81  float* nw,
82  float* nh,
83  float epsilon,
84  float decay,
85  float lr) {
86  AVX_F16C_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr);
87  BASE_DO(adagrad_update, N, w, g, h, nw, nh, epsilon, decay, lr);
88 }
89 
90 decltype(adagrad_update_prefetch__base) adagrad_update_prefetch__avx_f16c;
91 void adagrad_update_prefetch(
92  int N,
93  const float* w,
94  const float* w_n, // prefetch ptr
95 
96  const float* g,
97 
98  const float* h,
99  const float* h_n, // prefetch ptr
100 
101  float* nw,
102  float* nw_n, // prefetch ptr
103 
104  float* nh,
105  float* nh_n, // prefetch ptr
106 
107  float epsilon,
108  float lr) {
109  AVX_F16C_DO(
110  adagrad_update_prefetch,
111  N,
112  w,
113  w_n,
114  g,
115  h,
116  h_n,
117  nw,
118  nw_n,
119  nh,
120  nh_n,
121  epsilon,
122  lr);
123  BASE_DO(
124  adagrad_update_prefetch,
125  N,
126  w,
127  w_n,
128  g,
129  h,
130  h_n,
131  nw,
132  nw_n,
133  nh,
134  nh_n,
135  epsilon,
136  lr);
137 }
138 
139 // Version with prefetching for embeddings and
140 // momentum using fp16
141 decltype(
142  adagrad_fp16_update_prefetch__base) adagrad_fp16_update_prefetch__avx_f16c;
143 void adagrad_fp16_update_prefetch(
144  int N,
145  const at::Half* w,
146  const at::Half* w_n, // prefetch ptr
147  const float* g,
148  const at::Half* h,
149  const at::Half* h_n, // prefetch ptr
150  at::Half* nw,
151  at::Half* nw_n, // prefetch ptr
152  at::Half* nh,
153  at::Half* nh_n, // prefetch ptr
154  float epsilon,
155  float lr) {
156  AVX_F16C_DO(
157  adagrad_fp16_update_prefetch,
158  N,
159  w,
160  w_n,
161  g,
162  h,
163  h_n,
164  nw,
165  nw_n,
166  nh,
167  nh_n,
168  epsilon,
169  lr);
170  BASE_DO(
171  adagrad_fp16_update_prefetch,
172  N,
173  w,
174  w_n,
175  g,
176  h,
177  h_n,
178  nw,
179  nw_n,
180  nh,
181  nh_n,
182  epsilon,
183  lr);
184 }
185 
186 decltype(rowwise_adagrad_update__base) rowwise_adagrad_update__avx_f16c;
187 void rowwise_adagrad_update(
188  int N,
189  float* w,
190  float* w_n, // prefetch ptr
191 
192  const float* g,
193 
194  float* h,
195  float* h_n, // prefetch ptr
196 
197  float epsilon,
198  float lr) {
199  AVX_F16C_DO(rowwise_adagrad_update, N, w, w_n, g, h, h_n, epsilon, lr);
200  BASE_DO(rowwise_adagrad_update, N, w, w_n, g, h, h_n, epsilon, lr);
201 }
202 
203 SPARSE_ADAGRAD_SPECIALIZATION(int32_t, base);
204 
205 decltype(sparse_adagrad_int32_t__base) sparse_adagrad_int32_t__avx_f16c;
206 template <>
207 int sparse_adagrad(
208  int num_rows,
209  int block_size,
210  uint64_t param_size,
211  const float* w,
212  const float* g,
213  const float* h,
214  const int32_t* indices,
215  float* nw,
216  float* nh,
217  float epsilon,
218  float lr) {
219  AVX_F16C_DO(
220  sparse_adagrad_int32_t,
221  num_rows,
222  block_size,
223  param_size,
224  w,
225  g,
226  h,
227  indices,
228  nw,
229  nh,
230  epsilon,
231  lr);
232  BASE_DO(
233  sparse_adagrad_int32_t,
234  num_rows,
235  block_size,
236  param_size,
237  w,
238  g,
239  h,
240  indices,
241  nw,
242  nh,
243  epsilon,
244  lr);
245 }
246 
247 SPARSE_ADAGRAD_SPECIALIZATION(int64_t, base);
248 
249 decltype(sparse_adagrad_int64_t__base) sparse_adagrad_int64_t__avx_f16c;
250 template <>
251 int sparse_adagrad(
252  int num_rows,
253  int block_size,
254  uint64_t param_size,
255  const float* w,
256  const float* g,
257  const float* h,
258  const int64_t* indices,
259  float* nw,
260  float* nh,
261  float epsilon,
262  float lr) {
263  AVX_F16C_DO(
264  sparse_adagrad_int64_t,
265  num_rows,
266  block_size,
267  param_size,
268  w,
269  g,
270  h,
271  indices,
272  nw,
273  nh,
274  epsilon,
275  lr);
276  BASE_DO(
277  sparse_adagrad_int64_t,
278  num_rows,
279  block_size,
280  param_size,
281  w,
282  g,
283  h,
284  indices,
285  nw,
286  nh,
287  epsilon,
288  lr);
289 }
290 
291 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Flush-To-Zero and Denormals-Are-Zero mode.