from torch.nn import CrossEntropyLoss, L1Loss, MSELoss, NLLLoss
from torch.optim import SGD, Adam, AdamW, Optimizer
from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
- from torch.utils.data import DataLoader, Subset
+ from torch.utils.data import DataLoader, Dataset, Subset
from torch.utils.tensorboard import SummaryWriter

from aviary.core import BaseModelClass, Normalizer, TaskType, sampled_softmax
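Note on the `Dataset` import added above: `train_ensemble` and `results_multitask` (see the hunks below) now accept either a bare `Dataset` or the `Subset` wrapper that `torch.utils.data.random_split` produces. A minimal sketch of the two call shapes (the toy dataset and split sizes are illustrative, not from this repo):

```python
from torch.utils.data import Dataset, random_split

class ToyDataset(Dataset):
    """Placeholder dataset standing in for aviary's real dataset classes."""
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

full = ToyDataset(list(range(100)))
train_subset, val_subset = random_split(full, [80, 20])  # both are Subset
# Either `full` (a Dataset) or `train_subset` (a Subset) may now be passed.
```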
@@ -45,13 +45,13 @@ def init_model(
    Args:
        model_class (type[BaseModelClass]): Which model class to initialize.
        model_params (dict[str, Any]): Dictionary containing model specific hyperparameters.
-        device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on.
+        device (type[torch.device] | "cuda" | "cpu"): Device the model will run on.
        resume (str, optional): Path to model checkpoint to resume. Defaults to None.
        fine_tune (str, optional): Path to model checkpoint to fine tune. Defaults to None.
        transfer (str, optional): Path to model checkpoint to transfer. Defaults to None.

    Returns:
-        type[BaseModelClass]: An initialised model of type model_class
+        BaseModelClass: An initialised model of type model_class.
    """
    robust = model_params["robust"]
    n_targets = model_params["n_targets"]
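The `device` docstrings in this commit drop the `Literal[...]` notation in favour of plain `"cuda" | "cpu"`. For illustration only, one common way such an argument is coerced, not necessarily what `init_model` does internally:

```python
import torch

def as_torch_device(device) -> torch.device:
    # Hypothetical helper: accept a torch.device or the strings "cuda" / "cpu".
    return device if isinstance(device, torch.device) else torch.device(device)

assert as_torch_device("cpu") == torch.device("cpu")
```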
@@ -149,11 +149,11 @@ def init_optim(

    Args:
        model (type[BaseModelClass]): Model to be optimized.
-        optim (type[Optimizer] | Literal["SGD", "Adam", "AdamW"]): Which optimizer to use
-        learning_rate (float): Learning rate for optimzation
+        optim (type[Optimizer] | "SGD" | "Adam" | "AdamW"): Which optimizer to use
+        learning_rate (float): Learning rate for optimization
        weight_decay (float): Weight decay for optimizer
        momentum (float): Momentum for optimizer
-        device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on
+        device (type[torch.device] | "cuda" | "cpu"): Device the model will run on
        milestones (Iterable, optional): When to decay learning rate. Defaults to ().
        gamma (float, optional): Multiplier for learning rate decay. Defaults to 0.3.
        resume (str, optional): Path to model checkpoint to resume. Defaults to None.
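Since `init_optim` documents `optim` as either an `Optimizer` subclass or one of the names "SGD", "Adam", "AdamW", here is a sketch of how that union can be resolved using the classes imported at the top of the file. The dispatch table and `momentum` handling are assumptions for illustration, not the repo's actual code:

```python
from torch.optim import SGD, Adam, AdamW

OPTIMIZERS = {"SGD": SGD, "Adam": Adam, "AdamW": AdamW}  # assumed mapping

def resolve_optim(optim, params, learning_rate, weight_decay, momentum=0.0):
    optim_class = OPTIMIZERS[optim] if isinstance(optim, str) else optim
    kwargs = {"lr": learning_rate, "weight_decay": weight_decay}
    if optim_class is SGD:
        kwargs["momentum"] = momentum  # momentum is an SGD-only argument here
    return optim_class(params, **kwargs)
```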
@@ -203,7 +203,7 @@ def init_losses(

    Args:
        task_dict (dict[str, TaskType]): Map of target names to "regression" or "classification".
-        loss_dict (dict[str, Literal["L1", "L2", "CSE"]]): Map of target names to loss functions.
+        loss_dict (dict[str, "L1" | "L2" | "CSE"]): Map of target names to loss functions.
        robust (bool, optional): Whether to use an uncertainty adjusted loss. Defaults to False.

    Returns:
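The `loss_dict` values "L1", "L2" and "CSE" line up with the losses imported at the top of the file. A minimal name-to-class mapping as a sketch; the pairing is an assumption, and the robust (uncertainty-adjusted) variants mentioned above are not reproduced:

```python
from torch.nn import CrossEntropyLoss, L1Loss, MSELoss

LOSSES = {"L1": L1Loss, "L2": MSELoss, "CSE": CrossEntropyLoss}  # assumed mapping

def resolve_loss(name: str):
    return LOSSES[name]()  # e.g. resolve_loss("L2") -> MSELoss()
```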
@@ -253,7 +253,7 @@ def init_normalizers(

    Args:
        task_dict (dict[str, TaskType]): Map of target names to "regression" or "classification".
-        device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on
+        device (torch.device | "cuda" | "cpu"): Device the model will run on
        resume (str, optional): Path to model checkpoint to resume. Defaults to None.

    Returns:
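For context on what `init_normalizers` returns per regression target, here is a simplified stand-in for `aviary.core.Normalizer` showing the `fit`/`mean` interface the training loop below relies on. This is an illustration, not the real implementation:

```python
import torch

class MiniNormalizer:
    """Tracks mean/std of a target tensor so outputs can be (de)normalized."""

    def fit(self, sample: torch.Tensor) -> None:
        self.mean = sample.mean()
        self.std = sample.std()

    def norm(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

    def denorm(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean
```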
@@ -284,8 +284,8 @@ def train_ensemble(
    run_id: int,
    ensemble_folds: int,
    epochs: int,
-    train_set: Subset,
-    val_set: Subset,
+    train_set: Dataset | Subset,
+    val_set: Dataset | Subset,
    log: bool,
    data_params: dict[str, Any],
    setup_params: dict[str, Any],
@@ -310,12 +310,17 @@ def train_ensemble(
        setup_params (dict[str, Any]): Dictionary of setup parameters
        restart_params (dict[str, Any]): Dictionary of restart parameters
        model_params (dict[str, Any]): Dictionary of model parameters
-        loss_dict (dict[str, Literal["L1", "L2", "CSE"]]): Map of target names
+        loss_dict (dict[str, "L1" | "L2" | "CSE"]): Map of target names
            to loss functions.
        patience (int, optional): Maximum number of epochs without improvement
            when early stopping. Defaults to None.
        verbose (bool, optional): Whether to show progress bars for each epoch.
    """
+    if isinstance(train_set, Subset):
+        train_set = train_set.dataset
+    if isinstance(val_set, Subset):
+        val_set = val_set.dataset
+
    train_generator = DataLoader(train_set, **data_params)
    print(f"Training on {len(train_set):,} samples")
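What the new `isinstance` guard does: a `Subset` wraps a parent dataset plus a list of indices, and `.dataset` recovers the parent. Worth noting that after unwrapping, `len(train_set)` and the `DataLoader` cover the full parent dataset rather than only the split's indices. A small demonstration (`full` is the toy dataset from the first sketch):

```python
from torch.utils.data import Subset

subset = Subset(full, indices=[0, 2, 4])
assert subset.dataset is full             # .dataset recovers the parent
assert len(subset) == 3                   # the Subset sees only its indices
assert len(subset.dataset) == len(full)   # the unwrapped dataset sees everything
```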
@@ -350,13 +355,11 @@ def train_ensemble(

    for target, normalizer in normalizer_dict.items():
        if normalizer is not None:
-            sample_target = Tensor(
-                train_set.dataset.df[target].iloc[train_set.indices].values
-            )
+            sample_target = Tensor(train_set.df[target].values)
            if not restart_params["resume"]:
                normalizer.fit(sample_target)
            print(
-                f"Dummy MAE: {torch.mean(torch.abs(sample_target - normalizer.mean)):.4f}"
+                f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}"
            )

    if log:
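The rewritten Dummy MAE f-string is numerically identical to the old one; it just chains tensor methods instead of nesting `torch.mean(torch.abs(...))`. A quick check:

```python
import torch

x = torch.tensor([1.0, -2.0, 3.0])
mean = torch.tensor(0.5)
assert torch.mean(torch.abs(x - mean)) == (x - mean).abs().mean()
```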
@@ -415,7 +418,7 @@ def results_multitask(  # noqa: C901
    model_name: str,
    run_id: int,
    ensemble_folds: int,
-    test_set: Subset,
+    test_set: Dataset | Subset,
    data_params: dict[str, Any],
    robust: bool,
    task_dict: dict[str, TaskType],
@@ -436,7 +439,7 @@ def results_multitask(  # noqa: C901
            loss function.
        task_dict (dict[str, TaskType]): Map of target names to "regression" or
            "classification".
-        device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on
+        device (type[torch.device] | "cuda" | "cpu"): Device the model will run on
        eval_type (str, optional): Whether to use final or early-stopping checkpoints.
            Defaults to "checkpoint".
        print_results (bool, optional): Whether to print out summary metrics.
@@ -459,6 +462,9 @@ def results_multitask(  # noqa: C901
        "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
    )

+    if isinstance(test_set, Subset):
+        test_set = test_set.dataset
+
    test_generator = DataLoader(test_set, **data_params)
    print(f"Testing on {len(test_set):,} samples")