Skip to content

Commit

Permalink
New features for config classes (#701)
Browse files Browse the repository at this point in the history
* suggestion of new config class

* apply changes to all transforms
  • Loading branch information
thibaultdvx authored Feb 20, 2025
1 parent acf261d commit d6a9eb4
Show file tree
Hide file tree
Showing 29 changed files with 1,191 additions and 684 deletions.
1 change: 0 additions & 1 deletion clinicadl/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .datasets import CapsDataset, ConcatDataset
11 changes: 6 additions & 5 deletions clinicadl/data/datasets/caps_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,14 @@ class CapsDataset(Dataset):
>>> # │ └── sub-000_ses-M000_trc-18FAV45_space-MNI152NLin2009cSym_res-1x1x1_suvr-pons2_seg.nii.gz
>>> # ...
>>> # ...
>>> from clinicadl.data import CapsDataset
>>> from clinicadl.data.datasets import CapsDataset
>>> from clinicadl.data.datatype import PETLinear
>>> from clinicadl.transforms import Transforms, get_transform_config
>>> from clinicadl.transforms import Transforms
>>> from clinicadl.transforms.config import ZNormalizationConfig, MaskConfig, RandomFlipConfig
>>> from clinicadl.transforms.extraction import Patch
>>> normalization = get_transform_config("ZNormalization", masking_method="brain")
>>> mask = get_transform_config("Mask", masking_method="leftHippocampus")
>>> flip = get_transform_config("RandomFlip", flip_probability=0.3)
>>> normalization = ZNormalizationConfig(masking_method="brain")
>>> mask = MaskConfig(masking_method="leftHippocampus")
>>> flip = RandomFlipConfig(flip_probability=0.3)
>>> dataset = CapsDataset(
caps_directory="mycaps",
preprocessing=PETLinear(
Expand Down
5 changes: 2 additions & 3 deletions clinicadl/data/tensor_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
from pydantic import SerializeAsAny, ValidationError
from tqdm import tqdm

from clinicadl.dictionary.suffixes import JSON, PT
from clinicadl.dictionary.suffixes import JSON
from clinicadl.dictionary.words import (
AFFINE,
IMAGE,
LABEL,
MASK,
)
from clinicadl.transforms import get_transform_config
from clinicadl.transforms.config import TransformConfig
from clinicadl.transforms.config import TransformConfig, get_transform_config
from clinicadl.transforms.types import Transform
from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.exceptions import (
Expand Down
2 changes: 0 additions & 2 deletions clinicadl/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
from .extraction import Image, Patch, Slice
from .factory import get_transform_config
from .transforms import Transforms
11 changes: 8 additions & 3 deletions clinicadl/transforms/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from .base import TransformConfig
from .enum import ImplementedTransform, TransformType
from .factory import create_transform_config
from .base import *
from .enum import ImplementedTransform
from .factory import get_transform_config
from .intensity import *
from .intensity_augmentations import *
from .label import *
from .spatial import *
from .spatial_augmentations import *
85 changes: 52 additions & 33 deletions clinicadl/transforms/config/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional, Tuple, Union

import torchio as tio
from pydantic import (
NonNegativeFloat,
NonNegativeInt,
Expand All @@ -9,31 +9,33 @@
model_validator,
)

from clinicadl.utils.config import ClinicaDLConfig
from clinicadl.utils.factories import DefaultFromLibrary
from clinicadl.utils.config import NewClinicaDLConfig

from .enum import (
AnatomicalAxis,
AnatomicalLabel,
ImplementedTransform,
NumericalAxis,
TransformType,
)

__all__ = [
"TransformConfig",
"OneOfConfig",
]


class TransformConfig(ClinicaDLConfig, ABC):
class TransformConfig(NewClinicaDLConfig):
"""Base config class for the transforms."""

@computed_field
@property
@abstractmethod
def name(self) -> str:
"""The name of the transform."""
def get_object(self) -> tio.Transform:
"""
Returns the transform associated to this configuration,
parametrized with the parameters passed by the user.
@property
def _type(self) -> TransformType:
"""The source where the transform can be found."""
return TransformType.TORCHIO
Returns
-------
tio.Transform:
The TorchIO transform.
"""
return super().get_object()

@staticmethod
def _is_couple_sorted(tup: Tuple[Any, Any], field_name: str) -> None:
Expand Down Expand Up @@ -78,12 +80,43 @@ class OneOfConfig(TransformConfig):
transforms: List[TransformConfig]
probabilities: Optional[List[NonNegativeFloat]] = None

def __init__(
self,
transforms: List[TransformConfig],
probabilities: Optional[List[NonNegativeFloat]] = None,
):
super().__init__(
transforms=transforms,
probabilities=probabilities,
)

@computed_field
@property
def name(self) -> str:
"""The name of the transform."""
return ImplementedTransform.ONE_OF.value

def get_object(self) -> tio.Transform:
"""
Returns the transform associated to this configuration,
parametrized with the parameters passed by the user.
Returns
-------
tio.Transform:
The TorchIO transform.
"""
config_dict = {
transform.get_object(): proba
for transform, proba in zip(self.transforms, self.probabilities)
}
one_of = self._get_class()(transforms=config_dict)
return one_of

def _get_class(self) -> type[tio.Transform]:
"""Returns the transform associated to this config class."""
return tio.OneOf

@model_validator(mode="after")
def check_probabilities(self):
"""Checks that 'probabilities' is the same length as 'transforms'."""
Expand Down Expand Up @@ -111,12 +144,10 @@ def check_probabilities(self):
]


class _MaskingMethodConfig(ClinicaDLConfig):
"""Base config class for normalization transforms."""
class MaskingMethodConfig(NewClinicaDLConfig):
"""Base config class 'masking_method' argument."""

masking_method: Optional[
Union[str, AnatomicalLabel, Bounds, DefaultFromLibrary]
] = DefaultFromLibrary.YES
masking_method: Optional[Union[str, AnatomicalLabel, Bounds]]

@field_validator("masking_method", mode="before")
@classmethod
Expand All @@ -130,15 +161,3 @@ def validator_masking_method(cls, v):
except ValueError:
pass
return v


class _AnatomicalAxesConfig(ClinicaDLConfig):
"""Config class for 'axes' option when it supports anatomical values."""

axes: Union[
NumericalAxis,
Tuple[NumericalAxis, ...],
AnatomicalAxis,
Tuple[AnatomicalAxis, ...],
DefaultFromLibrary,
] = DefaultFromLibrary.YES
9 changes: 0 additions & 9 deletions clinicadl/transforms/config/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@
from clinicadl.utils.enum import BaseEnum


class TransformType(str, BaseEnum):
"""
Sources for transforms in ClinicaDL.
"""

TORCHIO = "TorchIO"
HOMEMADE = "HomeMade"


class ImplementedTransform(str, BaseEnum):
"""
Implemented transforms in ClinicaDL.
Expand Down
33 changes: 17 additions & 16 deletions clinicadl/transforms/config/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Type, Union
from typing import Any, Union

# pylint: disable=unused-import
from .base import OneOfConfig, TransformConfig
Expand Down Expand Up @@ -38,29 +38,30 @@
)


def create_transform_config(
transform: Union[str, ImplementedTransform],
) -> Type[TransformConfig]:
def get_transform_config(
name: Union[str, ImplementedTransform], **kwargs: Any
) -> TransformConfig:
"""
A factory function to create a config class suited for the transform.
Factory function to get a transform configuration object from its name
and parameters.
Parameters
----------
transform : Union[str, ImplementedTransform]
The name of the transform.
name : Union[str, ImplementedTransform]
the name of the transform. Check our documentation to know
supported transforms.
**kwargs : Any
any parameter of the transform. Check our documentation on transforms to
know these parameters.
Returns
-------
Type[TransformConfig]
The config class.
Raises
------
ValueError
If `transform` is not supported.
TransformConfig
the config object. Default values will be returned for the parameters
not passed by the user.
"""
transform = ImplementedTransform(transform)
transform = ImplementedTransform(name)
config_name = "".join([transform, "Config"])
config = globals()[config_name]

return config
return config(**kwargs)
Loading

0 comments on commit d6a9eb4

Please sign in to comment.