-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Code to perform some exploratory data analysis and initial PCA decomposition
of seizure data. Input data and results of running the code.
- Loading branch information
Showing
11 changed files
with
11,866 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,289 @@ | ||
import os | ||
|
||
import matplotlib | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.decomposition import PCA | ||
|
||
# Anchor point handed to ``legend(bbox_to_anchor=...)`` by
# create_feature_value_plot.
LEGEND_COORDS = (1.2, 0.8)
# Default number of principal components fitted by explore_pca.
TOTAL_PCA_COMPONENTS = 60
|
||
|
||
def main():
    """Explore the epileptic seizure classification data set.

    Read ``data/data.csv``, write the exploratory tables and plots, then
    hold out 30% of the rows and run the PCA exploration on the remaining
    training features. Results are written below the ``outputs`` directory,
    which is created here if it does not exist yet.

    Returns:
        None
    """
    epidata = pd.read_csv(os.path.join('data', 'data.csv'))
    # The exploration helpers write straight into 'outputs' and would fail
    # on a fresh checkout where that directory is missing.
    os.makedirs('outputs', exist_ok=True)
    set_matplotlib_params()
    explore_data(epidata)
    features = epidata.drop(['y', 'Unnamed: 0'], axis=1)
    target = epidata['y']
    # Only the training features are used below; the fixed random_state
    # keeps the split identical between runs.
    x_train, _, _, _ = train_test_split(features, target,
                                        test_size=0.3,
                                        random_state=0)
    explore_pca(x_train)
|
||
|
||
def explore_pca(x_train, n_components=TOTAL_PCA_COMPONENTS):
    """Fit a PCA to the training data and report its explained variance.

    Fit the first ``n_components`` principal components of ``x_train``.
    Write the explained variance ratio of each component to
    ``outputs/var_ratio.csv``, plot it to ``outputs/var_ratio.png`` and plot
    its cumulative sum to ``outputs/var_ratio_sum.png``.

    Args:
        x_train: pandas.DataFrame
            Training data to decompose using Principal Component Analysis
        n_components: int, optional, default TOTAL_PCA_COMPONENTS
            Number of PCA components to use in the plots and the csv file

    Returns:
        None
    """
    out_dir = 'outputs'
    fitted = PCA(n_components=n_components).fit(x_train)
    ratios = pd.Series(fitted.explained_variance_ratio_)
    cumulative = ratios.cumsum()
    ratios.to_csv(os.path.join(out_dir, 'var_ratio.csv'), header=True)
    plot_explained_variance(ratios, os.path.join(out_dir, 'var_ratio.png'))
    plot_sum_explained_variance(
        cumulative, os.path.join(out_dir, 'var_ratio_sum.png'))
|
||
|
||
def plot_sum_explained_variance(var_ratio, filename):
    """Plot one minus the cumulative explained variance ratio to a file.

    Draw ``1 - var_ratio`` as unconnected point markers against a linear
    x-axis and a logarithmic y-axis, save the figure and clear it. The plot
    lets the user read off how many components are needed to reach a given
    total explained variance ratio, such as 90% or 99%.

    Args:
        var_ratio: pandas.Series
            Values that are subtracted from one and plotted. The caller in
            this file passes the cumulative sum of the explained variance
            ratios, one entry per component.
        filename: str
            File to save. File format is inferred from the extension

    Returns:
        None
    """
    axes = plt.subplot(1, 1, 1)
    remaining = 1 - var_ratio
    axes.semilogy(remaining, marker='.', linestyle='none')
    axes.set(xlabel='PCA component',
             ylabel='1 - sum(Explained variance ratio)')
    figure = axes.get_figure()
    figure.savefig(filename)
    # Clear the figure after saving, as the other plot helpers here do.
    figure.clf()
|
||
|
||
def plot_explained_variance(var_ratio, filename):
    """Plot the explained variance ratio of each component to a file.

    Draw ``var_ratio`` as unconnected point markers against a linear x-axis
    and a logarithmic y-axis, save the figure and clear it. The plot shows
    how quickly the explained variance falls off from component to
    component and helps to identify an eigengap, if one exists.

    Args:
        var_ratio: pandas.Series
            Explained variance ratio of each component. Should be
            monotonically decreasing.
        filename: str
            File to save. File format is inferred from the extension

    Returns:
        None
    """
    axes = plt.subplot(1, 1, 1)
    axes.semilogy(var_ratio, marker='.', linestyle='none')
    axes.set(xlabel='PCA component', ylabel='Explained variance ratio')
    figure = axes.get_figure()
    figure.savefig(filename)
    # Clear the figure after saving, as the other plot helpers here do.
    figure.clf()
|
||
|
||
def explore_data(epidata):
    """Write summary tables and plots describing the seizure data.

    The columns 'y' and 'Unnamed: 0' are dropped to obtain the features.
    The following files are written below ``outputs``: ``double_desc.csv``
    (a description of the per-feature description), ``class_counts.csv``
    (value counts of column 'y'), ``feature_summary.png``,
    ``interquartile.png``, ``mean_median.png``, ``feature_std.png``,
    ``corr_heatmap.png`` and ``corr_X90.png``.

    Args:
        epidata: pandas.DataFrame
            Data about seizures. Each row is a data point and each column is
            a feature, except for the column 'y' which contains the
            classification target.

    Returns:
        None
    """
    def out_path(name):
        return os.path.join('outputs', name)

    features = epidata.drop(['y', 'Unnamed: 0'], axis=1)
    summary = features.describe()

    # Describe the description: how each statistic varies across features.
    summary_of_summary = summary.transpose().describe()
    summary_of_summary.to_csv(out_path('double_desc.csv'))

    class_counts = epidata['y'].value_counts()
    class_counts.to_csv(out_path('class_counts.csv'), header=True)

    create_summary_plot(summary, out_path('feature_summary.png'))
    create_interquartile_plot(summary, out_path('interquartile.png'))
    create_mean_median_plot(summary, out_path('mean_median.png'))
    create_std_plot(summary.loc['std', :], out_path('feature_std.png'))
    create_corr_heatmap(features, out_path('corr_heatmap.png'),
                        out_path('corr_X90.png'))
|
||
|
||
def create_corr_heatmap(features, filename1, filename2):
    """Create 2 plots: a feature correlation heat map and correlations with X90.

    The correlation matrix is also written to ``outputs/corr_mat.csv``.
    Both plots are drawn on a figure owned by this function and passed to
    the plotting calls explicitly, so the saved files do not depend on
    which figure pyplot currently treats as active. The figure is closed
    before returning.

    Args:
        features: pandas.DataFrame
            Feature columns. Each row is a data point and each column is a
            feature. One column must be named 'X90'.
        filename1: str
            Path to which to save correlation heatmap figure. File format is
            inferred from the extension.
        filename2: str
            Path to which to save figure showing correlations with feature
            X90. File format is inferred from the extension.

    Returns:
        None
    """
    corr_mat = features.corr()
    corr_mat.to_csv(os.path.join('outputs', 'corr_mat.csv'))

    fig, line_axes = plt.subplots()
    try:
        corr_mat['X90'].plot(ax=line_axes)
        line_axes.set_xlabel('Feature')
        line_axes.set_ylabel('Correlation with feature X90')
        fig.savefig(filename2)

        # Reuse the same figure for the heat map, on explicitly created
        # axes rather than on whatever plt.gca() would return.
        fig.clf()
        heat_axes = fig.add_subplot(1, 1, 1)
        sns.heatmap(corr_mat, center=0, cmap='coolwarm', ax=heat_axes)
        fig.savefig(filename1)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
|
||
|
||
def create_std_plot(std, filename):
    """Plot the standard deviation of each feature and save it to a file.

    Args:
        std: pandas.Series
            Standard deviation of each feature across the data set. The
            index is the name of the feature.
        filename: str
            Path to which to save the figure. The file format is inferred
            from the extension.

    Returns:
        None
    """
    axes = std.plot()
    axes.set(xlabel='Feature', ylabel='Standard deviation of feature value')
    figure = axes.get_figure()
    figure.savefig(filename)
    # Clear the figure after saving, as the other plot helpers here do.
    figure.clf()
|
||
|
||
def create_summary_plot(description, filename):
    """Plot every row of a feature description except 'count' and 'std'.

    For the output of ``DataFrame.describe`` this leaves the mean, min, max
    and quartiles, drawn with features on the x-axis.

    Args:
        description: pandas.DataFrame
            Description of the features. Result of running the
            pandas.DataFrame.describe method on the features
        filename: str
            Path to which to save the figure. The file format is inferred
            from the extension.

    Returns:
        None
    """
    kept_rows = description.drop(['count', 'std'])
    # Transpose so that each feature is a row and each statistic a column.
    create_feature_value_plot(kept_rows.T, filename)
|
||
|
||
def create_interquartile_plot(data, filename):
    """Plot the mean, median, and 25th and 75th percentile of the features.

    Args:
        data: pandas.DataFrame
            Description of the features. Result of running the
            pandas.DataFrame.describe method on the features
        filename: str
            Path to which to save the figure. The file format is inferred
            from the extension.

    Returns:
        None
    """
    # Transpose so that each feature is a row and each statistic a column.
    per_feature = data.T
    selected = per_feature[['mean', '25%', '50%', '75%']]
    create_feature_value_plot(selected, filename)
|
||
|
||
def create_mean_median_plot(data, filename):
    """Plot the mean and the median of the features.

    The '50%' statistic is labelled 'median' in the plot.

    Args:
        data: pandas.DataFrame
            Description of the features. Result of running the
            pandas.DataFrame.describe method on the features
        filename: str
            Path to which to save the figure. The file format is inferred
            from the extension.

    Returns:
        None
    """
    # Pick the two statistics first, then turn features into rows.
    mean_and_median = data.loc[['mean', '50%']].transpose()
    mean_and_median.columns = ['mean', 'median']
    create_feature_value_plot(mean_and_median, filename)
|
||
|
||
def create_feature_value_plot(data, filename):
    """Create a plot with features on the x-axis and values on the y-axis.

    The legend is anchored at LEGEND_COORDS and the figure is saved with a
    tight bounding box. The figure is closed after saving so that repeated
    calls do not accumulate open figures.

    Args:
        data: pandas.DataFrame
            Data to plot
        filename: str
            Path to which to save the figure. The file format is inferred
            from the extension.

    Returns:
        None
    """
    plot_axes = data.plot()
    fig = plot_axes.get_figure()
    try:
        plot_axes.set_xlabel('Feature')
        plot_axes.set_ylabel('Value')
        plot_axes.legend(loc='right', bbox_to_anchor=LEGEND_COORDS)
        fig.savefig(filename, bbox_inches='tight')
    finally:
        # Plotting without an explicit ``ax`` can open a new pyplot-managed
        # figure on every call; clearing it would leave it registered with
        # pyplot, so close it instead.
        plt.close(fig)
|
||
|
||
def set_matplotlib_params():
    """Set the matplotlib parameters used for all figures of this script.

    Set the resolution of saved figures to 200 dots per inch and the axes
    edge, tick and axis label colors to gray. The settings are stored in
    ``matplotlib.rcParams`` and so apply to all later matplotlib calls.

    Returns:
        None
    """
    gray_settings = ('axes.edgecolor', 'xtick.color', 'ytick.color',
                     'axes.labelcolor')
    matplotlib.rcParams['savefig.dpi'] = 200
    for setting in gray_settings:
        matplotlib.rcParams[setting] = 'gray'
|
||
|
||
# Run the exploration only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
,y | ||
5,2300 | ||
4,2300 | ||
3,2300 | ||
2,2300 | ||
1,2300 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
,count,mean,std,min,25%,50%,75%,max | ||
count,178.0,178.0,178.0,178.0,178.0,178.0,178.0,178.0 | ||
mean,11500.0,-7.72243624816805,164.54793542513497,-1812.5,-53.485955056179776,-7.314606741573034,36.41713483146067,1737.6404494382023 | ||
std,0.0,2.0152466233942636,3.4026820680967487,81.03924208099522,1.3267987207659901,0.9636970433526939,1.271904659140675,296.35957878313974 | ||
min,11500.0,-13.668869565217392,153.88138289800153,-1885.0,-57.0,-10.0,33.0,1213.0 | ||
25%,11500.0,-9.210934782608696,162.16481491658084,-1864.0,-54.0,-8.0,36.0,1451.25 | ||
50%,11500.0,-7.249,164.6027516168026,-1844.5,-53.0,-7.0,36.125,1781.5 | ||
75%,11500.0,-6.229760869565218,167.0659221012374,-1788.25,-53.0,-7.0,37.0,2047.0 | ||
max,11500.0,-4.143826086956522,172.43988729137322,-1415.0,-50.0,-5.0,39.0,2047.0 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
,0 | ||
0,0.05851539553721611 | ||
1,0.05295931429035144 | ||
2,0.05142283168157521 | ||
3,0.04956283068080061 | ||
4,0.04512750251578775 | ||
5,0.04183149861925076 | ||
6,0.040411969235401746 | ||
7,0.03890075947581718 | ||
8,0.03665612921066289 | ||
9,0.033500682440094055 | ||
10,0.03300985459132789 | ||
11,0.030857865853442135 | ||
12,0.029628639604841862 | ||
13,0.028965135860325994 | ||
14,0.02669929489623337 | ||
15,0.024036870519814693 | ||
16,0.022372536564910974 | ||
17,0.020754353287488345 | ||
18,0.020144226316130765 | ||
19,0.01891155013519223 | ||
20,0.018568892045012584 | ||
21,0.01776000369766883 | ||
22,0.017579631035157313 | ||
23,0.017399234894095263 | ||
24,0.016863362354587606 | ||
25,0.016444384563267522 | ||
26,0.014839093967262879 | ||
27,0.014522995797549343 | ||
28,0.013808867906598002 | ||
29,0.013715511316088055 | ||
30,0.01318465053262129 | ||
31,0.012507316799692868 | ||
32,0.011971642565251895 | ||
33,0.011080565341319713 | ||
34,0.010799991507960254 | ||
35,0.009172132818290572 | ||
36,0.008010301371800668 | ||
37,0.006540595431528309 | ||
38,0.005807131311384302 | ||
39,0.005342073543740615 | ||
40,0.004631168024817603 | ||
41,0.004387385376741865 | ||
42,0.0035364363434579418 | ||
43,0.0029813579309323807 | ||
44,0.0025291588746729023 | ||
45,0.0021924771259065347 | ||
46,0.0020344321871594746 | ||
47,0.0017454795147042851 | ||
48,0.0016198944890054435 | ||
49,0.0013954375021210123 | ||
50,0.001154674600832409 | ||
51,0.001089323872677253 | ||
52,0.0009106154349725524 | ||
53,0.0008807071328793311 | ||
54,0.0007910131242580489 | ||
55,0.000699931604917531 | ||
56,0.0006562820132999513 | ||
57,0.0005992152591999827 | ||
58,0.0005220298625399005 | ||
59,0.00046864753446095996 |