Hoy vamos a aprender cómo crear animaciones en Python, para hacer que nuestras visualizaciones sean mucho más impactantes y podamos dar más información de una forma visual e impactante. En este post, aprenderás a crear todo tipo de animaciones en Python desde animaciones sencillas a gráficos animados como barchartraces. ¿Te suena interesante? ¡Pues vamos a ello!
Funcionamiento de las animaciones en Python
Para crear nuestras animaciones usaremos la función FuncAnimation dentro de matplolib. Un aspecto fundamental para poder crear nuestras animaciones es entender que este paquete no crea animaciones, sino que simplemente se limita a crear animaciones a partir de una serie de gráficos que le pasemos.
Esto es algo muy importante, ya que es un enfoque muy diferente al de otros paquetes como el de gganimate de R (si no sabes cómo funciona, aquí tienes el tutorial).
De hecho, para crear animaciones en Python usando FuncAnimation simplemente debes pasar una función que tiene como valor de entrada un número que hace referencia a un frame y devuelve el gráfico correspondiente a ese frame.
Esto hace que para crear animaciones en Python debamos preparar muy bien los datos. Veámos cómo hacerlo.
Estructura de datos para crear animaciones en Python
Aunque se pueden crear gráficos partiendo de datos con una forma muy diferente, en mi opinión, para que sea más sencillo graficar los datos deben estar en formato tidy, es decir:
- Cada variable debe estar en una columna.
- Cada observación de esa variable debe ser una fila diferente.
Veamos un ejemplo con el dataset gapminder, que es el que usaremos como ejemplo:
gapminder.head()
country continent year lifeExp pop gdpPercap
0 Afghanistan Asia 1952 28.801 8425333 779.445314
1 Afghanistan Asia 1957 30.332 9240934 820.853030
2 Afghanistan Asia 1962 31.997 10267083 853.100710
3 Afghanistan Asia 1967 34.020 11537966 836.197138
4 Afghanistan Asia 1972 36.088 13079460 739.981106
Ahora que ya sabemos cómo deben estar los datos para crear una animación en Python, vamos a ver cómo crear diferentes animaciones en Python!
Cómo crear animaciones en Python
Para crear animaciones en Python usaremos las funciones animation del módulo matplotlib. Por tanto, crear una animación es muy sencillo y parecido a crear gráficos con matplotlib. Simplemente debemos crear dos cuestiones:
fig: es el objeto que utilizaremos para pintar nuestro gráfico.func: es una función que debe de devolver el estado de la animación para cada frame. Básicamente lo que debemos hacer es crear una función que devuelva todos los gráficos. Siguiendo el ejemplo de la animación de line chart comentada anteriormente, la función debe devolver, en la primera iteración un linechart con el primer año, en la segunda interación un linechart con los dos primeros años y así para todos los años.interval: es el delay en milisegundos entre los diferentes frames de la animación.frames: número de imágenes en las que se va a basar el gráfico. Esto dependerá de cuántos “estados” tenga la animación. Si tenemos una animación con datos en 5 estados diferentes (imaginemos, 5 años), el número de frames será 5, mientras que si tenemos datos de 100 años, el número de frames será 100.
Con estos tres argumentos podemos crear todo tipo de animaciones. Ahora bien, esto puede ser algo complejo (sobre todo la parte del update), así que yo siempre recomendaría primero crear el gráfico que nosotros queremos y, a partir de eso, generar la animación.
En cualquier caso, ya contamos con todo lo básico, así que, ¡veámos cómo crear animaciones en Python!
Cómo crear animación de líneas
Como decía, lo más fácil para crear una animación es primero crear un gráfico que se parezca a lo que nosotros queremos animar. En este caso es muy sencillo, simplemente debemos crear un linechart dell Pib per Cápita para los países España, Italia y Estados Unidos.
Ahora que ya tenemos nuestro gráfico, para crear una animación de líneas, simplemente tendremos que crear una función que, para cada iteración, cree la gráfica de línea pero para los datos que tengamos disponibles.
De esta forma, en la primera iteración la gráfica de línea deberá crear solo un punto para el año 1952, en la segunda iteración creará la gráfica con los dos primeros puntos (1952, 1957), y así hasta completar toda la gráfica.
Por suerte, crear esta iteración habiendo creado ya el gráfico es bastante sencillo, ya que simplemente deberemos utilizar los índices para definir los datos que el gráfico debe coger.
from matplotlib import animation
countries_plot = ['Spain', 'Italy', 'United States']
linechart_plot = gapminder.loc[gapminder['country'].isin(countries_plot), :]
# Define colors
colors = ['red', 'green', 'blue']
fig, ax = plt.subplots()
def update_linechart(i):
for j in range(len(colors)):
country = countries_plot[j]
color = colors[j]
data = linechart_plot.loc[linechart_plot['country'] == country,:]
ax.plot(data.year[:i], data.gdpPercap[:i], color)
Con esto, ya hemos creado todo lo que necesitamos para nuestra animación. Ahora simplemente la tenemos que llamar usando la función FuncAnimation que he explicado previamente. En este tensido, de cara
num_frames = len(linechart_plot['year'].unique())
anim = animation.FuncAnimation(fig, update_linechart, frames = num_frames)
anim.save('linechart.gif')
¡Ya tenemos nuestra animación de linechart creada con Python! Sencillo, ¿verdad? Ahora sigamos viendo cómo crear animaciones de barcharts!
Cómo crear una animación de barchart en Python
Una buena práctica para que crear nuestro barchart (y todas las animaciones más allá de linecharts) sea más sencillo es filtrar los datos dentro de la propia función de iteración. Esto nos facilitará mucho la creación de animaciones y hará que entenderlas sea mucho más fácil.
De todos modos, como siempre que queremos crear una animación, debemos empezar por graficar lo que queremos llegar a conseguir. Así pues, en este caso voy a crear barchart muy simple en el que veamos cómo ha evolucionado el Pib per Cápita de diferentes países.
countries_plot = ['Spain', 'Italy', 'United States','Ireland','China']
barchart_data = gapminder.loc[gapminder['country'].isin(countries_plot), :]
font = {
'weight': 'normal',
'size' : 40,
'color': 'lightgray'
}
colors =['#FF0000','#169b62','#008c45','#aa151b','#002868']
data_temp = barchart_data.loc[barchart_data['year'] == 2007, :]
fig, ax = plt.subplots(figsize=(10, 5))
ax.clear()
ax.barh(data_temp.country,data_temp.gdpPercap, color = colors)
ax.text(0.95, 0.2, data_temp['year'].iloc[0],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
plt.show()
En este gráfico veremos como he incluido dos cambios importantes, que pueden tener un impacto muy importante en tu gráfico.
Tip 1. Limpiar el gráfico anterior
Si creamos el gráfico dentro de nuestro objeto ax, los gráficos se nos van a ir “amontonando”, lo que puede hacer que nuestros datos no sean los reales. Lo peor de todo es que, si no usas transparencia en el gráfico, depende del del tipo de gráfico puede que ni te des cuenta de esto.
Para evitar esto, en cada iteración deberemos llamar a ax.clear(), de tal forma que limpie el resultado proveniente del frame anterior.
Tip 2. Filtra tus datos en la función de update
Hacer una animación para un único año es relativamente sencillo. Sin embargo, hacerlo para muchos años parece algo más complejo. Es por eso que para facilitar la creación de animaciones en Python yo recomiendo:
- Crear una lista con todos los estados posibles de la animación. En mi caso los estados son los años, así que creo una lista con todos los años que puedo llegar a plotear.
- Filtra el dataset completo en función del estado dentro de la propia función de update.
Con estos dos pasos se te hará mucho más fácil crear animaciones.
Así pues, voy a crear la función de update de mi animación de barplot teniendo en cuenta los dos puntos anteriores:
countries_plot = ['Spain', 'Italy', 'United States','Ireland','China']
barchart_data = gapminder.loc[gapminder['country'].isin(countries_plot), :]
font = {
'weight': 'normal',
'size' : 40,
'color': 'lightgray'
}
years = barchart_data['year'].unique()
colors =['#FF0000','#169b62','#008c45','#aa151b','#002868']
fig, ax = plt.subplots(figsize=(10, 5))
label = ax.text(0.95, 0.2, years[0],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
def update_barchart(i):
year = years[i]
data_temp = barchart_data.loc[barchart_data['year'] == year, :]
ax.clear()
ax.barh(data_temp.country,data_temp.gdpPercap, color = colors)
label.set_text(year)
anim = animation.FuncAnimation(fig, update_barchart, frames = len(years))
anim.save('barchart.gif')
¡Animación de barplot en Python lista! Como ves, filtrar los datos dentro de la propia función de update hace que todo sea mucho más sencillo.
Ahora que ya tenemos más control sobre las animaciones, vamos a crear una animación algo más compleja pero mucho más impactante: vamos a animar un gráfico de scatter plot en Python. ¡Vamos a ello!
Cómo animar Scatter Plot en Python
Una vez más, para crear la animación del scatter plot lo primero de todo es crear el gráfico para un único año. Para ello vamos a seguir exactamente los mismos casos que para crear la animación de barplot: primero filtramos los datos y después creamos el gráfico.
En este caso, al haber muchos países colorearé los países en base al continente y, además, les daré transparencia:
import numpy as np
import matplotlib
fig, ax = plt.subplots(figsize=(10, 5))
scatter_data = gapminder.copy()
# Create a color depending on
conditions = [
scatter_data.continent == 'Asia',
scatter_data.continent == 'Europe',
scatter_data.continent == 'Africa',
scatter_data.continent == 'Americas',
scatter_data.continent == 'Oceania',
]
values = list(range(5))
scatter_data['color'] = np.select(conditions, values)
font = {
'weight': 'normal',
'size' : 40,
'color': 'lightgray'
}
years = scatter_data['year'].unique()
data_temp = scatter_data.loc[scatter_data['year'] == years[-1], :]
label = ax.text(0.95, 0.25, years[-1],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
colors =[f'C{i}' for i in np.arange(1, 6)]
cmap, norm = matplotlib.colors.from_levels_and_colors(np.arange(1, 5+2), colors)
scatter = ax.scatter(data_temp.gdpPercap,
data_temp.lifeExp,
s=data_temp['pop']/500000,
alpha = 0.5,
c=data_temp.color,
cmap=cmap,
norm=norm)
label.set_text(years[-1])
plt.show()
Ahora que ya tenemos nuestra gráfica montada, ahora debemos convertirlo en función para animarla. En este caso, resulta fundamental que antes de cada frame limpiemos el contenido anterior del gráfico dentro del objeto ax, ya que sino la animación no quedará bien.
Más allá de eso, el procedimiento para crear la animación de scatter plot es el mismo que el explicado previamente para crear otro tipo de animaciones en Python:
fig, ax = plt.subplots(figsize=(10, 5))
years = scatter_data['year'].unique()
colors =[f'C{i}' for i in np.arange(1, 6)]
cmap, norm = matplotlib.colors.from_levels_and_colors(np.arange(1, 5+2), colors)
label = ax.text(0.95, 0.25, years[0],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
def update_scatter(i):
year = years[i]
data_temp = scatter_data.loc[scatter_data['year'] == year, :]
ax.clear()
label = ax.text(0.95, 0.20, years[i],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
ax.scatter(
data_temp['gdpPercap'],
data_temp['lifeExp'],
s=data_temp['pop']/500000,
alpha = 0.5,
c=data_temp.color,
cmap=cmap,
norm=norm
)
label.set_text(year)
anim = animation.FuncAnimation(fig, update_scatter, frames = len(years), interval = 30)
anim.save('scatter.gif')
¡Ya tenemos nuestro scatter plot animado! Ahora vayamos a por la última de las animaciones que vamos aprender a crear en Python: una animación de barchart race.
Cómo crear animaciones de barplot race en Python
La animación de barchart race es muy similar a la animación de barplot que hemos hecho anteriormente. La principal diferencia reside en que, en el barplot race los datos están ordenados, de tal forma que veamos cómo ha ido evolucionando el top de X observaciones para una variable (puede ser desde valoración en bolsa a uso de videojuegos o, como en nuestro caso, el PIB per Cápita de los países).
Así pues, para crear nuestro barplot race necesitaremos tener, para cada uno de los años, cuál es el ranking de los países. Para ello, yo recomiendo utilizar el método rank de pandas , ya que podremos obtener los rankings de una forma muy sencilla.
Una vez tenemos el ranking, simplemente deberemos filtrar los datos para quedarnos con el número de observaciones que nos interese, en mi caso 10.
Por último, una vez tengamos nuestros datos filtrados, solo habrá que crear un gráfico de barras horizontales donde el eje vertical sea el ranking. Además, para hacer que el gráfico sea más entendible, cambiaremos el nombre del tick por el nombre del país. Esto lo haremos con el parámetro tick_label de la función barh.
Así pues, veámos cómo sería para un único caso:
barchartrace_data = gapminder.copy()
n_observations = 10
font = {
'weight': 'normal',
'size' : 40,
'color': 'lightgray'
}
data_temp = barchartrace_data.loc[barchartrace_data['year'] == 1952, :]
# Create rank and get first 10 countries
data_temp['ranking'] = data_temp.gdpPercap.rank(method='max', ascending = False).values
data_temp = data_temp.loc[data_temp['ranking'] <= n_observations]
colors = plt.cm.Dark2(range(6))
fig, ax = plt.subplots(figsize=(10, 5))
ax.barh(y = data_temp['ranking'] ,
width = data_temp.gdpPercap,
tick_label=data_temp['country'],
color=colors)
ax.text(0.95, 0.2, data_temp['year'].iloc[0],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
ax.set_ylim(ax.get_ylim()[::-1]) # Revert axis
plt.show()
¡Perfecto! Ya tenemos nuestra base creada. Ahora solo queda crear nuestra función de animación. En este caso, yo recomiendo que el ranking y selección de los países se haga dentro de la propia función de actualización, ya que facilita mucho el entendimiento y permite aprovechar el código de nuestro gráfico base.
Así pues, la función de update de nuestro barplot race es la siguiente:
barchartrace_data = gapminder.copy()
n_observations = 10
fig, ax = plt.subplots(figsize=(10, 5))
font = {
'weight': 'normal',
'size' : 40,
'color': 'lightgray'
}
years = barchartrace_data['year'].unique()
label = ax.text(0.95, 0.20, years[0],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
def update_barchart_race(i):
year = years[i]
data_temp = barchartrace_data.loc[barchartrace_data['year'] == year, :]
# Create rank and get first 10 countries
data_temp['prueba'] = data_temp['gdpPercap'].rank(ascending = False)
data_temp = data_temp.loc[data_temp['prueba'] <= n_observations]
colors = plt.cm.Dark2(range(6))
ax.clear()
ax.barh(y = data_temp['prueba'] ,
width = data_temp.gdpPercap,
tick_label=data_temp['country'],
color=colors)
label = ax.text(0.95, 0.20, year,
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
ax.set_ylim(ax.get_ylim()[::-1]) # Revert axis
anim = animation.FuncAnimation(fig, update_barchart_race, frames = len(years))
anim.save('barchart_race.gif')
¡Barplot race creado! Ya hemos visto cómo crear diferentes tipos de animaciones en Python. Sin embargo, no parece que sean del todo visuales, ya que simplemente se limitan a poner un gráfico encima del otro. Así pues, veámos cómo podemos mejorar la fluidez de nuestras animaciones en Python. ¡Vamos a ello!
Cómo mejorar la fluidez de las animaciones de Python
Como he explicado más arriba en este post, la función FuncAnimation se limita a crear la animación poniendo las imágenes que genera nuestra función de update. Como nuestras imágenes cambian año a año, nuestra animación dará pequeños “saltos”.
Así pues, para que nuestra animación sea mucho más fluida, deberemos crear más datos entre cada uno de los estados que tenemos. De esta forma conseguiremos:
- Tener datos intermedios, por lo que el salto de las animaciones no será tan grande.
- Tener muchos más frames, de tal forma que para la misma duración de la animación, tendrá más fps (frames por segundo) haciendo que se vea mucho más fluida.
Para crear este objetivo, vamos a realizar lo siguiente:
- Crear más observaciones entre los datos que ya tenemos. Estas observaiones estarán vacías.
- Imputar datos a esas nuevas observaciones vacías mediante la interpolación entre los estados.
Suena complejo, pero es más fácil de lo que parece. Veámos cómo hacerlo:
Crear más observaciones entre los datos que ya tenemos
Crear más observaciones de las que ya tenemos entre los estados actuales es muy sencillo. Simplemente debemos partir de un índice que vaya de 0 al número de observaciones que tengamos. Esto lo podemos conseguir con el método reset_index.
Una vez nuestros datos son así, simplemente podemos cambiar el índice actual de los datos multiplicando cada índice por el número de frames entre estados que queramos crear. Si queremos crear 10 frames, al multiplicar el índice antiguo por 10, la segunda observación (índice 1) pasará a tener el índice 10 y entre medias se habrán creado muchas variables vacías.
En cualquier caso, para que la interpolación funcione bien, deberemos tener los datos en el formato adecuado, que es:
- Cada fila debe ser un estado, un año en mi caso.
- Cada columna debe ser la observación que nosotros vayamos a graficar, en mi caso, un país.
- El valor debe ser la variable que vayamos a graficar. En mi caso seguiré creando el barplot race, por lo que la variable sigue siendo el
gdpPercap.
Así pues, esto es lo que debemos realizar:
barchartrace_data = gapminder.copy()
n_observations = 10
n_frames_between_states = 30
barchartrace_data= barchartrace_data.pivot('year', 'country', 'gdpPercap')
barchartrace_data['year'] = barchartrace_data.index
barchartrace_data.reset_index(drop = True, inplace = True)
barchartrace_data.index = barchartrace_data.index * n_frames_between_states
barchartrace_data = barchartrace_data.reindex(range(barchartrace_data.index.max()+1))
barchartrace_data.iloc[:15,:5]
country Afghanistan Albania Algeria Angola Argentina
0 779.445314 1601.056136 2449.008185 3520.610273 5911.315053
1 NaN NaN NaN NaN NaN
2 NaN NaN NaN NaN NaN
3 NaN NaN NaN NaN NaN
4 NaN NaN NaN NaN NaN
5 NaN NaN NaN NaN NaN
6 NaN NaN NaN NaN NaN
7 NaN NaN NaN NaN NaN
8 NaN NaN NaN NaN NaN
9 NaN NaN NaN NaN NaN
10 NaN NaN NaN NaN NaN
11 NaN NaN NaN NaN NaN
12 NaN NaN NaN NaN
Como ves, cada columna es un país y he creado 30 nuevas observaciones entre los estados que ya tenía. Una vez hecho esto, ya podemos ver cómo imputar esos nuevos datos mediante interpolación.
Imputar datos a esas nuevas observaciones vacías mediante la interpolación entre los estados
Para imputar los datos vacíos, vamos a usar la interpolación. Esto se puede realizar con el método interpolate de pandas. Existen diferentes métodos de interpolación (puedes encontrar los método aquí) y cada uno darán un efecto diferente, como puedes ver en la siguiente animación hecha por Nicholas A Rossi (enlace).
En nuestro caso lo vamos a hacer fácil, dejando los valores por defecto del método, esto es, aplicando una interpolación lineal. Aunque sea lo más sencillo, el cambio va a ser importante, solo hay que ver la diferencia entre usar la interpolación lineal y no usar interpolación en la animación.
Así pues, podemos interpolar nuestros datos de la siguiente forma:
barchartrace_data = barchartrace_data.interpolate()
barchartrace_data.iloc[:15,:5]
country Afghanistan Albania Algeria Angola Argentina
0 779.445314 1601.056136 2449.008185 3520.610273 5911.315053
1 780.825572 1612.430406 2467.840446 3530.854613 5942.833092
2 782.205829 1623.804677 2486.672708 3541.098952 5974.351130
3 783.586086 1635.178947 2505.504969 3551.343292 6005.869169
4 784.966343 1646.553217 2524.337230 3561.587632 6037.387208
5 786.346600 1657.927487 2543.169491 3571.831972 6068.905246
6 787.726858 1669.301758 2562.001753 3582.076311 6100.423285
7 789.107115 1680.676028 2580.834014 3592.320651 6131.941323
8 790.487372 1692.050298 2599.666275 3602.564991 6163.459362
9 791.867629 1703.424568 2618.498536 3612.809331 6194.977401
10 793.247886 1714.798839 2637.330798 3623.053670 6226.495439
11 794.628143 1726.173109 2656.163059 3633.298010 6258.013478
12 796.008401 1737.547379 2674.995320 3643.542350 6289.531517
13 797.388658 1748.921649 2693.827581 3653.786690 6321.049555
14 798.768915 1760.295920 2712.659843 3664.031029 6352.567594
Por último, ahora que ya tenemos nuestros datos interpolados, vamos a cambiar la forma de nuestro dataframe para que siga manteniendo la forma que tenía antes, es decir, que tanto el año, como el país como el PIB per Cápita sean variables. Esto lo podemos conseguir con el método melt de pandas.
# Hacemos otro pivot para volver a los datos originales
barchartrace_data = barchartrace_data.melt(id_vars='year', var_name ='country', value_name = 'gdpPercap')
barchartrace_data.iloc[:15,:5]
year country gdpPercap
0 1952.000000 Afghanistan 779.445314
1 1952.166667 Afghanistan 780.825572
2 1952.333333 Afghanistan 782.205829
3 1952.500000 Afghanistan 783.586086
4 1952.666667 Afghanistan 784.966343
5 1952.833333 Afghanistan 786.346600
6 1953.000000 Afghanistan 787.726858
7 1953.166667 Afghanistan 789.107115
8 1953.333333 Afghanistan 790.487372
9 1953.500000 Afghanistan 791.867629
10 1953.666667 Afghanistan 793.247886
11 1953.833333 Afghanistan 794.628143
12 1954.000000 Afghanistan 796.008401
13 1954.166667 Afghanistan 797.388658
14 1954.333333 Afghanistan 798.768915
Si te fijas, tenemos un dataframe exactamente igual que el que teníamos cuando hemos hecho la animación del barchart race previamente, solo que ahora tenemos muchos más datos intermedios. Así pues, simplemente debemos replicar el código de antes para conseguir buenos resultados:
import math
n_observations = 10
fig, ax = plt.subplots(figsize=(10, 5))
font = {
'weight': 'normal',
'size' : 40,
'color': 'lightgray'
}
years = barchartrace_data['year'].unique()
label = ax.text(0.95, 0.20, years[0],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
colors = plt.cm.Dark2(range(200))
def update_barchart_race(i):
year = years[i]
data_temp = barchartrace_data.loc[barchartrace_data['year'] == year, :]
# Create rank and get first 10 countries
data_temp['ranking'] = data_temp['gdpPercap'].rank(method = 'first',ascending = False)
data_temp = data_temp.loc[data_temp['ranking'] <= n_observations]
ax.clear()
ax.barh(y = data_temp['ranking'] ,
width = data_temp.gdpPercap,
tick_label=data_temp['country'],
color=colors)
label = ax.text(0.95, 0.20, math.floor(year),
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
ax.set_ylim(ax.get_ylim()[::-1]) # Revert axis
anim = animation.FuncAnimation(fig, update_barchart_race, frames = len(years))
anim.save('barchart_race_cool.gif', fps = 20)
¡Animación mejorada! Ahora queda mucho mejor, ¿verdad? Sin embargo, hay una cosa que quizás se pueda seguir mejorando: los estados intermedio. Y es que, aunque las gráficas se animen, los cambios de posición siguen siendo saltos. Veamos cómo crear ese movimiento horizontal de nuestras animaciones en Python.
Cómo crear movimiento horizontal en barchart race
La razón por la cual las posiciones siguen dando “saltos” es que, aunque hayamos interpolado los datos de la gráfica, no hemos interpolado los datos de las posiciones. Así pues, podemos crear una interpolación de los datos de las posiciones y juntarnos a nuestro dataset anterior.
Para ello, primero tendremos que quedarnos con el ranking de cada país para cada año. Una vez tengamos ese dataframe, el proceso será exactamente el mismo al realizado anteriormente: pivotamos, interpolamos y deshacemos el pivotado con un melt.
Importante: para que este método funcione debemos usar el mismo sistema de interpolación que hemos usado anteriormente.
ranking_data = gapminder.copy()
n_observations = 10
n_frames_between_states = 30
#barchartrace_data['ranking']
ranking = ranking_data.groupby('year')['gdpPercap'].rank(method = 'first', ascending = False)
ranking = ranking.rename('ranking', axis = 1)
ranking_data = ranking_data.join(ranking)
ranking_data = ranking_data.pivot('year', 'country', 'ranking')
ranking_data['year'] =ranking_data.index
ranking_data.reset_index(drop = True, inplace = True)
ranking_data.index = ranking_data.index * n_frames_between_states
ranking_data = ranking_data.reindex(range(ranking_data.index.max()+1))
ranking_data = ranking_data.interpolate('linear')
ranking_data = ranking_data.melt(id_vars='year', var_name ='country', value_name = 'ranking')
ranking_data.iloc[:15,:5]
year country ranking
0 1952.000000 Afghanistan 113.000000
1 1952.166667 Afghanistan 113.066667
2 1952.333333 Afghanistan 113.133333
3 1952.500000 Afghanistan 113.200000
4 1952.666667 Afghanistan 113.266667
5 1952.833333 Afghanistan 113.333333
6 1953.000000 Afghanistan 113.400000
7 1953.166667 Afghanistan 113.466667
8 1953.333333 Afghanistan 113.533333
9 1953.500000 Afghanistan 113.600000
10 1953.666667 Afghanistan 113.666667
11 1953.833333 Afghanistan 113.733333
12 1954.000000 Afghanistan 113.800000
13 1954.166667 Afghanistan 113.866667
14 1954.333333 Afghanistan 113.933333
Ahora que tenemos esta información, deberemos unirla con la información que teníamos en el dataset previo:
barchartrace_data = ranking_data.merge(barchartrace_data,
left_on = ['country','year'],
right_on = ['country','year'])
barchartrace_data.head()
year country ranking gdpPercap
0 1952.000000 Afghanistan 113.000000 779.445314
1 1952.166667 Afghanistan 113.066667 780.825572
2 1952.333333 Afghanistan 113.133333 782.205829
3 1952.500000 Afghanistan 113.200000 783.586086
4 1952.666667 Afghanistan 113.266667 784.966343
Como veis, el ranking no son números redondos sino que se van modificando ligeramente en cada estado. Así pues, ya podemos crear de nuevo la animación. En este caso, como ya tenemos el ranking calculado, no hará falta que lo volvamos a calcular:
import math
n_observations = 10
fig, ax = plt.subplots(figsize=(10, 5))
font = {
'weight': 'normal',
'size' : 40,
'color': 'lightgray'
}
years = barchartrace_data['year'].unique()
label = ax.text(0.95, 0.20, years[0],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
# Create colors
# 1. Get continent
continent = gapminder[['country','continent']].drop_duplicates().reset_index(drop = True)
# 2. Add continent info
barchartrace_data = barchartrace_data.merge(continent,left_on = 'country', right_on = 'country')
# 3. Use continent to get color
conditions = [
barchartrace_data['continent'] == 'Asia',
barchartrace_data['continent'] == 'Europe',
barchartrace_data['continent'] == 'Africa',
barchartrace_data['continent'] == 'Americas',
barchartrace_data['continent'] == 'Oceania',
]
values = ['#0275d8', '#5cb85c', '#5bc0de', '#f0ad4e', '#d9534f']
barchartrace_data['color'] = np.select(conditions, values)
def update_barchart_race(i):
year = years[i]
data_temp = barchartrace_data.loc[barchartrace_data['year'] == year, :]
# Create rank and get first 10 countries
data_temp = data_temp.loc[data_temp['ranking'] <= n_observations]
ax.clear()
ax.barh(y = data_temp['ranking'] ,
width = data_temp.gdpPercap,
tick_label=data_temp['country'],
color=data_temp['color'])
label = ax.text(0.95, 0.20, math.floor(year),
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
ax.set_ylim(ax.get_ylim()[::-1]) # Revert axis
anim = animation.FuncAnimation(fig, update_barchart_race, frames = len(years), )
anim.save('barchart_race_cool2.gif', fps=30)
¡Ya tenemos nuestra animación creada en Python! Como ves, la interpolación de ha hecho que nuestra animación sea mucho más atractiva y parezca muchísimo más fluida de lo que era al principio. ¡Y todo, con solo unas pocas líneas de código! ¿No es fantástico? Pero eso no es todo, veamos otro truquito más que nos permitirá mejorar mucho nuestras animaciones en Python.
Evitar saltos en las animaciones
Un problema típico en las animaciones, como ha ocurrido en la animación del scatter plot, es que haya saltos entre frames. Esto se debe a que los ejes del gráfico cambian, haciendo que el contenido del mismo parezca diferente cuando, en realidad, no lo es.
Arreglar esto es bastante sencillo. Simplemente para cada frame se debe fijar el valor máximo de los ejes X e Y. Ese valor será el máximo que alcanzará el gráfico en toda la serie. De esta forma, conseguiremos evitar esos tirones.
Esto sobre todo es importante aplicarlo cuando animamos gráficos como el scatter plot. Sin embargo, a la hora de animar el linechart no es recomendable aplicarlo ya que reduce mucho el impacto visual de la animación.
Así pues, vamos a rehacer la animación del scatter plot, pero esta vez aplicando una mayor fluidez mediante la interpolación y evitando los saltos mediante la fijación de las escalas.
En este caso, deberemos aplicar la interpolación a las tres variables que se utilizan en la animación. Así pues, para facilitar el proceso crearemos una función que nos haga la interpolación.
scatter_data = gapminder.copy()
n_frames_between_states = 30
continent = gapminder[['country','continent']].drop_duplicates().reset_index(drop = True)
def interpolate_data(data,frame,obs,variable, n_new_frames, interpolation = 'linear'):
data= data.pivot(frame, obs, variable)
data[frame] = data.index
data.reset_index(drop = True, inplace = True)
data.index = data.index * n_new_frames
data = data.reindex(range(data.index.max()+1))
data = data.interpolate(interpolation)
data = data.melt(id_vars= frame, var_name = obs, value_name = variable)
return data
# Interpolate data
scatter_data_pop = interpolate_data(scatter_data, 'year', 'country','pop',30)
scatter_data_gdpPerCap = interpolate_data(scatter_data, 'year', 'country','gdpPercap',30)
scatter_data_lifeExp = interpolate_data(scatter_data, 'year', 'country','lifeExp',30)
# Merge the datasets
scatter_data = scatter_data_gdpPerCap.merge(scatter_data_pop,
left_on = ['country','year'],
right_on = ['country','year'])
scatter_data = scatter_data.merge(scatter_data_lifeExp,
left_on = ['country','year'],
right_on = ['country','year']).merge(continent)
scatter_data.head()
year country gdpPercap pop lifeExp continent
0 1952.000000 Afghanistan 779.445314 8425333.0 28.801000 Asia
1 1952.166667 Afghanistan 780.825572 8452519.7 28.852033 Asia
2 1952.333333 Afghanistan 782.205829 8479706.4 28.903067 Asia
3 1952.500000 Afghanistan 783.586086 8506893.1 28.954100 Asia
4 1952.666667 Afghanistan 784.966343 8534079.8 29.005133 Asia
Ahora que ya tenemos el dataset, podemos crear la animación. Para fijar los límites de la animación simplemente debemos utilizar el método set_xlim.
fig, ax = plt.subplots(figsize=(10, 5))
years = scatter_data['year'].unique()
conditions = [
scatter_data.continent == "Asia",
scatter_data.continent == "Europe",
scatter_data.continent == "Africa",
scatter_data.continent == "Americas",
scatter_data.continent == "Oceania",
]
values = list(range(5))
scatter_data['color'] = np.select(conditions, values)
colors =[f"C{i}" for i in np.arange(1, 6)]
cmap, norm = matplotlib.colors.from_levels_and_colors(np.arange(1, 5+2), colors)
# Get maximum values
x_max = scatter_data['gdpPercap'].max()
y_max = scatter_data['lifeExp'].max()
def update_scatter(i):
year = years[i]
data_temp = scatter_data.loc[scatter_data['year'] == year, :]
ax.clear()
label = ax.text(0.95, 0.25, years[0],
horizontalalignment='right',
verticalalignment='top',
transform=ax.transAxes,
fontdict=font)
# Set limits
ax.set_xlim((0,x_max))
ax.set_ylim((0,y_max))
ax.scatter(
data_temp['gdpPercap'],
data_temp['lifeExp'],
s=data_temp['pop']/500000,
alpha = 0.5,
c=data_temp.color,
cmap=cmap,
norm=norm
)
label.set_text(math.floor(year))
anim = animation.FuncAnimation(fig, update_scatter, frames = len(years))
anim.save('scatter2.gif', fps = 20)
Conclusión
Sin duda alguna crear animaciones en Python es algo que va a permitirte crear gráficos muy visuales que generen mucho más impacto. Esto es algo básico que te va a permitir desde generar mucho más impactantes a poder explicar procesos de una forma más sencilla, como hice con el algoritmo k-Mean en este post.
Además, si estás acostumbrado a trabajar con pandas y matplotlib y entiendes el funcionamiento detrás de las funciones de animación, es algo muy sencillo de poder hacer.
Espero que este post te haya gustado. Si es así, te animo a suscribirte para estar al día de todos los posts que voy subiendo. En cualquier caso, ¡nos vemos en el próximo!