@@ -316,17 +316,13 @@ def train_ensemble(
316
316
when early stopping. Defaults to None.
317
317
verbose (bool, optional): Whether to show progress bars for each epoch.
318
318
"""
319
- if isinstance (train_set , Subset ):
320
- train_set = train_set .dataset
321
- if isinstance (val_set , Subset ):
322
- val_set = val_set .dataset
323
-
324
319
train_loader = DataLoader (train_set , ** data_params )
325
320
print (f"Training on { len (train_set ):,} samples" )
326
321
327
322
if val_set is not None :
328
323
data_params .update ({"batch_size" : 16 * data_params ["batch_size" ]})
329
324
val_loader = DataLoader (val_set , ** data_params )
325
+ print (f"Validating on { len (val_set ):,} samples" )
330
326
else :
331
327
val_loader = None
332
328
@@ -354,7 +350,13 @@ def train_ensemble(
354
350
355
351
for target , normalizer in normalizer_dict .items ():
356
352
if normalizer is not None :
357
- sample_target = Tensor (train_set .df [target ].values )
353
+ if isinstance (train_set , Subset ):
354
+ sample_target = Tensor (
355
+ train_set .dataset .df [target ].iloc [train_set .indices ].values
356
+ )
357
+ else :
358
+ sample_target = Tensor (train_set .df [target ].values )
359
+
358
360
if not restart_params ["resume" ]:
359
361
normalizer .fit (sample_target )
360
362
print (f"Dummy MAE: { (sample_target - normalizer .mean ).abs ().mean ():.4f} " )
@@ -455,10 +457,6 @@ def results_multitask(
455
457
"------------Evaluate model on Test Set------------\n "
456
458
"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n "
457
459
)
458
-
459
- if isinstance (test_set , Subset ):
460
- test_set = test_set .dataset
461
-
462
460
test_loader = DataLoader (test_set , ** data_params )
463
461
print (f"Testing on { len (test_set ):,} samples" )
464
462
0 commit comments