Skip to content

Commit

Permalink
Code to perform some exploratory data analysis and initial PCA decomp…
Browse files Browse the repository at this point in the history
…osition of seizure data. Input data and results of running the code.
  • Loading branch information
moink committed Apr 30, 2019
1 parent 82b039e commit f79b7d1
Show file tree
Hide file tree
Showing 11 changed files with 11,866 additions and 0 deletions.
11,501 changes: 11,501 additions & 0 deletions data/data.csv

Large diffs are not rendered by default.

289 changes: 289 additions & 0 deletions epiclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
import os

import matplotlib
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

LEGEND_COORDS = (1.2, 0.8)
TOTAL_PCA_COMPONENTS = 60


def main():
"""Explore epileptic seizure classification data set
Returns:
None
"""
epidata = pd.read_csv(os.path.join('data', 'data.csv'))
set_matplotlib_params()
explore_data(epidata)
features = epidata.drop(['y', 'Unnamed: 0'], axis=1)
target = epidata['y']
x_train, x_test, y_train, y_test = train_test_split(features, target,
test_size=0.3,
random_state=0)
print(type(x_train))
explore_pca(x_train)


def explore_pca(x_train, n_components=TOTAL_PCA_COMPONENTS):
"""Create plots of Principal Component Analysis decomposition
Find the first TOTAL_PCA_COMPONENTS PCA components of the argument. Create
a plot of the explained variance ratio and a plot of the cumulative sum of
the explained variance ratio, over the number of PCA components used. Save
the explained variance ratio to a comma-separated file.
Args:
x_train: pandas.DataFrame
Training data to decompose using Principal Component Analysis
n_components : int, Optional, default 60
Number of PCA components to use in plots and csv file
Returns:
None
"""
pca = PCA(n_components=n_components)
pca.fit(x_train)
var_ratio = pd.Series(pca.explained_variance_ratio_)
sum_var_ratio = var_ratio.cumsum()
var_ratio.to_csv(os.path.join('outputs', 'var_ratio.csv'), header=True)
plot_explained_variance(var_ratio, os.path.join('outputs', 'var_ratio.png'))
plot_sum_explained_variance(sum_var_ratio,
os.path.join('outputs', 'var_ratio_sum.png'))


def plot_sum_explained_variance(var_ratio, filename):
"""Plot one minus cumulative sum of the explained variance of a transform
Create a plot with a linear x-axis and a logarithmic y-axis of one minus
the cumulative explained variance of the components of a dimensionality
reduction method and save the plot to a file. This plot allows the user
to determine how many components are required to achieve a certain total
explained variance ratio, such as 90% or 99%.
Args:
var_ratio: pandas.Series
Explained variance ratio of each component. Should be
monotonically decreasing.
filename: str
File to save. File format is inferred from the extension
Returns:
None
"""
plot_axes = plt.subplot(1, 1, 1)
plot_axes.semilogy(1 - var_ratio, marker='.', linestyle='none')
plot_axes.set_xlabel('PCA component')
plot_axes.set_ylabel('1 - sum(Explained variance ratio)')
fig = plot_axes.get_figure()
fig.savefig(filename)
fig.clf()


def plot_explained_variance(var_ratio, filename):
"""Plot the explained variance of each component of a transform
Create a plot with a linear x-axis and a logarithmic y-axis of the
explained variance ratio of each component. This plot allows the user
to see how rapidly the explained variance decreases by component and to
identify and eigengap, if it exists.
Args:
var_ratio: pandas.Series
Explained variance ratio of each component. Should be
monotonically decreasing.
filename: str
File to save. File format is inferred from the extension
Returns:
None
"""
plot_axes = plt.subplot(1, 1, 1)
plot_axes.semilogy(var_ratio, marker='.', linestyle='none')
plot_axes.set_xlabel('PCA component')
plot_axes.set_ylabel('Explained variance ratio')
fig = plot_axes.get_figure()
fig.savefig(filename)
fig.clf()


def explore_data(epidata):
"""Create a number of plots and csv files to explore the seizure data
Args:
epidata: pandas.DataFrame
Data about seizures. Each row is a data point and each column is
a feature, except for the column 'y' which contains the
classification target.
Returns:
None
"""
features = epidata.drop(['y', 'Unnamed: 0'], axis=1)
desc = features.describe()
desc.transpose().describe().to_csv(os.path.join('outputs',
'double_desc.csv'))
epidata['y'].value_counts().to_csv(os.path.join('outputs',
'class_counts.csv'),
header=True)
create_summary_plot(desc, os.path.join('outputs', 'feature_summary.png'))
create_interquartile_plot(desc, os.path.join('outputs',
'interquartile.png'))
create_mean_median_plot(desc, os.path.join('outputs', 'mean_median.png'))
create_std_plot(desc.loc['std', :], os.path.join('outputs',
'feature_std.png'))
create_corr_heatmap(features, os.path.join('outputs', 'corr_heatmap.png'),
os.path.join('outputs', 'corr_X90.png'))


def create_corr_heatmap(features, filename1, filename2):
"""Create 2 plots: a feature correlation heat map and correlations with X90
Args:
features: pandas.DataFrame
Feature columns. Each row is a data point and each column is a
feature. One column must be named 'X90'.
filename1: str
Path to which to save correlation heatmap figure. File format is
inferred from the extension.
filename2:
Path to which to save figure showing correlations with feature X90.
File format is inferred from the extension.
Returns:
None
"""
corr_mat = features.corr()
corr_mat.to_csv(os.path.join('outputs', 'corr_mat.csv'))
plot_axes = corr_mat['X90'].plot()
plot_axes.set_xlabel('Feature')
plot_axes.set_ylabel('Correlation with feature X90')
fig = plot_axes.get_figure()
fig.savefig(filename2)
fig.clf()
sns.heatmap(corr_mat, center=0, cmap='coolwarm')
fig.savefig(filename1)
fig.clf()


def create_std_plot(std, filename):
"""Create plot of standard deviation of each feature
Args:
std: pandas.Series
Standard deviation of each feature across the data set. The index is
the name of the feature.
filename: str
Path to which to save the figure. The file format is inferred
from the extension.
Returns:
None
"""
plot_axes = std.plot()
plot_axes.set_xlabel('Feature')
plot_axes.set_ylabel('Standard deviation of feature value')
fig = plot_axes.get_figure()
fig.savefig(filename)
fig.clf()


def create_summary_plot(description, filename):
"""Create a plot showing mean, min, max, and quartiles of features
Args:
description: pandas.DataFrame
Description of the features. Result of running the
pandas.DataFrame.describe method on the features
filename: str
Path to which to save the figure. The file format is inferred
from the extension.
Returns:
None
"""
to_plot = description.drop(['count', 'std']).transpose()
create_feature_value_plot(to_plot, filename)


def create_interquartile_plot(data, filename):
"""Create a plot of the mean, median, and 25th and 75th percentile
Args:
data: pandas.DataFrame
Description of the features. Result of running the
pandas.DataFrame.describe method on the features
filename: str
Path to which to save the figure. The file format is inferred
from the extension.
Returns:
None
"""
cols = ['mean', '25%', '50%', '75%']
to_plot = data.transpose()[cols]
create_feature_value_plot(to_plot, filename)


def create_mean_median_plot(data, filename):
"""Create a plot of the mean and median of the features
Args:
data: pandas.DataFrame
Description of the features. Result of running the
pandas.DataFrame.describe method on the features
filename: str
Path to which to save the figure. The file format is inferred
from the extension.
Returns:
None
"""
cols = ['mean', '50%']
to_plot = data.transpose()[cols]
to_plot.columns = ['mean', 'median']
create_feature_value_plot(to_plot, filename)


def create_feature_value_plot(data, filename):
"""Create a plot with features on the x-axis and values on the y-axis
Args:
data: pandas.DataFrame
Data to plot
filename: str
Path to which to save the figure. The file format is inferred
from the extension.
Returns:
None
"""
plot_axes = data.plot()
plot_axes.set_xlabel('Feature')
plot_axes.set_ylabel('Value')
plot_axes.legend(loc='right', bbox_to_anchor=LEGEND_COORDS)
fig = plot_axes.get_figure()
fig.savefig(filename, bbox_inches='tight')
fig.clf()


def set_matplotlib_params():
"""Set matplotlib parameters to chosen aesthetics
Set the figure dots per inch to 200, and the edge, tick and axis label
colors to gray for all future matplotlib calls.
Returns:
None
"""
matplotlib.rcParams['savefig.dpi'] = 200
matplotlib.rcParams['axes.edgecolor'] = 'gray'
matplotlib.rcParams['xtick.color'] = 'gray'
matplotlib.rcParams['ytick.color'] = 'gray'
matplotlib.rcParams['axes.labelcolor'] = 'gray'


if __name__ == '__main__':
main()
6 changes: 6 additions & 0 deletions outputs/class_counts.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
,y
5,2300
4,2300
3,2300
2,2300
1,2300
Binary file added outputs/corr_X90.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added outputs/corr_heatmap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
9 changes: 9 additions & 0 deletions outputs/double_desc.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
,count,mean,std,min,25%,50%,75%,max
count,178.0,178.0,178.0,178.0,178.0,178.0,178.0,178.0
mean,11500.0,-7.72243624816805,164.54793542513497,-1812.5,-53.485955056179776,-7.314606741573034,36.41713483146067,1737.6404494382023
std,0.0,2.0152466233942636,3.4026820680967487,81.03924208099522,1.3267987207659901,0.9636970433526939,1.271904659140675,296.35957878313974
min,11500.0,-13.668869565217392,153.88138289800153,-1885.0,-57.0,-10.0,33.0,1213.0
25%,11500.0,-9.210934782608696,162.16481491658084,-1864.0,-54.0,-8.0,36.0,1451.25
50%,11500.0,-7.249,164.6027516168026,-1844.5,-53.0,-7.0,36.125,1781.5
75%,11500.0,-6.229760869565218,167.0659221012374,-1788.25,-53.0,-7.0,37.0,2047.0
max,11500.0,-4.143826086956522,172.43988729137322,-1415.0,-50.0,-5.0,39.0,2047.0
Binary file added outputs/feature_std.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added outputs/feature_summary.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added outputs/interquartile.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added outputs/mean_median.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
61 changes: 61 additions & 0 deletions outputs/var_ratio.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
,0
0,0.05851539553721611
1,0.05295931429035144
2,0.05142283168157521
3,0.04956283068080061
4,0.04512750251578775
5,0.04183149861925076
6,0.040411969235401746
7,0.03890075947581718
8,0.03665612921066289
9,0.033500682440094055
10,0.03300985459132789
11,0.030857865853442135
12,0.029628639604841862
13,0.028965135860325994
14,0.02669929489623337
15,0.024036870519814693
16,0.022372536564910974
17,0.020754353287488345
18,0.020144226316130765
19,0.01891155013519223
20,0.018568892045012584
21,0.01776000369766883
22,0.017579631035157313
23,0.017399234894095263
24,0.016863362354587606
25,0.016444384563267522
26,0.014839093967262879
27,0.014522995797549343
28,0.013808867906598002
29,0.013715511316088055
30,0.01318465053262129
31,0.012507316799692868
32,0.011971642565251895
33,0.011080565341319713
34,0.010799991507960254
35,0.009172132818290572
36,0.008010301371800668
37,0.006540595431528309
38,0.005807131311384302
39,0.005342073543740615
40,0.004631168024817603
41,0.004387385376741865
42,0.0035364363434579418
43,0.0029813579309323807
44,0.0025291588746729023
45,0.0021924771259065347
46,0.0020344321871594746
47,0.0017454795147042851
48,0.0016198944890054435
49,0.0013954375021210123
50,0.001154674600832409
51,0.001089323872677253
52,0.0009106154349725524
53,0.0008807071328793311
54,0.0007910131242580489
55,0.000699931604917531
56,0.0006562820132999513
57,0.0005992152591999827
58,0.0005220298625399005
59,0.00046864753446095996

0 comments on commit f79b7d1

Please sign in to comment.