Skip to content

Commit 181e2b2

Browse files
authored
Address subset issue highlighted in #95 (#97)
* fix: address subset issue highlighted in #95 * test: add checks that would have caught the test set subset issue
1 parent 63e0ea0 commit 181e2b2

8 files changed

+16
-11
lines changed

aviary/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def evaluate(
346346
metrics_str = " ".join(
347347
f"{key} {val:<9.2f}" for key, val in avrg_metrics[target].items()
348348
)
349-
print(f"{action:>9}: {target} {metrics_str}")
349+
print(f"{action:>9}: {target} N {len(data_loader):,} {metrics_str}")
350350

351351
return avrg_metrics
352352

aviary/utils.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -316,17 +316,13 @@ def train_ensemble(
316316
when early stopping. Defaults to None.
317317
verbose (bool, optional): Whether to show progress bars for each epoch.
318318
"""
319-
if isinstance(train_set, Subset):
320-
train_set = train_set.dataset
321-
if isinstance(val_set, Subset):
322-
val_set = val_set.dataset
323-
324319
train_loader = DataLoader(train_set, **data_params)
325320
print(f"Training on {len(train_set):,} samples")
326321

327322
if val_set is not None:
328323
data_params.update({"batch_size": 16 * data_params["batch_size"]})
329324
val_loader = DataLoader(val_set, **data_params)
325+
print(f"Validating on {len(val_set):,} samples")
330326
else:
331327
val_loader = None
332328

@@ -354,7 +350,13 @@ def train_ensemble(
354350

355351
for target, normalizer in normalizer_dict.items():
356352
if normalizer is not None:
357-
sample_target = Tensor(train_set.df[target].values)
353+
if isinstance(train_set, Subset):
354+
sample_target = Tensor(
355+
train_set.dataset.df[target].iloc[train_set.indices].values
356+
)
357+
else:
358+
sample_target = Tensor(train_set.df[target].values)
359+
358360
if not restart_params["resume"]:
359361
normalizer.fit(sample_target)
360362
print(f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}")
@@ -455,10 +457,6 @@ def results_multitask(
455457
"------------Evaluate model on Test Set------------\n"
456458
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
457459
)
458-
459-
if isinstance(test_set, Subset):
460-
test_set = test_set.dataset
461-
462460
test_loader = DataLoader(test_set, **data_params)
463461
print(f"Testing on {len(test_set):,} samples")
464462

tests/test_cgcnn_classification.py

+1
Original file line numberDiff line numberDiff line change
@@ -136,5 +136,6 @@ def test_cgcnn_clf(df_matbench_phonons):
136136

137137
ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()
138138

139+
assert len(targets) == len(test_set) == len(test_idx)
139140
assert ens_acc > 0.85
140141
assert ens_roc_auc > 0.9

tests/test_cgcnn_regression.py

+1
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def test_cgcnn_regression(df_matbench_phonons):
135135

136136
mae, rmse, r2 = get_metrics(targets, y_ens, task).values()
137137

138+
assert len(targets) == len(test_set) == len(test_idx)
138139
assert r2 > 0.7
139140
assert mae < 150
140141
assert rmse < 300

tests/test_roost_classification.py

+2
Original file line numberDiff line numberDiff line change
@@ -138,5 +138,7 @@ def test_roost_clf(df_matbench_phonons):
138138

139139
ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()
140140

141+
assert len(logits) == ensemble
142+
assert len(targets) == len(test_set) == len(test_idx)
141143
assert ens_acc > 0.9
142144
assert ens_roc_auc > 0.9

tests/test_roost_regression.py

+1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def test_roost_regression(df_matbench_phonons):
137137

138138
mae, rmse, r2 = get_metrics(targets, y_ens, task).values()
139139

140+
assert len(targets) == len(test_set) == len(test_idx)
140141
assert r2 > 0.7
141142
assert mae < 150
142143
assert rmse < 300

tests/test_wren_classification.py

+1
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,6 @@ def test_wren_clf(df_matbench_phonons_wyckoff):
146146

147147
ens_acc, *_, ens_roc_auc = get_metrics(targets, ens_logits, task).values()
148148

149+
assert len(targets) == len(test_set) == len(test_idx)
149150
assert ens_acc > 0.85
150151
assert ens_roc_auc > 0.9

tests/test_wren_regression.py

+1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def test_wren_regression(df_matbench_phonons_wyckoff):
145145

146146
mae, rmse, r2 = get_metrics(targets, y_ens, task).values()
147147

148+
assert len(targets) == len(test_set) == len(test_idx)
148149
assert r2 > 0.7
149150
assert mae < 150
150151
assert rmse < 300

0 commit comments

Comments
 (0)