Skip to content

Commit

Permalink
Updated list of models to load; added a model description for those t…
Browse files Browse the repository at this point in the history
…hat have an embedding layer that learned simultaneously
  • Loading branch information
AndreCNF committed Aug 24, 2020
1 parent cd44178 commit 85dddb5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 22 deletions.
38 changes: 25 additions & 13 deletions callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,44 +151,54 @@ def load_model_callback(model_name, dataset_name):
# Based on the chosen dataset and model type, set a file path to the desired model
if dataset_name == 'ALS':
subdata_path = 'ALS/'
if model_name == 'Bidir LSTM, embedded, time aware':
if model_name == 'Bidir LSTM, time aware':
# Specify the model file name and class
model_file_name = 'lstm_bidir_pre_embedded_delta_ts_90dayswindow_0.3481valloss_06_07_2020_04_15'
model_file_name = 'lstm_bidir_one_hot_encoded_delta_ts_90dayswindow_0.3784valloss_08_07_2020_04_14'
model_class = Models.VanillaLSTM
is_custom = False
elif model_name == 'Bidir LSTM, time aware':
elif model_name == 'Bidir LSTM, embedded':
# Specify the model file name and class
model_file_name = 'lstm_bidir_one_hot_encoded_delta_ts_90dayswindow_0.3809valloss_06_07_2020_04_08'
model_file_name = 'lstm_bidir_pre_embedded_90dayswindow_0.2490valloss_06_07_2020_03_47'
model_class = Models.VanillaLSTM
is_custom = False
elif model_name == 'Bidir LSTM, embedded':
elif model_name == 'Bidir LSTM':
# Specify the model file name and class
model_file_name = 'lstm_bidir_pre_embedded_90dayswindow_0.2490valloss_06_07_2020_03_47'
model_file_name = 'lstm_bidir_one_hot_encoded_90dayswindow_0.4497valloss_08_07_2020_04_31'
model_class = Models.VanillaLSTM
is_custom = False
elif model_name == 'Bidir LSTM, embedded, time aware':
# Specify the model file name and class
model_file_name = 'lstm_bidir_pre_embedded_delta_ts_90dayswindow_0.3705valloss_08_07_2020_04_04'
model_class = Models.VanillaLSTM
is_custom = False
elif model_name == 'LSTM':
# Specify the model file name and class
model_file_name = 'lstm_one_hot_encoded_90dayswindow_0.4363valloss_06_07_2020_03_28'
model_file_name = 'lstm_one_hot_encoded_90dayswindow_0.5125valloss_08_07_2020_04_41'
model_class = Models.VanillaLSTM
is_custom = False
elif model_name == 'Bidir RNN, embedded, time aware':
# Specify the model file name and class
model_file_name = 'rnn_bidir_pre_embedded_delta_ts_90dayswindow_0.3059valloss_06_07_2020_03_10'
model_file_name = 'rnn_bidir_pre_embedded_delta_ts_90dayswindow_0.3579valloss_08_07_2020_03_55'
model_class = Models.VanillaRNN
is_custom = False
elif model_name == 'RNN, embedded':
elif model_name == 'Bidir RNN':
# Specify the model file name and class
model_file_name = 'rnn_with_embedding_90dayswindow_0.5569valloss_30_06_2020_17_04.pth'
model_file_name = 'rnn_bidir_one_hot_encoded_90dayswindow_0.3713valloss_08_07_2020_04_49'
model_class = Models.VanillaRNN
is_custom = False
elif model_name == 'RNN, time aware':
# Specify the model file name and class
model_file_name = 'rnn_one_hot_encoded_delta_ts_90dayswindow_0.5354valloss_21_08_2020_04_24.pth'
model_class = Models.VanillaRNN
is_custom = False
elif model_name == 'RNN':
# Specify the model file name and class
model_file_name = 'rnn_one_hot_encoded_90dayswindow_0.5497valloss_30_06_2020_18_25.pth'
model_file_name = 'rnn_one_hot_encoded_90dayswindow_0.5445valloss_21_08_2020_04_34.pth'
model_class = Models.VanillaRNN
is_custom = False
elif model_name == 'MF1-LSTM':
# Specify the model file name and class
model_file_name = 'mf1lstm_one_hot_encoded_90dayswindow_0.6009valloss_07_07_2020_03_46'
model_file_name = 'mf1lstm_pre_embedded_90dayswindow_0.6516valloss_07_07_2020_03_35'
model_class = Models.MF1LSTM
is_custom = True
elif dataset_name == 'Toy Example':
Expand Down Expand Up @@ -276,7 +286,9 @@ def load_model_description(model_mod, model_file_name):
else:
description_list.append(dbc.ListGroupItem(descriptions['embedding']['one hot encoded']))
# Add time variation description
if 'delta_ts' in model_file_name:
if 'mf1lstm' in model_file_name:
description_list.append(dbc.ListGroupItem(descriptions['delta_ts']['integrates_in_model']))
elif 'delta_ts' in model_file_name:
description_list.append(dbc.ListGroupItem(descriptions['delta_ts']['uses delta_ts']))
else:
description_list.append(dbc.ListGroupItem(descriptions['delta_ts']['no delta_ts']))
Expand Down
24 changes: 16 additions & 8 deletions layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@
dcc.Dropdown(
id='model_dropdown',
options=[
dict(label='Bidir LSTM, embedded, time aware', value='Bidir LSTM, embedded, time aware'),
dict(label='Bidir LSTM, time aware', value='Bidir LSTM, time aware'),
dict(label='Bidir LSTM, embedded', value='Bidir LSTM, embedded'),
dict(label='Bidir LSTM', value='Bidir LSTM'),
dict(label='Bidir LSTM, embedded, time aware', value='Bidir LSTM, embedded, time aware'),
dict(label='LSTM', value='LSTM'),
dict(label='Bidir RNN, embedded, time aware', value='Bidir RNN, embedded, time aware'),
dict(label='RNN, embedded', value='RNN, embedded'),
dict(label='Bidir RNN', value='Bidir RNN'),
dict(label='RNN, time aware', value='RNN, time aware'),
dict(label='RNN', value='RNN'),
dict(label='MF1-LSTM', value='MF1-LSTM')
],
Expand Down Expand Up @@ -305,12 +307,14 @@
dcc.Dropdown(
id='model_dropdown',
options=[
dict(label='Bidir LSTM, embedded, time aware', value='Bidir LSTM, embedded, time aware'),
dict(label='Bidir LSTM, time aware', value='Bidir LSTM, time aware'),
dict(label='Bidir LSTM, embedded', value='Bidir LSTM, embedded'),
dict(label='Bidir LSTM', value='Bidir LSTM'),
dict(label='Bidir LSTM, embedded, time aware', value='Bidir LSTM, embedded, time aware'),
dict(label='LSTM', value='LSTM'),
dict(label='Bidir RNN, embedded, time aware', value='Bidir RNN, embedded, time aware'),
dict(label='RNN, embedded', value='RNN, embedded'),
dict(label='Bidir RNN', value='Bidir RNN'),
dict(label='RNN, time aware', value='RNN, time aware'),
dict(label='RNN', value='RNN'),
dict(label='MF1-LSTM', value='MF1-LSTM')
],
Expand Down Expand Up @@ -502,12 +506,14 @@
dcc.Dropdown(
id='model_dropdown',
options=[
dict(label='Bidir LSTM, embedded, time aware', value='Bidir LSTM, embedded, time aware'),
dict(label='Bidir LSTM, time aware', value='Bidir LSTM, time aware'),
dict(label='Bidir LSTM, embedded', value='Bidir LSTM, embedded'),
dict(label='Bidir LSTM', value='Bidir LSTM'),
dict(label='Bidir LSTM, embedded, time aware', value='Bidir LSTM, embedded, time aware'),
dict(label='LSTM', value='LSTM'),
dict(label='Bidir RNN, embedded, time aware', value='Bidir RNN, embedded, time aware'),
dict(label='RNN, embedded', value='RNN, embedded'),
dict(label='Bidir RNN', value='Bidir RNN'),
dict(label='RNN, time aware', value='RNN, time aware'),
dict(label='RNN', value='RNN'),
dict(label='MF1-LSTM', value='MF1-LSTM')
],
Expand Down Expand Up @@ -781,12 +787,14 @@
dcc.Dropdown(
id='model_dropdown',
options=[
dict(label='Bidir LSTM, embedded, time aware', value='Bidir LSTM, embedded, time aware'),
dict(label='Bidir LSTM, time aware', value='Bidir LSTM, time aware'),
dict(label='Bidir LSTM, embedded', value='Bidir LSTM, embedded'),
dict(label='Bidir LSTM', value='Bidir LSTM'),
dict(label='Bidir LSTM, embedded, time aware', value='Bidir LSTM, embedded, time aware'),
dict(label='LSTM', value='LSTM'),
dict(label='Bidir RNN, embedded, time aware', value='Bidir RNN, embedded, time aware'),
dict(label='RNN, embedded', value='RNN, embedded'),
dict(label='Bidir RNN', value='Bidir RNN'),
dict(label='RNN, time aware', value='RNN, time aware'),
dict(label='RNN', value='RNN'),
dict(label='MF1-LSTM', value='MF1-LSTM')
],
Expand Down
4 changes: 3 additions & 1 deletion models/model_descriptions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ embedding:
delta_ts:
uses delta_ts:
Has time variation between samples in consideration, using that information as an additional feature.
integrates_in_model:
Incorporates time variation information inside the model architecture.
no delta_ts:
Doesn't directly consider time variation between samples.
Doesn't directly consider time variation between samples.

0 comments on commit 85dddb5

Please sign in to comment.