-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
29 lines (24 loc) · 1.08 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def plot_multi(data, cols=None, spacing=.1, **kwargs):
    """Plot several columns of ``data`` on one figure, each with its own y-axis.

    The first column is drawn on a regular axes; every further column gets a
    twin axes sharing the same x-axis, with its right-hand spine shifted
    outwards so the y-scales do not overlap. A single combined legend is
    attached to the first axes.

    Args:
        data: Object exposing ``.columns`` and ``.loc[:, label].plot(...)``
            (used here as a pandas DataFrame).
        cols: Iterable of column labels to plot, in order. Defaults to all
            columns of ``data``.
        spacing: Offset between successive right-hand spines, as a fraction
            of the axes width. The first twin spine sits at the normal right
            edge (1.0); each later one is ``spacing`` further out.
        **kwargs: Extra options forwarded to every ``.plot`` call. ``ax``,
            ``figsize`` and ``title`` are applied to the first plot only,
            because each twin axes is created here and the figure-level
            settings only need to be applied once.

    Returns:
        The axes of the first column, or ``None`` when there is nothing to
        plot.
    """
    # Materialise the selection so any iterable of labels (Index, tuple,
    # generator, ...) can be sized and indexed. Truthiness is only tested on
    # the resulting list: a pandas Index itself has no unambiguous truth value.
    cols = list(data.columns if cols is None else cols)
    if not cols:
        return None

    # Imported lazily, after the guard, so the nothing-to-plot path stays cheap.
    import matplotlib.pyplot as plt

    # Default colour cycle of the active matplotlib style; reused cyclically
    # when there are more columns than colours.
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    n_colors = len(colors)

    # Options that target the primary axes / figure and must not be repeated
    # on the twins ("ax" would also clash with the explicit ax= below).
    first_only = ('ax', 'figsize', 'title')
    twin_kwargs = {key: val for key, val in kwargs.items() if key not in first_only}

    # First axis
    first = cols[0]
    ax = data.loc[:, first].plot(label=first, color=colors[0], **kwargs)
    ax.set_ylabel(ylabel=first)
    handles, labels = ax.get_legend_handles_labels()

    for n, col in enumerate(cols[1:], start=1):
        # Additional y-axis sharing the x-axis of the first one
        ax_new = ax.twinx()
        ax_new.spines['right'].set_position(('axes', 1 + spacing * (n - 1)))
        data.loc[:, col].plot(
            ax=ax_new, label=col, color=colors[n % n_colors], **twin_kwargs
        )
        ax_new.set_ylabel(ylabel=col)

        # Collect legend entries so one legend can describe every series
        new_handles, new_labels = ax_new.get_legend_handles_labels()
        handles += new_handles
        labels += new_labels

    # loc=0 lets matplotlib choose the "best" legend position
    ax.legend(handles, labels, loc=0)
    return ax