blob: 9d634609cc9b7ea79c6b53a328225543a8ebf66f [file] [log] [blame]
// @generated from test/cpp/api/optim_baseline.py
#include <torch/tensor.h>
#include <vector>
namespace expected_parameters {
static std::vector<std::vector<torch::Tensor>> Adam = {
{
torch::tensor({0.7889791973017408, 0.5023527440741749, 0.8586918159203789, 0.6579591153929213, 0.747610883848348, 1.697537897359327}),
torch::tensor({0.8914325948147117, 0.7020467393446147, 1.6891939505415117}),
torch::tensor({-1.0508020464078212, -1.3941340315784612, -1.2843369730699878}),
torch::tensor({-1.0711376814874036}),
},
{
torch::tensor({8.232343369651838, 7.970643300186945, 6.643546447481872, 6.470927350729255, 6.1699929180461135, 7.150644529115176}),
torch::tensor({8.417513698774671, 6.597182008001362, 7.231731333798338}),
torch::tensor({-6.7296200590850805, -7.097441464483235, -6.7533081426144665}),
torch::tensor({-6.435644769127909}),
},
{
torch::tensor({8.232728629793431, 7.971029896507965, 6.643845645407439, 6.471228017080045, 6.170273299845014, 7.150926509034132}),
torch::tensor({8.41790341440641, 6.597486966540781, 7.232017965615688}),
torch::tensor({-6.729913952070152, -7.097736635913622, -6.7535910561856145}),
torch::tensor({-6.435922230155599}),
},
{
torch::tensor({8.232728644291049, 7.971029911061701, 6.6438456566822675, 6.471228028420597, 6.170273310405472, 7.150926519662466}),
torch::tensor({8.417903429093856, 6.597486978073832, 7.232017976441867}),
torch::tensor({-6.729913994977667, -7.0977366809700015, -6.753591085251188}),
torch::tensor({-6.435922253675698}),
},
{
torch::tensor({8.232728644308507, 7.971029911086271, 6.643845656714944, 6.471228028466046, 6.170273310429644, 7.150926519696103}),
torch::tensor({8.417903429138356, 6.597486978157095, 7.232017976503389}),
torch::tensor({-6.729914033672082, -7.097736722215963, -6.753591107635156}),
torch::tensor({-6.435922269573368}),
},
{
torch::tensor({8.232728644328265, 7.971029911114318, 6.643845656752554, 6.471228028518527, 6.170273310457381, 7.150926519734861}),
torch::tensor({8.417903429189604, 6.597486978253601, 7.232017976574605}),
torch::tensor({-6.729914078724216, -7.09773677023889, -6.753591133696778}),
torch::tensor({-6.435922288082886}),
},
{
torch::tensor({8.232728644350706, 7.971029911146173, 6.643845656795268, 6.471228028578131, 6.1702733104888825, 7.150926519778878}),
torch::tensor({8.417903429247799, 6.597486978363203, 7.232017976655485}),
torch::tensor({-6.729914129890292, -7.0977368247789245, -6.75359116329518}),
torch::tensor({-6.435922309104302}),
},
{
torch::tensor({8.232728644375776, 7.9710299111817635, 6.64384565684299, 6.471228028644725, 6.170273310524078, 7.150926519828057}),
torch::tensor({8.417903429312823, 6.597486978485657, 7.232017976745851}),
torch::tensor({-6.729914187056874, -7.097736885715142, -6.753591196364736}),
torch::tensor({-6.435922332591009}),
},
{
torch::tensor({8.232728644403466, 7.971029911221072, 6.643845656895699, 6.471228028718272, 6.17027331056295, 7.150926519882374}),
torch::tensor({8.417903429384637, 6.597486978620901, 7.232017976845656}),
torch::tensor({-6.729914250194772, -7.09773695301644, -6.753591232888567}),
torch::tensor({-6.435922358531008}),
},
{
torch::tensor({8.232728644433786, 7.971029911264116, 6.643845656953418, 6.471228028798811, 6.170273310605518, 7.150926519941853}),
torch::tensor({8.41790342946328, 6.597486978769002, 7.23201797695495}),
torch::tensor({-6.729914319334562, -7.097737026715396, -6.753591272884355}),
torch::tensor({-6.43592238693687}),
},
{
torch::tensor({8.232728644466773, 7.971029911310945, 6.64384565701621, 6.471228028886431, 6.170273310651826, 7.150926520006559}),
torch::tensor({8.417903429548836, 6.5974869789301245, 7.23201797707385}),
torch::tensor({-6.729914394552445, -7.097737106893246, -6.753591316396183}),
torch::tensor({-6.435922417839901}),
},
};
static std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay = {
{
torch::tensor({0.7890338917145869, 0.5024064972281554, 0.8586928862731152, 0.6579604913213795, 0.7476152291155436, 1.697523935068844}),
torch::tensor({0.8914365922767382, 0.7020469437416427, 1.689192459420757}),
torch::tensor({-1.0508020445177773, -1.3941340146813552, -1.2843369695447353}),
torch::tensor({-1.071137681045855}),
},
{
torch::tensor({0.17835897288288402, 0.2542141735779514, 0.19682011079203052, 0.23522758723292447, 0.17806013441679716, 0.2294364029080342}),
torch::tensor({0.6227661366261539, 0.6058596073202995, 0.6077177004897373}),
torch::tensor({-1.4259754714918282, -1.4333347967076562, -1.4085456423279246}),
torch::tensor({-2.0710783910024624}),
},
{
torch::tensor({0.17965695285279146, 0.24254347996350406, 0.1796466384372902, 0.24250836158041728, 0.17962895987963448, 0.24249920721192142}),
torch::tensor({0.628714524574237, 0.6286955878301643, 0.6286563325801343}),
torch::tensor({-1.4123887231473444, -1.4124007117933108, -1.4122701547931236}),
torch::tensor({-2.0633570397920584}),
},
{
torch::tensor({0.17963666509735735, 0.24250861931018913, 0.17963731763071286, 0.24250861136640806, 0.17963720028927982, 0.2425089024869051}),
torch::tensor({0.6287221269545457, 0.6287225821212604, 0.6287220275523637}),
torch::tensor({-1.4123466102957032, -1.4123465669345847, -1.41234626147394}),
torch::tensor({-2.063368365141719}),
},
{
torch::tensor({0.17963666103062387, 0.24250882317040834, 0.17963665831242884, 0.24250882481070404, 0.17963666029239178, 0.24250882426284776}),
torch::tensor({0.6287216329916357, 0.6287216340516892, 0.6287216326966386}),
torch::tensor({-1.412346754262532, -1.4123467542352315, -1.4123467478192075}),
torch::tensor({-2.063369043244171}),
},
{
torch::tensor({0.17963666098500486, 0.24250882442379565, 0.17963666099348433, 0.242508824412024, 0.17963666097250178, 0.2425088244105809}),
torch::tensor({0.628721634379825, 0.6287216343800635, 0.6287216343742585}),
torch::tensor({-1.4123467490742738, -1.4123467490725563, -1.4123467490678534}),
torch::tensor({-2.0633690434425387}),
},
{
torch::tensor({0.1796366609840712, 0.24250882442250232, 0.17963666098407377, 0.24250882442242358, 0.17963666098409362, 0.24250882442251606}),
torch::tensor({0.6287216343836621, 0.628721634383615, 0.628721634383626}),
torch::tensor({-1.412346749067226, -1.412346749067243, -1.4123467490671688}),
torch::tensor({-2.063369043434909}),
},
{
torch::tensor({0.17963666098407, 0.24250882442244054, 0.1796366609840706, 0.24250882442244054, 0.17963666098407016, 0.2425088244224409}),
torch::tensor({0.6287216343837067, 0.6287216343837064, 0.6287216343837068}),
torch::tensor({-1.4123467490671704, -1.4123467490671708, -1.4123467490671715}),
torch::tensor({-2.0633690434349052}),
},
{
torch::tensor({0.17963666098407016, 0.24250882442244107, 0.17963666098407033, 0.24250882442244107, 0.17963666098407036, 0.24250882442244096}),
torch::tensor({0.6287216343837067, 0.6287216343837069, 0.6287216343837068}),
torch::tensor({-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}),
torch::tensor({-2.0633690434349052}),
},
{
torch::tensor({0.17963666098407077, 0.24250882442244168, 0.1796366609840704, 0.242508824422441, 0.17963666098407038, 0.24250882442244104}),
torch::tensor({0.6287216343837073, 0.6287216343837069, 0.6287216343837069}),
torch::tensor({-1.4123467490671708, -1.4123467490671706, -1.4123467490671708}),
torch::tensor({-2.0633690434349052}),
},
{
torch::tensor({0.1796366609840697, 0.24250882442244098, 0.1796366609840703, 0.2425088244224411, 0.1796366609840702, 0.24250882442244087}),
torch::tensor({0.6287216343837071, 0.6287216343837067, 0.6287216343837069}),
torch::tensor({-1.4123467490671706, -1.4123467490671708, -1.4123467490671706}),
torch::tensor({-2.0633690434349052}),
},
};
static std::vector<std::vector<torch::Tensor>> Adam_with_weight_decay_and_amsgrad = {
{
torch::tensor({0.7889792072185753, 0.502352757161707, 0.8586918160350755, 0.6579591155483523, 0.7476108843716649, 1.6975378965521928}),
torch::tensor({0.8914325952640971, 0.7020467393659713, 1.6891939504169646}),
torch::tensor({-1.0508020464076324, -1.3941340315767958, -1.2843369730696377}),
torch::tensor({-1.0711376814873597}),
},
{
torch::tensor({6.790172150533646, 6.914645717041209, 6.415265715837617, 6.29759694822821, 5.845043191735576, 6.8621426309929}),
torch::tensor({7.958560080109429, 6.511332850580849, 7.100944983002762}),
torch::tensor({-6.690685400916639, -7.0565911313808, -6.7211550236002875}),
torch::tensor({-6.4066139474811115}),
},
{
torch::tensor({4.707385912311419, 5.291370521290747, 6.045088989172656, 6.024360881197099, 5.309433796539775, 6.388035972014135}),
torch::tensor({7.200400889176835, 6.398381825147094, 6.904102965924078}),
torch::tensor({-6.664146183636052, -7.026722729554791, -6.705827953119048}),
torch::tensor({-6.3963166164396865}),
},
{
torch::tensor({2.950915886969458, 3.7657694681407405, 5.607364810245922, 5.695752074431366, 4.701371245420197, 5.835001213891114}),
torch::tensor({6.3434824675494825, 6.258238067497203, 6.6630019537301886}),
torch::tensor({-6.630457443730077, -6.988861422125357, -6.686200564723017}),
torch::tensor({-6.3830161296508745}),
},
{
torch::tensor({1.713005764813263, 2.536434364252745, 5.140371990006822, 5.337998982747555, 4.083858717692442, 5.254544687928855}),
torch::tensor({5.4779207832576855, 6.100064103985914, 6.3947278642170975}),
torch::tensor({-6.591738209089138, -6.94536203674196, -6.663591073998776}),
torch::tensor({-6.3676809797869085}),
},
{
torch::tensor({0.9342570168001864, 1.6340561447464659, 4.667953486696388, 4.967659606858919, 3.4932659094295624, 4.678153658892501}),
torch::tensor({4.65515648111164, 5.929480127155326, 6.110044531515085}),
torch::tensor({-6.549112173520781, -6.897492853519762, -6.63863605131254}),
torch::tensor({-6.35073759645243}),
},
{
torch::tensor({0.483598008466956, 1.0143518793820192, 4.205896945782518, 4.5961697937201675, 2.950176323373372, 4.125793850415685}),
torch::tensor({3.9036580383003816, 5.750502254470126, 5.816601628842057}),
torch::tensor({-6.5033672421993005, -6.846143861683003, -6.611779357893088}),
torch::tensor({-6.332482574254225}),
},
{
torch::tensor({0.23940946262165735, 0.6100979055073588, 3.764576093338919, 4.23158112532352, 2.464744464995124, 3.6096701455897247}),
torch::tensor({3.236802951205823, 5.566165699240412, 5.520071244528464}),
torch::tensor({-6.4550968672620765, -6.791985760021875, -6.58335397163011}),
torch::tensor({-6.313137987164135}),
},
{
torch::tensor({0.1140432483273187, 0.35710698257092627, 3.3504972618755904, 3.879514641744938, 2.040187813168003, 3.136530496314723}),
torch::tensor({2.658047229800407, 5.378832755537033, 5.224730288801605}),
torch::tensor({-6.40476778326666, -6.735546192994843, -6.553621096227888}),
torch::tensor({-6.292877798552594}),
},
{
torch::tensor({0.05253197784154339, 0.2041401230295221, 2.9673514620600994, 3.5437785414463425, 1.6752543175367425, 2.70927259889588}),
torch::tensor({2.1645721260950497, 5.190371109223889, 4.933812708502093}),
torch::tensor({-6.352757654031313, -6.677252148986569, -6.52279182009777}),
torch::tensor({-6.271842467566972}),
},
{
torch::tensor({0.0234986931355529, 0.11430841965659456, 2.616783974306594, 3.226808143825143, 1.3659901252130222, 2.3281244462759787}),
torch::tensor({1.7498774217269997, 5.002268663873182, 4.649746694230509}),
torch::tensor({-6.299378178284937, -6.617456004434928, -6.491040351925264}),
torch::tensor({-6.250147860242402}),
},
};
static std::vector<std::vector<torch::Tensor>> Adagrad = {
{
torch::tensor({0.7891011045987429, 0.502443924512199, 0.8587078329085825, 0.6579710994224826, 0.7476364836215006, 1.697557019500397}),
torch::tensor({0.8914687688941954, 0.7020514988069096, 1.6892015076050444}),
torch::tensor({-1.0508031297732776, -1.3941351871450518, -1.284337597261839}),
torch::tensor({-1.071138124161711}),
},
{
torch::tensor({2.4079229696892583, 2.2346803754764286, 1.6967885588547365, 1.5522796958276492, 1.2259044248443602, 2.221279696180243}),
torch::tensor({2.9334079162217193, 1.7619824934767887, 2.3464577179091473}),
torch::tensor({-2.221396083069719, -2.549950976011168, -1.9709315957317095}),
torch::tensor({-1.5858816837541876}),
},
{
torch::tensor({2.510404433941812, 2.3522584510262887, 1.7921695110761213, 1.6577558258368463, 1.2891186618593045, 2.291878516133922}),
torch::tensor({3.092171180776419, 1.8971624370952997, 2.438734251283465}),
torch::tensor({-2.437641633486504, -2.7704264590526573, -2.0949471699460225}),
torch::tensor({-1.6769121890401757}),
},
{
torch::tensor({2.5652648968109415, 2.4155313947260972, 1.844241233613541, 1.71565133512464, 1.3245206506797171, 2.3315409972138825}),
torch::tensor({3.178399916514377, 1.9721945764936502, 2.4909037706250428}),
torch::tensor({-2.5658710403147933, -2.901921821645266, -2.168560672193225}),
torch::tensor({-1.7307903926154131}),
},
{
torch::tensor({2.6021584494332592, 2.4582101324909065, 1.8796060082750778, 1.755096520741472, 1.3489253597999988, 2.3589345190118247}),
torch::tensor({3.2368674310041516, 2.0236468833666894, 2.52707132741292}),
torch::tensor({-2.6573969292994164, -2.9960731060650505, -2.2211375717304076}),
torch::tensor({-1.7692090167089707}),
},
{
torch::tensor({2.629700772579208, 2.4901377017698683, 1.906173477530586, 1.7847957161833834, 1.3674517119505822, 2.3797578857769905}),
torch::tensor({3.2807643102638546, 2.062561811940094, 2.5546379424362775}),
torch::tensor({-2.7286379977755035, -3.0695109399636236, -2.262081199960513}),
torch::tensor({-1.7990936323432214}),
},
{
torch::tensor({2.6515471766995247, 2.51550257362603, 1.927341363452414, 1.8084994719811578, 1.3823309942932445, 2.3964995243914373}),
torch::tensor({3.3157334001309473, 2.093728023484945, 2.5768468697402924}),
torch::tensor({-2.786981763434855, -3.129746439571402, -2.29562487034177}),
torch::tensor({-1.8235564908139104}),
},
{
torch::tensor({2.6695780544837886, 2.53646401614724, 1.9448721033433505, 1.8281575823539011, 1.3947329882074622, 2.4104657178934947}),
torch::tensor({3.344694775590452, 2.1196465761628516, 2.5954050923596252}),
torch::tensor({-2.8363936812536537, -3.1808219609745194, -2.32404190866147}),
torch::tensor({-1.8442667636913117}),
},
{
torch::tensor({2.684883801533072, 2.5542762192735515, 1.9597939532350015, 1.8449096080124192, 1.4053459079217485, 2.4224257790968386}),
torch::tensor({3.369349515259956, 2.1417845308976795, 2.611319989214332}),
torch::tensor({-2.879251075341889, -3.225165734647855, -2.3486956737228057}),
torch::tensor({-1.86222449978646}),
},
{
torch::tensor({2.698151012423769, 2.56972998600169, 1.972757472697587, 1.8594775691681185, 1.4146081751022495, 2.43287021079559}),
torch::tensor({3.390772758897601, 2.1610741754331757, 2.6252349489549824}),
torch::tensor({-2.917092322961074, -3.264351563375218, -2.370468664387175}),
torch::tensor({-1.8780765115117757}),
},
{
torch::tensor({2.7098389356033783, 2.5833548721723747, 1.9841994925173085, 1.8723468731726325, 1.4228158926355312, 2.4421305315945085}),
torch::tensor({3.4096859099156673, 2.178143852041279, 2.6375854547611364}),
torch::tensor({-2.9509704554208467, -3.2994581338995044, -2.3899651139415874}),
torch::tensor({-1.8922653655195538}),
},
};
static std::vector<std::vector<torch::Tensor>> Adagrad_with_weight_decay = {
{
torch::tensor({0.7891011218979068, 0.5024439415126254, 0.8587078332470682, 0.6579710998575992, 0.7476364849956589, 1.6975570150849029}),
torch::tensor({0.8914687701583902, 0.7020514988715463, 1.6892015071335027}),
torch::tensor({-1.0508031297726799, -1.3941351871397083, -1.2843375972607243}),
torch::tensor({-1.0711381241615712}),
},
{
torch::tensor({0.18461166785222133, 0.2494407710310792, 0.1865174543775577, 0.2521909353304177, 0.18712037968446715, 0.2528920644405524}),
torch::tensor({0.6482869597891655, 0.6580215784646756, 0.6581256007663538}),
torch::tensor({-1.4547097114436809, -1.4748063405174818, -1.4811625946604763}),
torch::tensor({-1.9052928365443633}),
},
{
torch::tensor({0.1805989599928146, 0.24385155392577762, 0.18067177884778177, 0.24397186395008685, 0.18168388351830786, 0.2453385384605201}),
torch::tensor({0.6325250261983025, 0.6331827793513023, 0.6366659383355597}),
torch::tensor({-1.4208033337508772, -1.421562724054165, -1.4320264544533394}),
torch::tensor({-2.030135641848322}),
},
{
torch::tensor({0.17981392697398357, 0.2427571544305694, 0.1798115041445173, 0.2427572599231052, 0.18014798619115757, 0.24321449562278158}),
torch::tensor({0.6294321320817985, 0.6294873737410742, 0.6306958589251878}),
torch::tensor({-1.4139253354785764, -1.4139026804709813, -1.4173628530293867}),
torch::tensor({-2.056210117690093}),
},
{
torch::tensor({0.17967006242163752, 0.24255582734557285, 0.1796687367730195, 0.24255462870545763, 0.1797588230898851, 0.2426772907276576}),
torch::tensor({0.6288576295241087, 0.6288643132826753, 0.6291921485342001}),
torch::tensor({-1.412646587978757, -1.4126335126907266, -1.4135586793353685}),
torch::tensor({-2.0618018405404825}),
},
{
torch::tensor({0.17964321284685653, 0.24251808241139367, 0.17964291377171068, 0.242517791041986, 0.17966515741781022, 0.24254810590854564}),
torch::tensor({0.6287486931367787, 0.6287498167193977, 0.6288312441271763}),
torch::tensor({-1.412405895385289, -1.4124029484481166, -1.412631305131538}),
torch::tensor({-2.06302231630993}),
},
{
torch::tensor({0.17963799739278477, 0.2425107191246213, 0.1796379363134341, 0.24251065925361734, 0.1796432180220549, 0.24251786094585853}),
torch::tensor({0.6287272170354158, 0.6287274414587726, 0.6287468362309862}),
torch::tensor({-1.4123588626342636, -1.412358263650784, -1.4124124805696718}),
torch::tensor({-2.0632918101480255}),
},
{
torch::tensor({0.17963694231402447, 0.24250922426893617, 0.17963692980716212, 0.24250921210074514, 0.17963815759528076, 0.24251088666939838}),
torch::tensor({0.628722819525588, 0.628722867517244, 0.628727383936762}),
torch::tensor({-1.4123493065102781, -1.412349184462438, -1.4123617872243597}),
torch::tensor({-2.063351765096138}),
},
{
torch::tensor({0.17963672159046323, 0.2425089107097866, 0.17963671897003045, 0.24250890818318074, 0.17963700091084855, 0.24250929278196057}),
torch::tensor({0.6287218911936108, 0.6287219017313679, 0.6287229399204574}),
torch::tensor({-1.4123473011084142, -1.412347275640343, -1.41235016959507}),
torch::tensor({-2.0633651674043505}),
},
{
torch::tensor({0.17963667424978777, 0.24250884333120676, 0.17963667368829767, 0.24250884279379462, 0.17963673796131557, 0.24250893047794023}),
torch::tensor({0.6287216908150736, 0.6287216931558691, 0.6287219299749585}),
torch::tensor({-1.4123468700596722, -1.4123468646187736, -1.4123475243360133}),
torch::tensor({-2.0633681724342527}),
},
{
torch::tensor({0.17963666391853464, 0.24250882860835238, 0.1796366637961457, 0.24250882849182895, 0.17963667838367048, 0.24250884839397405}),
torch::tensor({0.6287216468984885, 0.6287216474215306, 0.6287217011907862}),
torch::tensor({-1.4123467758545658, -1.4123467746710385, -1.412346924400766}),
torch::tensor({-2.0633688474977463}),
},
};
static std::vector<std::vector<torch::Tensor>> Adagrad_with_weight_decay_and_lr_decay = {
{
torch::tensor({0.7891011046018798, 0.5024439245163383, 0.8587078329086189, 0.6579710994225316, 0.747636483621666, 1.697557019500142}),
torch::tensor({0.8914687688943375, 0.7020514988069164, 1.6892015076050049}),
torch::tensor({-1.0508031297732776, -1.3941351871450511, -1.284337597261839}),
torch::tensor({-1.0711381241617108}),
},
{
torch::tensor({2.346218944110103, 2.191939439502003, 1.683355201740813, 1.5405520021635604, 1.2137800230828062, 2.2052834637173024}),
torch::tensor({2.9090564593404, 1.7509657336815554, 2.336166413186925}),
torch::tensor({-2.206159683368316, -2.5344318233445415, -1.9622783535807609}),
torch::tensor({-1.5796101463783623}),
},
{
torch::tensor({2.3889328781057233, 2.2678221038007296, 1.7667624725138267, 1.6358015176639824, 1.2655767687152566, 2.2610880567112814}),
torch::tensor({3.045569451994985, 1.8770196253823253, 2.4192707519566765}),
torch::tensor({-2.4079300017528613, -2.7399112002234305, -2.0780613510632375}),
torch::tensor({-1.664722108226537}),
},
{
torch::tensor({2.3886137557806384, 2.2922158071009178, 1.8078384116424007, 1.6843524744409322, 1.290353948335789, 2.287071550970649}),
torch::tensor({3.111110355394278, 1.9438501730282314, 2.4630249355872826}),
torch::tensor({-2.5226122034499263, -2.857315093916292, -2.143964860243905}),
torch::tensor({-1.7130685809905042}),
},
{
torch::tensor({2.374703352203156, 2.298804499257456, 1.8330249458212446, 1.7151661013307247, 1.3048586226945842, 2.301765059046427}),
torch::tensor({3.150318222034133, 1.9877926185369321, 2.491399976401679}),
torch::tensor({-2.601415913361488, -2.938203895113964, -2.1892988334550028}),
torch::tensor({-1.7462964261966805}),
},
{
torch::tensor({2.3553658567303812, 2.297191758042688, 1.8501154749072124, 1.7368360586881881, 1.3141313000193942, 2.310745259215385}),
torch::tensor({3.1762315339155434, 2.0197585204578647, 2.5117041377790197}),
torch::tensor({-2.6606644002288697, -2.9991216074293856, -2.223413376189609}),
torch::tensor({-1.7712905233118807}),
},
{
torch::tensor({2.3338052201696207, 2.2913023710914993, 1.8624163948044772, 1.7530300731725457, 1.3203313209234842, 2.3163969478854742}),
torch::tensor({3.1943525925688934, 2.044447386769377, 2.527109724607397}),
torch::tensor({-2.7076634717294894, -3.047500808469036, -2.250495807208967}),
torch::tensor({-1.7911288238757486}),
},
{
torch::tensor({2.3114979154644892, 2.2830501835377808, 1.871616142999356, 1.7656325976608411, 1.324565631636651, 2.3199392342052025}),
torch::tensor({3.2074779925809085, 2.0642940833670544, 2.5392671301471235}),
torch::tensor({-2.7463093287485925, -3.087315541134716, -2.272780318857348}),
torch::tensor({-1.8074516661537263}),
},
{
torch::tensor({2.2891841627387346, 2.2734716995793693, 1.8786818999895825, 1.7757301317117604, 1.3274682997719436, 2.322067935399382}),
torch::tensor({3.2172019619454075, 2.0807140893178175, 2.5491374815141876}),
torch::tensor({-2.7789204504423823, -3.1209351402429175, -2.2915969523376867}),
torch::tensor({-1.821234722948421}),
},
{
torch::tensor({2.2672498238343066, 2.2631678037928893, 1.8842131287032622, 1.7840007705383885, 1.3294311820750493, 2.3232112430345424}),
torch::tensor({3.224507488068445, 2.094598223519413, 2.5573257155791715}),
torch::tensor({-2.8069849199086647, -3.1498826045022925, -2.3077996970997727}),
torch::tensor({-1.8331040438272388}),
},
{
torch::tensor({2.2458961718688957, 2.2525031725114775, 1.8886034384961112, 1.7908930341267955, 1.3307102435291205, 2.3236474976004615}),
torch::tensor({3.230036385413078, 2.1065407459636134, 2.5642349249609664}),
torch::tensor({-2.83151249424399, -3.1751926295316566, -2.3219682378974036}),
torch::tensor({-1.8434843744626483}),
},
};
static std::vector<std::vector<torch::Tensor>> RMSprop = {
{
torch::tensor({0.7890625772821005, 0.502415108650816, 0.8587027713011453, 0.6579673123006431, 0.7476283936579036, 1.6975509766054537}),
torch::tensor({0.8914573371873159, 0.7020499947573374, 1.6891991194739453}),
torch::tensor({-1.0508027874171133, -1.3941348219724659, -1.2843374000099703}),
torch::tensor({-1.0711379842715099}),
},
{
torch::tensor({2.4485718582774427, 2.2809152044417678, 1.734642444915197, 1.5940004770230667, 1.250761131839982, 2.248993270255382}),
torch::tensor({2.994661478530102, 1.815048529086425, 2.382542610897819}),
torch::tensor({-2.3036981738757825, -2.6337299521275646, -2.0183701223588213}),
torch::tensor({-1.620787559800898}),
},
{
torch::tensor({2.583758247560778, 2.4365737242301537, 1.8622886519354542, 1.7357065282848232, 1.3369695670141974, 2.3454934716983695}),
torch::tensor({3.2061266499381618, 1.9981112525417783, 2.5092495986614}),
torch::tensor({-2.6110809365525953, -2.9484807193016787, -2.1948985607984395}),
torch::tensor({-1.7501043480625826}),
},
{
torch::tensor({2.66996905113451, 2.536559412710799, 1.9456091681389676, 1.828914948091767, 1.3952956766999587, 2.4110816686341923}),
torch::tensor({3.3436729755936576, 2.1204057198913, 2.5961524902119497}),
torch::tensor({-2.8372329851331, -3.1817729538857207, -2.324997185399696}),
torch::tensor({-1.8450422173907484}),
},
{
torch::tensor({2.7375365004059113, 2.615307154535863, 2.0117493624534317, 1.9033001982031035, 1.4427501882445097, 2.4646213743186127}),
torch::tensor({3.452912454199796, 2.2190451524127526, 2.667552790123282}),
torch::tensor({-3.03294794567315, -3.384582488936652, -2.437729982499714}),
torch::tensor({-1.9271014784118223}),
},
{
torch::tensor({2.7952372917068744, 2.6828202203757217, 2.0687223272687003, 1.9676545487787112, 1.4844410726622166, 2.5117888904510117}),
torch::tensor({3.5471904628565745, 2.30511354826214, 2.7307948248967304}),
torch::tensor({-3.2141190290332533, -3.5729446336144495, -2.542197020654683}),
torch::tensor({-2.002997698521966}),
},
{
torch::tensor({2.8467333937519474, 2.743278517711039, 2.119898810135386, 2.025680541674126, 1.5225256464280221, 2.554983108087885}),
torch::tensor({3.632098323876194, 2.383289304179777, 2.7889864719999222}),
torch::tensor({-3.3875799441679257, -3.75376580108393, -2.642312326626043}),
torch::tensor({-2.0756169514457246}),
},
{
torch::tensor({2.89386004980602, 2.7987769841836148, 2.166973814751561, 2.0792385384302032, 1.5580820887115123, 2.5954023023969692}),
torch::tensor({3.71043881435304, 2.455919099321952, 2.8436784441941008}),
torch::tensor({-3.5567368287146413, -3.9304848687096916, -2.7400264794345746}),
torch::tensor({-2.1463982568717577}),
},
{
torch::tensor({2.9376493943999384, 2.8504929653123106, 2.2109027772324468, 2.1293746183147637, 1.5917084661873866, 2.6337100421079533}),
torch::tensor({3.7837853328443516, 2.5243155701130595, 2.8957265009949373}),
torch::tensor({-3.723426848521047, -4.104949193318518, -2.8363906937997516}),
torch::tensor({-2.2161187733606105}),
},
{
torch::tensor({2.9787316798887553, 2.8991451078473203, 2.252272487010976, 2.176728872920277, 1.6237627746697592, 2.670302507579268}),
torch::tensor({3.8530980655045086, 2.589275531553024, 2.9456388178450936}),
torch::tensor({-3.888685645961963, -4.278191888396593, -2.931996492835032}),
torch::tensor({-2.2852176995051234}),
},
{
torch::tensor({3.0175156205790485, 2.9452004251453263, 2.2914699255229634, 2.221721702578292, 1.6544730927621272, 2.705431441204422}),
torch::tensor({3.9190041004420166, 2.6513176244659373, 2.993736489599992}),
torch::tensor({-4.053111559341913, -4.450801801162238, -3.0271845519131966}),
torch::tensor({-2.353949890597344}),
},
};
static std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay = {
{
torch::tensor({0.7890798754118442, 0.5024321083861885, 0.8587031097835685, 0.6579677474141494, 0.7476297677960806, 1.6975465611838714}),
torch::tensor({0.891458601354904, 0.7020500593937647, 1.6891986479348047}),
torch::tensor({-1.0508027868194278, -1.3941348166291232, -1.2843373988951865}),
torch::tensor({-1.0711379841318796}),
},
{
torch::tensor({0.21398926524040954, 0.27790117133525366, 0.1868480279466365, 0.2507569370785431, 0.19145335235126934, 0.2557687813140428}),
torch::tensor({0.6720959116687695, 0.6480734848063959, 0.6542630070046453}),
torch::tensor({-1.4357633640899108, -1.4493557950073153, -1.4619011018357038}),
torch::tensor({-1.9673083558727882}),
},
{
torch::tensor({0.2396193574469268, 0.3035423686592136, 0.19567694278592207, 0.2544696440134417, 0.21982879020282778, 0.27495711471996653}),
torch::tensor({0.6927895724662014, 0.6380155354791558, 0.6523245375961474}),
torch::tensor({-1.4137225835003873, -1.4170291001633764, -1.4166977298480798}),
torch::tensor({-2.062665111543728}),
},
{
torch::tensor({0.2506635865271791, 0.3146395114286015, 0.2511689291081516, 0.3043139957959059, 0.2521962504870668, 0.31601100081704064}),
torch::tensor({0.7051419232960152, 0.6699011906397474, 0.6990972846784108}),
torch::tensor({-1.4206083241624177, -1.4257037444100082, -1.4171061826065106}),
torch::tensor({-2.0755378747636994}),
},
{
torch::tensor({0.23285924743060057, 0.29652494304773935, 0.23353220027378083, 0.2969991261380096, 0.23358272245225947, 0.2973997498165736}),
torch::tensor({0.6855589594924651, 0.6796983775695622, 0.686417480398288}),
torch::tensor({-1.431107627946497, -1.4334934742818015, -1.422739552145117}),
torch::tensor({-2.0842642493046037}),
},
{
torch::tensor({0.23356397699390252, 0.2973714239198578, 0.23367622061822785, 0.297494475971607, 0.23418481357396348, 0.2981812292515654}),
torch::tensor({0.6866530583001251, 0.6858933385102606, 0.6883944045412651}),
torch::tensor({-1.4564955509607054, -1.458354813150068, -1.4418225445708615}),
torch::tensor({-2.1064103749186205}),
},
{
torch::tensor({0.2318717301174728, 0.2952904159872864, 0.23194024439476718, 0.2953701982498774, 0.2316421336904662, 0.2951041425983899}),
torch::tensor({0.6834813130194516, 0.6834401711464204, 0.6837275457100467}),
torch::tensor({-1.464783580527677, -1.465345240817906, -1.457114211277771}),
torch::tensor({-2.1209598505912095}),
},
{
torch::tensor({0.2308683396504175, 0.2940474629750446, 0.23089067678260933, 0.29407615110959273, 0.2306406931421414, 0.29379043611390215}),
torch::tensor({0.6815062281792609, 0.6815233687209212, 0.6812759203026143}),
torch::tensor({-1.4643013018530677, -1.4644523635284243, -1.4617493939684876}),
torch::tensor({-2.124729363567885}),
},
{
torch::tensor({0.23066464201462666, 0.29376059273730415, 0.2306706924585737, 0.29376903997842674, 0.2305755121160667, 0.2936477517373107}),
torch::tensor({0.6809028781780303, 0.6809134105028245, 0.6807404613096301}),
torch::tensor({-1.4637927352177984, -1.4638374228010729, -1.46299287102643}),
torch::tensor({-2.125808272063811}),
},
{
torch::tensor({0.23062625079199164, 0.29369907874257056, 0.23062789247291138, 0.293701483466165, 0.23059813368156995, 0.29366073890476385}),
torch::tensor({0.680725180468908, 0.6807295616357244, 0.6806523640328993}),
torch::tensor({-1.4635790398985613, -1.4635929269022856, -1.4633272688236563}),
torch::tensor({-2.1261396358141798}),
},
{
torch::tensor({0.23061701122700168, 0.29368398381778166, 0.23061747865501642, 0.29368467998690795, 0.2306085559563819, 0.2936719507340019}),
torch::tensor({0.6806714673830828, 0.6806730903175792, 0.6806434720800852}),
torch::tensor({-1.4635008778278131, -1.4635052859178375, -1.463420837506828}),
torch::tensor({-2.1262432969587723}),
},
};
static std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay_and_centered = {
{
torch::tensor({0.7941000061626792, 0.507452636734552, 0.8637405354185987, 0.663005089317529, 0.7526661272860107, 1.7025887305065852}),
torch::tensor({0.8964950370033696, 0.7070877948157552, 1.6942369105467197}),
torch::tensor({-1.055840599214661, -1.3991726335388424, -1.2893752132746332}),
torch::tensor({-1.0761757981162612}),
},
{
torch::tensor({2.3762999876885833, 2.239095829416783, 1.726175067071914, 1.5891569459230446, 1.2410074108588462, 2.2345431036725723}),
torch::tensor({2.990896455635836, 1.8152108764849464, 2.377985429759037}),
torch::tensor({-2.3071822180635286, -2.636859516619699, -2.0198181394256642}),
torch::tensor({-1.622583045791722}),
},
{
torch::tensor({2.372800588647971, 2.3022753207224254, 1.836028714221617, 1.7190937269287108, 1.3068955839895078, 2.3035835673200364}),
torch::tensor({3.1656599892042343, 1.9942937608209463, 2.4947143457182657}),
torch::tensor({-2.6139790332516775, -2.9507738987695404, -2.1954425128779516}),
torch::tensor({-1.7513053380188806}),
},
{
torch::tensor({2.2398453700818455, 2.2513384246965904, 1.8892176431436287, 1.7921873754661688, 1.3310951408713538, 2.3236392222350397}),
torch::tensor({3.240166119454613, 2.1097428136001883, 2.5651614461576973}),
torch::tensor({-2.8388734382997454, -3.1824200770676123, -2.324831397600949}),
torch::tensor({-1.8460315737386976}),
},
{
torch::tensor({1.9829606312242465, 2.097356567850692, 1.9050263843525033, 1.8325835415812348, 1.3222762370713104, 2.3024963133870147}),
torch::tensor({3.2465360572089974, 2.196726604586991, 2.6091992649970672}),
torch::tensor({-3.0326878099587207, -3.3827004807595005, -2.436989182250496}),
torch::tensor({-1.928273216206344}),
},
{
torch::tensor({1.6051175329080525, 1.8332107491649117, 1.8794767349053179, 1.8403588051948858, 1.273824111314107, 2.2296571379436823}),
torch::tensor({3.1814362940910437, 2.2630192140728465, 2.6273016977574013}),
torch::tensor({-3.210932646440219, -3.567153254014387, -2.541016943923914}),
torch::tensor({-2.0049155134617154}),
},
{
torch::tensor({1.1588059349082709, 1.4778613795232265, 1.7992410089026636, 1.8064600091986671, 1.1739931551629919, 2.08647960875392}),
torch::tensor({3.03843703712275, 2.3082030683758767, 2.6125393914734083}),
torch::tensor({-3.379830678608588, -3.741970414470626, -2.6410082400846546}),
torch::tensor({-2.079294995910487}),
},
{
torch::tensor({0.7701433312419087, 1.110502667742475, 1.6465075169366392, 1.7162526909817901, 1.013748545414221, 1.8532966501655352}),
torch::tensor({2.827176875885245, 2.3274019481599275, 2.5535309398603405}),
torch::tensor({-3.54193329850986, -3.9096652952123145, -2.739408870192437}),
torch::tensor({-2.1537939241668997}),
},
{
torch::tensor({0.559892312935121, 0.8460500042788702, 1.408417554916502, 1.5547314210944567, 0.8019580519338424, 1.5258384663629627}),
torch::tensor({2.5774950379490265, 2.3131013066991266, 2.4388695757441745}),
torch::tensor({-3.6974974230160087, -4.070190514312716, -2.8378932675718405}),
torch::tensor({-2.2307225014430423}),
},
{
torch::tensor({0.5016784472836648, 0.7258690889265434, 1.0976902935953958, 1.3199491879725134, 0.5853930356154851, 1.1446978015944624}),
torch::tensor({2.3235249877284945, 2.259284097042017, 2.2681461698609375}),
torch::tensor({-3.8444921272569115, -4.22021051361099, -2.9373192115434263}),
torch::tensor({-2.312733063937045}),
},
{
torch::tensor({0.4875468895095056, 0.6878747871467128, 0.7787871237567607, 1.046259254610218, 0.4416468896022397, 0.8122992916762792}),
torch::tensor({2.107873451558748, 2.17034337037527, 2.0666325968568535}),
torch::tensor({-3.9782695475825216, -4.352093055115415, -3.0377809502927033}),
torch::tensor({-2.403496388200805}),
},
};
static std::vector<std::vector<torch::Tensor>> RMSprop_with_weight_decay_and_centered_and_momentum = {
{
torch::tensor({0.7941000061626794, 0.507452636734552, 0.8637405354185986, 0.6630050893175291, 0.7526661272860108, 1.7025887305065852}),
torch::tensor({0.8964950370033697, 0.7070877948157552, 1.6942369105467197}),
torch::tensor({-1.055840599214661, -1.3991726335388426, -1.289375213274633}),
torch::tensor({-1.0761757981162612}),
},
{
torch::tensor({11.587263945492356, 12.552112516667208, 10.773002960161074, 10.782117868337808, 9.675467654064095, 10.83068936005479}),
torch::tensor({15.298238342006444, 11.252244653209868, 11.423905295074075}),
torch::tensor({-11.287147147258441, -11.673871066494183, -11.143068139029769}),
torch::tensor({-10.744790465364128}),
},
{
torch::tensor({5.99313075778439, 7.778269455146454, 9.705741295559012, 9.974952848613889, 8.171307305871647, 9.551498426643079}),
torch::tensor({12.811268477045154, 10.912201832960704, 10.87477550647832}),
torch::tensor({-11.20842921856976, -11.58706973895515, -11.098172235374586}),
torch::tensor({-10.71411038369856}),
},
{
torch::tensor({1.917316794757855, 3.442098373003916, 8.160846071267297, 8.76673426856121, 6.163892823252042, 7.748894752821818}),
torch::tensor({9.529299379813787, 10.371703621802427, 10.02242566317017}),
torch::tensor({-11.07914626767133, -11.444639737948599, -11.02397978065452}),
torch::tensor({-10.663204622623407}),
},
{
torch::tensor({0.242111629257451, 0.8235150923738453, 6.109652191353378, 7.070860554523036, 3.8366635637770212, 5.460370584182961}),
torch::tensor({5.7908039507441, 9.53430906906639, 8.752252906881251}),
torch::tensor({-10.868651889371552, -11.212965695734527, -10.90242744782103}),
torch::tensor({-10.57959689981644}),
},
{
torch::tensor({0.002420600902047647, 0.055217404976894736, 3.753606156332189, 4.9331546064599685, 1.7094621184709604, 3.022224882400484}),
torch::tensor({2.4729429920325234, 8.29021143930646, 6.983317870704776}),
torch::tensor({-10.529133489023623, -10.839885990130032, -10.704345435808353}),
torch::tensor({-10.442792354138112}),
},
{
torch::tensor({8.523664833406609e-06, -0.000184980158096171, 1.6343074841140277, 2.683608480982545, 0.4142510780713272, 1.0921118166095118}),
torch::tensor({0.5531198735383179, 6.566845593450315, 4.783317472190565}),
torch::tensor({-9.990101114696575, -10.24914448933998, -10.38447825909146}),
torch::tensor({-10.22038237537473}),
},
{
torch::tensor({5.366918233939821e-08, -2.899704029399208e-07, 0.37916783268568177, 0.9399553431452387, 0.028595281293376063, 0.1765061433770474}),
torch::tensor({0.03166973497545415, 4.4428469940935225, 2.520346492875472}),
torch::tensor({-9.15653357178671, -9.339631853060773, -9.875729313751442}),
torch::tensor({-9.862669711962376}),
},
{
torch::tensor({2.1133356499004305e-06, 2.4524630407767974e-06, 0.023655729923601897, 0.1427370957829138, -8.950192389690754e-05, 0.00423769700896404}),
torch::tensor({-0.00012364097582548354, 2.2911918591079274, 0.833141440960252}),
torch::tensor({-7.922566174765117, -8.003055545094796, -9.086673634672907}),
torch::tensor({-9.297519364373226}),
},
{
torch::tensor({0.0023497430294992443, 0.002861131671472504, 0.0006998739627296077, 0.0036571565360575277, 0.0016543034713696228, 0.0018171459470053383}),
torch::tensor({0.004569191565477358, 0.729246659971123, 0.11475431260766132}),
torch::tensor({-6.223834483308681, -6.185383631607397, -7.912955414853613}),
torch::tensor({-8.430731662958186}),
},
{
torch::tensor({0.1039382034036754, 0.13982074666181726, 0.08314071982729489, 0.10183584198629941, 0.139495945169722, 0.17822672100147102}),
torch::tensor({0.34039464502063893, 0.24860888862359676, 0.31914045155310644}),
torch::tensor({-4.174294597914298, -4.037528929635062, -6.297198700024484}),
torch::tensor({-7.182093090194919}),
},
};
static std::vector<std::vector<torch::Tensor>> SGD = {
{
torch::tensor({-0.21063957030131192, -0.4972093725858961, -0.13931849072410168, -0.33939101965581686, -0.25112865488453673, 0.6992101966874735}),
torch::tensor({-0.1076573444246077, -0.2913064413859577, 0.6933846874181748}),
torch::tensor({-0.07998325778863398, -0.42149210515421365, -0.33498349553944556}),
torch::tensor({-0.14255126505509488}),
},
{
torch::tensor({-0.15543131540224012, -0.4235110396372034, -0.04196796248622072, -0.20952231780684996, -0.16031407286541025, 0.8209742464453325}),
torch::tensor({0.07724343607160139, 0.0338752947249023, 1.0028793648054941}),
torch::tensor({-0.8213382425894498, -1.1570800333254736, -1.6154760331657425}),
torch::tensor({-1.873409073108485}),
},
{
torch::tensor({-0.13342791770744886, -0.39415097094881035, -0.011470356542661935, -0.16885142516066973, -0.1330668069352811, 0.8576491729785701}),
torch::tensor({0.1508101460076168, 0.13560816175111742, 1.0971559708365837}),
torch::tensor({-0.9780975407869251, -1.3215153697157922, -1.8760213876051515}),
torch::tensor({-2.202441305652889}),
},
{
torch::tensor({-0.11963097684681223, -0.3757367513013453, 0.00699871664138837, -0.14420855651125983, -0.11733423659038761, 0.8788673419128562}),
torch::tensor({0.19698293387590052, 0.1973461164047132, 1.1520119567305152}),
torch::tensor({-1.0677802792431819, -1.4166561260631116, -2.0220337532169905}),
torch::tensor({-2.3834524272927813}),
},
{
torch::tensor({-0.10950806441156272, -0.36222266992185936, 0.02028489243523426, -0.12647254228380078, -0.10635775660996466, 0.8936912722040982}),
torch::tensor({0.23089462331826796, 0.24184450074084415, 1.1904864598387046}),
torch::tensor({-1.1306213044009719, -1.483718648357814, -2.1228846025142074}),
torch::tensor({-2.50713525051584}),
},
{
torch::tensor({-0.10149090356585248, -0.35151721158128657, 0.030662536099764087, -0.11261325211798627, -0.09797248308626626, 0.905027632401109}),
torch::tensor({0.25777759826689434, 0.2766609657536915, 1.2199973265718322}),
torch::tensor({-1.1789655573653979, -1.5355073692636771, -2.1996125838846075}),
torch::tensor({-2.6005295414716625}),
},
{
torch::tensor({-0.09484472748389533, -0.34264050232430837, 0.03917399284640637, -0.10124188994381239, -0.09121264836307838, 0.9141743475340721}),
torch::tensor({0.2800829300171032, 0.3052600200290068, 1.2438661306695873}),
torch::tensor({-1.2182324765944266, -1.577685139408549, -2.261370486631629}),
torch::tensor({-2.675274336197319}),
},
{
torch::tensor({-0.08916446117741175, -0.33505233521798655, 0.04638527943959316, -0.09160422984057529, -0.08556486270584647, 0.9218219103015535}),
torch::tensor({0.2991619380154852, 0.32952375512951004, 1.2638639017720827}),
torch::tensor({-1.251282493526328, -1.613256463950431, -2.312952993721385}),
torch::tensor({-2.7374195723946606}),
},
{
torch::tensor({-0.08420245801272856, -0.3284224385121881, 0.05263847708646642, -0.08324438788845256, -0.08072424164719601, 0.9283806476306355}),
torch::tensor({0.31584087342663564, 0.3505901981820039, 1.2810450644764015}),
torch::tensor({-1.2798091496372141, -1.644007253821019, -2.3571804629611046}),
torch::tensor({-2.7905023459395877}),
},
{
torch::tensor({-0.07979600214534928, -0.3225337978155752, 0.0581562720006689, -0.07586555700667837, -0.0764952395510804, 0.9341138824526719}),
torch::tensor({0.3306627217189733, 0.36920055785772116, 1.2960873917356066}),
torch::tensor({-1.3048976883823566, -1.671085574250112, -2.3958498984546135}),
torch::tensor({-2.8367650855551236}),
},
{
torch::tensor({-0.07583232846497832, -0.3172360102461861, 0.06309179259248046, -0.06926361352067169, -0.07274510848082805, 0.9392004636935606}),
torch::tensor({0.3440038606091545, 0.3858647867996721, 1.3094518934419668}),
torch::tensor({-1.3272851146877218, -1.695273130850265, -2.4301754289421593}),
torch::tensor({-2.877716472882302}),
},
};
static std::vector<std::vector<torch::Tensor>> SGD_with_weight_decay = {
{
torch::tensor({-0.21042867144447805, -0.49671181653925384, -0.13917719856207697, -0.3390489907590303, -0.2508762913762564, 0.6985126396619242}),
torch::tensor({-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}),
torch::tensor({-0.079932454658518, -0.42109796996670307, -0.33469915794198624}),
torch::tensor({-0.14248012693079315}),
},
{
torch::tensor({-0.13579982290274883, -0.3765456284475787, -0.03166970700350032, -0.18102559254681197, -0.1373234786735746, 0.7522156177001302}),
torch::tensor({0.0855000382601442, 0.051563225553454176, 0.9321399061276381}),
torch::tensor({-0.7963122388825842, -1.1010063686038731, -1.5363716774172782}),
torch::tensor({-1.8045854907382846}),
},
{
torch::tensor({-0.09659168723529124, -0.3056207693658826, 0.006712867145512926, -0.1166002367977548, -0.09012083166238948, 0.7264953102453366}),
torch::tensor({0.16531808496504807, 0.16488328577596395, 0.9610743966573316}),
torch::tensor({-0.9202466399245917, -1.205282927289183, -1.7049756710541348}),
torch::tensor({-2.0415977924493043}),
},
{
torch::tensor({-0.06728100597713034, -0.24965896016541955, 0.03186158526394668, -0.07105441484407879, -0.056478595544178806, 0.6910758436366732}),
torch::tensor({0.21707768347081782, 0.23575238192099465, 0.9564382346520685}),
torch::tensor({-0.9788195039030002, -1.2447191597975944, -1.7620201560619633}),
torch::tensor({-2.131504419683077}),
},
{
torch::tensor({-0.043049550531555035, -0.20206572730420896, 0.050959513946324454, -0.03470009355744102, -0.02922465201167018, 0.654761170560436}),
torch::tensor({0.25638982315377085, 0.2878867158887637, 0.94142216852528}),
torch::tensor({-1.0143969472996657, -1.2623288365082086, -1.780047146006567}),
torch::tensor({-2.170255083720924}),
},
{
torch::tensor({-0.02215471703826272, -0.16036518660639854, 0.06644401410758824, -0.004183373274651932, -0.0059658779785277715, 0.6200298215101534}),
torch::tensor({0.2886406829874718, 0.32924516791460257, 0.9230983700837222}),
torch::tensor({-1.0397895250773483, -1.271091416624018, -1.780775800960309}),
torch::tensor({-2.1862978976514738}),
},
{
torch::tensor({-0.0037439139848316947, -0.1232829330825193, 0.07944696186805639, 0.022100305718441984, 0.014399113804332047, 0.587697912745227}),
torch::tensor({0.31628710746920086, 0.36346293565421134, 0.9042402154310412}),
torch::tensor({-1.0602349614300883, -1.2762264965487673, -1.7731268727630665}),
torch::tensor({-2.191253945056341}),
},
{
torch::tensor({0.01267598593885474, -0.09003711893222124, 0.09059095692632842, 0.04506778924310346, 0.03247299240601001, 0.5579755127260052}),
torch::tensor({0.3406226998933173, 0.3924947745885882, 0.8860121369119326}),
torch::tensor({-1.0781407849705036, -1.2800528898634016, -1.7613120374342217}),
torch::tensor({-2.190575043873577}),
},
{
torch::tensor({0.027425440985778007, -0.06008809958617214, 0.10026092920861805, 0.0653109294703924, 0.04862875490793198, 0.5308215072596255}),
torch::tensor({0.36239744520280553, 0.4175162387638887, 0.8688788105023479}),
torch::tensor({-1.0946579691370502, -1.2836103422269476, -1.7474706191775766}),
torch::tensor({-2.1870021744944763}),
},
{
torch::tensor({0.040732509801474116, -0.033030241035550065, 0.10871770475931387, 0.08324870459183514, 0.06312228688815541, 0.5060892094042873}),
torch::tensor({0.38208249693950175, 0.4393002654989596, 0.8529817924677643}),
torch::tensor({-1.1103326127955466, -1.2873324059163584, -1.7327386627485202}),
torch::tensor({-2.1819672316721337}),
},
{
torch::tensor({0.052771609187326055, -0.008539186625351379, 0.11615154444871967, 0.09919929206676083, 0.07614530177703589, 0.48359250162323586}),
torch::tensor({0.3999968617221315, 0.4583944200925636, 0.838313296680579}),
torch::tensor({-1.1254107858333455, -1.2913604197768884, -1.717739109221235}),
torch::tensor({-2.176236807160431}),
},
};
static std::vector<std::vector<torch::Tensor>> SGD_with_weight_decay_and_momentum = {
{
torch::tensor({-0.21042867144447805, -0.49671181653925384, -0.13917719856207697, -0.3390489907590303, -0.2508762913762564, 0.6985126396619242}),
torch::tensor({-0.10754881320494518, -0.2910084928862701, 0.6926954859081793}),
torch::tensor({-0.079932454658518, -0.42109796996670307, -0.33469915794198624}),
torch::tensor({-0.14248012693079315}),
},
{
torch::tensor({0.005611848725195478, -0.0710915563059199, 0.07701400891926034, 0.047067327035013866, 0.042865405297259754, 0.4352977220593751}),
torch::tensor({0.23834837300214828, 0.3236638250370417, 0.712832101663469}),
torch::tensor({-1.0419477883948856, -1.1730950187020546, -1.7648205873351155}),
torch::tensor({-2.3359277661920594}),
},
{
torch::tensor({0.11520007183759419, 0.12894537687632862, 0.1458684555595196, 0.17753415358762184, 0.15614155642578992, 0.33379126147460536}),
torch::tensor({0.4658536564136851, 0.5201979178769089, 0.7274876508280723}),
torch::tensor({-1.203474644488253, -1.2861269692338677, -1.604528340632377}),
torch::tensor({-2.203215909196624}),
},
{
torch::tensor({0.15331258730375, 0.19790903623360404, 0.16663814647374192, 0.2183320498727895, 0.18032745504822867, 0.2836274579441783}),
torch::tensor({0.5532312776994918, 0.5834224152126114, 0.6903579410976887}),
torch::tensor({-1.3052171323471546, -1.3514190497186431, -1.5153574535010634}),
torch::tensor({-2.123181139806548}),
},
{
torch::tensor({0.16814113185552507, 0.2238657220144887, 0.17413795101952864, 0.2328051532626163, 0.18391422079762276, 0.26144994958709084}),
torch::tensor({0.592282876576759, 0.6083877519652824, 0.6634387486999062}),
torch::tensor({-1.3591143274292896, -1.3836730658309968, -1.467157893517277}),
torch::tensor({-2.087859547998447}),
},
{
torch::tensor({0.1743742243877178, 0.23431261530597985, 0.1771694292764225, 0.23838669643330085, 0.18308461132092924, 0.25149544624452974}),
torch::tensor({0.6108281747800746, 0.6192657661217672, 0.6475519545045926}),
torch::tensor({-1.386052705444441, -1.3988166642380868, -1.4412527948055516}),
torch::tensor({-2.0731939075659627}),
},
{
torch::tensor({0.1771465478751462, 0.23875859951719525, 0.17848682715848566, 0.24067863725664956, 0.18181103291606765, 0.24687877342069478}),
torch::tensor({0.6198586021174767, 0.6242349464856269, 0.638736845373371}),
torch::tensor({-1.3993307716862977, -1.405896519385159, -1.42747775986796}),
torch::tensor({-2.06726758434046}),
},
{
torch::tensor({0.17843093585357678, 0.24073954802700467, 0.17908697027440862, 0.24166758399092678, 0.18088350526559055, 0.24467193314356372}),
torch::tensor({0.6243071074374693, 0.6265628975677454, 0.6339840865876518}),
torch::tensor({-1.4058750036106915, -1.4092362337714566, -1.4202202926903085}),
torch::tensor({-2.0649062340635584}),
},
{
torch::tensor({0.17904350645021616, 0.2416549694624704, 0.1793692065848773, 0.2421116448977685, 0.18031858582735985, 0.24359239926305207}),
torch::tensor({0.6265134455078061, 0.6276715667697312, 0.6314641991686346}),
torch::tensor({-1.409113940967948, -1.4108307952354533, -1.41642472852534}),
torch::tensor({-2.0639728292802046}),
},
{
torch::tensor({0.17934167113683835, 0.24208996240463102, 0.17950490408309286, 0.24231745350706002, 0.17999989292556764, 0.2430557755257576}),
torch::tensor({0.6276131793232345, 0.62820623280908, 0.6301427155170752}),
torch::tensor({-1.4107251789010826, -1.411601182417186, -1.4144511767962418}),
torch::tensor({-2.063605631667394}),
},
{
torch::tensor({0.17948886155124505, 0.2423009633220481, 0.17957117450689372, 0.242415213133214, 0.17982712042628354, 0.24278620392248684}),
torch::tensor({0.6281635672171683, 0.6284667582211864, 0.6294549191500092}),
torch::tensor({-1.4115305541843781, -1.4119772978756442, -1.4134296522818637}),
torch::tensor({-2.0634616066978615}),
},
};
static std::vector<std::vector<torch::Tensor>> SGD_with_weight_decay_and_nesterov_momentum = {
{
torch::tensor({-0.21040617235121148, -0.49689727139951717, -0.13754215970803657, -0.33701686525263036, -0.2500172388792182, 0.700697918175925}),
torch::tensor({-0.1068708360895515, -0.2853285323043249, 0.6971494161502307}),
torch::tensor({-0.10624536304143092, -0.44611325614778935, -0.3805647497874434}),
torch::tensor({-0.2068230782168696}),
},
{
torch::tensor({-0.12623871135486553, -0.3844658218758334, 0.03124406856508882, -0.11170532152425779, -0.09823268522398335, 0.9040698525178972}),
torch::tensor({0.17551336074135093, 0.2797661479202715, 1.213839968098513}),
torch::tensor({-1.5928404135955905, -1.8986806244521566, -2.9661819144548276}),
torch::tensor({-3.7728444542017687}),
},
{
torch::tensor({-0.11614716303292186, -0.3709539909720773, 0.0430707804551277, -0.09588329367245822, -0.08795603365024907, 0.9178771227283019}),
torch::tensor({0.20944042006388677, 0.31954838894016663, 1.250027034831072}),
torch::tensor({-1.6350110524945016, -1.9463243375558275, -3.0357080369739844}),
torch::tensor({-3.8570351018212796}),
},
{
torch::tensor({-0.10793942832760069, -0.3599569797368297, 0.05260329955808714, -0.08312010825923574, -0.07986326997915322, 0.9287409473303162}),
torch::tensor({0.23705744590903954, 0.3516841502052484, 1.2786184381275743}),
torch::tensor({-1.6691418106580105, -1.9848943707673132, -3.0912595329171024}),
torch::tensor({-3.923827025320545}),
},
{
torch::tensor({-0.10101428268579213, -0.3506724761241543, 0.060586427651359506, -0.07242353828264113, -0.07320722520220561, 0.9376663294528951}),
torch::tensor({0.2603753137363851, 0.37864768429039014, 1.302192517495494}),
torch::tensor({-1.697862366801323, -2.017346013780729, -3.1375112487519083}),
torch::tensor({-3.9791368472670334}),
},
{
torch::tensor({-0.09502239258273842, -0.3426342587463101, 0.0674491214906093, -0.06322196689556117, -0.06756850374320936, 0.9452179348012486}),
torch::tensor({0.28056290217301727, 0.40186559210837874, 1.322201974233735}),
torch::tensor({-1.722666737567296, -2.0453651314263936, -3.1770946252356755}),
torch::tensor({-4.02626353351958}),
},
{
torch::tensor({-0.08974074058929346, -0.33554465536214045, 0.07346375443579242, -0.05515233627910104, -0.06268648871001675, 0.9517469338705254}),
torch::tensor({0.29836675933627543, 0.42224471824689475, 1.3395523443811077}),
torch::tensor({-1.7445022249557354, -2.070022023204061, -3.211664011269998}),
torch::tensor({-4.0672678681014025}),
},
{
torch::tensor({-0.08501805567029427, -0.3292017353590117, 0.0788141885573336, -0.04796950604202485, -0.0583884110641127, 0.9574862878804136}),
torch::tensor({0.31429348099811794, 0.44039784591233994, 1.3548455497581404}),
torch::tensor({-1.7640079833055844, -2.092039516337239, -3.242327272701799}),
torch::tensor({-4.1035226231344275}),
},
{
torch::tensor({-0.08074691916762686, -0.3234620912535539, 0.0836304195409182, -0.041500113264404184, -0.05455400195824528, 0.9625982669377223}),
torch::tensor({0.32870295540834465, 0.45675864744543776, 1.3685016029692183}),
torch::tensor({-1.7816359699653226, -2.1119291060463254, -3.269862566805428}),
torch::tensor({-4.135987715622241}),
},
{
torch::tensor({-0.07684825926741547, -0.31822014874966065, 0.08800775942949751, -0.035617020506473744, -0.05109626940874276, 0.967200308063431}),
torch::tensor({0.34186020738651046, 0.47164537681262014, 1.3808249543962092}),
torch::tensor({-1.7977177360253924, -2.1300662313234127, -3.294837291074354}),
torch::tensor({-4.165360368453936}),
},
{
torch::tensor({-0.07326215742204907, -0.313395898483588, 0.0920181697641692, -0.03022421717885477, -0.04795031746059942, 0.9713800632469923}),
torch::tensor({0.3539661450911825, 0.4852985244949888, 1.3920431643924076}),
torch::tensor({-1.8125038126190633, -2.146734711618823, -3.317677824015751}),
torch::tensor({-4.192162739857097}),
},
};
} // namespace expected_parameters