-rw-r--r--   solver.py   248
1 file changed, 248 insertions, 0 deletions
diff --git a/solver.py b/solver.py
new file mode 100644
index 0000000..02f6b9c
--- /dev/null
+++ b/solver.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+import numpy as np
+import pandas as pd
+from csv import reader
+from csv import writer
+from scipy.integrate import solve_ivp
+from scipy.optimize import minimize
+import matplotlib.pyplot as plt
+from datetime import timedelta, datetime
+import argparse
+import sys
+import json
+import ssl
+import urllib.request
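+
+# Fits a SIR model (beta, gamma) to per-country COVID-19 time series and plots the result.
+# Illustrative invocation (flag names as defined in parse_arguments below; the country names
+# are only examples and must match the names used in the data):
+#   python3 solver.py --download-data --countries Italy,Spain --prediction-days 150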
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        '--countries',
+        action='store',
+        dest='countries',
+        help='Countries in CSV format. ' +
+        'Each name must exactly match the country names in the data or you will get an out-of-bounds error.',
+        metavar='COUNTRY_CSV',
+        type=str,
+        default="")
+
+    parser.add_argument(
+        '--download-data',
+        action='store_true',
+        dest='download_data',
+        help='Download fresh data and then run',
+        default=False
+    )
+
+    parser.add_argument(
+        '--start-date',
+        required=False,
+        action='store',
+        dest='start_date',
+        help='Start date in MM/DD/YY format. ' +
+        'Defaults to the first date available in the data, 1/22/20',
+        metavar='START_DATE',
+        type=str,
+        default="1/22/20")
+
+    parser.add_argument(
+        '--prediction-days',
+        required=False,
+        dest='predict_range',
+        help='Days to predict with the model. Defaults to 150',
+        metavar='PREDICT_RANGE',
+        type=int,
+        default=150)
+
+    parser.add_argument(
+        '--S_0',
+        required=False,
+        dest='s_0',
+        help='S_0 (initial susceptible population). Defaults to 100000',
+        metavar='S_0',
+        type=int,
+        default=100000)
+
+    parser.add_argument(
+        '--I_0',
+        required=False,
+        dest='i_0',
+        help='I_0 (initial infected). Defaults to 2',
+        metavar='I_0',
+        type=int,
+        default=2)
+
+    parser.add_argument(
+        '--R_0',
+        required=False,
+        dest='r_0',
+        help='R_0 (initial recovered). Defaults to 10',
+        metavar='R_0',
+        type=int,
+        default=10)
+
+    args = parser.parse_args()
+
+    country_list = []
+    if args.countries != "":
+        try:
+            countries_raw = args.countries
+            country_list = countries_raw.split(",")
+        except Exception:
+            sys.exit("QUIT: countries parameter is not in CSV format")
+    else:
+        sys.exit("QUIT: You must pass a country list in CSV format.")
+
+    return (country_list, args.download_data, args.start_date, args.predict_range, args.s_0, args.i_0, args.r_0)
+
+
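+# The downloaded time series have one row per province/state, while the model below works per
+# country, so this collapses the per-province rows into a single summed row per country
+# (Province/State in column 0, Country/Region in column 1, daily counts from column 4 onwards).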
+def sumCases_province(input_file, output_file):
+    with open(input_file, "r") as read_obj, open(output_file, 'w', newline='') as write_obj:
+        csv_reader = reader(read_obj)
+        csv_writer = writer(write_obj)
+
+        lines = []
+        for line in csv_reader:
+            lines.append(line)
+
+        ix = 0
+        for i in range(0, len(lines) - 1):
+            if lines[i][1] == lines[i + 1][1]:
+                # Consecutive rows for the same country: accumulate into the first row of the group.
+                if ix == 0:
+                    ix = i
+                lines[ix][4:] = np.asarray(lines[ix][4:], dtype=float) + np.asarray(lines[i + 1][4:], dtype=float)
+            else:
+                if not ix == 0:
+                    # End of a multi-province group: blank the province field and write the summed row.
+                    lines[ix][0] = ""
+                    csv_writer.writerow(lines[ix])
+                    ix = 0
+                else:
+                    csv_writer.writerow(lines[i])
+        # The loop above stops one row early, so flush the last group (or the last single row).
+        if ix != 0:
+            lines[ix][0] = ""
+            csv_writer.writerow(lines[ix])
+        elif lines:
+            csv_writer.writerow(lines[-1])
+
+
+def download_data(url_dictionary):
+    # Download every file in the dictionary into ./data/, keyed by its target file name.
+    for url_title in url_dictionary.keys():
+        urllib.request.urlretrieve(url_dictionary[url_title], "./data/" + url_title)
+
+
+def load_json(json_file_str):
+    # Loads JSON into a dictionary or quits the program if it cannot.
+    try:
+        with open(json_file_str, "r") as json_file:
+            json_variable = json.load(json_file)
+            return json_variable
+    except Exception:
+        sys.exit("Cannot open JSON file: " + json_file_str)
+
+
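+# Fits beta and gamma for a single country by minimizing the loss function below, then
+# integrates the SIR model forward for predict_range days and saves a plot of the data
+# and the fitted curves to <country>.png.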
+class Learner(object):
+    def __init__(self, country, loss, start_date, predict_range, s_0, i_0, r_0):
+        self.country = country
+        self.loss = loss
+        self.start_date = start_date
+        self.predict_range = predict_range
+        self.s_0 = s_0
+        self.i_0 = i_0
+        self.r_0 = r_0
+
+
+    def load_confirmed(self, country):
+        # Confirmed cases for one country, as a date-indexed series starting at start_date.
+        df = pd.read_csv('data/time_series_19-covid-Confirmed-country.csv')
+        country_df = df[df['Country/Region'] == country]
+        return country_df.iloc[0].loc[self.start_date:]
+
+
+    def load_recovered(self, country):
+        df = pd.read_csv('data/time_series_19-covid-Recovered-country.csv')
+        country_df = df[df['Country/Region'] == country]
+        return country_df.iloc[0].loc[self.start_date:]
+
+
+    def load_dead(self, country):
+        df = pd.read_csv('data/time_series_19-covid-Deaths-country.csv')
+        country_df = df[df['Country/Region'] == country]
+        return country_df.iloc[0].loc[self.start_date:]
+
+
+    def extend_index(self, index, new_size):
+        # Extend the date index with consecutive calendar days until it reaches new_size entries.
+        values = index.values
+        current = datetime.strptime(index[-1], '%m/%d/%y')
+        while len(values) < new_size:
+            current = current + timedelta(days=1)
+            values = np.append(values, datetime.strftime(current, '%m/%d/%y'))
+        return values
+
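+    # SIR dynamics used for both fitting (loss) and prediction, with no normalization by
+    # total population:  dS/dt = -beta*S*I,  dI/dt = beta*S*I - gamma*I,  dR/dt = gamma*I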
+    def predict(self, beta, gamma, data, recovered, death, country, s_0, i_0, r_0):
+        # Integrate the SIR model with the fitted beta and gamma over the full prediction range,
+        # padding the observed series with None so everything shares the same extended index.
+        new_index = self.extend_index(data.index, self.predict_range)
+        size = len(new_index)
+
+        def SIR(t, y):
+            S = y[0]
+            I = y[1]
+            R = y[2]
+            return [-beta*S*I, beta*S*I-gamma*I, gamma*I]
+
+        extended_actual = np.concatenate((data.values, [None] * (size - len(data.values))))
+        extended_recovered = np.concatenate((recovered.values, [None] * (size - len(recovered.values))))
+        extended_death = np.concatenate((death.values, [None] * (size - len(death.values))))
+        return new_index, extended_actual, extended_recovered, extended_death, solve_ivp(SIR, [0, size], [s_0, i_0, r_0], t_eval=np.arange(0, size, 1))
+
+
+    def train(self):
+        recovered = self.load_recovered(self.country)
+        death = self.load_dead(self.country)
+        # Active infected = confirmed - recovered - dead.
+        data = (self.load_confirmed(self.country) - recovered - death)
+
+        # Fit beta and gamma by minimizing the loss over the observed data.
+        optimal = minimize(self.loss, [0.001, 0.001], args=(data, recovered, self.s_0, self.i_0, self.r_0), method='L-BFGS-B', bounds=[(0.00000001, 0.4), (0.00000001, 0.4)])
+        print(optimal)
+        beta, gamma = optimal.x
+        new_index, extended_actual, extended_recovered, extended_death, prediction = self.predict(beta, gamma, data, recovered, death, self.country, self.s_0, self.i_0, self.r_0)
+        df = pd.DataFrame({'Infected data': extended_actual, 'Recovered data': extended_recovered, 'Death data': extended_death, 'Susceptible': prediction.y[0], 'Infected': prediction.y[1], 'Recovered': prediction.y[2]}, index=new_index)
+        fig, ax = plt.subplots(figsize=(15, 10))
+        ax.set_title(self.country)
+        df.plot(ax=ax)
+        # Note: the r_0 printed here is the basic reproduction number beta/gamma, not the initial recovered count.
+        print(f"country={self.country}, beta={beta:.8f}, gamma={gamma:.8f}, r_0:{(beta/gamma):.8f}")
+        fig.savefig(f"{self.country}.png")
+
+
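+# Loss used for fitting: the RMSE of the modelled infected curve against the active-case data,
+# blended with the RMSE of the modelled recovered curve against the recovered data, weighted
+# alpha = 0.1 vs (1 - alpha) = 0.9 in favour of the recovered fit.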
+def loss(point, data, recovered, s_0, i_0, r_0):
+    size = len(data)
+    beta, gamma = point
+
+    def SIR(t, y):
+        S = y[0]
+        I = y[1]
+        R = y[2]
+        return [-beta*S*I, beta*S*I-gamma*I, gamma*I]
+
+    solution = solve_ivp(SIR, [0, size], [s_0, i_0, r_0], t_eval=np.arange(0, size, 1), vectorized=True)
+    l1 = np.sqrt(np.mean((solution.y[1] - data)**2))
+    l2 = np.sqrt(np.mean((solution.y[2] - recovered)**2))
+    alpha = 0.1
+    return alpha * l1 + (1 - alpha) * l2
+
+
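+# Pipeline: parse the arguments, optionally download fresh data, collapse provinces into
+# country totals, then fit and plot each requested country.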
+def main():
+
+    countries, download, startdate, predict_range, s_0, i_0, r_0 = parse_arguments()
+
+    if download:
+        data_d = load_json("./data_url.json")
+        download_data(data_d)
+
+    sumCases_province('data/time_series_19-covid-Confirmed.csv', 'data/time_series_19-covid-Confirmed-country.csv')
+    sumCases_province('data/time_series_19-covid-Recovered.csv', 'data/time_series_19-covid-Recovered-country.csv')
+    sumCases_province('data/time_series_19-covid-Deaths.csv', 'data/time_series_19-covid-Deaths-country.csv')
+
+    for country in countries:
+        learner = Learner(country, loss, startdate, predict_range, s_0, i_0, r_0)
+        #try:
+        learner.train()
+        #except BaseException:
+        #    print('WARNING: Problem processing ' + str(country) +
+        #          '. Be sure it exists in the data exactly as you entered it.' +
+        #          ' Also check the date format if you passed it as a parameter.')
+
+
+if __name__ == '__main__':
+    main()