Fixed filtering and card sizes in the feature importance page
AndreCNF committed Aug 19, 2020
1 parent 68c0f90 commit 0dee235
Showing 3 changed files with 32 additions and 15 deletions.
35 changes: 30 additions & 5 deletions callbacks.py
@@ -424,6 +424,22 @@ def update_sample_table(n_clicks, dataset_store, id_column, ts_column, label_col
return (data_columns, filtered_df.to_dict('records'), False,
f'Sample from patient {subject_id} on timestamp {ts}')

+@app.callback(Output('feature_importance_dropdown', 'options'),
+              [Input('dataset_store', 'modified_timestamp')],
+              [State('dataset_store', 'data')])
+def load_feat_import_filter_options(dataset_mod, dataset_store):
+    # Reconvert the dataframe to Pandas
+    df = pd.DataFrame(dataset_store)
+    # Find the SHAP value columns and recover the corresponding feature names
+    shap_column_names = [feature for feature in df.columns
+                         if feature.endswith('_shap')]
+    feature_names = [feature.split('_shap')[0] for feature in shap_column_names]
+    # Create the feature dropdown filter
+    options = list()
+    options.append(dict(label='All', value='All'))
+    [options.append(dict(label=feat, value=feat)) for feat in feature_names]
+    return options

# Page headers callbacks
@app.callback(Output('model_perf_header', 'children'),
[Input('model_name_div', 'children')])
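
For reference, a minimal, self-contained sketch of what the new load_feat_import_filter_options callback computes, with the side-effect list comprehension rewritten as a plain list build. The toy column names follow the Var0/Var4_a placeholders used by the old layouts.py options below; they are not values taken from the commit.

import pandas as pd

# Toy dataframe standing in for dataset_store after pd.DataFrame(dataset_store)
df = pd.DataFrame(columns=['subject_id', 'ts', 'Var0', 'Var0_shap', 'Var4_a', 'Var4_a_shap'])
# Columns ending in '_shap' hold the SHAP values; stripping the suffix recovers the feature names
shap_column_names = [col for col in df.columns if col.endswith('_shap')]
feature_names = [col.split('_shap')[0] for col in shap_column_names]
# Build the dropdown options: 'All' first, then one entry per feature
options = [dict(label='All', value='All')]
options += [dict(label=feat, value=feat) for feat in feature_names]
# options == [{'label': 'All', 'value': 'All'},
#             {'label': 'Var0', 'value': 'Var0'},
#             {'label': 'Var4_a', 'value': 'Var4_a'}]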
@@ -1664,7 +1680,9 @@ def create_fltd_feat_import_cards(df, data_filter=None):
if data_filter is None:
# Use the full dataframe
cards_list.append(create_feat_import_card(df, card_title='Feature importance',
-                                                  xaxis_title='mean(|SHAP value|)'))
+                                                  max_display=15,
+                                                  xaxis_title='mean(|SHAP value|)',
+                                                  card_height=None))
else:
# Filter the data on the specified filter (categorical feature)
categ_columns = [column for column in df.columns
@@ -1691,7 +1709,9 @@
card_title = f'Feature importance on data with {data_filter} = {categ.split(data_filter)[1]}'
# Add a feature importance card
cards_list.append(create_feat_import_card(filtered_df, card_title=card_title,
-                                                  xaxis_title='mean(|SHAP value|)'))
+                                                  max_display=15,
+                                                  xaxis_title='mean(|SHAP value|)',
+                                                  card_height=None))
return cards_list
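
The code collapsed between the two hunks above builds filtered_df from the one-hot encoded columns of the selected categorical feature. The sketch below is a rough reconstruction of that pattern; only the card_title f-string and the shape of the categ_columns comprehension come from the visible lines, while the startswith condition and the == 1 row filter are assumptions.

import pandas as pd

def filter_by_category(df, data_filter):
    # One-hot columns belonging to the selected categorical feature, e.g. 'Var4_a', 'Var4_b'
    categ_columns = [column for column in df.columns
                     if column.startswith(data_filter)]
    filtered = dict()
    for categ in categ_columns:
        # Assumed filter: keep only the rows where this category is active
        filtered_df = df[df[categ] == 1]
        card_title = f'Feature importance on data with {data_filter} = {categ.split(data_filter)[1]}'
        filtered[card_title] = filtered_df
    return filtered

# Example with dummy one-hot columns in the style of the repo's Var4 feature
toy = pd.DataFrame({'Var4_a': [1, 0, 1], 'Var4_b': [0, 1, 0], 'Var0_shap': [0.2, -0.1, 0.4]})
for title, sub_df in filter_by_category(toy, 'Var4').items():
    print(title, len(sub_df))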

@app.callback(Output('feature_importance_cards', 'children'),
@@ -2134,11 +2154,11 @@ def apply_data_changes(new_data, dataset_store, id_column, ts_column, label_colu
# Guarantee that the model is in evaluation mode, so as to deactivate dropout
model.eval()
# Recalculate the SHAP values
-    interpreter = ModelInterpreter(model, df, id_column_name=id_column, inst_column_name=ts_column,
+    interpreter = ModelInterpreter(model, id_column_name=id_column, inst_column_name=ts_column,
label_column_name=label_column, fast_calc=True,
padding_value=padding_value, is_custom=is_custom,
seq_len_dict=seq_len_dict,
-                                   feat_names=feature_names)
+                                   feat_names=feature_names+[label_column])
_ = interpreter.interpret_model(test_data=features, test_labels=labels,
instance_importance=False, feature_importance='shap')
# Join the updated SHAP values to the dataframe
@@ -2175,10 +2195,15 @@ def calc_exp_val(dataset_mod, model_mod, model_file_name, dataset_name):
# Load the model
model = du.deep_learning.load_checkpoint(filepath=f'{models_path}{subdata_path}{model_file_name}.pth',
ModelClass=model_class)
+    if 'mf1lstm' in model_file_name:
+        # Account for the `delta_ts` column, which isn't directly used as a feature
+        n_inputs = model.n_inputs+1
+    else:
+        n_inputs = model.n_inputs
# Put the model in evaluation mode to deactivate dropout
model.eval()
# Create an all zeroes reference value
-    ref_value = torch.zeros((1, 1, model.n_inputs))
+    ref_value = torch.zeros((1, 1, n_inputs))
# Calculate the expected value by getting the output from the reference value
exp_val = model(ref_value).item()
return exp_val
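
For context, a standalone sketch of the expected-value computation this hunk adjusts: an all-zeros reference sequence is pushed through the model and its output is taken as the expected value. The toy LSTM below is hypothetical; only the all-zeros reference pattern and the n_inputs adjustment come from the diff.

import torch

class ToyModel(torch.nn.Module):
    """Hypothetical stand-in for the loaded checkpoint."""
    def __init__(self, n_inputs=4):
        super().__init__()
        self.n_inputs = n_inputs
        self.lstm = torch.nn.LSTM(n_inputs, 8, batch_first=True)
        self.fc = torch.nn.Linear(8, 1)

    def forward(self, x):
        out, _ = self.lstm(x)
        return torch.sigmoid(self.fc(out[:, -1]))

model = ToyModel()
model.eval()                               # deactivate dropout, as in the callback
n_inputs = model.n_inputs                  # +1 here if the model also consumes `delta_ts`
ref_value = torch.zeros((1, 1, n_inputs))  # all-zeros reference: batch of 1, sequence length 1
exp_val = model(ref_value).item()          # model output on the reference input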
2 changes: 1 addition & 1 deletion index.py
@@ -47,4 +47,4 @@
])

if __name__ == '__main__':
-    app.run_server(debug=True)
+    app.run_server(debug=False)
10 changes: 1 addition & 9 deletions layouts.py
@@ -826,15 +826,7 @@
id='feature_importance_dropdown',
# [TODO] Add options dynamically, according to the dataset's categorical features
options=[
-    dict(label='All', value='All'),
-    dict(label='Var0', value='Var0'),
-    dict(label='Var4_a', value='Var4_a'),
-    dict(label='Var4_b', value='Var4_b'),
-    dict(label='Var4_c', value='Var4_c'),
-    # dict(label='Sex', value='Sex'),
-    # dict(label='Age', value='Age'),
-    # dict(label='Diagnostic', value='Diagnostic'),
-    # dict(label='Treatment', value='Treatment')
+    dict(label='All', value='All')
],
placeholder='Choose how to filter the data',
value='All',

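Putting the two changes together: the layout now declares only the static 'All' option, and the new load_feat_import_filter_options callback fills in the rest at runtime. A minimal reconstruction of the dropdown as the visible layouts.py lines suggest; any properties hidden by the collapsed part of the diff are omitted, and the dash_core_components import path is the one current at the time of this 2020 commit.

import dash_core_components as dcc  # in newer Dash releases: from dash import dcc

feature_importance_dropdown = dcc.Dropdown(
    id='feature_importance_dropdown',
    # Remaining options are filled in by the load_feat_import_filter_options callback
    options=[dict(label='All', value='All')],
    placeholder='Choose how to filter the data',
    value='All',
)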