#!/usr/bin/env python3

# Created in 2020 by Ryan A. Colyer.
# This work is released with CC0 into the public domain.
# https://creativecommons.org/publicdomain/zero/1.0/

import datetime
import matplotlib
import numpy as np
import os
import pandas as pd
import pylab as plt
import re
import scipy.stats
import sys
import generate_html
# Select the matplotlib backend.
# NOTE(review): this runs after pylab has already been imported, and the code
# in this file only writes plots with savefig; TkCairo is presumably a Tk GUI
# backend -- confirm it is intended rather than a file-only backend (Agg).
matplotlib.use("TkCairo")

# Minimum age, in minutes, of the last fetch before RefreshData pulls again.
refresh_minutes = 30
# Shell commands run by RefreshData: the initial clone of the
# CSSEGISandData COVID-19 repository, and an update of an existing checkout.
data_grab = 'git clone "https://github.com/CSSEGISandData/COVID-19.git"'
data_refresh = 'cd COVID-19; git pull'
# Folder of daily-report CSV files inside the checkout (read by LoadData).
datadir = os.path.join('COVID-19', 'csse_covid_19_data',
    'csse_covid_19_daily_reports')
# Output folder for the PNG plots; per-state plots go in its 'US' subfolder.
plotdir = 'plots'
# Compiled from multiple sources, with most from worldbank.org
country_popfile = 'country_populations_2018.csv'
# From census.gov
state_popfile = 'state_populations_2019.csv'
# When True, all plots share one date axis (and population-scaled plots one
# y range) so they can be compared directly.
fixed_axes = True
# When True, CasesLogPlot overlays an exponential fit of recent active cases.
fit_active = False

# (state name, alternate name) pairs, one plot each. UpdatePlots matches
# Province/State against the name (case-insensitively) or the alternate;
# the alternate is the postal abbreviation, except for DC and Puerto Rico
# where it is another full spelling.
states_abbrevs = [('Alabama', 'AL'), ('Alaska', 'AK'), ('Arizona', 'AZ'),
  ('Arkansas', 'AR'), ('California', 'CA'), ('Colorado', 'CO'),
  ('Connecticut', 'CT'), ('Delaware', 'DE'), ('Florida', 'FL'),
  ('Georgia', 'GA'), ('Hawaii', 'HI'), ('Idaho', 'ID'), ('Illinois', 'IL'),
  ('Indiana', 'IN'), ('Iowa', 'IA'), ('Kansas', 'KS'), ('Kentucky', 'KY'),
  ('Louisiana', 'LA'), ('Maine', 'ME'), ('Maryland', 'MD'),
  ('Massachusetts', 'MA'), ('Michigan', 'MI'), ('Minnesota', 'MN'),
  ('Mississippi', 'MS'), ('Missouri', 'MO'), ('Montana', 'MT'),
  ('Nebraska', 'NE'), ('Nevada', 'NV'), ('New Hampshire', 'NH'),
  ('New Jersey', 'NJ'), ('New Mexico', 'NM'), ('New York', 'NY'),
  ('North Carolina', 'NC'), ('North Dakota', 'ND'), ('Ohio', 'OH'),
  ('Oklahoma', 'OK'), ('Oregon', 'OR'), ('Pennsylvania', 'PA'),
  ('Rhode Island', 'RI'), ('South Carolina', 'SC'), ('South Dakota', 'SD'),
  ('Tennessee', 'TN'), ('Texas', 'TX'), ('Utah', 'UT'), ('Vermont', 'VT'),
  ('Virginia', 'VA'), ('Washington', 'WA'), ('West Virginia', 'WV'),
  ('Wisconsin', 'WI'), ('Wyoming', 'WY'),
  ('District of Columbia', 'Washington DC'), ('Puerto Rico', 'Puerto Rico')]
# Canonical country name -> aliases whose rows are plotted under it.
# A string alias is another Country/Region value; a tuple is a
# (Province/State, Country/Region) pair for a region that some reports list
# as a province of another country.
country_remap = {
    'Aruba':[('Aruba', 'Netherlands')],
    'Cape Verde':['Cabo Verde'],
    'China':['Mainland China', 'Macao', 'Macao SAR', 'Macau'],
    'Czech Republic':['Czechia'],
    'Denmark':['Faroe Islands', 'Greenland'],
    'France':['Saint Barthelemy', 'Saint Martin', 'St. Martin',
      'French Guiana', 'Mayotte', 'Guadeloupe', 'Reunion', 'Martinique'],
    'Hong Kong':['Hong Kong SAR', ('Hong Kong', 'China')],
    'Iran':['Iran (Islamic Republic of)'],
    'Ireland':['Republic of Ireland'],
    'Ivory Coast':['Cote d\'Ivoire'],
    'Moldova':['Republic of Moldova'],
    'Myanmar':['Burma'],
    'Netherlands':['Curacao'],
    'Palestine':['occupied Palestinian territory',
      'West Bank and Gaza'],
    'Russia':['Russian Federation'],
    'Congo (Brazzaville)':['Republic of the Congo'],
    'South Korea':['Republic of Korea', 'Korea, South'],
    'St. Kitts and Nevis':['Saint Kitts and Nevis'],
    'Taiwan':['Taiwan*','Taipei and environs'],
    'Gambia':['Gambia, The', 'The Gambia'],
    'Bahamas':['Bahamas, The', 'The Bahamas'],
    'Timor-Leste':['East Timor'],
    'United Kingdom':['UK', 'Cayman Islands', 'Channel Islands',
      'Gibraltar', 'North Ireland'],
    'United States':['US', 'Guam', 'Puerto Rico'],
    'Vatican City':['Holy See'],
    'Vietnam':['Viet Nam']}
# Flat list of every alias above; aliases are skipped when UpdatePlots builds
# the list of countries to plot.
country_remaplist = [c for v in country_remap.values() for c in v]
# Country/Region values that never get a plot of their own.
countries_remove = {'Cruise Ship', 'Others', 'Guernsey', 'Jersey',
    'Diamond Princess', 'MS Zaandam', 'Summer Olympics 2020'}

class Progress:
  """In-place console percentage counter (redrawn with backspaces)."""

  def __init__(self, total=0):
    self.count = 0
    # -1 guarantees the first tick always prints.
    self.perc = -1
    self.Set(total)

  def Set(self, total):
    """Set how many ticks make up 100%; a total <= 0 disables ticking."""
    self.total = total

  def Tick(self):
    """Count one finished step and redraw the percentage if it went up."""
    if self.total <= 0:
      return
    self.count += 1
    new_perc = round(self.count*100.0/self.total)
    if new_perc <= self.perc:
      return
    self.perc = new_perc
    # Four backspaces erase the previous value (at most "100%").
    print('\x08' * 4 + str(new_perc) + '%', end='', flush=True)


def LoadData(directory=None):
  """Load every daily-report CSV file into a single DataFrame.

  Files are read in sorted file-name order. Column names from the newer
  report layout ('Country_Region', 'Province_State') are mapped to the older
  ones, and only the columns the plots use are kept.

  Args:
    directory: Folder holding the report CSVs; defaults to the module-level
      `datadir`. File names are assumed to look like 'MM-DD-YYYY.csv'.

  Returns:
    DataFrame with columns Province/State, Country/Region, Confirmed,
    Deaths, Recovered and Date ('YYYY-MM-DD', built from the file name).
    Each file keeps its own row index labels, so labels repeat across
    files. An empty DataFrame when no CSV files are found.
  """
  if directory is None:
    directory = datadir
  csvfiles = sorted(f for f in os.listdir(directory)
                    if len(f) > 4 and f.endswith('.csv'))

  frames = []
  for f in csvfiles:
    data = pd.read_csv(os.path.join(directory, f), skipinitialspace=True)
    data = data.rename(columns={
      'Country_Region':'Country/Region',
      'Province_State':'Province/State',
      'Active':'BrokenActiveData'})
    data = data[
        ['Province/State','Country/Region','Confirmed','Deaths','Recovered']]
    # 'MM-DD-YYYY.csv' -> 'YYYY-MM-DD', so dates sort correctly as strings.
    # assign() returns a new frame instead of writing into a column slice.
    frames.append(data.assign(Date=f[6:10] + '-' + f[0:5]))

  if not frames:
    return pd.DataFrame()
  # DataFrame.append was removed in pandas 2.0; a single concat is also
  # linear rather than quadratic in the number of files.
  return pd.concat(frames)

def PlotYRaw(x, y, ylab, pop=None, color=None):
  """Plot one series on the current axes and return the plotted values.

  Args:
    x: x positions (or labels), one per value of y.
    y: Values to plot; an (n,) or (n, 1) array-like is flattened to (n,).
    ylab: Legend label.
    pop: Optional population; when truthy, y is divided by it.
    color: Optional matplotlib color.

  Returns:
    The plotted values as a 1-D float64 array. Every value that is not > 0
    (including NaN) is replaced by NaN, so it is left as a gap on the log
    axis.
  """
  # Flatten first: a one-column frame's .values is (n, 1), and masking that
  # element-by-element would mix length-1 arrays with scalar NaNs, which
  # NumPy >= 1.24 rejects as a ragged array.
  values = np.asarray(y, dtype=np.float64).ravel()
  if pop:
    values = values / pop
  with np.errstate(invalid='ignore'):
    values = np.where(values > 0, values, np.nan)

  plt.plot(x, values, '.-', label=ylab, color=color)
  return values


def PlotY(df, xlab, ylab, pop=None, x_axis_labs=None, color=None):
  """Sum column `ylab` for each `xlab` value and plot it via PlotYRaw.

  Args:
    df: DataFrame with at least the `xlab` and `ylab` columns.
    xlab: Column to group by (the x axis).
    ylab: Column to sum and plot; also used as the legend label.
    pop: Optional population to normalize by.
    x_axis_labs: Optional list of axis labels. When given, each group is
      plotted at the position of its label in this list.
    color: Optional matplotlib color.

  Returns:
    (group labels, plotted values, min, max). The min and max are taken
    over the plotted (non-NaN) values, or are (1, 1) when nothing is
    plottable.
  """
  data = df.groupby(xlab)[ylab].sum()

  if x_axis_labs:
    # Build the label -> position map once (keeping the first occurrence,
    # like list.index) instead of a linear search per group.
    position = {}
    for i, lab in enumerate(x_axis_labs):
      position.setdefault(lab, i)
    x = [position[dx] for dx in data.index]
  else:
    x = data.index

  # A Series gives a 1-D array; a one-column frame's .values would be (n, 1).
  values = PlotYRaw(x, data.to_numpy(), ylab, pop, color)
  valuesvalid = values[~np.isnan(values)]
  if len(valuesvalid) == 0:
    return data.index, values, 1, 1
  return data.index, values, np.min(valuesvalid), np.max(valuesvalid)

def CalculateActive(df):
  """Return active cases (Confirmed - Deaths - Recovered) for each row.

  Missing (NaN) counts are treated as 0.

  Args:
    df: DataFrame with 'Confirmed', 'Deaths' and 'Recovered' columns.

  Returns:
    A list with one value per row of df, in row order.
  """
  # Subtract plain arrays so the result is positional and unaffected by
  # repeated or unsorted index labels.
  confirmed, deaths, recovered = (
      df[col].fillna(0).to_numpy()
      for col in ('Confirmed', 'Deaths', 'Recovered'))
  return list(confirmed - deaths - recovered)


def CalculateDeathsPerDay(df):
  """Return new deaths for each date, in date order.

  Deaths are totalled over all rows sharing a 'Date'. Each entry is that
  date's total minus the previous date's total; the first entry is 0.
  """
  totals = df.groupby('Date')['Deaths'].sum().to_numpy()
  daily_change = np.diff(totals)
  return np.concatenate(([0], daily_change))


def GrowthPerDay(active, fitlen):
  """Fit exponential growth to the last `fitlen` values of `active`.

  A straight line is fitted to log(values) over the trailing window, i.e.
  values ~ exp(slope * day + intercept), where day is the position in
  `active`.

  Args:
    active: 1-D sequence of values (counts or population fractions), one
      per day. Lists are accepted as well as arrays.
    fitlen: Number of trailing days to fit.

  Returns:
    (fit, growth_perc): the fitted values over the window, and the implied
    percent change per day. A flat window gives (window, 0). Returns
    (None, None) when fitlen <= 0, when there are fewer than `fitlen`
    values, or when any window value is non-finite or <= 0 (the logarithm
    would be undefined).
  """
  active = np.asarray(active, dtype=np.float64)
  if fitlen <= 0 or len(active) < fitlen:
    return None, None
  window = active[-fitlen:]
  if not np.all(np.isfinite(window)) or np.any(window <= 0):
    return None, None
  # Check the window actually being fitted (not the whole series): a flat
  # window means zero growth, and a single point cannot be regressed.
  if np.all(window == window[0]):
    return window, 0
  x = np.arange(len(active) - fitlen, len(active))
  result = scipy.stats.linregress(x, np.log(window))
  fit = np.exp(result.slope * x + result.intercept)
  growth_perc = (np.exp(result.slope) - 1) * 100
  return fit, growth_perc


def CasesLogPlot(df, title, subdir='', pop=None, axes=None):
  """Draw the log-scale case history of one region and save it as a PNG.

  Plots per-'Date' totals of Confirmed, Deaths and Recovered. It adds
  Active cases when the latest Recovered total is positive, an exponential
  fit of Active when the module flag `fit_active` is set, and new deaths
  per day when the latest Deaths total is positive.

  Args:
    df: Rows for the region, with 'Date', 'Confirmed', 'Deaths' and
      'Recovered' columns. The caller's frame is not modified.
    title: Plot title; also the file name, with spaces replaced by '_'.
    subdir: Optional subdirectory of `plotdir` to save into.
    pop: Optional population. When truthy, every series is divided by it
      and the y axis shows a fraction of the population.
    axes: Optional shared-axis tuple (x_labels, x_labels_shown, y_min,
      y_max). When given, points are placed at their date's index in
      x_labels, and y_min/y_max fix the y range of population-scaled plots.

  Returns:
    Path of the saved PNG.
  """
  xlab = 'Date'
  fig = plt.figure(figsize=(6.4, 5.2))
  plt.rcParams.update({'font.size': 12})

  x_axis_labs = None
  if axes:
    x_axis_labs = axes[0]
    x_axis_labs_show = axes[1]
    plt.xticks(np.arange(0, len(x_axis_labs)), x_axis_labs_show, rotation=90)

  x_data_labs, _, minv, maxv = PlotY(df, xlab, 'Confirmed', pop, x_axis_labs)
  _, deaths, _, _ = PlotY(df, xlab, 'Deaths', pop, x_axis_labs)
  _, recov, _, _ = PlotY(df, xlab, 'Recovered', pop, x_axis_labs)

  if (len(recov) > 0) and (recov[-1] > 0):
    # assign() returns a new frame: the caller's frame never gains an
    # 'Active' column (insert() would mutate it, and fail if called twice).
    df = df.assign(Active=CalculateActive(df))
    active = df[[xlab, 'Active']].groupby(xlab).sum().values.ravel()
    PlotY(df, xlab, 'Active', pop, x_axis_labs)

    if fit_active:
      fitlen = 4
      actfit = active / pop if pop else active
      fit, growth_perc = GrowthPerDay(actfit, fitlen)
      if fit is not None:
        label = str(round(growth_perc, 1)) + '% / day'
        if growth_perc > 0:
          label = '+' + label
        if axes:
          last = x_axis_labs.index(x_data_labs[-1])
          fit_x = np.arange(last + 1 - fitlen, last + 1)
        else:
          fit_x = x_data_labs[-fitlen:]
        plt.plot(fit_x, fit, '--', color='k', linewidth=2, label=label)

  if (len(deaths) > 0) and (deaths[-1] > 0):
    deaths_per_day = CalculateDeathsPerDay(df)
    if x_axis_labs:
      x = [x_axis_labs.index(dx) for dx in x_data_labs]
    else:
      # No shared axis: plot against the date labels, as PlotY does.
      x = x_data_labs
    PlotYRaw(x, deaths_per_day, 'Deaths/Day', pop, color='purple')

  if axes:
    if pop:
      minv = axes[2]
      maxv = axes[3]
    plt.xlim(0, len(x_axis_labs))

  if axes is None:
    # Show every date on short series; on long ones keep about 12 labels,
    # counted back from the most recent date.
    x_data_labs_show = x_data_labs
    if len(x_data_labs) > 24:
      step = int(len(x_data_labs) / 12)
      x_data_labs_show = x_data_labs[::-1][::step][::-1]
    plt.xticks(x_data_labs_show, rotation=90)

  plt.yscale('log')
  if pop:
    plt.ylabel('Fraction of Population [Log scale]')
    plt.axhline(2e-3, ls='--', lw=1.5, color='orange', label='Projection')
    if maxv > minv:
      plt.ylim(minv, maxv)
  else:
    plt.ylabel('Cases [Log scale]')
    plt.ylim(bottom=1)
  plt.title(title)
  plt.legend(loc='upper left', labelspacing=0.2)
  fig.set_tight_layout(True)

  destdir = os.path.join(plotdir, subdir) if subdir else plotdir
  filename = os.path.join(destdir, title.replace(' ', '_') + '.png')
  plt.savefig(filename)
  plt.close(fig)

  return filename


def _Population(pops, key_col, name):
  """Return 'population' from the first row of `pops` whose `key_col` is `name`.

  Raises IndexError when no row matches.
  """
  return pops[pops[key_col] == name]['population'].values[0]


def _FixedAxes(df, world_pop):
  """Build the shared (dates, shown dates, y_min, y_max) axes for CasesLogPlot."""
  # Only the sorted distinct dates are needed; summing every column (the
  # text columns included) just to get this index is unnecessary.
  x_dates = df.groupby('Date').size().index
  x_dates_show = list(x_dates)
  if len(x_dates) > 24:
    # Blank out all but every step-th label, counted back from the newest.
    step = int(len(x_dates) / 12)
    x_dates_show = [''] * len(x_dates)
    for i in range(len(x_dates) - 1, -1, -step):
      x_dates_show[i] = x_dates[i]
  return (list(x_dates), x_dates_show, 1.0 / world_pop, 0.999e-1)


def _CountryMask(df, country):
  """Boolean mask of the rows of `df` that belong to `country` (with aliases)."""
  mask = df['Country/Region'] == country
  for alias in country_remap.get(country, []):
    if isinstance(alias, tuple):
      # (province, parent country): a region reported as a province.
      province, parent = alias
      mask |= ((df['Province/State'] == province) &
               (df['Country/Region'] == parent))
    else:
      mask |= df['Country/Region'] == alias
  return mask


def _StateRows(df_US, name, abbrev):
  """Rows of `df_US` for one state.

  A row matches when its Province/State equals `name` (case-insensitive),
  or contains `abbrev` as a whole word (e.g. 'King County, WA').
  """
  prov = df_US['Province/State'].fillna('').astype(str).str.strip()
  # Exact name match, so 'Virginia' does not pick up 'West Virginia' and
  # 'Kansas' does not pick up 'Arkansas'.
  by_name = prov.str.lower() == name.lower()
  by_abbrev = prov.str.contains(r'\b' + re.escape(abbrev) + r'\b', regex=True)
  return df_US[by_name | by_abbrev]


def UpdatePlots():
  """Regenerate every plot from the daily reports.

  Writes the World, World-outside-China, per-country and per-US-state PNGs
  under `plotdir` (the state plots go in its 'US' subdirectory).

  Returns:
    (country_files, state_files): lists of saved PNG paths, in plot order.
  """
  print('Generating plots.')
  df = LoadData()
  country_pops = pd.read_csv(country_popfile)
  state_pops = pd.read_csv(state_popfile)

  os.makedirs(os.path.join(plotdir, 'US'), exist_ok=True)

  countries = set(df['Country/Region']) | set(country_remap.keys())
  countries -= countries_remove
  # Aliases are plotted under their canonical name; floats are NaN from
  # blank Country/Region entries.
  countries = sorted(c for c in countries
                     if c not in country_remaplist and not isinstance(c, float))

  progress = Progress(2 + len(countries) + len(states_abbrevs))

  world_pop = _Population(country_pops, 'country', 'World')
  axes = _FixedAxes(df, world_pop) if fixed_axes else None

  CasesLogPlot(df.copy(), 'World', pop=world_pop, axes=axes)
  progress.Tick()

  china_pop = _Population(country_pops, 'country', 'China')
  CasesLogPlot(df[~_CountryMask(df, 'China')], 'World outside China',
               pop=world_pop - china_pop, axes=axes)
  progress.Tick()

  for c in countries:
    if len(country_pops[country_pops['country'] == c]) != 1:
      print('\nMissing population data for ' + c)

  country_files = []
  for c in countries:
    try:
      pop = _Population(country_pops, 'country', c)
    except (IndexError, KeyError):
      print('\nFailed to get population data for ' + c)
      pop = None
    fname = CasesLogPlot(df[_CountryMask(df, c)], c, pop=pop, axes=axes)
    country_files.append(fname)
    progress.Tick()

  # US states only appear as provinces of 'US'.
  df_US = df[df['Country/Region'] == 'US']
  state_files = []
  for name, abbrev in states_abbrevs:
    try:
      pop = _Population(state_pops, 'state', name)
    except (IndexError, KeyError):
      # Every state is expected to have population data: report and abort.
      print('\nFailed to get population data for ' + name)
      raise
    fname = CasesLogPlot(_StateRows(df_US, name, abbrev), name, 'US',
                         pop=pop, axes=axes)
    state_files.append(fname)
    progress.Tick()
  print()

  return country_files, state_files


def RefreshData():
  """Make sure the local data checkout exists and is reasonably fresh.

  - If `datadir` is missing, runs `data_grab` (a clone).
  - Otherwise, runs `data_refresh` unless the timestamp stored in
    'last_read' is less than `refresh_minutes` old.

  'last_read' is rewritten only when the command exits with status 0, so a
  failed fetch is retried on the next run.

  Returns:
    True if a clone or refresh ran and succeeded, else False.
  """
  print('Checking data state.')
  # Seconds since the Unix epoch (timezone-aware; utcnow() is deprecated).
  epochtime = datetime.datetime.now(datetime.timezone.utc).timestamp()

  stale = True
  try:
    with open('last_read', 'r') as fr:
      lastepochtime = float(fr.read())
    if (epochtime - lastepochtime) < refresh_minutes*60:
      stale = False
  except (FileNotFoundError, ValueError):
    # No usable timestamp: treat the data as stale.
    pass

  command = None
  if not os.path.exists(datadir):
    print('Checking out new data.')
    command = data_grab
  elif stale:
    print('Refreshing data.')
    command = data_refresh
  else:
    print('Data refreshed within '+str(refresh_minutes)+' minutes.')

  refreshed = False
  if command is not None:
    status = os.system(command)
    if status == 0:
      refreshed = True
    else:
      print('Data update command failed with status '+str(status)+'.')

  if refreshed:
    with open('last_read', 'w') as fw:
      fw.write(str(epochtime))

  return refreshed

if __name__ == '__main__':
  # -u: refresh the data checkout; -p: regenerate plots and index.html.
  # With no arguments at all, both steps run.
  args = sys.argv[1:]
  run_everything = not args

  if run_everything or '-u' in args:
    RefreshData()

  if run_everything or '-p' in args:
    country_files, state_files = UpdatePlots()
    # Build the page before opening the file so a failure leaves the old
    # index.html untouched.
    page = generate_html.GenerateHTML(country_files, state_files)
    with open('index.html', 'w') as fw:
      fw.write(page)

