
Commit e91f75a

Handle Dataset and Subset in train_ensemble() and results_multitask() (#42)
* handle Dataset and Subset in train_ensemble() and results_multitask()
* more concise literal types with pipe char in docstrings
1 parent 8fa857a commit e91f75a
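
Both functions now apply the same guard: accept either a full Dataset or a Subset, and when given a Subset, unwrap it via its .dataset attribute so attributes only the underlying dataset has (such as its df DataFrame) stay reachable. A minimal standalone sketch of that pattern, using a toy TensorDataset rather than aviary's own dataset classes:

    import torch
    from torch.utils.data import Dataset, Subset, TensorDataset

    # Toy stand-in for aviary's dataset classes.
    full_set = TensorDataset(torch.arange(10).float())
    train_set: Dataset | Subset = Subset(full_set, indices=[0, 2, 4])

    # The guard this commit adds: a Subset wraps its source dataset,
    # so unwrapping recovers the full dataset and its attributes.
    if isinstance(train_set, Subset):
        train_set = train_set.dataset

    assert train_set is full_set

Note that after unwrapping, downstream consumers such as DataLoader see every sample in the underlying dataset, not just the Subset's indices.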

File tree

2 files changed: +26 -20 lines


aviary/core.py (+3 -3)
@@ -44,7 +44,7 @@ def __init__(
         Args:
             task_dict (dict[str, TaskType]): Map target names to "regression" or "classification".
             robust (bool): Whether to estimate standard deviation for use in a robust loss function
-            device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on.
+            device (torch.device | "cuda" | "cpu"): Device the model will run on.
             epoch (int, optional): Epoch model training will begin/resume from. Defaults to 1.
             best_val_scores (dict[str, float], optional): Validation score to use for early
                 stopping. Defaults to None.
@@ -228,12 +228,12 @@ def evaluate(
             optimizer (torch.optim.Optimizer): PyTorch Optimizer
             normalizer_dict (dict[str, Normalizer]): Dictionary of Normalizers to apply
                 to each task.
-            action (Literal["train", "val"], optional): Whether to track gradients depending on
+            action ("train" | "val", optional): Whether to track gradients depending on
                 whether we are carrying out a training or validation pass. Defaults to "train".
             verbose (bool, optional): Whether to print out intermediate results. Defaults to False.
 
         Returns:
-            dict[str, dict[Literal["Loss", "MAE", "RMSE", "Acc", "F1"], np.ndarray]]: nested
+            dict[str, dict["Loss" | "MAE" | "RMSE" | "Acc" | "F1", np.ndarray]]: nested
                 dictionary of metrics for each task.
         """
         if action == "val":
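
The docstring edits above are shorthand only: in the annotations themselves a fixed set of string values is still spelled with typing.Literal, since a bare "train" | "val" is not a valid type expression. A hypothetical stripped-down signature (not aviary's actual evaluate()) showing the two notations side by side:

    from typing import Literal

    def evaluate(action: Literal["train", "val"] = "train") -> bool:
        """Run one pass over the data.

        Args:
            action ("train" | "val", optional): Whether to track gradients.
                Defaults to "train".
        """
        # Gradients are tracked only during a training pass.
        return action == "train"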

aviary/utils.py (+23 -17)
@@ -19,7 +19,7 @@
 from torch.nn import CrossEntropyLoss, L1Loss, MSELoss, NLLLoss
 from torch.optim import SGD, Adam, AdamW, Optimizer
 from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
-from torch.utils.data import DataLoader, Subset
+from torch.utils.data import DataLoader, Dataset, Subset
 from torch.utils.tensorboard import SummaryWriter
 
 from aviary.core import BaseModelClass, Normalizer, TaskType, sampled_softmax
@@ -45,13 +45,13 @@ def init_model(
     Args:
         model_class (type[BaseModelClass]): Which model class to initialize.
         model_params (dict[str, Any]): Dictionary containing model specific hyperparameters.
-        device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on.
+        device (type[torch.device] | "cuda" | "cpu"): Device the model will run on.
         resume (str, optional): Path to model checkpoint to resume. Defaults to None.
         fine_tune (str, optional): Path to model checkpoint to fine tune. Defaults to None.
         transfer (str, optional): Path to model checkpoint to transfer. Defaults to None.
 
     Returns:
-        type[BaseModelClass]: An initialised model of type model_class
+        BaseModelClass: An initialised model of type model_class.
     """
     robust = model_params["robust"]
     n_targets = model_params["n_targets"]
@@ -149,11 +149,11 @@ def init_optim(
 
     Args:
         model (type[BaseModelClass]): Model to be optimized.
-        optim (type[Optimizer] | Literal["SGD", "Adam", "AdamW"]): Which optimizer to use
-        learning_rate (float): Learning rate for optimzation
+        optim (type[Optimizer] | "SGD" | "Adam" | "AdamW"): Which optimizer to use
+        learning_rate (float): Learning rate for optimization
         weight_decay (float): Weight decay for optimizer
         momentum (float): Momentum for optimizer
-        device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on
+        device (type[torch.device] | "cuda" | "cpu"): Device the model will run on
         milestones (Iterable, optional): When to decay learning rate. Defaults to ().
         gamma (float, optional): Multiplier for learning rate decay. Defaults to 0.3.
         resume (str, optional): Path to model checkpoint to resume. Defaults to None.
@@ -203,7 +203,7 @@ def init_losses(
 
     Args:
         task_dict (dict[str, TaskType]): Map of target names to "regression" or "classification".
-        loss_dict (dict[str, Literal["L1", "L2", "CSE"]]): Map of target names to loss functions.
+        loss_dict (dict[str, "L1" | "L2" | "CSE"]): Map of target names to loss functions.
         robust (bool, optional): Whether to use an uncertainty adjusted loss. Defaults to False.
 
     Returns:
@@ -253,7 +253,7 @@ def init_normalizers(
 
     Args:
         task_dict (dict[str, TaskType]): Map of target names to "regression" or "classification".
-        device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on
+        device (torch.device | "cuda" | "cpu"): Device the model will run on
        resume (str, optional): Path to model checkpoint to resume. Defaults to None.
 
     Returns:
@@ -284,8 +284,8 @@ def train_ensemble(
     run_id: int,
     ensemble_folds: int,
     epochs: int,
-    train_set: Subset,
-    val_set: Subset,
+    train_set: Dataset | Subset,
+    val_set: Dataset | Subset,
     log: bool,
     data_params: dict[str, Any],
     setup_params: dict[str, Any],
@@ -310,12 +310,17 @@ def train_ensemble(
         setup_params (dict[str, Any]): Dictionary of setup parameters
         restart_params (dict[str, Any]): Dictionary of restart parameters
         model_params (dict[str, Any]): Dictionary of model parameters
-        loss_dict (dict[str, Literal["L1", "L2", "CSE"]]): Map of target names
+        loss_dict (dict[str, "L1" | "L2" | "CSE"]): Map of target names
             to loss functions.
         patience (int, optional): Maximum number of epochs without improvement
             when early stopping. Defaults to None.
         verbose (bool, optional): Whether to show progress bars for each epoch.
     """
+    if isinstance(train_set, Subset):
+        train_set = train_set.dataset
+    if isinstance(val_set, Subset):
+        val_set = val_set.dataset
+
     train_generator = DataLoader(train_set, **data_params)
     print(f"Training on {len(train_set):,} samples")

@@ -350,13 +355,11 @@ def train_ensemble(
 
     for target, normalizer in normalizer_dict.items():
         if normalizer is not None:
-            sample_target = Tensor(
-                train_set.dataset.df[target].iloc[train_set.indices].values
-            )
+            sample_target = Tensor(train_set.df[target].values)
             if not restart_params["resume"]:
                 normalizer.fit(sample_target)
             print(
-                f"Dummy MAE: {torch.mean(torch.abs(sample_target-normalizer.mean)):.4f}"
+                f"Dummy MAE: {(sample_target - normalizer.mean).abs().mean():.4f}"
             )
 
     if log:
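
The new f-string above is a pure refactor: torch.mean(torch.abs(x)) and x.abs().mean() compute the same thing, so the printed Dummy MAE is unchanged. A quick check of the equivalence, with illustrative values standing in for sample_target and normalizer.mean:

    import torch

    sample_target = torch.tensor([1.0, 2.0, 3.0, 4.0])
    mean = sample_target.mean()  # stand-in for normalizer.mean

    old_style = torch.mean(torch.abs(sample_target - mean))
    new_style = (sample_target - mean).abs().mean()

    assert torch.equal(old_style, new_style)
    print(f"Dummy MAE: {new_style:.4f}")
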
@@ -415,7 +418,7 @@ def results_multitask( # noqa: C901
     model_name: str,
     run_id: int,
     ensemble_folds: int,
-    test_set: Subset,
+    test_set: Dataset | Subset,
     data_params: dict[str, Any],
     robust: bool,
     task_dict: dict[str, TaskType],
@@ -436,7 +439,7 @@ def results_multitask( # noqa: C901
             loss function.
         task_dict (dict[str, TaskType]): Map of target names to "regression" or
             "classification".
-        device (type[torch.device] | Literal["cuda", "cpu"]): Device the model will run on
+        device (type[torch.device] | "cuda" | "cpu"): Device the model will run on
         eval_type (str, optional): Whether to use final or early-stopping checkpoints.
             Defaults to "checkpoint".
         print_results (bool, optional): Whether to print out summary metrics.
@@ -459,6 +462,9 @@ def results_multitask( # noqa: C901
         "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     )
 
+    if isinstance(test_set, Subset):
+        test_set = test_set.dataset
+
     test_generator = DataLoader(test_set, **data_params)
     print(f"Testing on {len(test_set):,} samples")
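
With the guard in place, results_multitask() can be handed either a full dataset or a split. A sketch of both call paths on a toy TensorDataset (not aviary's API), mirroring the unwrap-then-load sequence above:

    import torch
    from torch.utils.data import DataLoader, Subset, TensorDataset

    full_set = TensorDataset(torch.randn(8, 3))

    for test_set in (full_set, Subset(full_set, indices=range(4))):
        if isinstance(test_set, Subset):  # the guard added above
            test_set = test_set.dataset
        test_generator = DataLoader(test_set, batch_size=2)
        print(f"Testing on {len(test_set):,} samples")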
