
Allow DataLoader and Dataset to retain Generic features from torch #8333

Open
stuartthomson opened this issue Feb 10, 2025 · 1 comment

@stuartthomson

Is your feature request related to a problem? Please describe.
I would like to be able to provide better type hints for my MONAI code. The `DataLoader` and `Dataset` classes in MONAI inherit from their torch counterparts but hide the fact that, in torch, these are generic classes. For example, in torch I can define a dataset like:

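```python
from typing import TypedDict

import torch
from torch.utils.data import Dataset


class MyData(TypedDict):
    filename: str
    image: torch.Tensor
    segmentation: torch.Tensor


# create_data() stands for any function that builds a Dataset of MyData items
my_dataset: Dataset[MyData] = create_data()
```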

This means that elsewhere in the code I (and a type checker) know what each item looks like. This isn't possible with MONAI's `Dataset` and `DataLoader`, which are not generic.
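To make the snippet above concrete, one possible `create_data()` is sketched here. It relies on a hypothetical `MyDataset` class that returns dummy tensors instead of loading files, and on a hypothetical helper `first_filename`. Because the item type is carried by `Dataset[MyData]`, a type checker can verify that `first_filename` only uses keys that `MyData` defines.

```python
from typing import TypedDict

import torch
from torch.utils.data import Dataset


class MyData(TypedDict):
    filename: str
    image: torch.Tensor
    segmentation: torch.Tensor


class MyDataset(Dataset[MyData]):
    """Hypothetical dataset that returns dummy tensors for a list of filenames."""

    def __init__(self, filenames: list[str]) -> None:
        self.filenames = filenames

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, index: int) -> MyData:
        # Real code would load the image and segmentation from disk here.
        return {
            "filename": self.filenames[index],
            "image": torch.zeros(1, 8, 8),
            "segmentation": torch.zeros(1, 8, 8, dtype=torch.uint8),
        }


def create_data() -> Dataset[MyData]:
    # Hypothetical filenames; any list of paths would do.
    return MyDataset(["a.nii.gz", "b.nii.gz"])


def first_filename(dataset: Dataset[MyData]) -> str:
    # dataset[0] is inferred as MyData, so the "filename" key is known to be a str.
    return dataset[0]["filename"]


my_dataset: Dataset[MyData] = create_data()
print(first_filename(my_dataset))  # prints a.nii.gz
```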

Describe the solution you'd like
I think you could solve this with something like the following for `Dataset` (`DataLoader` would need a similar change):

```python
from __future__ import annotations

import collections.abc
from typing import Any, Mapping, Sequence, TypeVar, Union, overload

import numpy as np
import torch
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset as _TorchSubset

NdarrayOrTensor = Union[np.ndarray, torch.Tensor]

# The bound is evaluated at runtime, so use Union[...] rather than `X | Y`
# to stay compatible with Python versions before 3.10.
# Note: a TypedDict with non-array values (such as MyData's `filename: str`)
# does not satisfy this bound; it may need loosening to allow that case.
T = TypeVar(
    "T",
    bound=Union[NdarrayOrTensor, Sequence[NdarrayOrTensor], Mapping[Any, NdarrayOrTensor]],
)


class Dataset(_TorchDataset[T]):

    # Leave the rest of the class as-is
    ...

    @overload
    def __getitem__(self, index: slice) -> _TorchSubset[T]: ...

    @overload
    def __getitem__(self, index: Sequence[int]) -> _TorchSubset[T]: ...

    @overload
    def __getitem__(self, index: int) -> T: ...

    def __getitem__(self, index: int | slice | Sequence[int]) -> T | _TorchSubset[T]:
        """
        Returns a `Subset` if `index` is a slice or Sequence, a data item otherwise.
        """
        if isinstance(index, slice):
            # dataset[:42]
            start, stop, step = index.indices(len(self))
            indices = range(start, stop, step)
            return _TorchSubset(dataset=self, indices=indices)
        if isinstance(index, collections.abc.Sequence):
            # dataset[[1, 3, 4]]
            return _TorchSubset(dataset=self, indices=index)
        # Single index: load and transform one item via the existing helper.
        return self._transform(index)
```

stuartthomson commented Feb 11, 2025

Let me know if this is something you'd be interested in seeing a PR for; I'd be happy to have a go.
