import pandas as pd
import geopandas
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import HTML, display


# Global plot styling: seaborn's default theme, then matplotlib's xkcd
# (hand-drawn) mode for every figure created afterwards.
sns.set()
plt.xkcd()

# Sources: US county boundaries (one row per county, see the geo_df preview)
# and the California per-county COVID-19 case time series.
COUNTY_GEOJSON_URL = "https://raw.githubusercontent.com/plotly/datasets/master/geojson-counties-fips.json"
CA_CASES_CSV_URL = 'https://data.ca.gov/dataset/590188d5-8545-4c93-a9a0-e230f0db7290/resource/926fd08f-cc91-4828-af38-bd45de97f8c3/download/statewide_cases.csv'

geo_df = geopandas.read_file(COUNTY_GEOJSON_URL)
df = pd.read_csv(CA_CASES_CSV_URL)

# Preview the case table (last expression of the cell is displayed).
df.head()
county totalcountconfirmed totalcountdeaths newcountconfirmed newcountdeaths date
0 Santa Clara 151.0 6.0 151 6 2020-03-18
1 Santa Clara 183.0 8.0 32 2 2020-03-19
2 Santa Clara 246.0 8.0 63 0 2020-03-20
3 Santa Clara 269.0 10.0 23 2 2020-03-21
4 Santa Clara 284.0 13.0 15 3 2020-03-22
# Preview the first five rows of the county geometry table.
geo_df.head(5)
id GEO_ID STATE COUNTY NAME LSAD CENSUSAREA geometry
0 01001 0500000US01001 01 001 Autauga County 594.436 POLYGON ((-86.49677 32.34444, -86.71790 32.402...
1 01009 0500000US01009 01 009 Blount County 644.776 POLYGON ((-86.57780 33.76532, -86.75914 33.840...
2 01017 0500000US01017 01 017 Chambers County 596.531 POLYGON ((-85.18413 32.87053, -85.12342 32.772...
3 01021 0500000US01021 01 021 Chilton County 692.854 POLYGON ((-86.51734 33.02057, -86.51596 32.929...
4 01033 0500000US01033 01 033 Colbert County 592.619 POLYGON ((-88.13999 34.58170, -88.13925 34.587...
# Attach county geometry / FIPS metadata to every California case row.
# County names are not unique across states (e.g. "Orange" or "Lake" counties
# exist in several states), so the geometry table is restricted to California
# (STATE '06') before joining on the bare name. Joining against the whole US
# table would duplicate each such case row once per same-named county.
# Case rows whose county has no geometry match (e.g. non-county buckets, if
# any) are kept by the left join with NaN geometry columns, including STATE.
merged_df = pd.merge(
    df,
    geo_df[geo_df['STATE'] == '06'],
    left_on='county',
    right_on='NAME',
    how='left',
    # Fail loudly if a California county name ever matched more than one row.
    validate='many_to_one',
)
merged_df.shape
(22727, 14)
# Preview the first five rows of the merged case + geometry table.
merged_df.head(5)
county totalcountconfirmed totalcountdeaths newcountconfirmed newcountdeaths date id GEO_ID STATE COUNTY NAME LSAD CENSUSAREA geometry
0 Santa Clara 151.0 6.0 151 6 2020-03-18 06085 0500000US06085 06 085 Santa Clara County 1290.1 POLYGON ((-121.65404 36.95058, -121.75760 37.0...
1 Santa Clara 183.0 8.0 32 2 2020-03-19 06085 0500000US06085 06 085 Santa Clara County 1290.1 POLYGON ((-121.65404 36.95058, -121.75760 37.0...
2 Santa Clara 246.0 8.0 63 0 2020-03-20 06085 0500000US06085 06 085 Santa Clara County 1290.1 POLYGON ((-121.65404 36.95058, -121.75760 37.0...
3 Santa Clara 269.0 10.0 23 2 2020-03-21 06085 0500000US06085 06 085 Santa Clara County 1290.1 POLYGON ((-121.65404 36.95058, -121.75760 37.0...
4 Santa Clara 284.0 13.0 15 3 2020-03-22 06085 0500000US06085 06 085 Santa Clara County 1290.1 POLYGON ((-121.65404 36.95058, -121.75760 37.0...
from matplotlib.animation import FuncAnimation


# California rows only. Rows without a geometry match have STATE NaN and are
# excluded here as well.
ca_df = merged_df[merged_df['STATE'] == '06']

# One frame per date (groupby sorts the ISO date strings chronologically):
# the 10 counties with the most confirmed cases, in ascending order so that
# barh, which stacks bars bottom-up, draws the largest county on top.
# Missing values are dropped first. Otherwise sort_values would place NaNs
# last and they would be selected as the "largest" counties.
plot_data_frames = []
for _, gdf in ca_df.groupby('date'):
    top = gdf.dropna(subset=['totalcountconfirmed']).sort_values('totalcountconfirmed').tail(10)
    if not top.empty:  # a frame needs at least one row to read its date
        plot_data_frames.append(top)


# Give each county a fixed colour so a county keeps its colour across frames.
# Note: tab20 has 20 colours and seaborn cycles it, so colours repeat.
counties = sorted(ca_df['county'].unique())
colors = dict(zip(counties, sns.color_palette("tab20", n_colors=len(counties))))


def get_colors(counties):
    """Return the fixed colour assigned to each county name in *counties*."""
    return [colors[c] for c in counties]


def nice_axes(ax):
    """Style *ax* for the bar race: small tick labels, no tick marks, no spines."""
    ax.tick_params(labelsize=8, length=0)
    ax.set_axisbelow(True)
    for spine in ax.spines.values():
        spine.set_visible(False)


def init():
    """FuncAnimation init_func: start from a cleared, restyled axes."""
    ax.clear()
    nice_axes(ax)


def update(i):
    """Draw frame *i*: replace the previous bars with the top counties for that date."""
    # Container.remove() also deletes the container from ax.containers, so
    # iterate over a copy rather than the list being shrunk.
    for bar in list(ax.containers):
        bar.remove()

    tmp_df = plot_data_frames[i]
    ax.barh(
        # One slot per row actually present: a date may have fewer than 10 counties.
        y=[n + 0.5 for n in range(len(tmp_df))],
        width=tmp_df['totalcountconfirmed'],
        tick_label=tmp_df['county'],
        color=get_colors(tmp_df['county'].values),
    )

    date = tmp_df.iloc[0]['date']

    ax.set_title(f'COVID-19 Total Cases by County - {date}')


# A Figure built directly (not via plt.figure) is not registered with pyplot;
# it is rendered only through the animation below.
fig = plt.Figure(figsize=(12, 8))
ax = fig.add_subplot()
anim = FuncAnimation(
    fig=fig,
    func=update,
    init_func=init,
    frames=len(plot_data_frames),
    interval=100,
    repeat=False,
)

display(HTML(anim.to_jshtml()))
</input>
from matplotlib.animation import FuncAnimation


# California rows only. Rows without a geometry match have STATE NaN and are
# excluded here as well.
ca_df = merged_df[merged_df['STATE'] == '06']

# One frame per date (groupby sorts the ISO date strings chronologically):
# the 10 counties with the most new confirmed cases, in ascending order so
# that barh, which stacks bars bottom-up, draws the largest county on top.
# Missing values are dropped first so they can never be ranked as "largest".
plot_data_frames = []
for _, gdf in ca_df.groupby('date'):
    top = gdf.dropna(subset=['newcountconfirmed']).sort_values('newcountconfirmed').tail(10)
    if not top.empty:  # a frame needs at least one row to read its date
        plot_data_frames.append(top)


# Give each county a fixed colour so a county keeps its colour across frames.
# Note: tab20 has 20 colours and seaborn cycles it, so colours repeat.
counties = sorted(ca_df['county'].unique())
colors = dict(zip(counties, sns.color_palette("tab20", n_colors=len(counties))))


def get_colors(counties):
    """Return the fixed colour assigned to each county name in *counties*."""
    return [colors[c] for c in counties]


def nice_axes(ax):
    """Style *ax* for the bar race: small tick labels, no tick marks, no spines."""
    ax.tick_params(labelsize=8, length=0)
    ax.set_axisbelow(True)
    for spine in ax.spines.values():
        spine.set_visible(False)


def init():
    """FuncAnimation init_func: start from a cleared, restyled axes."""
    ax.clear()
    nice_axes(ax)


def update(i):
    """Draw frame *i*: replace the previous bars with the top counties for that date."""
    # Container.remove() also deletes the container from ax.containers, so
    # iterate over a copy rather than the list being shrunk.
    for bar in list(ax.containers):
        bar.remove()

    tmp_df = plot_data_frames[i]
    ax.barh(
        # One slot per row actually present: a date may have fewer than 10 counties.
        y=[n + 0.5 for n in range(len(tmp_df))],
        width=tmp_df['newcountconfirmed'],
        tick_label=tmp_df['county'],
        color=get_colors(tmp_df['county'].values),
    )

    date = tmp_df.iloc[0]['date']

    ax.set_title(f'COVID-19 New Cases by County - {date}', fontsize='small')


# A Figure built directly (not via plt.figure) is not registered with pyplot;
# it is rendered only through the animation below.
fig = plt.Figure(figsize=(12, 8))
ax = fig.add_subplot()
anim = FuncAnimation(
    fig=fig,
    func=update,
    init_func=init,
    frames=len(plot_data_frames),
    interval=100,
    repeat=False,
)


display(HTML(anim.to_jshtml()))
</input>