In [1]:
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import os
from scipy import interpolate
from scipy.interpolate import UnivariateSpline
import datetime

from covid.database import WorldCovidMatcher, CountryCollection
from sbir import ode, display_params, df_google

plt.close('all')
In [2]:
# Build one matcher per case type and extract daily (non-cumulative),
# normalized series for every country of interest.
# NOTE(review): the axis label in the next cell suggests "normalized" means
# per 100 000 inhabitants — TODO confirm in covid.database.
cm_c = WorldCovidMatcher('confirmed')
cm_c.build_database()
cm_r = WorldCovidMatcher('recovered')
cm_r.build_database()
cm_d = WorldCovidMatcher('deaths')
cm_d.build_database()

country_list = ["Italy", "France", "Spain", "Belgium", "Denmark", "Norway", "Netherlands",
                "Switzerland", "Greece",
                "United States", "United Kingdom", "Sweden",
                "Luxembourg", "Ireland",
                "Poland", "Portugal", "Austria", "Germany"]

c = cm_c.get_cases(CountryCollection(country_list), normalized=True, cumulative=False)
r = cm_r.get_cases(CountryCollection(country_list), normalized=True, cumulative=False)
d = cm_d.get_cases(CountryCollection(country_list), normalized=True, cumulative=False)

# Zero the first row of every country's daily series.
# NOTE(review): presumably the first "daily" value is an artefact of
# differencing the cumulative data (e.g. the whole total up to the first date)
# — confirm against covid.database.
# Write through the frame with `.loc` instead of `c[country].values[0] = 0`:
# that form mutates a chained view and fails (read-only array) under pandas
# copy-on-write. Assumes get_cases returns a DataFrame with one column per
# country — TODO confirm.
for cases_frame in (c, d, r):
    cases_frame.loc[cases_frame.index[0], country_list] = 0
In [3]:
# Compare each country's daily confirmed cases with the SBIR5 and SBEIR model
# curves from `sbir.ode`. t = 0 is the first day a country's daily confirmed
# cases reach FIRST_CASE_THRESHOLD (same units as `c`), optionally shifted by a
# manual per-country offset. The models are driven by `beta_minus_alpha`, fitted
# as the slope of log(cumulative cases) over days FIT_FIRST_DAY..FIT_LAST_DAY
# after t = 0.

# Portugal and Greece were previously listed here but commented out.
PLOT_COUNTRIES = ['Germany', 'France', 'Spain', 'Italy', 'Belgium', 'Norway', 'Netherlands', 'Sweden']
FIRST_CASE_THRESHOLD = 0.1
# Manual shifts (in days) of the t = 0 reference for specific countries.
T0_SHIFT_DAYS = {'Luxembourg': 6, 'Netherlands': -2}
# Inclusive day range (relative to t = 0) used for the growth-rate fit.
FIT_FIRST_DAY, FIT_LAST_DAY = 1, 25
# Added to every daily value before log(cumsum) so log() never sees 0.
LOG_FLOOR = 1e-3
# Model time grid (days); the same for every country and model.
t = np.linspace(0, 150, 200)


def _early_growth_rate(cases):
    """Return the slope of log(cumulative cases) over the fit window.

    Parameters
    ----------
    cases : Series
        Daily cases indexed by time elapsed since t = 0 (a TimedeltaIndex, as
        produced by subtracting the t = 0 date from a datetime index).

    Returns
    -------
    float or None
        Fitted slope per day, or None when the window holds fewer than two
        points (too short to fit a line).
    """
    days = cases.index.days
    window = cases[(days >= FIT_FIRST_DAY) & (days <= FIT_LAST_DAY)]
    if len(window) < 2:
        return None
    # Fit against the actual day numbers so a short series or a date gap cannot
    # misalign x and y (a fixed np.arange(0, 25) would).
    slope, _intercept = np.polyfit(window.index.days,
                                   np.log(np.cumsum(LOG_FLOOR + window)), 1)
    return slope


# NOTE(review): nothing is drawn on this matplotlib figure; it only renders as
# an empty "<Figure ... with 0 Axes>" output. Kept in case later cells use it.
fig_R0 = plt.figure(figsize=(10/1.5, 6/1.5))

for country in PLOT_COUNTRIES:
    cases = c[country].copy()
    reached = cases[cases.values >= FIRST_CASE_THRESHOLD]
    if reached.empty:
        print('%s: daily confirmed cases never reach %g; skipped' % (country, FIRST_CASE_THRESHOLD))
        continue
    offset_date = reached.index[0] + datetime.timedelta(days=T0_SHIFT_DAYS.get(country, 0))
    cases.index -= offset_date  # index becomes time elapsed since t = 0

    beta_minus_alpha = _early_growth_rate(cases)
    if beta_minus_alpha is None:
        print('%s: too few days after t = 0 to fit the growth rate; skipped' % country)
        continue

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=cases.index.days, y=cases, mode="markers",
                             legendgroup="group2", name="confirmed"))
    for mode in ['SBIR5', 'SBEIR']:
        result = ode.integrate(t, mode=mode, country=country, beta_minus_alpha=beta_minus_alpha)
        # Scaled by 1e5 to match the per-100 000 axis label
        # (assumes result['P'] is a population fraction — TODO confirm).
        fig.add_trace(go.Scatter(x=t, y=result['P']*1e5, mode="lines",
                                 legendgroup="group2", name=mode))
    fig.update_yaxes(type="log")
    fig.update_layout(title='%s, t=0 corresponds to %s' %
                      (country, offset_date.strftime('%Y %m %d')),
                      yaxis_title='Daily confirmed cases (per 100 000)',
                      xaxis_title='Time(days)')
    fig.show()


    
<Figure size 480x288 with 0 Axes>
In [4]:
%%javascript
// Workaround so plotly figures still render after the notebook is exported:
// point RequireJS at CDN copies of plotly.js and of the libraries declared as
// its dependencies. RequireJS appends ".js" to each path itself.
// NOTE(review): the `plotly-latest` CDN alias appears to be frozen at an old
// plotly.js 1.x release; consider pinning the version the Python `plotly`
// package targets.
(function () {
  var cdnPaths = {
    d3: 'https://cdnjs.cloudflare.com/ajax/libs/d3/5.9.2/d3',
    jquery: 'https://code.jquery.com/jquery-3.4.1.min',
    plotly: 'https://cdn.plot.ly/plotly-latest.min'
  };

  // Load d3 and jquery before plotly; for a non-AMD build, the global named
  // by `exports` becomes the module value.
  var plotlyShim = {
    deps: ['d3', 'jquery'],
    exports: 'plotly'
  };

  require.config({
    paths: cdnPaths,
    shim: { plotly: plotlyShim }
  });
})();