diff options
Diffstat (limited to 'solver.py')
-rw-r--r-- | solver.py | 248 |
1 files changed, 248 insertions, 0 deletions
diff --git a/solver.py b/solver.py new file mode 100644 index 0000000..02f6b9c --- /dev/null +++ b/solver.py @@ -0,0 +1,248 @@ +#!/usr/bin/python +import numpy as np +import pandas as pd +from csv import reader +from csv import writer +from scipy.integrate import solve_ivp +from scipy.optimize import minimize +import matplotlib.pyplot as plt +from datetime import timedelta, datetime +import argparse +import sys +import json +import ssl +import urllib.request + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + '--countries', + action='store', + dest='countries', + help='Countries on CSV format. ' + + 'It must exact match the data names or you will get out of bonds error.', + metavar='COUNTRY_CSV', + type=str, + default="") + + parser.add_argument( + '--download-data', + action='store_true', + dest='download_data', + help='Download fresh data and then run', + default=False + ) + + parser.add_argument( + '--start-date', + required=False, + action='store', + dest='start_date', + help='Start date on MM/DD/YY format ... I know ...' + + 'It defaults to first data available 1/22/20', + metavar='START_DATE', + type=str, + default="1/22/20") + + parser.add_argument( + '--prediction-days', + required=False, + dest='predict_range', + help='Days to predict with the model. Defaults to 150', + metavar='PREDICT_RANGE', + type=int, + default=150) + + parser.add_argument( + '--S_0', + required=False, + dest='s_0', + help='S_0. Defaults to 100000', + metavar='S_0', + type=int, + default=100000) + + parser.add_argument( + '--I_0', + required=False, + dest='i_0', + help='I_0. Defaults to 2', + metavar='I_0', + type=int, + default=2) + + parser.add_argument( + '--R_0', + required=False, + dest='r_0', + help='R_0. Defaults to 0', + metavar='R_0', + type=int, + default=10) + + args = parser.parse_args() + + country_list = [] + if args.countries != "": + try: + countries_raw = args.countries + country_list = countries_raw.split(",") + except Exception: + sys.exit("QUIT: countries parameter is not on CSV format") + else: + sys.exit("QUIT: You must pass a country list on CSV format.") + + return (country_list, args.download_data, args.start_date, args.predict_range, args.s_0, args.i_0, args.r_0) + + +def sumCases_province(input_file, output_file): + with open(input_file, "r") as read_obj, open(output_file,'w',newline='') as write_obj: + csv_reader = reader(read_obj) + csv_writer = writer(write_obj) + + lines=[] + for line in csv_reader: + lines.append(line) + + i=0 + ix=0 + for i in range(0,len(lines[:])-1): + if lines[i][1]==lines[i+1][1]: + if ix==0: + ix=i + lines[ix][4:] = np.asfarray(lines[ix][4:],float)+np.asfarray(lines[i+1][4:] ,float) + else: + if not ix==0: + lines[ix][0]="" + csv_writer.writerow(lines[ix]) + ix=0 + else: + csv_writer.writerow(lines[i]) + i+=1 + + +def download_data(url_dictionary): + #Lets download the files + for url_title in url_dictionary.keys(): + urllib.request.urlretrieve(url_dictionary[url_title], "./data/" + url_title) + + +def load_json(json_file_str): + # Loads JSON into a dictionary or quits the program if it cannot. + try: + with open(json_file_str, "r") as json_file: + json_variable = json.load(json_file) + return json_variable + except Exception: + sys.exit("Cannot open JSON file: " + json_file_str) + + +class Learner(object): + def __init__(self, country, loss, start_date, predict_range,s_0, i_0, r_0): + self.country = country + self.loss = loss + self.start_date = start_date + self.predict_range = predict_range + self.s_0 = s_0 + self.i_0 = i_0 + self.r_0 = r_0 + + + def load_confirmed(self, country): + df = pd.read_csv('data/time_series_19-covid-Confirmed-country.csv') + country_df = df[df['Country/Region'] == country] + return country_df.iloc[0].loc[self.start_date:] + + + def load_recovered(self, country): + df = pd.read_csv('data/time_series_19-covid-Recovered-country.csv') + country_df = df[df['Country/Region'] == country] + return country_df.iloc[0].loc[self.start_date:] + + + def load_dead(self, country): + df = pd.read_csv('data/time_series_19-covid-Deaths-country.csv') + country_df = df[df['Country/Region'] == country] + return country_df.iloc[0].loc[self.start_date:] + + + def extend_index(self, index, new_size): + values = index.values + current = datetime.strptime(index[-1], '%m/%d/%y') + while len(values) < new_size: + current = current + timedelta(days=1) + values = np.append(values, datetime.strftime(current, '%m/%d/%y')) + return values + + def predict(self, beta, gamma, data, recovered, death, country, s_0, i_0, r_0): + new_index = self.extend_index(data.index, self.predict_range) + size = len(new_index) + def SIR(t, y): + S = y[0] + I = y[1] + R = y[2] + return [-beta*S*I, beta*S*I-gamma*I, gamma*I] + extended_actual = np.concatenate((data.values, [None] * (size - len(data.values)))) + extended_recovered = np.concatenate((recovered.values, [None] * (size - len(recovered.values)))) + extended_death = np.concatenate((death.values, [None] * (size - len(death.values)))) + return new_index, extended_actual, extended_recovered, extended_death, solve_ivp(SIR, [0, size], [s_0,i_0,r_0], t_eval=np.arange(0, size, 1)) + + + def train(self): + recovered = self.load_recovered(self.country) + death = self.load_dead(self.country) + data = (self.load_confirmed(self.country) - recovered - death) + + optimal = minimize(loss, [0.001, 0.001], args=(data, recovered, self.s_0, self.i_0, self.r_0), method='L-BFGS-B', bounds=[(0.00000001, 0.4), (0.00000001, 0.4)]) + print(optimal) + beta, gamma = optimal.x + new_index, extended_actual, extended_recovered, extended_death, prediction = self.predict(beta, gamma, data, recovered, death, self.country, self.s_0, self.i_0, self.r_0) + df = pd.DataFrame({'Infected data': extended_actual, 'Recovered data': extended_recovered, 'Death data': extended_death, 'Susceptible': prediction.y[0], 'Infected': prediction.y[1], 'Recovered': prediction.y[2]}, index=new_index) + fig, ax = plt.subplots(figsize=(15, 10)) + ax.set_title(self.country) + df.plot(ax=ax) + print(f"country={self.country}, beta={beta:.8f}, gamma={gamma:.8f}, r_0:{(beta/gamma):.8f}") + fig.savefig(f"{self.country}.png") + + +def loss(point, data, recovered, s_0, i_0, r_0): + size = len(data) + beta, gamma = point + def SIR(t, y): + S = y[0] + I = y[1] + R = y[2] + return [-beta*S*I, beta*S*I-gamma*I, gamma*I] + solution = solve_ivp(SIR, [0, size], [s_0,i_0,r_0], t_eval=np.arange(0, size, 1), vectorized=True) + l1 = np.sqrt(np.mean((solution.y[1] - data)**2)) + l2 = np.sqrt(np.mean((solution.y[2] - recovered)**2)) + alpha = 0.1 + return alpha * l1 + (1 - alpha) * l2 + + +def main(): + + countries, download, startdate, predict_range , s_0, i_0, r_0 = parse_arguments() + + if download: + data_d = load_json("./data_url.json") + download_data(data_d) + + sumCases_province('data/time_series_19-covid-Confirmed.csv', 'data/time_series_19-covid-Confirmed-country.csv') + sumCases_province('data/time_series_19-covid-Recovered.csv', 'data/time_series_19-covid-Recovered-country.csv') + sumCases_province('data/time_series_19-covid-Deaths.csv', 'data/time_series_19-covid-Deaths-country.csv') + + for country in countries: + learner = Learner(country, loss, startdate, predict_range, s_0, i_0, r_0) + #try: + learner.train() + #except BaseException: + # print('WARNING: Problem processing ' + str(country) + + # '. Be sure it exists in the data exactly as you entry it.' + + # ' Also check date format if you passed it as parameter.') + + +if __name__ == '__main__': + main() |