Caffe2 - C++ API
A deep learning, cross platform ML framework
optim_baseline.h
1 // @generated from test/cpp/api/optim_baseline.py
2 
3 #include <torch/types.h>
4 
5 #include <vector>
6 
7 namespace expected_parameters {
8 
9 inline std::vector<std::vector<torch::Tensor>> Adam() {
10  return {
11  {
12  torch::tensor({0.7889791973017408, 0.5023527440741749, 0.8586918159203789, 0.6579591153929213, 0.747610883848348, 1.697537897359327}),
13  torch::tensor({0.8914325948147117, 0.7020467393446147, 1.6891939505415117}),
14  torch::tensor({-1.0508020464078212, -1.3941340315784612, -1.2843369730699878}),
15  torch::tensor({-1.0711376814874036}),
16  },
17  {
18  torch::tensor({8.232343369651838, 7.970643300186943, 6.643546447481872, 6.470927350729255, 6.1699929180461135, 7.150644529115176}),
19  torch::tensor({8.417513698774671, 6.597182008001362, 7.231731333798338}),
20  torch::tensor({-6.7296200590850805, -7.097441464483235, -6.7533081426144665}),
21  torch::tensor({-6.435644769127909}),
22  },
23  {
24  torch::tensor({8.232728629793431, 7.971029896507964, 6.643845645407439, 6.471228017080045, 6.170273299845014, 7.150926509034132}),
25  torch::tensor({8.41790341440641, 6.597486966540781, 7.232017965615688}),
26  torch::tensor({-6.729913952070152, -7.097736635913622, -6.7535910561856145}),
27  torch::tensor({-6.435922230155599}),
28  },
29  {
30  torch::tensor({8.232728644291049, 7.9710299110617, 6.6438456566822675, 6.471228028420597, 6.170273310405472, 7.150926519662466}),
31  torch::tensor({8.417903429093856, 6.597486978073832, 7.232017976441867}),
32  torch::tensor({-6.729913994977667, -7.0977366809700015, -6.753591085251188}),
33  torch::tensor({-6.435922253675698}),
34  },
35  {
36  torch::tensor({8.232728644308507, 7.971029911086269, 6.643845656714944, 6.471228028466046, 6.170273310429644, 7.150926519696103}),
37  torch::tensor({8.417903429138356, 6.597486978157095, 7.232017976503389}),
38  torch::tensor({-6.729914033672082, -7.097736722215963, -6.753591107635156}),
39  torch::tensor({-6.435922269573368}),
40  },
41  {
42  torch::tensor({8.232728644328265, 7.971029911114316, 6.643845656752554, 6.471228028518527, 6.170273310457381, 7.150926519734861}),
43  torch::tensor({8.417903429189604, 6.597486978253601, 7.232017976574605}),
44  torch::tensor({-6.729914078724216, -7.09773677023889, -6.753591133696778}),
45  torch::tensor({-6.435922288082886}),
46  },
47  {
48  torch::tensor({8.232728644350706, 7.9710299111461715, 6.643845656795268, 6.471228028578131, 6.1702733104888825, 7.150926519778878}),
49  torch::tensor({8.417903429247799, 6.597486978363203, 7.232017976655485}),
50  torch::tensor({-6.729914129890292, -7.0977368247789245, -6.75359116329518}),
51  torch::tensor({-6.435922309104302}),
52  },
53  {
54  torch::tensor({8.232728644375776, 7.971029911181762, 6.64384565684299, 6.471228028644725, 6.170273310524078, 7.150926519828057}),
55  torch::tensor({8.417903429312823, 6.597486978485657, 7.232017976745851}),
56  torch::tensor({-6.729914187056874, -7.097736885715142, -6.753591196364736}),
57  torch::tensor({-6.435922332591009}),
58  },
59  {
60  torch::tensor({8.232728644403466, 7.97102991122107, 6.643845656895699, 6.471228028718272, 6.17027331056295, 7.150926519882374}),
61  torch::tensor({8.417903429384637, 6.597486978620901, 7.232017976845656}),
62  torch::tensor({-6.729914250194772, -7.09773695301644, -6.753591232888567}),
63  torch::tensor({-6.435922358531008}),
64  },
65  {
66  torch::tensor({8.232728644433786, 7.9710299112641145, 6.643845656953418, 6.471228028798811, 6.170273310605518, 7.150926519941853}),
67  torch::tensor({8.41790342946328, 6.597486978769002, 7.23201797695495}),
68  torch::tensor({-6.729914319334562, -7.097737026715396, -6.753591272884355}),
69  torch::tensor({-6.43592238693687}),
70  },
71  {
72  torch::tensor({8.232728644466773, 7.971029911310943, 6.64384565701621, 6.471228028886431, 6.170273310651826, 7.150926520006559}),
73  torch::tensor({8.417903429548836, 6.5974869789301245, 7.23201797707385}),
74  torch::tensor({-6.729914394552445, -7.097737106893246, -6.753591316396183}),
75  torch::tensor({-6.435922417839901}),
76  },
77  };
78 }
79 
80 inline std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay() {
81  return {
82  {
83  torch::tensor({0.7890338917145869, 0.5024064972281554, 0.8586928862731152, 0.6579604913213795, 0.7476152291155436, 1.697523935068844}),
84  torch::tensor({0.8914365922767382, 0.7020469437416427, 1.689192459420757}),
85  torch::tensor({-1.0508020445177773, -1.3941340146813552, -1.2843369695447353}),
86  torch::tensor({-1.071137681045855}),
87  },
88  {
89  torch::tensor({0.1783589728828845, 0.25421417357795134, 0.19682011079203035, 0.2352275872329244, 0.17806013441679713, 0.22943640290803413}),
90  torch::tensor({0.6227661366261539, 0.6058596073202991, 0.607717700489737}),
91  torch::tensor({-1.4259754714918282, -1.433334796707656, -1.4085456423279246}),
92  torch::tensor({-2.0710783910024624}),
93  },
94  {
95  torch::tensor({0.17965695285279176, 0.24254347996350384, 0.1796466384372904, 0.24250836158041741, 0.17962895987963462, 0.24249920721192175}),
96  torch::tensor({0.6287145245742372, 0.6286955878301645, 0.6286563325801346}),
97  torch::tensor({-1.412388723147344, -1.4124007117933108, -1.4122701547931238}),
98  torch::tensor({-2.0633570397920584}),
99  },
100  {
101  torch::tensor({0.17963666509735718, 0.2425086193101887, 0.1796373176307127, 0.242508611366408, 0.17963720028927957, 0.24250890248690504}),
102  torch::tensor({0.6287221269545458, 0.6287225821212604, 0.628722027552364}),
103  torch::tensor({-1.4123466102957034, -1.4123465669345845, -1.4123462614739397}),
104  torch::tensor({-2.063368365141719}),
105  },
106  {
107  torch::tensor({0.17963666103062498, 0.24250882317040753, 0.17963665831242875, 0.24250882481070415, 0.1796366602923916, 0.24250882426284776}),
108  torch::tensor({0.6287216329916353, 0.6287216340516892, 0.6287216326966388}),
109  torch::tensor({-1.4123467542625323, -1.4123467542352317, -1.4123467478192078}),
110  torch::tensor({-2.0633690432441707}),
111  },
112  {
113  torch::tensor({0.17963666098500547, 0.2425088244237952, 0.17963666099348424, 0.2425088244120238, 0.17963666097250158, 0.24250882441058078}),
114  torch::tensor({0.6287216343798244, 0.6287216343800633, 0.6287216343742589}),
115  torch::tensor({-1.4123467490742736, -1.412346749072556, -1.4123467490678534}),
116  torch::tensor({-2.063369043442539}),
117  },
118  {
119  torch::tensor({0.17963666098407072, 0.24250882442250232, 0.17963666098407363, 0.24250882442242327, 0.17963666098409353, 0.24250882442251578}),
120  torch::tensor({0.6287216343836612, 0.6287216343836145, 0.6287216343836255}),
121  torch::tensor({-1.4123467490672261, -1.4123467490672432, -1.412346749067169}),
122  torch::tensor({-2.063369043434909}),
123  },
124  {
125  torch::tensor({0.17963666098407027, 0.242508824422441, 0.1796366609840707, 0.24250882442244046, 0.1796366609840702, 0.24250882442244093}),
126  torch::tensor({0.6287216343837065, 0.6287216343837063, 0.6287216343837067}),
127  torch::tensor({-1.4123467490671706, -1.412346749067171, -1.4123467490671715}),
128  torch::tensor({-2.0633690434349052}),
129  },
130  {
131  torch::tensor({0.17963666098407013, 0.24250882442244104, 0.17963666098407038, 0.2425088244224411, 0.1796366609840705, 0.242508824422441}),
132  torch::tensor({0.628721634383707, 0.6287216343837067, 0.6287216343837067}),
133  torch::tensor({-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}),
134  torch::tensor({-2.0633690434349052}),
135  },
136  {
137  torch::tensor({0.1796366609840703, 0.24250882442244112, 0.17963666098407036, 0.24250882442244112, 0.17963666098407038, 0.242508824422441}),
138  torch::tensor({0.628721634383707, 0.6287216343837068, 0.6287216343837069}),
139  torch::tensor({-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}),
140  torch::tensor({-2.0633690434349052}),
141  },
142  {
143  torch::tensor({0.1796366609840689, 0.24250882442244087, 0.1796366609840702, 0.2425088244224409, 0.17963666098407016, 0.2425088244224408}),
144  torch::tensor({0.6287216343837064, 0.6287216343837068, 0.6287216343837067}),
145  torch::tensor({-1.4123467490671706, -1.412346749067171, -1.4123467490671706}),
146  torch::tensor({-2.0633690434349052}),
147  },
148  };
149 }
150 
151 inline std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay_and_amsgrad() {
152  return {
153  {
154  torch::tensor({0.7889792072185753, 0.502352757161707, 0.8586918160350755, 0.6579591155483523, 0.7476108843716649, 1.6975378965521928}),
155  torch::tensor({0.8914325952640971, 0.7020467393659713, 1.6891939504169646}),
156  torch::tensor({-1.0508020464076324, -1.3941340315767958, -1.2843369730696377}),
157  torch::tensor({-1.0711376814873597}),
158  },
159  {
160  torch::tensor({6.790172150533645, 6.914645717041209, 6.415265715837617, 6.297596948228211, 5.845043191735576, 6.8621426309929}),
161  torch::tensor({7.958560080109429, 6.511332850580849, 7.100944983002762}),
162  torch::tensor({-6.690685400916639, -7.0565911313808, -6.7211550236002875}),
163  torch::tensor({-6.4066139474811115}),
164  },
165  {
166  torch::tensor({4.707385912311418, 5.291370521290748, 6.045088989172656, 6.0243608811971, 5.309433796539775, 6.388035972014135}),
167  torch::tensor({7.200400889176835, 6.398381825147094, 6.904102965924078}),
168  torch::tensor({-6.664146183636052, -7.026722729554791, -6.705827953119048}),
169  torch::tensor({-6.3963166164396865}),
170  },
171  {
172  torch::tensor({2.950915886969457, 3.7657694681407414, 5.607364810245922, 5.6957520744313666, 4.701371245420197, 5.835001213891114}),
173  torch::tensor({6.3434824675494825, 6.258238067497203, 6.6630019537301886}),
174  torch::tensor({-6.630457443730077, -6.988861422125357, -6.686200564723017}),
175  torch::tensor({-6.3830161296508745}),
176  },
177  {
178  torch::tensor({1.713005764813263, 2.536434364252745, 5.140371990006822, 5.337998982747556, 4.083858717692442, 5.254544687928855}),
179  torch::tensor({5.477920783257686, 6.100064103985914, 6.3947278642170975}),
180  torch::tensor({-6.591738209089138, -6.94536203674196, -6.663591073998776}),
181  torch::tensor({-6.3676809797869085}),
182  },
183  {
184  torch::tensor({0.9342570168001864, 1.6340561447464659, 4.667953486696388, 4.96765960685892, 3.4932659094295624, 4.678153658892501}),
185  torch::tensor({4.655156481111641, 5.929480127155326, 6.110044531515085}),
186  torch::tensor({-6.549112173520781, -6.897492853519762, -6.63863605131254}),
187  torch::tensor({-6.35073759645243}),
188  },
189  {
190  torch::tensor({0.483598008466956, 1.0143518793820192, 4.205896945782518, 4.596169793720168, 2.950176323373372, 4.125793850415685}),
191  torch::tensor({3.9036580383003816, 5.750502254470126, 5.816601628842057}),
192  torch::tensor({-6.5033672421993005, -6.846143861683003, -6.611779357893088}),
193  torch::tensor({-6.332482574254225}),
194  },
195  {
196  torch::tensor({0.23940946262165735, 0.6100979055073588, 3.764576093338919, 4.23158112532352, 2.464744464995124, 3.6096701455897247}),
197  torch::tensor({3.236802951205823, 5.566165699240412, 5.520071244528464}),
198  torch::tensor({-6.4550968672620765, -6.791985760021875, -6.58335397163011}),
199  torch::tensor({-6.313137987164135}),
200  },
201  {
202  torch::tensor({0.1140432483273187, 0.35710698257092627, 3.3504972618755904, 3.879514641744938, 2.040187813168003, 3.1365304963147236}),
203  torch::tensor({2.6580472298004074, 5.378832755537033, 5.224730288801605}),
204  torch::tensor({-6.40476778326666, -6.735546192994843, -6.553621096227888}),
205  torch::tensor({-6.292877798552594}),
206  },
207  {
208  torch::tensor({0.05253197784154339, 0.2041401230295221, 2.9673514620600994, 3.5437785414463425, 1.6752543175367425, 2.7092725988958803}),
209  torch::tensor({2.1645721260950497, 5.190371109223889, 4.933812708502093}),
210  torch::tensor({-6.352757654031313, -6.677252148986569, -6.52279182009777}),
211  torch::tensor({-6.271842467566972}),
212  },
213  {
214  torch::tensor({0.0234986931355529, 0.11430841965659456, 2.616783974306594, 3.226808143825143, 1.3659901252130222, 2.328124446275979}),
215  torch::tensor({1.7498774217269997, 5.002268663873182, 4.649746694230509}),
216  torch::tensor({-6.299378178284937, -6.617456004434928, -6.491040351925264}),
217  torch::tensor({-6.250147860242402}),
218  },
219  };
220 }
221 
222 inline std::vector<std::vector<torch::Tensor>> Adagrad() {
223  return {
224  {
225  torch::tensor({0.7891011045987429, 0.502443924512199, 0.8587078329085825, 0.6579710994224826, 0.7476364836215006, 1.697557019500397}),
226  torch::tensor({0.8914687688941954, 0.7020514988069096, 1.6892015076050444}),
227  torch::tensor({-1.0508031297732776, -1.3941351871450518, -1.284337597261839}),
228  torch::tensor({-1.071138124161711}),
229  },
230  {
231  torch::tensor({2.407922969689259, 2.2346803754764286, 1.6967885588547365, 1.552279695827649, 1.2259044248443602, 2.221279696180243}),
232  torch::tensor({2.9334079162217193, 1.7619824934767887, 2.3464577179091473}),
233  torch::tensor({-2.221396083069719, -2.549950976011168, -1.9709315957317095}),
234  torch::tensor({-1.5858816837541876}),
235  },
236  {
237  torch::tensor({2.5104044339418126, 2.3522584510262887, 1.7921695110761213, 1.657755825836846, 1.2891186618593045, 2.291878516133922}),
238  torch::tensor({3.092171180776419, 1.8971624370952997, 2.438734251283465}),
239  torch::tensor({-2.437641633486504, -2.7704264590526573, -2.0949471699460225}),
240  torch::tensor({-1.6769121890401757}),
241  },
242  {
243  torch::tensor({2.565264896810942, 2.4155313947260972, 1.844241233613541, 1.7156513351246399, 1.3245206506797171, 2.3315409972138825}),
244  torch::tensor({3.178399916514377, 1.9721945764936502, 2.4909037706250428}),
245  torch::tensor({-2.5658710403147933, -2.901921821645266, -2.168560672193225}),
246  torch::tensor({-1.7307903926154131}),
247  },
248  {
249  torch::tensor({2.6021584494332597, 2.4582101324909065, 1.8796060082750778, 1.7550965207414717, 1.3489253597999988, 2.3589345190118247}),
250  torch::tensor({3.2368674310041516, 2.0236468833666894, 2.52707132741292}),
251  torch::tensor({-2.6573969292994164, -2.9960731060650505, -2.2211375717304076}),
252  torch::tensor({-1.7692090167089707}),
253  },
254  {
255  torch::tensor({2.6297007725792083, 2.4901377017698683, 1.906173477530586, 1.7847957161833832, 1.3674517119505822, 2.3797578857769905}),
256  torch::tensor({3.2807643102638546, 2.062561811940094, 2.5546379424362775}),
257  torch::tensor({-2.7286379977755035, -3.0695109399636236, -2.262081199960513}),
258  torch::tensor({-1.7990936323432214}),
259  },
260  {
261  torch::tensor({2.651547176699525, 2.51550257362603, 1.927341363452414, 1.8084994719811576, 1.3823309942932445, 2.3964995243914373}),
262  torch::tensor({3.3157334001309473, 2.093728023484945, 2.5768468697402924}),
263  torch::tensor({-2.786981763434855, -3.129746439571402, -2.29562487034177}),
264  torch::tensor({-1.8235564908139104}),
265  },
266  {
267  torch::tensor({2.669578054483789, 2.53646401614724, 1.9448721033433505, 1.828157582353901, 1.3947329882074622, 2.4104657178934947}),
268  torch::tensor({3.344694775590452, 2.1196465761628516, 2.5954050923596252}),
269  torch::tensor({-2.8363936812536537, -3.1808219609745194, -2.32404190866147}),
270  torch::tensor({-1.8442667636913117}),
271  },
272  {
273  torch::tensor({2.6848838015330725, 2.5542762192735515, 1.9597939532350015, 1.844909608012419, 1.4053459079217485, 2.4224257790968386}),
274  torch::tensor({3.369349515259956, 2.1417845308976795, 2.611319989214332}),
275  torch::tensor({-2.879251075341889, -3.225165734647855, -2.3486956737228057}),
276  torch::tensor({-1.86222449978646}),
277  },
278  {
279  torch::tensor({2.6981510124237693, 2.56972998600169, 1.972757472697587, 1.8594775691681182, 1.4146081751022495, 2.43287021079559}),
280  torch::tensor({3.390772758897601, 2.1610741754331757, 2.6252349489549824}),
281  torch::tensor({-2.917092322961074, -3.264351563375218, -2.370468664387175}),
282  torch::tensor({-1.8780765115117757}),
283  },
284  {
285  torch::tensor({2.7098389356033787, 2.5833548721723747, 1.9841994925173085, 1.8723468731726323, 1.4228158926355312, 2.4421305315945085}),
286  torch::tensor({3.4096859099156673, 2.178143852041279, 2.6375854547611364}),
287  torch::tensor({-2.950970455420847, -3.2994581338995044, -2.3899651139415874}),
288  torch::tensor({-1.8922653655195538}),
289  },
290  };
291 }
292 
293 inline std::vector<std::vector<torch::Tensor>> Adagrad_with_weight_decay() {
294  return {
295  {
296  torch::tensor({0.7891011218979068, 0.5024439415126254, 0.8587078332470682, 0.6579710998575992, 0.7476364849956589, 1.6975570150849029}),
297  torch::tensor({0.8914687701583902, 0.7020514988715463, 1.6892015071335027}),
298  torch::tensor({-1.0508031297726799, -1.3941351871397083, -1.2843375972607243}),
299  torch::tensor({-1.0711381241615712}),
300  },
301  {
302  torch::tensor({0.18461166785222127, 0.24944077103107912, 0.18651745437755765, 0.25219093533041764, 0.18712037968446704, 0.25289206444055223}),
303  torch::tensor({0.6482869597891654, 0.6580215784646756, 0.6581256007663536}),
304  torch::tensor({-1.454709711443681, -1.4748063405174818, -1.481162594660476}),
305  torch::tensor({-1.9052928365443633}),
306  },
307  {
308  torch::tensor({0.18059895999281467, 0.2438515539257777, 0.18067177884778177, 0.24397186395008688, 0.18168388351830783, 0.24533853846052}),
309  torch::tensor({0.6325250261983025, 0.6331827793513023, 0.6366659383355596}),
310  torch::tensor({-1.420803333750877, -1.4215627240541653, -1.432026454453339}),
311  torch::tensor({-2.0301356418483225}),
312  },
313  {
314  torch::tensor({0.17981392697398363, 0.2427571544305695, 0.17981150414451733, 0.2427572599231052, 0.18014798619115763, 0.2432144956227816}),
315  torch::tensor({0.6294321320817985, 0.6294873737410742, 0.6306958589251878}),
316  torch::tensor({-1.4139253354785764, -1.4139026804709813, -1.4173628530293867}),
317  torch::tensor({-2.056210117690093}),
318  },
319  {
320  torch::tensor({0.17967006242163752, 0.2425558273455728, 0.17966873677301953, 0.24255462870545763, 0.17975882308988514, 0.24267729072765762}),
321  torch::tensor({0.6288576295241085, 0.6288643132826753, 0.6291921485342002}),
322  torch::tensor({-1.4126465879787569, -1.4126335126907266, -1.4135586793353687}),
323  torch::tensor({-2.0618018405404825}),
324  },
325  {
326  torch::tensor({0.17964321284685655, 0.24251808241139367, 0.17964291377171063, 0.24251779104198595, 0.17966515741781022, 0.2425481059085456}),
327  torch::tensor({0.6287486931367788, 0.6287498167193976, 0.6288312441271762}),
328  torch::tensor({-1.412405895385289, -1.4124029484481164, -1.412631305131538}),
329  torch::tensor({-2.0630223163099304}),
330  },
331  {
332  torch::tensor({0.17963799739278502, 0.24251071912462158, 0.17963793631343422, 0.24251065925361745, 0.17964321802205502, 0.24251786094585862}),
333  torch::tensor({0.6287272170354161, 0.6287274414587727, 0.6287468362309863}),
334  torch::tensor({-1.4123588626342636, -1.4123582636507843, -1.4124124805696718}),
335  torch::tensor({-2.0632918101480255}),
336  },
337  {
338  torch::tensor({0.17963694231402433, 0.24250922426893604, 0.17963692980716203, 0.24250921210074505, 0.17963815759528076, 0.24251088666939835}),
339  torch::tensor({0.6287228195255877, 0.6287228675172438, 0.628727383936762}),
340  torch::tensor({-1.4123493065102781, -1.4123491844624378, -1.41236178722436}),
341  torch::tensor({-2.063351765096138}),
342  },
343  {
344  torch::tensor({0.17963672159046334, 0.24250891070978675, 0.1796367189700305, 0.2425089081831808, 0.17963700091084858, 0.24250929278196062}),
345  torch::tensor({0.6287218911936109, 0.6287219017313679, 0.6287229399204575}),
346  torch::tensor({-1.4123473011084144, -1.412347275640343, -1.41235016959507}),
347  torch::tensor({-2.06336516740435}),
348  },
349  {
350  torch::tensor({0.17963667424978788, 0.2425088433312069, 0.1796366736882977, 0.24250884279379462, 0.17963673796131566, 0.24250893047794034}),
351  torch::tensor({0.6287216908150737, 0.6287216931558691, 0.6287219299749586}),
352  torch::tensor({-1.4123468700596722, -1.4123468646187733, -1.4123475243360135}),
353  torch::tensor({-2.0633681724342527}),
354  },
355  {
356  torch::tensor({0.17963666391853464, 0.24250882860835232, 0.1796366637961457, 0.24250882849182892, 0.17963667838367048, 0.24250884839397405}),
357  torch::tensor({0.6287216468984883, 0.6287216474215305, 0.6287217011907862}),
358  torch::tensor({-1.4123467758545658, -1.4123467746710385, -1.412346924400766}),
359  torch::tensor({-2.0633688474977463}),
360  },
361  };
362 }
363 
364 inline std::vector<std::vector<torch::Tensor>> Adagrad_with_weight_decay_and_lr_decay() {
365  return {
366  {
367  torch::tensor({0.7891011046018798, 0.5024439245163383, 0.8587078329086189, 0.6579710994225316, 0.747636483621666, 1.697557019500142}),
368  torch::tensor({0.8914687688943375, 0.7020514988069164, 1.6892015076050049}),
369  torch::tensor({-1.0508031297732776, -1.3941351871450511, -1.284337597261839}),
370  torch::tensor({-1.0711381241617108}),
371  },
372  {
373  torch::tensor({2.3462189441101033, 2.1919394395020024, 1.6833552017408127, 1.5405520021635608, 1.2137800230828062, 2.2052834637173024}),
374  torch::tensor({2.9090564593404, 1.7509657336815554, 2.3361664131869246}),
375  torch::tensor({-2.206159683368316, -2.5344318233445415, -1.9622783535807609}),
376  torch::tensor({-1.5796101463783623}),
377  },
378  {
379  torch::tensor({2.3889328781057237, 2.267822103800729, 1.7667624725138262, 1.6358015176639829, 1.2655767687152566, 2.2610880567112814}),
380  torch::tensor({3.045569451994985, 1.8770196253823253, 2.419270751956676}),
381  torch::tensor({-2.4079300017528613, -2.7399112002234305, -2.0780613510632375}),
382  torch::tensor({-1.664722108226537}),
383  },
384  {
385  torch::tensor({2.388613755780639, 2.2922158071009173, 1.8078384116424002, 1.6843524744409326, 1.290353948335789, 2.287071550970649}),
386  torch::tensor({3.111110355394278, 1.9438501730282314, 2.463024935587282}),
387  torch::tensor({-2.5226122034499263, -2.8573150939162923, -2.143964860243905}),
388  torch::tensor({-1.713068580990504}),
389  },
390  {
391  torch::tensor({2.3747033522031566, 2.2988044992574554, 1.8330249458212442, 1.7151661013307251, 1.3048586226945842, 2.301765059046427}),
392  torch::tensor({3.150318222034133, 1.9877926185369321, 2.4913999764016785}),
393  torch::tensor({-2.601415913361488, -2.9382038951139644, -2.1892988334550028}),
394  torch::tensor({-1.7462964261966802}),
395  },
396  {
397  torch::tensor({2.3553658567303817, 2.2971917580426875, 1.850115474907212, 1.7368360586881886, 1.3141313000193942, 2.310745259215385}),
398  torch::tensor({3.1762315339155434, 2.0197585204578647, 2.5117041377790192}),
399  torch::tensor({-2.6606644002288697, -2.999121607429386, -2.223413376189609}),
400  torch::tensor({-1.7712905233118803}),
401  },
402  {
403  torch::tensor({2.333805220169621, 2.2913023710914993, 1.8624163948044767, 1.753030073172546, 1.3203313209234842, 2.3163969478854742}),
404  torch::tensor({3.1943525925688934, 2.044447386769377, 2.5271097246073966}),
405  torch::tensor({-2.7076634717294894, -3.0475008084690365, -2.250495807208967}),
406  torch::tensor({-1.7911288238757481}),
407  },
408  {
409  torch::tensor({2.3114979154644897, 2.2830501835377808, 1.8716161429993556, 1.7656325976608416, 1.324565631636651, 2.3199392342052025}),
410  torch::tensor({3.2074779925809085, 2.0642940833670544, 2.539267130147123}),
411  torch::tensor({-2.7463093287485925, -3.0873155411347164, -2.272780318857348}),
412  torch::tensor({-1.8074516661537259}),
413  },
414  {
415  torch::tensor({2.289184162738735, 2.273471699579369, 1.878681899989582, 1.7757301317117609, 1.3274682997719436, 2.322067935399382}),
416  torch::tensor({3.2172019619454075, 2.0807140893178175, 2.549137481514187}),
417  torch::tensor({-2.7789204504423823, -3.120935140242918, -2.2915969523376867}),
418  torch::tensor({-1.8212347229484205}),
419  },
420  {
421  torch::tensor({2.267249823834307, 2.263167803792889, 1.8842131287032617, 1.784000770538389, 1.3294311820750493, 2.3232112430345424}),
422  torch::tensor({3.224507488068445, 2.094598223519413, 2.557325715579171}),
423  torch::tensor({-2.8069849199086647, -3.149882604502293, -2.3077996970997727}),
424  torch::tensor({-1.8331040438272383}),
425  },
426  {
427  torch::tensor({2.245896171868896, 2.252503172511477, 1.8886034384961108, 1.790893034126796, 1.3307102435291205, 2.3236474976004615}),
428  torch::tensor({3.230036385413078, 2.1065407459636134, 2.564234924960966}),
429  torch::tensor({-2.83151249424399, -3.1751926295316566, -2.3219682378974036}),
430  torch::tensor({-1.8434843744626481}),
431  },
432  };
433 }
434 
435 inline std::vector<std::vector<torch::Tensor>> RMSprop() {
436  return {
437  {
438  torch::tensor({0.7890625772821005, 0.502415108650816, 0.8587027713011453, 0.6579673123006431, 0.7476283936579036, 1.6975509766054537}),
439  torch::tensor({0.8914573371873159, 0.7020499947573374, 1.6891991194739453}),
440  torch::tensor({-1.0508027874171133, -1.3941348219724659, -1.2843374000099703}),
441  torch::tensor({-1.0711379842715099}),
442  },
443  {
444  torch::tensor({2.4485718582774427, 2.2809152044417678, 1.7346424449151967, 1.5940004770230671, 1.2507611318399818, 2.248993270255382}),
445  torch::tensor({2.994661478530102, 1.815048529086425, 2.382542610897819}),
446  torch::tensor({-2.3036981738757825, -2.6337299521275646, -2.018370122358821}),
447  torch::tensor({-1.6207875598008983}),
448  },
449  {
450  torch::tensor({2.583758247560778, 2.4365737242301537, 1.862288651935454, 1.7357065282848236, 1.3369695670141972, 2.3454934716983695}),
451  torch::tensor({3.2061266499381618, 1.9981112525417783, 2.5092495986614}),
452  torch::tensor({-2.6110809365525958, -2.9484807193016787, -2.194898560798439}),
453  torch::tensor({-1.750104348062583}),
454  },
455  {
456  torch::tensor({2.66996905113451, 2.536559412710799, 1.9456091681389673, 1.8289149480917672, 1.3952956766999585, 2.4110816686341923}),
457  torch::tensor({3.3436729755936576, 2.1204057198913, 2.5961524902119497}),
458  torch::tensor({-2.8372329851331006, -3.1817729538857207, -2.324997185399695}),
459  torch::tensor({-1.845042217390749}),
460  },
461  {
462  torch::tensor({2.7375365004059113, 2.615307154535863, 2.0117493624534317, 1.9033001982031037, 1.4427501882445095, 2.4646213743186127}),
463  torch::tensor({3.452912454199796, 2.219045152412753, 2.667552790123282}),
464  torch::tensor({-3.03294794567315, -3.384582488936652, -2.437729982499713}),
465  torch::tensor({-1.927101478411823}),
466  },
467  {
468  torch::tensor({2.7952372917068744, 2.6828202203757217, 2.0687223272687003, 1.9676545487787112, 1.4844410726622164, 2.5117888904510117}),
469  torch::tensor({3.5471904628565754, 2.3051135482621405, 2.7307948248967304}),
470  torch::tensor({-3.2141190290332533, -3.572944633614449, -2.5421970206546822}),
471  torch::tensor({-2.002997698521967}),
472  },
473  {
474  torch::tensor({2.8467333937519474, 2.743278517711039, 2.119898810135386, 2.025680541674126, 1.522525646428022, 2.554983108087885}),
475  torch::tensor({3.6320983238761952, 2.3832893041797774, 2.7889864719999222}),
476  torch::tensor({-3.3875799441679257, -3.7537658010839294, -2.6423123266260427}),
477  torch::tensor({-2.0756169514457254}),
478  },
479  {
480  torch::tensor({2.893860049806019, 2.7987769841836143, 2.166973814751561, 2.0792385384302032, 1.5580820887115119, 2.5954023023969692}),
481  torch::tensor({3.7104388143530413, 2.4559190993219526, 2.8436784441941008}),
482  torch::tensor({-3.5567368287146417, -3.930484868709691, -2.740026479434574}),
483  torch::tensor({-2.1463982568717586}),
484  },
485  {
486  torch::tensor({2.9376493943999376, 2.85049296531231, 2.2109027772324468, 2.1293746183147637, 1.5917084661873862, 2.6337100421079533}),
487  torch::tensor({3.783785332844353, 2.52431557011306, 2.8957265009949373}),
488  torch::tensor({-3.7234268485210475, -4.104949193318518, -2.836390693799751}),
489  torch::tensor({-2.2161187733606114}),
490  },
491  {
492  torch::tensor({2.9787316798887544, 2.89914510784732, 2.252272487010976, 2.176728872920277, 1.6237627746697587, 2.670302507579268}),
493  torch::tensor({3.8530980655045095, 2.5892755315530245, 2.9456388178450936}),
494  torch::tensor({-3.8886856459619636, -4.278191888396593, -2.9319964928350317}),
495  torch::tensor({-2.2852176995051243}),
496  },
497  {
498  torch::tensor({3.0175156205790477, 2.945200425145326, 2.291469925522963, 2.221721702578292, 1.6544730927621267, 2.705431441204422}),
499  torch::tensor({3.9190041004420175, 2.6513176244659378, 2.993736489599992}),
500  torch::tensor({-4.053111559341913, -4.450801801162238, -3.027184551913196}),
501  torch::tensor({-2.353949890597345}),
502  },
503  };
504 }
505 
506 inline std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay() {
507  return {
508  {
509  torch::tensor({0.7890798754118442, 0.5024321083861885, 0.8587031097835685, 0.6579677474141494, 0.7476297677960806, 1.6975465611838714}),
510  torch::tensor({0.891458601354904, 0.7020500593937647, 1.6891986479348047}),
511  torch::tensor({-1.0508027868194278, -1.3941348166291232, -1.2843373988951865}),
512  torch::tensor({-1.0711379841318796}),
513  },
514  {
515  torch::tensor({0.21398926523995987, 0.2779011713348031, 0.1868480279465855, 0.2507569370784996, 0.19145335235116723, 0.2557687813139495}),
516  torch::tensor({0.6720959116683096, 0.64807348480635, 0.6542630070045603}),
517  torch::tensor({-1.435763364089916, -1.449355795007287, -1.4619011018356904}),
518  torch::tensor({-1.9673083558727738}),
519  },
520  {
521  torch::tensor({0.23961935744799467, 0.30354236866029477, 0.19567694278621234, 0.2544696440136598, 0.21982879020352686, 0.27495711472053963}),
522  torch::tensor({0.69278957246733, 0.6380155354793247, 0.6523245375965621}),
523  torch::tensor({-1.4137225835004055, -1.417029100163453, -1.4166977298481118}),
524  torch::tensor({-2.0626651115437737}),
525  },
526  {
527  torch::tensor({0.25066358652707205, 0.31463951142849156, 0.2511689291080567, 0.3043139957958546, 0.2521962504869568, 0.3160110008169302}),
528  torch::tensor({0.7051419232958933, 0.6699011906397252, 0.6990972846783207}),
529  torch::tensor({-1.4206083241624008, -1.425703744410001, -1.4171061826065008}),
530  torch::tensor({-2.075537874763715}),
531  },
532  {
533  torch::tensor({0.23285924743048006, 0.29652494304761695, 0.233532200273659, 0.2969991261378863, 0.2335827224521369, 0.2973997498164485}),
534  torch::tensor({0.6855589594923344, 0.6796983775694434, 0.6864174803981533}),
535  torch::tensor({-1.4311076279464525, -1.4334934742817511, -1.4227395521450898}),
536  torch::tensor({-2.0842642493045536}),
537  },
538  {
539  torch::tensor({0.23356397699391668, 0.29737142391987237, 0.23367622061824211, 0.29749447597162154, 0.2341848135739778, 0.2981812292515802}),
540  torch::tensor({0.6866530583001411, 0.6858933385102763, 0.6883944045412813}),
541  torch::tensor({-1.456495550960717, -1.45835481315008, -1.4418225445708688}),
542  torch::tensor({-2.1064103749186267}),
543  },
544  {
545  torch::tensor({0.23187173011747372, 0.29529041598728734, 0.23194024439476815, 0.2953701982498784, 0.23164213369046718, 0.2951041425983909}),
546  torch::tensor({0.6834813130194525, 0.6834401711464214, 0.6837275457100478}),
547  torch::tensor({-1.4647835805276774, -1.465345240817907, -1.457114211277772}),
548  torch::tensor({-2.1209598505912113}),
549  },
550  {
551  torch::tensor({0.23086833965041764, 0.29404746297504464, 0.23089067678260944, 0.29407615110959284, 0.23064069314214153, 0.2937904361139022}),
552  torch::tensor({0.6815062281792609, 0.6815233687209212, 0.6812759203026144}),
553  torch::tensor({-1.4643013018530677, -1.4644523635284243, -1.4617493939684876}),
554  torch::tensor({-2.1247293635678854}),
555  },
556  {
557  torch::tensor({0.23066464201462633, 0.2937605927373038, 0.23067069245857322, 0.2937690399784262, 0.23057551211606628, 0.2936477517373103}),
558  torch::tensor({0.6809028781780299, 0.6809134105028238, 0.6807404613096296}),
559  torch::tensor({-1.463792735217798, -1.4638374228010722, -1.4629928710264293}),
560  torch::tensor({-2.1258082720638107}),
561  },
562  {
563  torch::tensor({0.23062625079199184, 0.29369907874257073, 0.2306278924729116, 0.2937014834661652, 0.23059813368157012, 0.293660738904764}),
564  torch::tensor({0.6807251804689083, 0.6807295616357248, 0.6806523640328995}),
565  torch::tensor({-1.4635790398985618, -1.4635929269022863, -1.4633272688236565}),
566  torch::tensor({-2.12613963581418}),
567  },
568  {
569  torch::tensor({0.2306170112270019, 0.29368398381778194, 0.23061747865501664, 0.29368467998690806, 0.2306085559563821, 0.2936719507340021}),
570  torch::tensor({0.680671467383083, 0.6806730903175792, 0.6806434720800854}),
571  torch::tensor({-1.4635008778278134, -1.4635052859178375, -1.463420837506828}),
572  torch::tensor({-2.1262432969587723}),
573  },
574  };
575 }
576 
577 inline std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay_and_centered() {
578  return {
579  {
580  torch::tensor({0.7941000061626792, 0.507452636734552, 0.8637405354185987, 0.663005089317529, 0.7526661272860107, 1.7025887305065852}),
581  torch::tensor({0.8964950370033696, 0.7070877948157552, 1.6942369105467197}),
582  torch::tensor({-1.055840599214661, -1.3991726335388424, -1.2893752132746332}),
583  torch::tensor({-1.0761757981162612}),
584  },
585  {
586  torch::tensor({2.3762999876885833, 2.239095829416783, 1.726175067071914, 1.5891569459230446, 1.2410074108588462, 2.2345431036725723}),
587  torch::tensor({2.990896455635836, 1.8152108764849464, 2.377985429759037}),
588  torch::tensor({-2.3071822180635286, -2.636859516619699, -2.0198181394256642}),
589  torch::tensor({-1.622583045791722}),
590  },
591  {
592  torch::tensor({2.372800588647971, 2.3022753207224254, 1.836028714221617, 1.7190937269287108, 1.3068955839895078, 2.3035835673200364}),
593  torch::tensor({3.1656599892042343, 1.9942937608209463, 2.4947143457182657}),
594  torch::tensor({-2.6139790332516775, -2.9507738987695404, -2.1954425128779516}),
595  torch::tensor({-1.7513053380188808}),
596  },
597  {
598  torch::tensor({2.2398453700818455, 2.2513384246965904, 1.8892176431436287, 1.7921873754661688, 1.3310951408713536, 2.3236392222350397}),
599  torch::tensor({3.240166119454613, 2.1097428136001883, 2.5651614461576973}),
600  torch::tensor({-2.8388734382997454, -3.182420077067612, -2.324831397600949}),
601  torch::tensor({-1.8460315737386979}),
602  },
603  {
604  torch::tensor({1.9829606312242465, 2.097356567850692, 1.9050263843525033, 1.8325835415812348, 1.3222762370713101, 2.3024963133870147}),
605  torch::tensor({3.2465360572089974, 2.196726604586991, 2.6091992649970672}),
606  torch::tensor({-3.0326878099587207, -3.3827004807595005, -2.4369891822504957}),
607  torch::tensor({-1.9282732162063443}),
608  },
609  {
610  torch::tensor({1.6051175329080525, 1.8332107491649117, 1.8794767349053179, 1.8403588051948858, 1.2738241113141069, 2.2296571379436823}),
611  torch::tensor({3.1814362940910437, 2.2630192140728465, 2.6273016977574013}),
612  torch::tensor({-3.210932646440219, -3.567153254014387, -2.5410169439239136}),
613  torch::tensor({-2.004915513461716}),
614  },
615  {
616  torch::tensor({1.1588059349082709, 1.4778613795232265, 1.7992410089026636, 1.8064600091986671, 1.1739931551629919, 2.08647960875392}),
617  torch::tensor({3.03843703712275, 2.3082030683758767, 2.6125393914734083}),
618  torch::tensor({-3.3798306786085885, -3.7419704144706256, -2.641008240084654}),
619  torch::tensor({-2.0792949959104874}),
620  },
621  {
622  torch::tensor({0.7701433312419088, 1.110502667742475, 1.6465075169366392, 1.7162526909817901, 1.013748545414221, 1.8532966501655352}),
623  torch::tensor({2.8271768758852454, 2.3274019481599275, 2.5535309398603405}),
624  torch::tensor({-3.541933298509861, -3.909665295212314, -2.7394088701924364}),
625  torch::tensor({-2.1537939241668997}),
626  },
627  {
628  torch::tensor({0.5598923129351211, 0.8460500042788703, 1.408417554916502, 1.5547314210944567, 0.8019580519338422, 1.5258384663629627}),
629  torch::tensor({2.5774950379490265, 2.3131013066991266, 2.438869575744175}),
630  torch::tensor({-3.6974974230160096, -4.070190514312715, -2.83789326757184}),
631  torch::tensor({-2.2307225014430423}),
632  },
633  {
634  torch::tensor({0.5016784472836651, 0.7258690889265433, 1.0976902935953958, 1.3199491879725134, 0.5853930356154848, 1.1446978015944624}),
635  torch::tensor({2.3235249877284954, 2.259284097042017, 2.268146169860938}),
636  torch::tensor({-3.8444921272569124, -4.220210513610989, -2.9373192115434263}),
637  torch::tensor({-2.312733063937045}),
638  },
639  {
640  torch::tensor({0.4875468895095058, 0.6878747871467127, 0.7787871237567608, 1.046259254610218, 0.4416468896022396, 0.8122992916762793}),
641  torch::tensor({2.1078734515587483, 2.17034337037527, 2.066632596856854}),
642  torch::tensor({-3.9782695475825225, -4.3520930551154136, -3.0377809502927033}),
643  torch::tensor({-2.403496388200805}),
644  },
645  };
646 }
647 
648 inline std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay_and_centered_and_momentum() {
649  return {
650  {
651  torch::tensor({0.7941000061626794, 0.507452636734552, 0.8637405354185985, 0.663005089317529, 0.7526661272860107, 1.7025887305065852}),
652  torch::tensor({0.8964950370033699, 0.7070877948157552, 1.6942369105467197}),
653  torch::tensor({-1.055840599214661, -1.3991726335388424, -1.2893752132746332}),
654  torch::tensor({-1.0761757981162612}),
655  },
656  {
657  torch::tensor({11.587263945492355, 12.552112516667208, 10.773002960161074, 10.782117868337808, 9.675467654064093, 10.830689360054789}),
658  torch::tensor({15.298238342006444, 11.252244653209866, 11.423905295074075}),
659  torch::tensor({-11.287147147258441, -11.673871066494183, -11.143068139029769}),
660  torch::tensor({-10.744790465364126}),
661  },
662  {
663  torch::tensor({5.993130757784388, 7.778269455146454, 9.705741295559012, 9.974952848613889, 8.171307305871647, 9.551498426643077}),
664  torch::tensor({12.811268477045155, 10.912201832960703, 10.87477550647832}),
665  torch::tensor({-11.20842921856976, -11.58706973895515, -11.098172235374586}),
666  torch::tensor({-10.714110383698559}),
667  },
668  {
669  torch::tensor({1.917316794757853, 3.4420983730039167, 8.160846071267297, 8.766734268561208, 6.163892823252042, 7.748894752821816}),
670  torch::tensor({9.52929937981379, 10.371703621802425, 10.02242566317017}),
671  torch::tensor({-11.07914626767133, -11.444639737948599, -11.02397978065452}),
672  torch::tensor({-10.663204622623406}),
673  },
674  {
675  torch::tensor({0.24211162925745067, 0.8235150923738452, 6.109652191353378, 7.070860554523036, 3.8366635637770212, 5.46037058418296}),
676  torch::tensor({5.7908039507441, 9.534309069066389, 8.752252906881251}),
677  torch::tensor({-10.868651889371552, -11.212965695734527, -10.90242744782103}),
678  torch::tensor({-10.579596899816439}),
679  },
680  {
681  torch::tensor({0.0024206009020476234, 0.055217404976894764, 3.753606156332189, 4.9331546064599685, 1.7094621184709604, 3.022224882400484}),
682  torch::tensor({2.4729429920325234, 8.290211439306459, 6.983317870704776}),
683  torch::tensor({-10.529133489023623, -10.839885990130032, -10.704345435808353}),
684  torch::tensor({-10.44279235413811}),
685  },
686  {
687  torch::tensor({8.523664833406631e-06, -0.0001849801580961706, 1.6343074841140277, 2.683608480982545, 0.41425107807132744, 1.092111816609512}),
688  torch::tensor({0.553119873538318, 6.566845593450314, 4.783317472190566}),
689  torch::tensor({-9.990101114696575, -10.24914448933998, -10.38447825909146}),
690  torch::tensor({-10.220382375374728}),
691  },
692  {
693  torch::tensor({5.366918233939773e-08, -2.8997040293991813e-07, 0.37916783268568177, 0.9399553431452387, 0.02859528129337607, 0.17650614337704745}),
694  torch::tensor({0.03166973497545419, 4.442846994093523, 2.5203464928754724}),
695  torch::tensor({-9.15653357178671, -9.339631853060773, -9.875729313751442}),
696  torch::tensor({-9.862669711962374}),
697  },
698  {
699  torch::tensor({2.1133356499004343e-06, 2.4524630407768008e-06, 0.023655729923601883, 0.14273709578291383, -8.950192389690758e-05, 0.004237697008964042}),
700  torch::tensor({-0.00012364097582548376, 2.291191859107928, 0.8331414409602524}),
701  torch::tensor({-7.922566174765117, -8.003055545094796, -9.086673634672907}),
702  torch::tensor({-9.297519364373224}),
703  },
704  {
705  torch::tensor({0.002349743029499243, 0.002861131671472502, 0.000699873962729607, 0.0036571565360575295, 0.001654303471369622, 0.0018171459470053366}),
706  torch::tensor({0.004569191565477355, 0.7292466599711233, 0.11475431260766135}),
707  torch::tensor({-6.223834483308681, -6.185383631607397, -7.912955414853613}),
708  torch::tensor({-8.430731662958186}),
709  },
710  {
711  torch::tensor({0.10393820340367545, 0.1398207466618173, 0.0831407198272949, 0.10183584198629941, 0.13949594516972202, 0.17822672100147108}),
712  torch::tensor({0.34039464502063904, 0.2486088886235969, 0.31914045155310655}),
713  torch::tensor({-4.174294597914298, -4.037528929635062, -6.297198700024484}),
714  torch::tensor({-7.182093090194918}),
715  },
716  };
717 }
718 
719 inline std::vector<std::vector<torch::Tensor>> SGD() {
720  return {
721  {
722  torch::tensor({-0.21063957030131192, -0.4972093725858961, -0.13931849072410168, -0.33939101965581686, -0.25112865488453673, 0.6992101966874735}),
723  torch::tensor({-0.1076573444246077, -0.2913064413859577, 0.6933846874181748}),
724  torch::tensor({-0.07998325778863398, -0.42149210515421365, -0.33498349553944556}),
725  torch::tensor({-0.14255126505509488}),
726  },
727  {
728  torch::tensor({-0.15543131540224012, -0.42351103963720343, -0.04196796248622072, -0.20952231780684988, -0.16031407286541022, 0.8209742464453325}),
729  torch::tensor({0.0772434360716014, 0.03387529472490231, 1.0028793648054941}),
730  torch::tensor({-0.8213382425894498, -1.1570800333254736, -1.6154760331657425}),
731  torch::tensor({-1.873409073108485}),
732  },
733  {
734  torch::tensor({-0.13342791770744886, -0.3941509709488104, -0.011470356542661935, -0.16885142516066964, -0.13306680693528108, 0.8576491729785701}),
735  torch::tensor({0.15081014600761683, 0.13560816175111742, 1.0971559708365837}),
736  torch::tensor({-0.9780975407869251, -1.3215153697157922, -1.8760213876051515}),
737  torch::tensor({-2.202441305652889}),
738  },
739  {
740  torch::tensor({-0.11963097684681223, -0.37573675130134543, 0.00699871664138837, -0.14420855651125974, -0.11733423659038758, 0.8788673419128562}),
741  torch::tensor({0.19698293387590055, 0.19734611640471314, 1.1520119567305152}),
742  torch::tensor({-1.0677802792431819, -1.4166561260631116, -2.0220337532169905}),
743  torch::tensor({-2.3834524272927813}),
744  },
745  {
746  torch::tensor({-0.10950806441156272, -0.3622226699218595, 0.02028489243523426, -0.12647254228380073, -0.10635775660996463, 0.8936912722040982}),
747  torch::tensor({0.230894623318268, 0.2418445007408441, 1.1904864598387046}),
748  torch::tensor({-1.1306213044009719, -1.483718648357814, -2.1228846025142074}),
749  torch::tensor({-2.50713525051584}),
750  },
751  {
752  torch::tensor({-0.10149090356585248, -0.3515172115812867, 0.030662536099764083, -0.11261325211798622, -0.09797248308626623, 0.905027632401109}),
753  torch::tensor({0.2577775982668944, 0.2766609657536914, 1.2199973265718322}),
754  torch::tensor({-1.1789655573653979, -1.5355073692636771, -2.1996125838846075}),
755  torch::tensor({-2.6005295414716625}),
756  },
757  {
758  torch::tensor({-0.09484472748389533, -0.3426405023243085, 0.03917399284640637, -0.10124188994381234, -0.09121264836307835, 0.9141743475340721}),
759  torch::tensor({0.28008293001710327, 0.30526002002900676, 1.2438661306695873}),
760  torch::tensor({-1.2182324765944266, -1.577685139408549, -2.261370486631629}),
761  torch::tensor({-2.675274336197319}),
762  },
763  {
764  torch::tensor({-0.08916446117741175, -0.33505233521798666, 0.04638527943959316, -0.09160422984057523, -0.08556486270584644, 0.9218219103015535}),
765  torch::tensor({0.29916193801548524, 0.32952375512951, 1.2638639017720827}),
766  torch::tensor({-1.251282493526328, -1.613256463950431, -2.312952993721385}),
767  torch::tensor({-2.7374195723946606}),
768  },
769  {
770  torch::tensor({-0.08420245801272855, -0.3284224385121882, 0.05263847708646642, -0.08324438788845251, -0.08072424164719598, 0.9283806476306355}),
771  torch::tensor({0.3158408734266357, 0.3505901981820038, 1.2810450644764015}),
772  torch::tensor({-1.2798091496372141, -1.644007253821019, -2.3571804629611046}),
773  torch::tensor({-2.7905023459395877}),
774  },
775  {
776  torch::tensor({-0.07979600214534927, -0.3225337978155753, 0.0581562720006689, -0.07586555700667831, -0.07649523955108037, 0.9341138824526719}),
777  torch::tensor({0.3306627217189734, 0.3692005578577211, 1.2960873917356066}),
778  torch::tensor({-1.3048976883823566, -1.671085574250112, -2.3958498984546135}),
779  torch::tensor({-2.8367650855551236}),
780  },
781  {
782  torch::tensor({-0.07583232846497831, -0.3172360102461862, 0.06309179259248046, -0.06926361352067163, -0.07274510848082802, 0.9392004636935606}),
783  torch::tensor({0.34400386060915455, 0.38586478679967207, 1.3094518934419668}),
784  torch::tensor({-1.3272851146877218, -1.695273130850265, -2.4301754289421593}),
785  torch::tensor({-2.877716472882302}),
786  },
787  };
788 }
789 
790 inline std::vector<std::vector<torch::Tensor>> SGD_with_weight_decay() {
791  return {
792  {
793  torch::tensor({-0.21042867144447805, -0.49671181653925384, -0.13917719856207697, -0.3390489907590303, -0.2508762913762564, 0.6985126396619242}),
794  torch::tensor({-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}),
795  torch::tensor({-0.079932454658518, -0.42109796996670307, -0.33469915794198624}),
796  torch::tensor({-0.14248012693079315}),
797  },
798  {
799  torch::tensor({-0.13579982290274883, -0.3765456284475787, -0.03166970700350034, -0.18102559254681197, -0.1373234786735746, 0.7522156177001302}),
800  torch::tensor({0.08550003826014416, 0.051563225553454196, 0.9321399061276381}),
801  torch::tensor({-0.796312238882584, -1.1010063686038731, -1.5363716774172782}),
802  torch::tensor({-1.8045854907382846}),
803  },
804  {
805  torch::tensor({-0.09659168723529124, -0.3056207693658826, 0.006712867145512922, -0.1166002367977548, -0.09012083166238948, 0.7264953102453368}),
806  torch::tensor({0.16531808496504802, 0.16488328577596398, 0.9610743966573319}),
807  torch::tensor({-0.9202466399245914, -1.2052829272891827, -1.7049756710541348}),
808  torch::tensor({-2.0415977924493043}),
809  },
810  {
811  torch::tensor({-0.06728100597713035, -0.24965896016541955, 0.03186158526394667, -0.07105441484407878, -0.056478595544178806, 0.6910758436366733}),
812  torch::tensor({0.21707768347081777, 0.23575238192099465, 0.9564382346520687}),
813  torch::tensor({-0.978819503903, -1.2447191597975942, -1.7620201560619633}),
814  torch::tensor({-2.131504419683077}),
815  },
816  {
817  torch::tensor({-0.04304955053155505, -0.20206572730420896, 0.05095951394632446, -0.03470009355744099, -0.029224652011670186, 0.6547611705604361}),
818  torch::tensor({0.2563898231537708, 0.2878867158887637, 0.9414221685252803}),
819  torch::tensor({-1.0143969472996655, -1.2623288365082086, -1.780047146006567}),
820  torch::tensor({-2.170255083720924}),
821  },
822  {
823  torch::tensor({-0.02215471703826274, -0.16036518660639856, 0.06644401410758825, -0.004183373274651911, -0.005965877978527785, 0.6200298215101535}),
824  torch::tensor({0.2886406829874717, 0.32924516791460257, 0.9230983700837223}),
825  torch::tensor({-1.0397895250773481, -1.271091416624018, -1.780775800960309}),
826  torch::tensor({-2.1862978976514738}),
827  },
828  {
829  torch::tensor({-0.00374391398483171, -0.12328293308251932, 0.0794469618680564, 0.022100305718442004, 0.014399113804332031, 0.587697912745227}),
830  torch::tensor({0.31628710746920075, 0.36346293565421134, 0.9042402154310413}),
831  torch::tensor({-1.060234961430088, -1.2762264965487673, -1.7731268727630665}),
832  torch::tensor({-2.191253945056341}),
833  },
834  {
835  torch::tensor({0.012675985938854724, -0.09003711893222127, 0.09059095692632843, 0.04506778924310348, 0.03247299240601, 0.5579755127260052}),
836  torch::tensor({0.34062269989331717, 0.3924947745885882, 0.8860121369119327}),
837  torch::tensor({-1.0781407849705034, -1.2800528898634016, -1.7613120374342217}),
838  torch::tensor({-2.190575043873577}),
839  },
840  {
841  torch::tensor({0.027425440985777993, -0.060088099586172165, 0.10026092920861807, 0.06531092947039242, 0.048628754907931976, 0.5308215072596255}),
842  torch::tensor({0.3623974452028054, 0.41751623876388866, 0.8688788105023479}),
843  torch::tensor({-1.0946579691370502, -1.2836103422269478, -1.7474706191775766}),
844  torch::tensor({-2.1870021744944763}),
845  },
846  {
847  torch::tensor({0.040732509801474096, -0.0330302410355501, 0.10871770475931389, 0.08324870459183517, 0.06312228688815541, 0.5060892094042873}),
848  torch::tensor({0.38208249693950164, 0.43930026549895956, 0.8529817924677643}),
849  torch::tensor({-1.1103326127955466, -1.2873324059163587, -1.7327386627485202}),
850  torch::tensor({-2.1819672316721337}),
851  },
852  {
853  torch::tensor({0.052771609187326034, -0.008539186625351419, 0.11615154444871968, 0.09919929206676086, 0.07614530177703589, 0.48359250162323586}),
854  torch::tensor({0.3999968617221314, 0.45839442009256354, 0.838313296680579}),
855  torch::tensor({-1.1254107858333455, -1.2913604197768886, -1.717739109221235}),
856  torch::tensor({-2.176236807160431}),
857  },
858  };
859 }
860 
861 inline std::vector<std::vector<torch::Tensor>> SGD_with_weight_decay_and_momentum() {
862  return {
863  {
864  torch::tensor({-0.21042867144447805, -0.49671181653925384, -0.13917719856207697, -0.3390489907590303, -0.2508762913762564, 0.6985126396619242}),
865  torch::tensor({-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}),
866  torch::tensor({-0.079932454658518, -0.42109796996670307, -0.33469915794198624}),
867  torch::tensor({-0.14248012693079315}),
868  },
869  {
870  torch::tensor({0.005611848725195473, -0.0710915563059199, 0.07701400891926036, 0.047067327035013866, 0.0428654052972598, 0.4352977220593751}),
871  torch::tensor({0.23834837300214828, 0.3236638250370418, 0.712832101663469}),
872  torch::tensor({-1.041947788394885, -1.1730950187020548, -1.7648205873351157}),
873  torch::tensor({-2.3359277661920594}),
874  },
875  {
876  torch::tensor({0.11520007183759415, 0.12894537687632862, 0.1458684555595196, 0.1775341535876219, 0.15614155642578992, 0.3337912614746053}),
877  torch::tensor({0.465853656413685, 0.5201979178769089, 0.7274876508280723}),
878  torch::tensor({-1.2034746444882527, -1.2861269692338677, -1.604528340632377}),
879  torch::tensor({-2.2032159091966244}),
880  },
881  {
882  torch::tensor({0.15331258730374997, 0.197909036233604, 0.1666381464737419, 0.21833204987278948, 0.1803274550482287, 0.2836274579441783}),
883  torch::tensor({0.5532312776994918, 0.5834224152126114, 0.6903579410976888}),
884  torch::tensor({-1.3052171323471546, -1.3514190497186431, -1.5153574535010637}),
885  torch::tensor({-2.123181139806548}),
886  },
887  {
888  torch::tensor({0.16814113185552507, 0.22386572201448868, 0.17413795101952861, 0.23280515326261633, 0.1839142207976228, 0.2614499495870909}),
889  torch::tensor({0.5922828765767589, 0.6083877519652824, 0.6634387486999062}),
890  torch::tensor({-1.3591143274292896, -1.383673065830997, -1.4671578935172773}),
891  torch::tensor({-2.087859547998447}),
892  },
893  {
894  torch::tensor({0.1743742243877178, 0.23431261530597983, 0.1771694292764225, 0.2383866964333009, 0.18308461132092926, 0.25149544624452974}),
895  torch::tensor({0.6108281747800746, 0.6192657661217673, 0.6475519545045927}),
896  torch::tensor({-1.3860527054444405, -1.398816664238087, -1.4412527948055516}),
897  torch::tensor({-2.0731939075659627}),
898  },
899  {
900  torch::tensor({0.17714654787514617, 0.23875859951719522, 0.1784868271584857, 0.24067863725664962, 0.18181103291606765, 0.24687877342069475}),
901  torch::tensor({0.6198586021174767, 0.6242349464856269, 0.6387368453733712}),
902  torch::tensor({-1.3993307716862977, -1.4058965193851591, -1.42747775986796}),
903  torch::tensor({-2.0672675843404598}),
904  },
905  {
906  torch::tensor({0.1784309358535768, 0.24073954802700465, 0.1790869702744087, 0.2416675839909268, 0.18088350526559058, 0.24467193314356378}),
907  torch::tensor({0.6243071074374693, 0.6265628975677455, 0.6339840865876518}),
908  torch::tensor({-1.4058750036106915, -1.4092362337714568, -1.4202202926903085}),
909  torch::tensor({-2.0649062340635584}),
910  },
911  {
912  torch::tensor({0.17904350645021613, 0.24165496946247034, 0.17936920658487726, 0.24211164489776849, 0.18031858582735988, 0.2435923992630521}),
913  torch::tensor({0.626513445507806, 0.6276715667697311, 0.6314641991686346}),
914  torch::tensor({-1.409113940967948, -1.410830795235453, -1.4164247285253404}),
915  torch::tensor({-2.0639728292802046}),
916  },
917  {
918  torch::tensor({0.17934167113683835, 0.242089962404631, 0.17950490408309283, 0.24231745350706002, 0.17999989292556767, 0.24305577552575766}),
919  torch::tensor({0.6276131793232343, 0.6282062328090798, 0.6301427155170752}),
920  torch::tensor({-1.4107251789010826, -1.4116011824171857, -1.414451176796242}),
921  torch::tensor({-2.063605631667394}),
922  },
923  {
924  torch::tensor({0.17948886155124505, 0.24230096332204806, 0.17957117450689372, 0.242415213133214, 0.17982712042628354, 0.24278620392248687}),
925  torch::tensor({0.6281635672171683, 0.6284667582211863, 0.6294549191500092}),
926  torch::tensor({-1.4115305541843781, -1.4119772978756442, -1.413429652281864}),
927  torch::tensor({-2.0634616066978615}),
928  },
929  };
930 }
931 
932 inline std::vector<std::vector<torch::Tensor>> SGD_with_weight_decay_and_nesterov_momentum() {
933  return {
934  {
935  torch::tensor({-0.21040617235121148, -0.49689727139951717, -0.13754215970803657, -0.33701686525263036, -0.2500172388792182, 0.700697918175925}),
936  torch::tensor({-0.1068708360895515, -0.2853285323043249, 0.6971494161502307}),
937  torch::tensor({-0.10624536304143092, -0.4461132561477894, -0.3805647497874434}),
938  torch::tensor({-0.2068230782168696}),
939  },
940  {
941  torch::tensor({-0.1262387113548655, -0.3844658218758334, 0.03124406856508885, -0.1117053215242578, -0.09823268522398329, 0.9040698525178972}),
942  torch::tensor({0.17551336074135096, 0.27976614792027166, 1.2138399680985128}),
943  torch::tensor({-1.592840413595591, -1.8986806244521564, -2.966181914454827}),
944  torch::tensor({-3.7728444542017687}),
945  },
946  {
947  torch::tensor({-0.11614716303292183, -0.3709539909720773, 0.04307078045512774, -0.09588329367245822, -0.08795603365024901, 0.9178771227283019}),
948  torch::tensor({0.20944042006388683, 0.3195483889401668, 1.2500270348310718}),
949  torch::tensor({-1.635011052494502, -1.9463243375558272, -3.035708036973984}),
950  torch::tensor({-3.8570351018212796}),
951  },
952  {
953  torch::tensor({-0.10793942832760066, -0.35995697973682966, 0.05260329955808717, -0.08312010825923574, -0.07986326997915316, 0.9287409473303162}),
954  torch::tensor({0.2370574459090396, 0.35168415020524857, 1.278618438127574}),
955  torch::tensor({-1.669141810658011, -1.984894370767313, -3.091259532917102}),
956  torch::tensor({-3.923827025320545}),
957  },
958  {
959  torch::tensor({-0.1010142826857921, -0.35067247612415425, 0.06058642765135954, -0.07242353828264113, -0.07320722520220556, 0.9376663294528951}),
960  torch::tensor({0.26037531373638517, 0.3786476842903903, 1.3021925174954938}),
961  torch::tensor({-1.6978623668013235, -2.017346013780729, -3.137511248751908}),
962  torch::tensor({-3.9791368472670334}),
963  },
964  {
965  torch::tensor({-0.0950223925827384, -0.34263425874631004, 0.06744912149060933, -0.06322196689556117, -0.0675685037432093, 0.9452179348012486}),
966  torch::tensor({0.2805629021730173, 0.4018655921083789, 1.322201974233735}),
967  torch::tensor({-1.7226667375672964, -2.0453651314263936, -3.177094625235675}),
968  torch::tensor({-4.02626353351958}),
969  },
970  {
971  torch::tensor({-0.08974074058929345, -0.3355446553621404, 0.07346375443579245, -0.05515233627910104, -0.0626864887100167, 0.9517469338705254}),
972  torch::tensor({0.2983667593362755, 0.4222447182468949, 1.3395523443811077}),
973  torch::tensor({-1.7445022249557358, -2.070022023204061, -3.2116640112699977}),
974  torch::tensor({-4.0672678681014025}),
975  },
976  {
977  torch::tensor({-0.08501805567029426, -0.32920173535901165, 0.07881418855733363, -0.04796950604202485, -0.05838841106411265, 0.9574862878804136}),
978  torch::tensor({0.314293480998118, 0.4403978459123401, 1.3548455497581404}),
979  torch::tensor({-1.7640079833055848, -2.092039516337239, -3.2423272727017984}),
980  torch::tensor({-4.1035226231344275}),
981  },
982  {
983  torch::tensor({-0.08074691916762684, -0.32346209125355385, 0.08363041954091822, -0.041500113264404184, -0.054554001958245224, 0.9625982669377223}),
984  torch::tensor({0.3287029554083447, 0.4567586474454379, 1.3685016029692183}),
985  torch::tensor({-1.781635969965323, -2.1119291060463254, -3.2698625668054278}),
986  torch::tensor({-4.135987715622241}),
987  },
988  {
989  torch::tensor({-0.07684825926741545, -0.3182201487496606, 0.08800775942949754, -0.035617020506473744, -0.051096269408742706, 0.967200308063431}),
990  torch::tensor({0.3418602073865105, 0.4716453768126203, 1.3808249543962092}),
991  torch::tensor({-1.7977177360253929, -2.1300662313234127, -3.2948372910743537}),
992  torch::tensor({-4.165360368453936}),
993  },
994  {
995  torch::tensor({-0.07326215742204906, -0.31339589848358795, 0.09201816976416922, -0.03022421717885477, -0.04795031746059936, 0.9713800632469923}),
996  torch::tensor({0.35396614509118257, 0.48529852449498895, 1.3920431643924076}),
997  torch::tensor({-1.8125038126190638, -2.146734711618823, -3.3176778240157505}),
998  torch::tensor({-4.192162739857097}),
999  },
1000  };
1001 }
1002 
1003 } // namespace expected_parameters