Fixed the instance importance calculation
AndreCNF committed Aug 18, 2020
1 parent 5bf9d82 commit a954dd7
Showing 4 changed files with 161 additions and 20 deletions.
61 changes: 49 additions & 12 deletions callbacks.py
@@ -18,6 +18,8 @@
from app import app
import layouts

+# Set the random seed for reproducibility:
+du.set_random_seed(42)
# Path to the directory where all the ML models are stored
models_path = 'models/'
metrics_path = 'metrics/individual_models/'
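Note: du is presumably the data-utils package imported above this hunk (the import itself is outside the diff context). Seeding helpers of this kind usually pin every random number generator in play at once; the actual data-utils internals aren't shown here, but a minimal sketch of the usual pattern, assuming Python's random, NumPy, and PyTorch as the randomness sources, would be:

    import random
    import numpy as np
    import torch

    def set_random_seed(seed):
        # Pin the Python, NumPy and PyTorch RNGs so repeated runs
        # sample the same sequences and yield the same importance plots
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)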
@@ -107,7 +109,8 @@ def load_dataset_callback(dataset_name, dataset_mod, model_file_name,

@app.callback([Output('model_store', 'data'),
Output('model_metrics', 'data'),
-                Output('model_hyperparam', 'data')],
+                Output('model_hyperparam', 'data'),
+                Output('is_custom_store', 'data')],
[Input('model_name_div', 'children'),
Input('dataset_name_div', 'children')])
def load_model_callback(model_name, dataset_name):
@@ -119,60 +122,74 @@ def load_model_callback(model_name, dataset_name):
# Specify the model file name and class
model_file_name = 'lstm_bidir_one_hot_encoded_delta_ts_90dayswindow_0.3809valloss_06_07_2020_04_08'
model_class = Models.VanillaLSTM
+        is_custom = False
elif model_name == 'Bidir LSTM, embedded, time aware':
# Specify the model file name and class
model_file_name = 'lstm_bidir_pre_embedded_delta_ts_90dayswindow_0.3481valloss_06_07_2020_04_15'
model_class = Models.VanillaLSTM
+        is_custom = False
elif model_name == 'Bidir LSTM, embedded':
# Specify the model file name and class
model_file_name = 'lstm_bidir_pre_embedded_90dayswindow_0.2490valloss_06_07_2020_03_47'
model_class = Models.VanillaLSTM
+        is_custom = False
elif model_name == 'LSTM':
# Specify the model file name and class
model_file_name = 'lstm_one_hot_encoded_90dayswindow_0.4363valloss_06_07_2020_03_28'
model_class = Models.VanillaLSTM
+        is_custom = False
elif model_name == 'Bidir RNN, embedded, time aware':
# Specify the model file name and class
model_file_name = 'rnn_bidir_pre_embedded_delta_ts_90dayswindow_0.3059valloss_06_07_2020_03_10'
model_class = Models.VanillaRNN
+        is_custom = False
elif model_name == 'RNN, embedded':
# Specify the model file name and class
model_file_name = 'rnn_with_embedding_90dayswindow_0.5569valloss_30_06_2020_17_04.pth'
model_class = Models.VanillaRNN
+        is_custom = False
elif model_name == 'MF1-LSTM':
# Specify the model file name and class
model_file_name = 'mf1lstm_one_hot_encoded_90dayswindow_0.6009valloss_07_07_2020_03_46'
model_class = Models.MF1LSTM
+        is_custom = True
elif dataset_name == 'Toy Example':
# [TODO] Train and add each model for the toy example
if model_name == 'Bidir LSTM, time aware':
# Specify the model file name and class
model_file_name = ''
model_class = Models.VanillaLSTM
+            is_custom = False
elif model_name == 'Bidir LSTM, embedded, time aware':
# Specify the model file name and class
model_file_name = ''
model_class = Models.VanillaLSTM
+            is_custom = False
elif model_name == 'Bidir LSTM, embedded':
# Specify the model file name and class
model_file_name = ''
model_class = Models.VanillaLSTM
+            is_custom = False
elif model_name == 'LSTM':
# Specify the model file name and class
model_file_name = ''
model_class = Models.VanillaLSTM
+            is_custom = False
elif model_name == 'Bidir RNN, embedded, time aware':
# Specify the model file name and class
model_file_name = ''
model_class = Models.VanillaRNN
+            is_custom = False
elif model_name == 'RNN, embedded':
# Specify the model file name and class
model_file_name = ''
model_class = Models.VanillaRNN
+            is_custom = False
elif model_name == 'MF1-LSTM':
# Specify the model file name and class
model_file_name = ''
model_class = Models.MF1LSTM
+            is_custom = True
else:
raise Exception(f'ERROR: The HAI dashboard isn\'t currently suited to load the dataset named {dataset_name}.')
# Load the metrics file
@@ -186,7 +203,8 @@ def load_model_callback(model_name, dataset_name):
for param in model_args])
return (model_file_name,
metrics,
-            hyperparams)
+            hyperparams,
+            is_custom)

@app.callback(Output('model_description_list', 'children'),
[Input('model_store', 'modified_timestamp')],
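With the fourth Output added, the tuple returned by load_model_callback is mapped positionally onto the declared Outputs, so is_custom lands in the new is_custom_store. A self-contained sketch of this multi-output pattern in Dash 1.x (component values here are illustrative, not taken from the app):

    import dash
    import dash_core_components as dcc
    import dash_html_components as html
    from dash.dependencies import Input, Output

    app = dash.Dash(__name__)
    app.layout = html.Div([
        html.Div(id='model_name_div', children='MF1-LSTM'),
        dcc.Store(id='model_store', storage_type='memory'),
        dcc.Store(id='is_custom_store', storage_type='memory'),
    ])

    @app.callback([Output('model_store', 'data'),
                   Output('is_custom_store', 'data')],
                  [Input('model_name_div', 'children')])
    def pick_model(model_name):
        # The returned tuple is matched to the Outputs in declaration order
        is_custom = model_name == 'MF1-LSTM'
        return model_name, is_custom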
@@ -681,7 +699,7 @@ def update_det_analysis_preview(df_store, model_file_name, id_column):
ids = random.sample(ids, 4)
tmp_df = tmp_df[tmp_df[id_column].isin(ids)]
# Calculate the instance importance scores (it should be fast enough; otherwise compute them beforehand and integrate them into the dataframe)
-    interpreter = ModelInterpreter(model, tmp_df, inst_column=1, is_custom=True)
+    interpreter = ModelInterpreter(model, tmp_df, inst_column=1, is_custom=is_custom)
interpreter.interpret_model(instance_importance=True, feature_importance=False)
# Get the instance importance plot
return interpreter.instance_importance_plot(interpreter.test_data,
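The occlusion_wgt argument used further down suggests model-interpreter scores instances by occlusion; its internals aren't part of this diff, but under that assumption the idea is roughly (shapes and model signature assumed for illustration):

    import torch

    def instance_importance(model, sequence, padding_value=999999):
        # Occlusion sketch: score each timestep by how much the final
        # prediction moves when that timestep is masked out.
        # Assumes model maps a [1, T, F] batch to per-timestep outputs [1, T].
        with torch.no_grad():
            baseline = model(sequence.unsqueeze(0))[0, -1]
            scores = []
            for t in range(sequence.shape[0]):
                occluded = sequence.clone()
                occluded[t] = padding_value  # hide one instance (timestep)
                pred = model(occluded.unsqueeze(0))[0, -1]
                scores.append((baseline - pred).item())
        return scores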
@@ -707,7 +725,8 @@ def update_det_analysis_preview_callback(dataset_mod, model_mod, df_store, model
return update_det_analysis_preview(df_store, model_file_name, id_column)

@cache.memoize(timeout=TIMEOUT)
-def update_full_inst_import(df_store, model_file_name):
+def update_full_inst_import(df_store, model_file_name, id_column,
+                            ts_column, label_column, is_custom):
global models_path
# Reconvert the dataframe to Pandas
df = pd.DataFrame(df_store)
@@ -724,12 +743,16 @@ def update_full_inst_import(df_store, model_file_name):
# Guarantee that the model is in evaluation mode, so as to deactivate dropout
model.eval()
# Create a dataframe copy that doesn't include the feature importance columns
-    column_names = [feature for feature in df.columns
-                    if not feature.endswith('_shap')]
+    shap_column_names = [feature for feature in df.columns
+                         if feature.endswith('_shap')]
+    column_names = list(df.columns)
+    [column_names.remove(shap_column) for shap_column in shap_column_names]
tmp_df = df.copy()
tmp_df = tmp_df[column_names]
# Calculate the instance importance scores (it should be fast enough; otherwise compute them beforehand and integrate them into the dataframe)
-    interpreter = ModelInterpreter(model, tmp_df, inst_column=1, is_custom=True)
+    interpreter = ModelInterpreter(model, tmp_df, id_column_name=id_column, inst_column_name=ts_column,
+                                   label_column_name=label_column, fast_calc=True,
+                                   padding_value=padding_value, is_custom=is_custom, occlusion_wgt=0.7)
interpreter.interpret_model(instance_importance=True, feature_importance=False)
# Get the instance importance plot
return interpreter.instance_importance_plot(interpreter.test_data,
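Aside: the new column filtering runs list.remove inside a list comprehension purely for its side effects. A side-effect-free equivalent that keeps both name lists would be:

    # Equivalent filtering without relying on list.remove side effects
    shap_column_names = [feature for feature in df.columns
                         if feature.endswith('_shap')]
    column_names = [feature for feature in df.columns
                    if feature not in shap_column_names]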
@@ -746,11 +769,23 @@

@app.callback(Output('instance_importance_graph', 'figure'),
[Input('dataset_store', 'modified_timestamp'),
-                Input('model_store', 'modified_timestamp')],
+                Input('model_store', 'modified_timestamp'),
+                Input('id_col_name_store', 'modified_timestamp'),
+                Input('ts_col_name_store', 'modified_timestamp'),
+                Input('label_col_name_store', 'modified_timestamp'),
+                Input('is_custom_store', 'modified_timestamp')],
[State('dataset_store', 'data'),
-               State('model_store', 'data')])
-def update_full_inst_import_callback(dataset_mod, model_mod, df_store, model_file_name):
-    return update_full_inst_import(df_store, model_file_name)
+               State('model_store', 'data'),
+               State('id_col_name_store', 'data'),
+               State('ts_col_name_store', 'data'),
+               State('label_col_name_store', 'data'),
+               State('is_custom_store', 'data')])
+def update_full_inst_import_callback(dataset_mod, model_mod, id_column_mod, ts_column_mod,
+                                     label_column_mod, is_custom_mod, df_store,
+                                     model_file_name, id_column, ts_column, label_column,
+                                     is_custom):
+    return update_full_inst_import(df_store, model_file_name, id_column,
+                                   ts_column, label_column, is_custom)

@app.callback(Output('salient_features_list', 'children'),
[Input('instance_importance_graph', 'hoverData'),
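Throughout these callbacks each store is wired with its modified_timestamp as the Input trigger while the payload is read back as State, the convention Dash's dcc.Store documentation suggests for reacting to writes. Reduced to one store (the figure-building helper is hypothetical):

    import dash
    from dash.dependencies import Input, Output, State
    from app import app  # as in callbacks.py

    @app.callback(Output('instance_importance_graph', 'figure'),
                  [Input('is_custom_store', 'modified_timestamp')],
                  [State('is_custom_store', 'data')])
    def redraw_on_store_change(is_custom_mod, is_custom):
        # The timestamp only signals that the store was written;
        # the value itself arrives through the State argument
        if is_custom_mod is None:
            raise dash.exceptions.PreventUpdate
        return build_instance_importance_figure(is_custom)  # hypothetical helper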
@@ -860,11 +895,12 @@ def update_ts_feat_import(hovered_data, clicked_data, dataset_mod,
State('model_store', 'data'),
State('model_name_div', 'children'),
State('id_col_name_store', 'data'),
+              State('ts_col_name_store', 'data'),
State('clicked_ts', 'children'),
State('hovered_ts', 'children'),
State('curr_final_output', 'data')])
def update_final_output(dataset_mod, hovered_data, clicked_data, df_store, model_file_name,
-                        model_name, id_column, clicked_ts, hovered_ts, prev_output):
+                        model_name, id_column, ts_column, clicked_ts, hovered_ts, prev_output):
global is_custom
global clicked_thrsh
global models_path
@@ -897,6 +933,7 @@ def update_final_output(dataset_mod, hovered_data, clicked_data, df_store, model
# Only use the model-relevant features
feature_names = [feature.split('_shap')[0] for feature in df.columns
if feature.endswith('_shap')]
+    # feature_names = [id_column, ts_column] + feature_names
filtered_df = filtered_df[feature_names]
data = torch.from_numpy(filtered_df.values)
# Remove unwanted columns from the data
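For reference, the unchanged lines above feed the filtered dataframe straight into the model; the general shape of that inference step, with the float32 cast being an assumption about the stored dtypes, is:

    import numpy as np
    import torch

    # filtered_df and model come from the surrounding callback code
    data = torch.from_numpy(filtered_df.values.astype(np.float32))
    model.eval()           # deactivate dropout, as done earlier in this file
    with torch.no_grad():  # plain forward pass, no gradient tracking needed
        output = model(data.unsqueeze(0))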
5 changes: 5 additions & 0 deletions layouts.py
@@ -35,6 +35,7 @@
dcc.Store(id='label_col_name_store', storage_type='memory'),
dcc.Store(id='cols_to_remove_store', storage_type='memory'),
dcc.Store(id='expected_value_store', storage_type='memory'),
+    dcc.Store(id='is_custom_store', storage_type='memory'),
# Chosen machine learning model
html.Div(id='model_name_div', children='LSTM', hidden=True),
dcc.Store(id='model_store', storage_type='memory'),
@@ -267,6 +268,7 @@
dcc.Store(id='label_col_name_store', storage_type='memory'),
dcc.Store(id='cols_to_remove_store', storage_type='memory'),
dcc.Store(id='expected_value_store', storage_type='memory'),
+    dcc.Store(id='is_custom_store', storage_type='memory'),
# Chosen machine learning model
html.Div(id='model_name_div', children='LSTM', hidden=True),
dcc.Store(id='model_store', storage_type='memory'),
@@ -472,6 +474,7 @@
dcc.Store(id='label_col_name_store', storage_type='memory'),
dcc.Store(id='cols_to_remove_store', storage_type='memory'),
dcc.Store(id='expected_value_store', storage_type='memory'),
+    dcc.Store(id='is_custom_store', storage_type='memory'),
# Chosen machine learning model
html.Div(id='model_name_div', children='LSTM', hidden=True),
dcc.Store(id='model_store', storage_type='memory'),
@@ -596,6 +599,7 @@
dcc.Store(id='label_col_name_store', storage_type='memory'),
dcc.Store(id='cols_to_remove_store', storage_type='memory'),
dcc.Store(id='expected_value_store', storage_type='memory'),
+    dcc.Store(id='is_custom_store', storage_type='memory'),
# Chosen machine learning model
html.Div(id='model_name_div', children='LSTM', hidden=True),
dcc.Store(id='model_store', storage_type='memory'),
@@ -661,6 +665,7 @@
dcc.Store(id='label_col_name_store', storage_type='memory'),
dcc.Store(id='cols_to_remove_store', storage_type='memory'),
dcc.Store(id='expected_value_store', storage_type='memory'),
+    dcc.Store(id='is_custom_store', storage_type='memory'),
# Chosen machine learning model
html.Div(id='model_name_div', children='LSTM', hidden=True),
dcc.Store(id='model_store', storage_type='memory'),
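Each of the five page layouts gains the same memory-backed store; abridged, the surrounding block reads (memory storage resets on page reload, which fits this per-session flag):

    import dash_core_components as dcc
    import dash_html_components as html

    layout = html.Div([
        dcc.Store(id='expected_value_store', storage_type='memory'),
        dcc.Store(id='is_custom_store', storage_type='memory'),
        # Chosen machine learning model
        html.Div(id='model_name_div', children='LSTM', hidden=True),
        dcc.Store(id='model_store', storage_type='memory'),
    ])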
13 changes: 5 additions & 8 deletions poetry.lock

Some generated files are not rendered by default.

102 changes: 102 additions & 0 deletions requirements.txt
@@ -0,0 +1,102 @@
alabaster==0.7.12
attrs==19.3.0
Babel==2.8.0
certifi==2020.6.20
chardet==3.0.4
Click==7.0
cloudpickle==1.5.0
colorlover==0.3.0
comet-git-pure==0.19.16
comet-ml==3.1.16
configobj==5.0.6
dash==1.14.0
dash-bootstrap-components==0.10.3
dash-core-components==1.10.2
dash-html-components==1.0.3
dash-renderer==1.6.0
dash-table==4.9.0
dask==2.22.0
-e git+https://github.com/AndreCNF/data-utils.git@91455bbd0768377d33253c5e573f538143d3fe48#egg=data_utils
distributed==2.22.0
docutils==0.16
everett==1.0.2
falcon==2.0.0
Flask==1.1.1
Flask-Caching==1.9.0
Flask-Compress==1.4.0
fsspec==0.6.3
future==0.18.2
gitdb==4.0.5
gitdb2==4.0.2
GitPython==3.1.7
HeapDict==1.0.1
hug==2.6.1
idna==2.10
imagesize==1.2.0
importlib-metadata==1.7.0
itsdangerous==1.1.0
Jinja2==2.11.1
joblib==0.14.1
jsonschema==3.2.0
livereload==2.6.2
lunr==0.5.8
Mako==1.1.3
Markdown==3.2.2
MarkupSafe==1.1.1
mkdocs==1.1.2
mkdocs-material==5.5.6
mkdocs-material-extensions==1.0
-e git+https://github.com/AndreCNF/model-interpreter.git@122b00416b27d318e9629359966fa43a14c2eb13#egg=model_interpreter
modin==0.7.3
msgpack==1.0.0
netifaces==0.10.9
nltk==3.5
numpy==1.18.1
nvidia-ml-py3==7.352.0
packaging==20.1
pandas==1.0.3
pdoc3==0.7.5
pdocs==1.0.2
Pillow==7.2.0
plotly==4.9.0
portray==1.4.0
psutil==5.7.2
Pygments==2.6.1
pymdown-extensions==7.1
pyparsing==2.4.6
pyrsistent==0.16.0
python-dateutil==2.8.1
pytz==2019.3
PyYAML==5.3.1
regex==2020.7.14
requests==2.24.0
retrying==1.3.3
scikit-learn==0.22.1
scipy==1.4.1
-e git+https://github.com/AndreCNF/shap.git@18e24569d551b9f53ea73dc44f30866cdc92623b#egg=shap
six==1.14.0
sklearn==0.0
smmap==3.0.4
snowballstemmer==2.0.0
sortedcontainers==2.2.2
Sphinx==2.4.4
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==1.0.3
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.4
tblib==1.7.0
toml==0.10.1
toolz==0.10.0
torch==1.4.0
tornado==6.0.4
tqdm==4.43.0
urllib3==1.25.10
websocket-client==0.57.0
Werkzeug==1.0.0
wrapt==1.12.1
wurlitzer==2.0.1
yaspin==0.15.0
zict==2.0.0
zipp==3.1.0
