blob: 85702f538ed73713c3dca845eeaebbbc704f92e6 [file] [log] [blame]
// @generated from test/cpp/api/optim_baseline.py
#include <torch/tensor.h>
#include <vector>
namespace expected_parameters {
static std::vector<std::vector<torch::Tensor>> Adam = {
{
torch::tensor({0.7889792072185753, 0.502352757161707, 0.8586918160350755, 0.6579591155483523, 0.7476108843716649, 1.6975378965521928}),
torch::tensor({0.8914325952640971, 0.7020467393659713, 1.6891939504169646}),
torch::tensor({-1.0508020464076324, -1.3941340315767958, -1.2843369730696377}),
torch::tensor({-1.0711376814873597}),
},
{
torch::tensor({6.7760957077832815, 6.9124599235220785, 6.433510595961532, 6.317833927015116, 5.858358965431237, 6.877000885215112}),
torch::tensor({7.976190958038579, 6.534978464582878, 7.121533207755487}),
torch::tensor({-6.714974314073826, -7.080922682620933, -6.744778112213496}),
torch::tensor({-6.4298819001431085}),
},
{
torch::tensor({4.558988462696042, 5.176192443970086, 6.0336571634341505, 6.0224389139758, 5.280816350261895, 6.365433817392172}),
torch::tensor({7.158919001867603, 6.412706290235066, 6.908680508139067}),
torch::tensor({-6.686226984996727, -7.0485763202090785, -6.728173785196665}),
torch::tensor({-6.418722639668962}),
},
{
torch::tensor({2.676497911906038, 3.514230904704933, 5.540289462213301, 5.651091500603315, 4.599555032332985, 5.7436411842005715}),
torch::tensor({6.199573530483314, 6.253637773745561, 6.635649515135911}),
torch::tensor({-6.6479125255000415, -7.005525929642858, -6.705848142237787}),
torch::tensor({-6.403590364569412}),
},
{
torch::tensor({1.390551436159619, 2.186808822332009, 4.994097914148984, 5.230206052925842, 3.8873275598775168, 5.068205713003804}),
torch::tensor({5.2028669487041, 6.065647517713214, 6.318260904353201}),
torch::tensor({-6.601668647421558, -6.9535874186665785, -6.67883101301033}),
torch::tensor({-6.385259022353005}),
},
{
torch::tensor({0.6428874779821907, 1.2530294670027888, 4.42474396483426, 4.778966032173873, 3.194104800709639, 4.379924438267535}),
torch::tensor({4.239990461988167, 5.853776036724999, 5.967533875546956}),
torch::tensor({-6.548225472035984, -6.893592571399344, -6.64750966472527}),
torch::tensor({-6.363980692898975}),
},
{
torch::tensor({0.2649056275082859, 0.6623036233906192, 3.855580769864683, 4.312992553271765, 2.5539685072813745, 3.7096880079750543}),
torch::tensor({3.358522498538248, 5.6220080436157085, 5.592491715335737}),
torch::tensor({-6.488047848171285, -6.826075945430906, -6.612113718185246}),
torch::tensor({-6.339899671178354}),
},
{
torch::tensor({0.09720440879172912, 0.32295960311161326, 3.305014661197478, 3.8453888770306373, 1.987959590244318, 3.0803155463808802}),
torch::tensor({2.5867474425386776, 5.373786289684033, 5.200999931806661}),
torch::tensor({-6.421459980696732, -6.751415029544431, -6.572787274104027}),
torch::tensor({-6.313101192728845}),
},
{
torch::tensor({0.03168219948352204, 0.14514092874328724, 2.787089436799649, 3.3871307837850804, 1.506328896591618, 2.5075443167981337}),
torch::tensor({1.9371981268942096, 5.1122381676870345, 4.800104891435606}),
torch::tensor({-6.34870330695251, -6.669895358637439, -6.529622284228346}),
torch::tensor({-6.283633921566764}),
},
{
torch::tensor({0.009143320123746463, 0.06000692560705293, 2.3118678318360013, 2.947231977224673, 1.110782434234418, 2.0008734743554664}),
torch::tensor({1.4101939064831694, 4.840286974170306, 4.3961763185205}),
torch::tensor({-6.269967843778274, -6.581745817714579, -6.482676531197509}),
torch::tensor({-6.251522016769645}),
},
{
torch::tensor({0.00234842105352119, 0.02278637440138154, 1.8857973138568447, 2.532835881548222, 0.7967451883439284, 1.5644546013645035}),
torch::tensor({0.9973538456553341, 4.560708133428994, 3.9949611760478665}),
torch::tensor({-6.18541110750813, -6.487160016579177, -6.4319844222067895}),
torch::tensor({-6.216772332487236}),
},
};
static std::vector<std::vector<torch::Tensor>> Adagrad = {
{
torch::tensor({0.7891011046018798, 0.5024439245163383, 0.8587078329086189, 0.6579710994225316, 0.747636483621666, 1.697557019500142}),
torch::tensor({0.8914687688943375, 0.7020514988069164, 1.6892015076050049}),
torch::tensor({-1.0508031297732776, -1.3941351871450511, -1.284337597261839}),
torch::tensor({-1.0711381241617108}),
},
{
torch::tensor({2.346218944110103, 2.191939439502003, 1.683355201740813, 1.5405520021635604, 1.2137800230828062, 2.2052834637173024}),
torch::tensor({2.9090564593404, 1.7509657336815554, 2.336166413186925}),
torch::tensor({-2.206159683368316, -2.5344318233445415, -1.9622783535807609}),
torch::tensor({-1.5796101463783623}),
},
{
torch::tensor({2.3889328781057233, 2.2678221038007296, 1.7667624725138267, 1.6358015176639824, 1.2655767687152566, 2.2610880567112814}),
torch::tensor({3.045569451994985, 1.8770196253823253, 2.4192707519566765}),
torch::tensor({-2.4079300017528613, -2.7399112002234305, -2.0780613510632375}),
torch::tensor({-1.664722108226537}),
},
{
torch::tensor({2.3886137557806384, 2.2922158071009178, 1.8078384116424007, 1.6843524744409322, 1.290353948335789, 2.287071550970649}),
torch::tensor({3.111110355394278, 1.9438501730282314, 2.4630249355872826}),
torch::tensor({-2.5226122034499263, -2.857315093916292, -2.143964860243905}),
torch::tensor({-1.7130685809905042}),
},
{
torch::tensor({2.374703352203156, 2.298804499257456, 1.8330249458212446, 1.7151661013307247, 1.3048586226945842, 2.301765059046427}),
torch::tensor({3.150318222034133, 1.9877926185369321, 2.491399976401679}),
torch::tensor({-2.601415913361488, -2.938203895113964, -2.1892988334550028}),
torch::tensor({-1.7462964261966805}),
},
{
torch::tensor({2.3553658567303812, 2.297191758042688, 1.8501154749072124, 1.7368360586881881, 1.3141313000193942, 2.310745259215385}),
torch::tensor({3.1762315339155434, 2.0197585204578647, 2.5117041377790197}),
torch::tensor({-2.6606644002288697, -2.9991216074293856, -2.223413376189609}),
torch::tensor({-1.7712905233118807}),
},
{
torch::tensor({2.3338052201696207, 2.2913023710914993, 1.8624163948044772, 1.7530300731725457, 1.3203313209234842, 2.3163969478854742}),
torch::tensor({3.1943525925688934, 2.044447386769377, 2.527109724607397}),
torch::tensor({-2.7076634717294894, -3.047500808469036, -2.250495807208967}),
torch::tensor({-1.7911288238757486}),
},
{
torch::tensor({2.3114979154644892, 2.2830501835377808, 1.871616142999356, 1.7656325976608411, 1.324565631636651, 2.3199392342052025}),
torch::tensor({3.2074779925809085, 2.0642940833670544, 2.5392671301471235}),
torch::tensor({-2.7463093287485925, -3.087315541134716, -2.272780318857348}),
torch::tensor({-1.8074516661537263}),
},
{
torch::tensor({2.2891841627387346, 2.2734716995793693, 1.8786818999895825, 1.7757301317117604, 1.3274682997719436, 2.322067935399382}),
torch::tensor({3.2172019619454075, 2.0807140893178175, 2.5491374815141876}),
torch::tensor({-2.7789204504423823, -3.1209351402429175, -2.2915969523376867}),
torch::tensor({-1.821234722948421}),
},
{
torch::tensor({2.2672498238343066, 2.2631678037928893, 1.8842131287032622, 1.7840007705383885, 1.3294311820750493, 2.3232112430345424}),
torch::tensor({3.224507488068445, 2.094598223519413, 2.5573257155791715}),
torch::tensor({-2.8069849199086647, -3.1498826045022925, -2.3077996970997727}),
torch::tensor({-1.8331040438272388}),
},
{
torch::tensor({2.2458961718688957, 2.2525031725114775, 1.8886034384961112, 1.7908930341267955, 1.3307102435291205, 2.3236474976004615}),
torch::tensor({3.230036385413078, 2.1065407459636134, 2.5642349249609664}),
torch::tensor({-2.83151249424399, -3.1751926295316566, -2.3219682378974036}),
torch::tensor({-1.8434843744626483}),
},
};
static std::vector<std::vector<torch::Tensor>> RMSprop = {
{
torch::tensor({0.789062580418602, 0.5024151127899734, 0.8587027713374162, 0.6579673123497956, 0.7476283938233967, 1.6975509763502088}),
torch::tensor({0.891457337329431, 0.7020499947640909, 1.6891991194345595}),
torch::tensor({-1.0508027874170536, -1.3941348219719392, -1.2843374000098593}),
torch::tensor({-1.0711379842714959}),
},
{
torch::tensor({11.617971383650625, 12.569324191541254, 10.750362589690871, 10.758744280441727, 9.645515422139535, 10.800246335822324}),
torch::tensor({15.281460788745067, 11.227695130869545, 11.392623137702232}),
torch::tensor({-11.253200612957224, -11.640304407367251, -11.10515970703974}),
torch::tensor({-10.70450790002659}),
},
{
torch::tensor({6.046736887816883, 7.825431817224273, 9.687011716351043, 9.954711288408674, 8.14740348271963, 9.525985838440336}),
torch::tensor({12.822053316618073, 10.889090672152843, 10.845780422808176}),
torch::tensor({-11.174857199859773, -11.553906933533701, -11.060485249649355}),
torch::tensor({-10.673987015109141}),
},
{
torch::tensor({1.9571373332473823, 3.490982039449725, 8.146064947642957, 8.749830187230062, 6.146253042228082, 7.728648273553586}),
torch::tensor({9.565022547037527, 10.350286153772108, 9.996138388056213}),
torch::tensor({-11.046063491906642, -11.411999818458279, -10.986591391047115}),
torch::tensor({-10.623300659238438}),
},
{
torch::tensor({0.2533394450092949, 0.8487216725008625, 6.099516863836628, 7.058052576203578, 3.826101977602061, 5.4465243550891405}),
torch::tensor({5.842262376064937, 9.515175340901033, 8.729600468763985}),
torch::tensor({-10.836290562755634, -11.181092820342055, -10.865488142142983}),
torch::tensor({-10.540026730263222}),
},
{
torch::tensor({0.0028003666727239255, 0.05902391869285125, 3.748716429598524, 4.925383772212378, 1.7054258503085213, 3.0154295058730067}),
torch::tensor({2.5189723383469143, 8.274294567516511, 6.965678470134445}),
torch::tensor({-10.497890796162736, -10.809198403294875, -10.668113905016504}),
torch::tensor({-10.40375240539236}),
},
{
torch::tensor({8.271676181935537e-06, -0.00018355896681218517, 1.6335648369488047, 2.6808923568105154, 0.41363824898472834, 1.090479198637994}),
torch::tensor({0.5739272720885893, 6.5553203371069895, 4.77208963856775}),
torch::tensor({-9.96060858122121, -10.220305543889674, -10.349375660214779}),
torch::tensor({-10.182195023732161}),
},
{
torch::tensor({6.185398198559693e-08, -1.541936772053279e-07, 0.3796273256352156, 0.9401265830080575, 0.028596845546205944, 0.17648335822427214}),
torch::tensor({0.03428668381950183, 4.436633619476827, 2.515631003256126}),
torch::tensor({-9.129738869797919, -9.313631525351905, -9.842416162331277}),
torch::tensor({-9.82584888365321}),
},
{
torch::tensor({2.301163769642748e-06, 2.6686276058313296e-06, 0.023759872757083522, 0.14307496311038903, -8.923289496023129e-05, 0.004248712398606914}),
torch::tensor({-0.0001300076077318404, 2.2897575859027177, 0.8324154760613494}),
torch::tensor({-7.899784778698039, -7.981253754035267, -9.056141702142126}),
torch::tensor({-9.262858557574456}),
},
{
torch::tensor({0.0024924016682296903, 0.0030339372114999833, 0.0007509002769717036, 0.0037341431497881668, 0.0017606472677062757, 0.0019348763452492956}),
torch::tensor({0.0048385780006693566, 0.7299482676059379, 0.11498900317323348}),
torch::tensor({-6.2066461415350735, -6.169374349409609, -7.886599467110102}),
torch::tensor({-8.399404241287046}),
},
{
torch::tensor({0.105613217805617, 0.14212963123097416, 0.08487768339425271, 0.10413270403927237, 0.14189898312301077, 0.1815884431559531}),
torch::tensor({0.34655586728243415, 0.2531201095942877, 0.32677017754456106}),
torch::tensor({-4.164684884039514, -4.029224257118343, -6.277175857463741}),
torch::tensor({-7.156227835327511}),
},
};
static std::vector<std::vector<torch::Tensor>> SGD = {
{
torch::tensor({-0.21063954921142625, -0.49720932283029146, -0.1393184765948855, -0.3393909854529272, -0.25112862964818594, 0.699210126931771}),
torch::tensor({-0.10765733357148573, -0.2913064115911077, 0.6933846184980238}),
torch::tensor({-0.07998325270832098, -0.4214920657406949, -0.33498346710568583}),
torch::tensor({-0.14255125794128246}),
},
{
torch::tensor({-0.1330316364215997, -0.3935294274545123, 0.031868002506474064, -0.11084376970207319, -0.10044018049327709, 0.9010178292324523}),
torch::tensor({0.1528090373898639, 0.2821438649147616, 1.205386760430569}),
torch::tensor({-1.713468945736039, -2.005156118333251, -3.179277317098681}),
torch::tensor({-4.088880710164363}),
},
{
torch::tensor({-0.12680536431075715, -0.3851833683510085, 0.03904239282451148, -0.10124023247972329, -0.0940339175875239, 0.9095833537085721}),
torch::tensor({0.17373114038388335, 0.3062644150893206, 1.227860257896027}),
torch::tensor({-1.7373016253033575, -2.03249513013615, -3.219138770050986}),
torch::tensor({-4.137292114732651}),
},
{
torch::tensor({-0.12132497250032251, -0.3778331039904768, 0.04532003261274458, -0.09283201036735936, -0.08851196423135906, 0.9169578599973672}),
torch::tensor({0.19215412593446154, 0.327399336185489, 1.2472665827949112}),
torch::tensor({-1.7582130574461698, -2.056484638774683, -3.2538206997778527}),
torch::tensor({-4.1792051428429815}),
},
{
torch::tensor({-0.11642629644462195, -0.371259551231067, 0.05090337508675891, -0.08534947911495054, -0.0836630803858705, 0.9234252017697436}),
torch::tensor({0.20862718059402438, 0.34622033891380727, 1.2643318695604322}),
torch::tensor({-1.7768873985691658, -2.0779002006617127, -3.2845633491961754}),
torch::tensor({-4.216211613875442}),
},
{
torch::tensor({-0.11199692998473046, -0.3653126926175739, 0.055928814442105936, -0.07861097082436985, -0.07934681806193093, 0.9291742365403652}),
torch::tensor({0.22352641732263448, 0.3631801602983304, 1.2795402134567078}),
torch::tensor({-1.7937606819069465, -2.0972428048607665, -3.312155416912881}),
torch::tensor({-4.249311719641607}),
},
{
torch::tensor({-0.10795430968582471, -0.35988233095293476, 0.06049618926552477, -0.07248348167814321, -0.0754620039452837, 0.9343410885098413}),
torch::tensor({0.2371281737399247, 0.3786104526446995, 1.2932411233560754}),
torch::tensor({-1.8091521591375157, -2.114879507429728, -3.33717091609132}),
torch::tensor({-4.279229817804015}),
},
{
torch::tensor({-0.10423597676568656, -0.35488509675405316, 0.06468078527531465, -0.06686669857198466, -0.07193342158854438, 0.9390270110811988}),
torch::tensor({0.2496415568519247, 0.39276148981107395, 1.305694968823765}),
torch::tensor({-1.8233027090805771, -2.131087403685365, -3.3600405655741863}),
torch::tensor({-4.306507101504117}),
},
{
torch::tensor({-0.10079352100780103, -0.35025636793356374, 0.06854077976598513, -0.06168309068553341, -0.06870368260858187, 0.9433092952856199}),
torch::tensor({0.2612286853643801, 0.40582680705425617, 1.3171007763490907}),
torch::tensor({-1.8363986223633926, -2.146080884869125, -3.381095779579462}),
torch::tensor({-4.331558400559893}),
},
{
torch::tensor({-0.09758865034515093, -0.3459450254605736, 0.07212206086787502, -0.05687149432043213, -0.06572803950285522, 0.9472482340146743}),
torch::tensor({0.2720178289132921, 0.4179591600527258, 1.327613999831149}),
torch::tensor({-1.8485869819731393, -2.160029278482207, -3.400596982161749}),
torch::tensor({-4.354708572999517}),
},
{
torch::tensor({-0.094590547958586, -0.3419099268142475, 0.07546146037543552, -0.0523828089910123, -0.06297095100755593, 0.9508917353335773}),
torch::tensor({0.2821122445127561, 0.4292812357169464, 1.337358307292812}),
torch::tensor({-1.8599859731758501, -2.173068681359289, -3.418752504544578}),
torch::tensor({-4.376216692846435}),
},
};
} // namespace expected_parameters