diff options
-rw-r--r-- | COVID-19.png | bin | 42113 -> 40198 bytes | |||
-rw-r--r-- | SARS.png | bin | 0 -> 28850 bytes | |||
-rw-r--r-- | out/COVID-19-data.csv | 6 | ||||
-rw-r--r-- | out/COVID-19-prediction.csv | 154 | ||||
-rw-r--r-- | out/SARS-data.csv | 3 | ||||
-rw-r--r-- | out/SARS-prediction.csv | 10 | ||||
-rw-r--r-- | solver2.py | 54 |
7 files changed, 120 insertions, 107 deletions
diff --git a/COVID-19.png b/COVID-19.png Binary files differindex 4e7413e..b57bf1a 100644 --- a/COVID-19.png +++ b/COVID-19.png diff --git a/SARS.png b/SARS.png Binary files differnew file mode 100644 index 0000000..c166acf --- /dev/null +++ b/SARS.png diff --git a/out/COVID-19-data.csv b/out/COVID-19-data.csv index 7cfc7a0..7809e69 100644 --- a/out/COVID-19-data.csv +++ b/out/COVID-19-data.csv @@ -1,3 +1,3 @@ -Beta: 0.10176825591408271 -Gamma: 0.003925680724581611 -R0: 25.9237220380241
\ No newline at end of file +Beta: 0.12123601582761864 +Gamma: 0.004788532110372095 +R0: 25.317991616891955
\ No newline at end of file diff --git a/out/COVID-19-prediction.csv b/out/COVID-19-prediction.csv index ff5fb0b..adec705 100644 --- a/out/COVID-19-prediction.csv +++ b/out/COVID-19-prediction.csv @@ -1,78 +1,78 @@ ,Actual,S,I,R -1/22/20,0.004129051987767584,13499.000074068588,0.9999259314124879,0.0 -1/23/20,0.004129051987767584,13498.89317704777,1.1026991091086147,0.004123843121000688 -1/24/20,0.008258103975535168,13498.775294623783,1.2160338603561553,0.008671515861079792 -1/25/20,0.008258103975535168,13498.64529772263,1.341015689947722,0.013686587421359685 -1/26/20,0.02064525993883792,13498.501773727428,1.479002645211681,0.019223627359204125 -1/27/20,0.02064525993883792,13498.343097865096,1.6315566734055231,0.02534546149822523 -1/28/20,0.02064525993883792,13498.168236409509,1.7996717065348078,0.032091883956462224 -1/29/20,0.02064525993883792,13497.975885292719,1.9846015727694382,0.039513134511468394 -1/30/20,0.02064525993883792,13497.764413397454,2.1879144982525722,0.047672104294688565 -1/31/20,0.028903363914373086,13497.531862557107,2.41149310710062,0.05664433579145909 -2/1/20,0.03303241590214067,13497.275947555756,2.6575344214032457,0.06651802284100784 -2/2/20,0.03303241590214067,13496.99405612814,2.9285498612233654,0.07739401063645422 -2/3/20,0.045419571865443424,13496.683248959678,3.2273652445971495,0.08938579572480919 -2/4/20,0.045419571865443424,13496.340259686458,3.5571207875340214,0.10261952600697523 -2/5/20,0.045419571865443424,13495.961494895246,3.921271104016657,0.11723400073774633 -2/6/20,0.045419571865443424,13495.543034123473,4.3235852060009865,0.13338067052580804 -2/7/20,0.045419571865443424,13495.080629859249,4.768146503416194,0.1512236373337374 -2/8/20,0.045419571865443424,13494.569707541357,5.259352804164712,0.170939654478003 -2/9/20,0.045419571865443424,13494.00536555925,5.801916314122231,0.19271812662896506 -2/10/20,0.045419571865443424,13493.382375253052,6.400863637137696,0.2167611098108751 -2/11/20,0.049548623853211,13492.695180913564,7.0615357750333,0.24328331140187642 -2/12/20,0.049548623853211,13491.937899782262,7.789588127604491,0.27251209013400374 -2/13/20,0.053677675840978586,13491.104322051287,8.590990492619973,0.3046874560931832 -2/14/20,0.053677675840978586,13490.183011156712,9.476727935291171,0.3402609079974099 -2/15/20,0.053677675840978586,13489.163372539942,10.456979418076992,0.37964804198033986 -2/16/20,0.053677675840978586,13488.040725988436,11.536252355793748,0.4230216557711019 -2/17/20,0.053677675840978586,13486.808070134239,12.721283831018647,0.47064603474341726 -2/18/20,0.053677675840978586,13485.456057522571,14.021064513284362,0.5228779641451708 -2/19/20,0.053677675840978586,13483.972994611822,15.44683865907901,0.580166729098411 -2/20/20,0.053677675840978586,13482.344841773554,17.01210411184617,0.64305411459935 -2/21/20,0.061935779816513756,13480.555213292497,18.732612301984894,0.7121744055183638 -2/22/20,0.061935779816513756,13478.585377366551,20.62636824684967,0.7882543865999916 -2/23/20,0.061935779816513756,13476.414256106787,22.713630550750445,0.8721133424629367 -2/24/20,0.061935779816513756,13474.018425537448,25.016911404952634,0.964663057600066 -2/25/20,0.061935779816513756,13471.372115595945,27.560976587677093,1.0669078163784096 -2/26/20,0.061935779816513756,13468.447210132861,30.37284546410017,1.179944403039162 -2/27/20,0.06606483180428134,13465.213246911948,33.481790986353616,1.3049621016976805 -2/28/20,0.06606483180428134,13461.637417610133,36.91933969352468,1.4432426963434866 -2/29/20,0.099097247706422,13457.684567817505,40.719271711656056,1.596160470840266 -3/1/20,0.12387155963302751,13453.31719703733,44.91762075374588,1.7651822089258662 -3/2/20,0.21883975535168193,13448.49545868604,49.55267411974779,1.9518671942123016 -3/3/20,0.3014207951070336,13443.177160093244,54.66497269657083,2.157867210185746 -3/4/20,0.4294214067278287,13437.317762501714,60.297310958079514,2.3849265402065396 -3/5/20,0.7101969418960244,13430.870381067398,66.49473696509384,2.6348819675091875 -3/6/20,0.8960042813455656,13423.785784859409,73.30455236538921,2.9096627752023543 -3/7/20,1.3873614678899082,13415.993035589532,80.79446982966446,3.2124945808042553 -3/8/20,1.8580733944954126,13407.413026717277,89.04019184550404,3.5467814372193227 -3/9/20,2.122332721712538,13397.982321633168,98.10306252218177,3.914615844651356 -3/10/20,2.9233688073394495,13387.62389885365,108.05737845476976,4.318722691580334 -3/11/20,4.56260244648318,13376.247015299878,118.99051659186165,4.762468108260993 -3/12/20,6.428933944954128,13363.747206297705,131.0029342355727,5.249859466722829 -3/13/20,8.865074617737003,13350.00628557769,144.2081690415396,5.785545380770096 -3/14/20,11.796701529051987,13334.892345275097,158.73283901892086,6.374815705981806 -3/15/20,12.048573700305809,13318.259755929892,174.71664253039623,7.023601539711728 -3/16/20,17.783826911314982,13299.949166486744,192.31235829216723,7.738475221088386 -3/17/20,25.17070091743119,13279.787504295027,211.68584537395694,8.526650331015071 -3/18/20,36.63707828746177,13257.58797510882,233.01604319900989,9.395981692169823 -3/19/20,58.194858715596325,13233.150063086901,256.4949715440922,10.354965369005443 -3/20/20,80.11599571865443,13206.25953079276,282.32773053949165,11.412738667749492 -3/21/20,106.2198623853211,13176.688419194577,310.73250066901755,12.579080136404288 -3/22/20,138.8765345565749,13144.195047665253,341.94054277000055,13.864409564746897 -3/23/20,180.286796941896,13108.524013982378,376.19619803329323,15.279787984329166 -3/24/20,221.87873761467887,13069.406194328254,413.7568880032695,16.83691766847768 -3/25/20,271.6007816513761,13026.558743289883,454.89311457782486,18.54814213229378 -3/26/20,346.1632024464831,12979.685093858972,499.8884600083764,20.42644613265358 -3/27/20,419.7470379204893,12928.47456158081,549.0398554613624,22.485582957829788 -3/28/20,501.58897737003053,12872.561237871914,602.6941421259124,24.744620002174376 -3/29/20,581.8205865443425,12811.496145572517,661.2801353150604,27.223719112424398 -3/30/20,668.2333865443425,12744.817714929719,725.2391216290611,29.94316344122026 -3/31/20,776.9719706422018,12672.059066800972,795.0159397367895,32.92499346223956 -4/1/20,880.9827902140672,12592.748012654063,871.0589803757406,36.1930069701971 -4/2/20,1005.9031290519877,12506.407054567126,953.8201863520294,39.77275908084485 -4/3/20,1137.8924048929662,12412.553385228637,1043.7550525403915,43.69156223097204 -4/4/20,1275.2370611620795,12310.698887937413,1141.3226258841821,47.97848617840502 -4/5/20,1391.7589082568807,12200.350136602616,1246.9855053953772,52.664358002007376 -4/6/20,1513.9582018348622,12081.008395743751,1361.2098421545716,57.78176210167991 -4/7/20,1636.0171076452598,11952.169620490658,1484.465339310981,63.36504019836059 +1/22/20,0.01529051987767584,49999.0,1.0,0.0 +1/23/20,0.01529051987767584,49998.87142424206,1.1234972154720178,0.005078542461735381 +1/24/20,0.03058103975535168,49998.726971267446,1.2622445067142218,0.010784225844925973 +1/25/20,0.03058103975535168,49998.564682877775,1.418122720321673,0.017194401905434777 +1/26/20,0.0764525993883792,49998.38234997699,1.5932536885634114,0.024396334439772325 +1/27/20,0.0764525993883792,49998.1761616234,1.7912976312291053,0.032540745365268506 +1/28/20,0.0764525993883792,49997.94267086507,2.0155651686230516,0.04176396630732312 +1/29/20,0.0764525993883792,49997.68103501622,2.2668658997885043,0.052099083983799914 +1/30/20,0.0764525993883792,49997.389420767366,2.5469608746381214,0.06361835799887736 +1/31/20,0.10703363914373089,49997.064956030794,2.858608840222778,0.07643512898078467 +2/1/20,0.12232415902140673,49996.70372994068,3.2055662407315637,0.09070381858180188 +2/2/20,0.12232415902140673,49996.30079285303,3.592587217491785,0.10661992947825968 +2/3/20,0.16819571865443425,49995.85015634566,4.025423608968964,0.12442004537053958 +2/4/20,0.16819571865443425,49995.344793218246,4.510824950766836,0.14438183098307383 +2/5/20,0.16819571865443425,49994.77663749231,5.05653847562736,0.16682403206434543 +2/6/20,0.16819571865443425,49994.13658441118,5.671309113430702,0.1921064753868881 +2/7/20,0.16819571865443425,49993.41449044005,6.36487949119525,0.2206300687472865 +2/8/20,0.16819571865443425,49992.59917326595,7.1479899330776036,0.2528368009661757 +2/9/20,0.16819571865443425,49991.67841179774,8.032378460372584,0.2892097418882419 +2/10/20,0.16819571865443425,49990.6389461661,9.03078079151322,0.3302730423822217 +2/11/20,0.1834862385321101,49989.46647772359,10.156930342070767,0.3765919343409029 +2/12/20,0.1834862385321101,49988.145669044556,11.425558224754685,0.4287727306811233 +2/13/20,0.19877675840978593,49986.66014392524,12.85239324941266,0.4874628253437725 +2/14/20,0.19877675840978593,49984.99248738367,14.454161923030584,0.5533506932937899 +2/15/20,0.19877675840978593,49983.12424565975,16.248588449732576,0.6271658905201664 +2/16/20,0.19877675840978593,49981.02889523054,18.261139779188554,0.70996499027104 +2/17/20,0.19877675840978593,49978.651306317464,20.54472852896415,0.8039651535683541 +2/18/20,0.19877675840978593,49975.97895646455,23.11139937053732,0.909644164911142 +2/19/20,0.19877675840978593,49972.9967703992,25.97565285443824,1.027576746355153 +2/20/20,0.19877675840978593,49969.678659363395,29.162561937280955,1.1587786993212803 +2/21/20,0.2293577981651376,49965.98752111364,32.70777198176339,1.3047069045955584 +2/22/20,0.2293577981651376,49961.875239921,36.65750075666734,1.4672593223291648 +2/23/20,0.2293577981651376,49957.2826865711,41.068538436858454,1.6487749920384194 +2/24/20,0.2293577981651376,49952.13971836411,46.0082476032863,1.8520340326047855 +2/25/20,0.2293577981651376,49946.365179114735,51.55456324298426,2.0802576422748684 +2/26/20,0.2293577981651376,49939.86689915227,57.795992749069626,2.337108098660416 +2/27/20,0.24464831804281345,49932.54169532051,64.83161592074356,2.626688758738319 +2/28/20,0.24464831804281345,49924.275370977855,72.7710849632911,2.95354405885061 +2/29/20,0.3669724770642202,49914.94271599721,81.73462448808112,3.322659514704465 +3/1/20,0.4587155963302752,49904.40750676606,91.85303151256642,3.739461721372202 +3/2/20,0.8103975535168195,49892.522506186426,103.26767546028361,4.209818353291284 +3/3/20,1.1162079510703364,49879.12946367488,116.13049816085325,4.740038164264311 +3/4/20,1.5902140672782874,49864.05911516256,130.6040138499797,5.336870987459033 +3/5/20,2.6299694189602447,49847.13118309514,146.86130916945123,6.007507735408336 +3/6/20,3.318042813455657,49828.154376432845,165.08604316714002,6.759580400010251 +3/7/20,5.137614678899083,49806.92639065047,185.472447297002,7.601162052527955 +3/8/20,6.8807339449541285,49783.15387362768,208.30131583295304,8.544810539364274 +3/9/20,7.859327217125382,49756.46216012896,233.93144509123292,9.606394779802377 +3/10/20,10.825688073394495,49726.5796666417,262.62464902953786,10.79568432875744 +3/11/20,16.896024464831804,49693.16121368366,294.7130775011809,12.125708815156308 +3/12/20,23.807339449541285,49655.78361959187,330.60339369236084,13.612986715763734 +3/13/20,32.82874617737003,49613.94570052266,370.7767741221621,15.277525355182393 +3/14/20,43.68501529051988,49567.068270451586,415.7889086425552,17.14282090585288 +3/15/20,44.6177370030581,49514.49414117355,466.2700004383959,19.2358583880537 +3/16/20,65.85626911314985,49455.48812230267,522.9247660274261,21.587111669901272 +3/17/20,93.21100917431193,49389.237021272376,586.5324352602737,24.23054346734994 +3/18/20,135.67278287461772,49314.84964333535,657.9467513204519,27.20360534419196 +3/19/20,215.5045871559633,49231.35679156358,738.0959707243596,30.5472377120575 +3/20/20,296.6819571865443,49137.7112668483,827.9828633212821,34.30586983041465 +3/21/20,393.348623853211,49032.787867900035,928.68471229339,38.52741980656942 +3/22/20,514.2813455657492,48915.38339124859,1041.3533141557398,43.26329459566571 +3/23/20,667.6299694189603,48784.216631243034,1167.2149787562737,48.5683900006854 +3/24/20,821.651376146789,48637.92838005173,1307.5705292758196,54.5010906724482 +3/25/20,1005.7798165137615,48475.23815719861,1463.632752792264,61.12909000912159 +3/26/20,1281.8960244648317,48294.55190123027,1636.9083208388115,68.53977793091411 +3/27/20,1554.388379204893,48093.39034634198,1829.7869013768861,76.82275228112964 +3/28/20,1857.4617737003057,47869.19051439562,2044.730256305486,86.07922929888963 +3/29/20,2154.571865443425,47619.33475639337,2284.2426197884333,96.4226238181884 +3/30/20,2474.571865443425,47341.150752477726,2550.8706982543795,107.97854926789326 +3/31/20,2877.2477064220184,47031.91151193145,2847.2036703968,120.88481767174422 +4/1/20,3262.4159021406726,46688.835373177644,3175.8731871739956,135.29143964835436 +4/2/20,3725.0152905198775,46309.086003779696,3539.553371809094,151.36062441120953 +4/3/20,4213.792048929664,45889.772400441274,3940.960819790049,169.26677976866858 +4/4/20,4722.400611620795,45427.94888900639,4382.85459886964,189.1965121239631 +4/5/20,5153.899082568807,44920.61512445932,4868.036249065472,211.3486264751976 +4/6/20,5606.422018348624,44364.71609092467,5399.349782659975,235.93412641534957 +4/7/20,6058.4250764526,43757.14210166731,5979.681684200406,263.1762141322693 diff --git a/out/SARS-data.csv b/out/SARS-data.csv new file mode 100644 index 0000000..f98e7f5 --- /dev/null +++ b/out/SARS-data.csv @@ -0,0 +1,3 @@ +Beta: 1e-08 +Gamma: 0.01700912239686379 +R0: 5.879198095396057e-07
\ No newline at end of file diff --git a/out/SARS-prediction.csv b/out/SARS-prediction.csv new file mode 100644 index 0000000..2849d47 --- /dev/null +++ b/out/SARS-prediction.csv @@ -0,0 +1,10 @@ +,Actual,S,I,R +4/10/03,0.21148953068592058,13499.000074068588,0.9999259314124879,0.0 +4/11/03,0.2244162454873646,13499.000074058673,0.98306190629524,0.016864035031213732 +4/12/03,0.23479999999999998,13499.000074048927,0.9664822974727793,0.033443653600438765 +4/14/03,0.2521768953068592,13499.000074039344,0.9501823087970815,0.049743651858518506 +4/15/03,0.2610772563176895,13499.000074029924,0.9341572239618826,0.06576874611448988 +4/16/03,0.2687061371841155,13499.000074020661,0.9184024060302004,0.08152357330806109 +4/17/03,0.27485162454873646,13499.000074011556,0.9029132970915866,0.09701269135235983 +4/18/03,0.2877783393501805,13499.000074002603,0.8876854169726223,0.11224058042343865 +4/19/03,0.2877783393501805,13499.000073993802,0.872714359634142,0.12721164656305334 @@ -8,7 +8,6 @@ from scipy.optimize import minimize import argparse import os -# TODO - Parse arguments for different model options parser = argparse.ArgumentParser() parser.add_argument('--mode', '-m', dest = 'mode', help = 'change the mode of the model (SIR, Linear, ESIR, SEIR); default: SIR', default = 'SIR', choices = ['SIR', 'Linear', 'ESIR', 'SEIR']) @@ -20,16 +19,17 @@ parser.add_argument('--start', '-s', dest = 'start', default = '1/22/20', help = parser.add_argument('--end', '-e', dest = 'end', default = None, help = 'the date where the data stops (defaults to whereever the input data ends)') parser.add_argument('--incubation', '-i', dest = 'incubation_period', default = None, help = 'the incubation period of the disease (only applicable if using SIRE model; ignored otherwise); none by default') parser.add_argument('--predict', '-p', dest = 'prediction_range', default = None, help = 'the number of days to predict the course of the disease (defaults to None, meaning the model will not predict beyond the given data)') +parser.add_argument('--country', '-c', dest = 'country', default = 'US', help = 'the country that is being modeled (defaults to US)') +parser.add_argument('--population', '-P', dest = 'population', default = '10000', help = 'the population of the model (defaults to 10000)') args = parser.parse_args() -S_0 = 13500/13501 -I_0 = 1/13501 -R_0 = 0 # Both are equal to 0/13501 -E_0 = 0 # Both are equal to 0/13501 +S_0 = (int(args.population) - 1) / int(args.population) +I_0 = 1 / int(args.population) +R_0 = 0 +E_0 = 0 -# Running a model for 3.27 million population is quite hard, so here we've reduced the population to 13.5 thousand people, and modified -# the actual stats to match -correction_factor = 13502/3270000 if args.disease == 'COVID-19' else 1 +# Running a model for a million population is quite hard, so here we've reduced the population and modified the actual stats to match +correction_factor = int(args.population) / 3270000 if args.country == 'US' else int(args.population) / 63710000 if args.country == 'Hong_Kong' else 1 class Learner(object): def __init__(self, country): @@ -45,7 +45,7 @@ class Learner(object): if args.end != None: confirmed_sums = np.sum([reg.loc[args.start:args.end].values for reg in country_df.iloc], axis = 0) else: - confirmed_sums = np.sum([reg.loc[args.start:].values for reg in country_df.iloc], axis = 0) + confirmed_sums = np.sum([reg.loc[[args.start:]].values for reg in country_df.iloc], axis = 0) if args.end != None: new_data = pd.DataFrame(confirmed_sums, country_df.iloc[0].loc[args.start:args.end].index.tolist()) @@ -155,7 +155,7 @@ class Learner(object): beta, gamma = optimal.x print(f'Beta: {beta}, Gamma: {gamma}, R0: {beta/gamma}') new_index, extended_actual, prediction = self.predict(confirmed_data, beta = beta, gamma = gamma) - print(f'Predicted I: {prediction.y[1][-1] * 13500}, Actual I: {extended_actual[-1] * correction_factor}') + print(f'Predicted I: {prediction.y[1][-1] * int(args.population)}, Actual I: {extended_actual[-1] * correction_factor}') df = compose_df(prediction, extended_actual, correction_factor, new_index) with open(f'out/{args.disease}-data.csv', 'w+') as file: file.write(f'Beta: {beta}\nGamma: {gamma}\nR0: {beta/gamma}') @@ -170,7 +170,7 @@ class Learner(object): beta, gamma = optimal.x print(f'Beta: {beta}, Gamma: {gamma}, R0: {beta/gamma}') new_index, extended_actual, prediction = self.predict(confirmed_data, beta = beta, gamma = gamma) - print(f'Predicted I: {prediction.y[1][-1] * 13500}, Actual I: {extended_actual[-1] * correction_factor}') + print(f'Predicted I: {prediction.y[1][-1] * int(args.population)}, Actual I: {extended_actual[-1] * correction_factor}') df = compose_df(prediction, extended_actual, correction_factor, new_index) with open(f'out/{args.disease}-data.csv', 'w+') as file: file.write(f'Beta: {beta}\nGamma: {gamma}\nR0: {beta/gamma}') @@ -185,7 +185,7 @@ class Learner(object): beta, gamma, mu = optimal.x print(f'Beta: {beta}, Gamma: {gamma}, Mu: {mu} R0: {beta/(gamma + mu)}') new_index, extended_actual, prediction = self.predict(confirmed_data, beta = beta, gamma = gamma, mu = mu) - print(f'Predicted I: {prediction.y[1][-1] * 13500}, Actual I: {extended_actual[-1] * correction_factor}') + print(f'Predicted I: {prediction.y[1][-1] * int(args.population)}, Actual I: {extended_actual[-1] * correction_factor}') df = compose_df(prediction, extended_actual, correction_factor, new_index) with open(f'out/{args.disease}-data.csv', 'w+') as file: file.write(f'Beta: {beta}\nGamma: {gamma}\nMu: {mu}\nR0: {beta/(gamma + mu)}') @@ -202,7 +202,7 @@ class Learner(object): beta, gamma, mu, sigma = optimal.x print(f'Beta: {beta}, Gamma: {gamma}, Mu: {mu}, Sigma: {sigma} R0: {(beta * sigma)/((mu + gamma) * (mu + sigma))}') new_index, extended_actual, prediction = self.predict(confirmed_data, beta = beta, gamma = gamma, mu = mu) - print(f'Predicted I: {prediction.y[1][-1] * 13500}, Actual I: {extended_actual[-1] * correction_factor}') + print(f'Predicted I: {prediction.y[1][-1] * int(args.population)}, Actual I: {extended_actual[-1] * correction_factor}') df = compose_df(prediction, extended_actual, correction_factor, new_index) with open(f'out/{args.disease}-data.csv', 'w+') as file: file.write(f'Beta: {beta}\nGamma: {gamma}\nMu: {mu}\nSigma: {sigma}\nR0: {(beta * sigma)/((mu + gamma) * (mu + sigma))}') @@ -226,13 +226,13 @@ def compose_df(prediction, actual, correction_factor, index): if data == 'Actual': df_dict['Actual'] = filter_zeroes(actual * correction_factor) elif data == 'S': - df_dict['S'] = prediction.y[0] * 13500 + df_dict['S'] = prediction.y[0] * int(args.population) elif data == 'I': - df_dict['I'] = prediction.y[1] * 13500 + df_dict['I'] = prediction.y[1] * int(args.population) elif data == 'R': - df_dict['R'] = prediction.y[2] * 13500 + df_dict['R'] = prediction.y[2] * int(args.population) elif data == 'E': - df_dict['E'] = prediction.y[3] * 13500 + df_dict['E'] = prediction.y[3] * int(args.population) return pd.DataFrame(df_dict, index=index) @@ -247,8 +247,8 @@ def loss_linear(point, confirmed, recovered): R = y[2] return [-beta * S, beta * S - gamma * I, gamma * I] solution = solve_ivp(model, [0, size], [S_0,I_0,R_0], t_eval=np.arange(0, size, 1), vectorized=True) - sol_inf = np.sqrt(np.mean((solution.y[1] - (confirmed.values.flatten() * correction_factor/13500))**2)) - sol_rec = np.sqrt(np.mean((solution.y[2] - (recovered.values * correction_factor/13500))**2)) + sol_inf = np.sqrt(np.mean((solution.y[1] - (confirmed.values.flatten() * correction_factor/int(args.population)))**2)) + sol_rec = np.sqrt(np.mean((solution.y[2] - (recovered.values * correction_factor/int(args.population)))**2)) return sol_inf * 0.5 + sol_rec * 0.5 def loss_sir(point, confirmed, recovered): @@ -260,8 +260,8 @@ def loss_sir(point, confirmed, recovered): R = y[2] return [-beta * S * I, beta * S * I - gamma * I, gamma * I] solution = solve_ivp(model, [0, size], [S_0,I_0,R_0], t_eval=np.arange(0, size, 1), vectorized=True) - sol_inf = np.sqrt(np.mean((solution.y[1] - (confirmed.values.flatten() * correction_factor/13500))**2)) - sol_rec = np.sqrt(np.mean((solution.y[2] - (recovered.values * correction_factor/13500))**2)) + sol_inf = np.sqrt(np.mean((solution.y[1] - (confirmed.values.flatten() * correction_factor/int(args.population)))**2)) + sol_rec = np.sqrt(np.mean((solution.y[2] - (recovered.values * correction_factor/int(args.population)))**2)) return sol_inf * 0.5 + sol_rec * 0.5 def loss_esir(point, confirmed, recovered): @@ -273,8 +273,8 @@ def loss_esir(point, confirmed, recovered): R = y[2] return [mu - beta * S * I - mu * S, beta * S * I - gamma * I - mu * I, gamma * I - mu * R] solution = solve_ivp(model, [0, size], [S_0,I_0,R_0], t_eval=np.arange(0, size, 1), vectorized=True) - sol_inf = np.sqrt(np.mean((solution.y[1] - (confirmed.values.flatten() * correction_factor/13500))**2)) - sol_rec = np.sqrt(np.mean((solution.y[2] - (recovered.values * correction_factor/13500))**2)) + sol_inf = np.sqrt(np.mean((solution.y[1] - (confirmed.values.flatten() * correction_factor/int(args.population)))**2)) + sol_rec = np.sqrt(np.mean((solution.y[2] - (recovered.values * correction_factor/int(args.population)))**2)) return sol_inf * 0.5 + sol_rec * 0.5 def loss_seir(point, confirmed, recovered, exposed): @@ -287,10 +287,10 @@ def loss_seir(point, confirmed, recovered, exposed): E = y[3] return [mu - beta * S * I - mu * S, beta * S * I - sigma * E - mu * E, sigma * E * I - gamma * I - mu * I, gamma * I - mu * R] solution = solve_ivp(model, [0, size], [S_0,E_0,I_0,R_0], t_eval=np.arange(0, size, 1), vectorized=True) - sol_inf = np.sqrt(np.mean((solution.y[1] - (confirmed.values.flatten() * correction_factor/13500))**2)) - sol_rec = np.sqrt(np.mean((solution.y[2] - (recovered.values * correction_factor/13500))**2)) - sol_exp = np.sqrt(np.mean((solution.y[3] - (exposed.values * correction_factor/13500))**2)) + sol_inf = np.sqrt(np.mean((solution.y[1] - (confirmed.values.flatten() * correction_factor/int(args.population)))**2)) + sol_rec = np.sqrt(np.mean((solution.y[2] - (recovered.values * correction_factor/int(args.population)))**2)) + sol_exp = np.sqrt(np.mean((solution.y[3] - (exposed.values * correction_factor/int(args.population)))**2)) return sol_inf/3 + sol_rec/3 + sol_exp/3 -my_learner = Learner('Hong_Kong') +my_learner = Learner(args.country) my_learner.train() |