♻️ Rework datasets #98

Draft · wants to merge 9 commits into main
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,7 +5,7 @@ loguru
numpy
opencv-python
Pillow
pycocotools
torchmetrics
requests
rich
torch
22 changes: 4 additions & 18 deletions tests/test_tools/test_data_loader.py
@@ -17,40 +17,26 @@ def test_create_dataloader_cache(train_cfg: Config):

    make_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
    load_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
    m_batch_size, m_images, _, m_reverse_tensors, m_image_paths = next(iter(make_cache_loader))
    l_batch_size, l_images, _, l_reverse_tensors, l_image_paths = next(iter(load_cache_loader))
    m_batch_size, m_images, _, m_reverse_tensors = next(iter(make_cache_loader))
    l_batch_size, l_images, _, l_reverse_tensors = next(iter(load_cache_loader))
    assert m_batch_size == l_batch_size
    assert m_images.shape == l_images.shape
    assert m_reverse_tensors.shape == l_reverse_tensors.shape
    assert m_image_paths == l_image_paths


def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
    """Test that the training data loader produces correctly shaped data and metadata."""
    batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
    batch_size, images, _, reverse_tensors = next(iter(train_dataloader))
    assert batch_size == 2
    assert images.shape == (2, 3, 640, 640)
    assert reverse_tensors.shape == (2, 5)
    expected_paths = [
        Path("tests/data/images/train/000000050725.jpg"),
        Path("tests/data/images/train/000000167848.jpg"),
    ]
    assert list(image_paths) == list(expected_paths)


def test_validation_data_loader_correctness(validation_dataloader: YoloDataLoader):
    batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
    batch_size, images, targets, reverse_tensors = next(iter(validation_dataloader))
    assert batch_size == 4
    assert images.shape == (4, 3, 640, 640)
    assert targets.shape == (4, 18, 5)
    assert reverse_tensors.shape == (4, 5)
    expected_paths = [
        Path("tests/data/images/val/000000151480.jpg"),
        Path("tests/data/images/val/000000284106.jpg"),
        Path("tests/data/images/val/000000323571.jpg"),
        Path("tests/data/images/val/000000570456.jpg"),
    ]
    assert list(image_paths) == list(expected_paths)


def test_file_stream_data_loader_frame(file_stream_data_loader: StreamDataLoader):
11 changes: 5 additions & 6 deletions tests/test_tools/test_solver.py
@@ -17,7 +17,7 @@
@pytest.fixture
def model_validator(validation_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
    validator = ModelValidator(
        validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
        validation_cfg.task, model, vec2box, validation_progress_logger, device
    )
    return validator

@@ -28,11 +28,10 @@ def test_model_validator_initialization(model_validator: ModelValidator):


def test_model_validator_solve_mock_dataset(model_validator: ModelValidator, validation_dataloader: YoloDataLoader):
    mAPs = model_validator.solve(validation_dataloader)
    except_mAPs = {"mAP.5": tensor(0.6969), "mAP.5:.95": tensor(0.4195)}
    assert allclose(mAPs["mAP.5"], except_mAPs["mAP.5"], rtol=0.1)
    print(mAPs)
    assert allclose(mAPs["mAP.5:.95"], except_mAPs["mAP.5:.95"], rtol=0.1)
    metrics = model_validator.solve(validation_dataloader)
    except_metrics = {"map_50": tensor(0.7515), "map": tensor(0.5986)}
    assert allclose(metrics["map_50"], except_metrics["map_50"], rtol=0.1)
    assert allclose(metrics["map"], except_metrics["map"], rtol=0.1)


@pytest.fixture
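The expected keys above (`map_50`, `map`) match the dictionary returned by torchmetrics' `MeanAveragePrecision`, which this PR adds to `requirements.txt` in place of `pycocotools`. A minimal sketch of how such a dictionary is produced, assuming the reworked `ModelValidator` delegates to torchmetrics; the boxes, scores, and labels below are made up:

```python
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# One predicted box and one ground-truth box for a single image.
preds = [{
    "boxes": torch.tensor([[50.0, 50.0, 150.0, 150.0]]),  # xyxy, pixel coordinates
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([0]),
}]
targets = [{
    "boxes": torch.tensor([[55.0, 55.0, 145.0, 145.0]]),
    "labels": torch.tensor([0]),
}]

metric = MeanAveragePrecision(box_format="xyxy", iou_type="bbox")
metric.update(preds, targets)
result = metric.compute()  # dict of tensors: "map", "map_50", "map_75", ...
print(result["map_50"], result["map"])
```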
16 changes: 1 addition & 15 deletions tests/test_utils/test_bounding_box_utils.py
@@ -15,7 +15,6 @@
    Vec2Box,
    bbox_nms,
    calculate_iou,
    calculate_map,
    generate_anchors,
    transform_bbox,
)
@@ -167,17 +166,4 @@ def test_bbox_nms():
    output = bbox_nms(cls_dist, bbox, nms_cfg)

    for out, exp in zip(output, expected_output):
        assert allclose(out, exp, atol=1e-4), f"Output: {out} Expected: {exp}"


def test_calculate_map():
    predictions = tensor([[0, 60, 60, 160, 160, 0.5], [0, 40, 40, 120, 120, 0.5]])  # [class, x1, y1, x2, y2]
    ground_truths = tensor([[0, 50, 50, 150, 150], [0, 30, 30, 100, 100]])  # [class, x1, y1, x2, y2]

    mAP = calculate_map(predictions, ground_truths)

    expected_ap50 = tensor(0.5)
    expected_ap50_95 = tensor(0.2)

    assert isclose(mAP["mAP.5"], expected_ap50, atol=1e-5), f"AP50 mismatch"
    assert isclose(mAP["mAP.5:.95"], expected_ap50_95, atol=1e-5), f"Mean AP mismatch"
        assert allclose(out, exp, atol=1e-4), f"Output: {out} Expected: {exp}"
4 changes: 3 additions & 1 deletion yolo/config/dataset/coco.yaml
@@ -1,4 +1,6 @@
path: data/coco
_target_: yolo.dataset.coco.CocoDataset

root_path: data/coco
train: train2017
validation: val2017

2 changes: 1 addition & 1 deletion yolo/config/task/inference.yaml
@@ -4,7 +4,7 @@ fast_inference: # onnx, trt, deploy or Empty
data:
  source: demo/images/inference/image.png
  image_size: ${image_size}
  data_augment: {}
  data_augment: []
nms:
  min_confidence: 0.5
  min_iou: 0.5
11 changes: 6 additions & 5 deletions yolo/config/task/train.yaml
@@ -12,11 +12,12 @@ data:
  shuffle: True
  pin_memory: True
  data_augment:
    Mosaic: 1
    # MixUp: 1
    # HorizontalFlip: 0.5
    RandomCrop: 1
    RemoveOutliers: 1e-8
    - _target_: yolo.tools.data_augmentation.Mosaic
      prob: 1
    - _target_: yolo.tools.data_augmentation.RandomCrop
      prob: 1
    - _target_: yolo.tools.data_augmentation.RemoveOutliers
      min_box_area: 1e-8

optimizer:
  type: SGD
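Each entry in the new `data_augment` list is a regular Hydra node with a `_target_`, so it can be resolved into a transform object with `hydra.utils.instantiate`. A minimal sketch of that resolution, assuming `AugmentationComposer` (or the dataset) handles the nodes this way; the classes are the ones named in the config above:

```python
import hydra
from omegaconf import OmegaConf

# The same list as in the YAML, built in code for illustration.
augment_cfg = OmegaConf.create(
    [
        {"_target_": "yolo.tools.data_augmentation.Mosaic", "prob": 1},
        {"_target_": "yolo.tools.data_augmentation.RandomCrop", "prob": 1},
        {"_target_": "yolo.tools.data_augmentation.RemoveOutliers", "min_box_area": 1e-8},
    ]
)

# instantiate() imports each `_target_` class and passes the remaining keys to its __init__.
transforms = [hydra.utils.instantiate(node) for node in augment_cfg]
```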
2 changes: 1 addition & 1 deletion yolo/config/task/validation.yaml
@@ -6,7 +6,7 @@ data:
  cpu_num: ${cpu_num}
  shuffle: False
  pin_memory: True
  data_augment: {}
  data_augment: []
nms:
  min_confidence: 0.05
  min_iou: 0.9
64 changes: 64 additions & 0 deletions yolo/dataset/base.py
@@ -0,0 +1,64 @@
from __future__ import annotations
from typing import List

import hydra
from torch.utils.data import Dataset
import torch

from yolo.config.config import DataConfig, DatasetConfig
from yolo.tools.data_augmentation import AugmentationComposer
from yolo.tools.dataset_preparation import prepare_dataset
from yolo.utils import primitives


def create_dataset(dataset_cfg: DatasetConfig, data_cfg: DataConfig, task: str) -> BaseDataset:
    if task == "train":
        init_config = dataset_cfg.train_dataset
    elif task == "validation":
        init_config = dataset_cfg.validation_dataset
    else:
        raise ValueError(f"Invalid task: {task}")
    return hydra.utils.instantiate(init_config, data_cfg=data_cfg)


class BaseDataset(Dataset):
    def __init__(self, data_cfg: DataConfig):
        self.image_size = data_cfg.image_size
        self.transform = AugmentationComposer(data_cfg.data_augment, self.image_size)
        self.transform.get_more_data = self.get_more_data
        self.load_data()

    def extract_data(self) -> List[primitives.YoloImage]:
        """Prepare the data for the dataset.
        Args:
            dataset_cfg: The dataset configuration.

        Returns:
            List[Image]: The list of images.
        """
        raise NotImplementedError

    def load_data(self) -> None:
        self.data = self.extract_data()

    def extract_sample(self, idx: int) -> tuple[primitives.YoloImage, list[torch.Tensor]]:
        sample = self.data[idx]
        image = sample.image
        image_path = sample.image_path

        bboxes = torch.stack([bbox.tensor for bbox in sample.bboxes])
        return image, bboxes, image_path

    def __getitem__(self, idx: int) -> primitives.YoloImage:
        image, bboxes, image_path = self.extract_sample(idx)
        image, bboxes, rev_tensor = self.transform(image, bboxes)
        return image, bboxes, rev_tensor, image_path

    def __len__(self) -> int:
        return len(self.data)

    def get_more_data(self, num: int = 1) -> List[primitives.YoloImage]:
        indices = torch.randint(0, len(self), (num,))
        return [self.extract_sample(idx)[:2] for idx in indices]
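`create_dataset` above hands construction to `hydra.utils.instantiate`, so the concrete dataset class is picked by the `_target_` key of the dataset config (compare `coco.yaml` earlier in the diff). A minimal sketch of that flow; the config node and the stand-in `data_cfg` below are hypothetical, not taken from the PR:

```python
import hydra
from omegaconf import OmegaConf

# Hypothetical node in the shape the new config uses.
node = OmegaConf.create(
    {
        "_target_": "yolo.dataset.yolo.YoloDataset",
        "root_path": "data/coco",
        "subset": "train2017",
    }
)

# Minimal stand-in for DataConfig: BaseDataset only reads image_size and data_augment here.
data_cfg = OmegaConf.create({"image_size": [640, 640], "data_augment": []})

# instantiate() imports the `_target_` class and forwards the remaining keys plus the
# extra kwargs to YoloDataset.__init__ (root_path, subset, data_cfg).
dataset = hydra.utils.instantiate(node, data_cfg=data_cfg)
```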

Empty file added yolo/dataset/coco.py
Empty file.
100 changes: 100 additions & 0 deletions yolo/dataset/yolo.py
@@ -0,0 +1,100 @@
from typing import List
import pathlib
from rich.progress import track

from yolo.dataset.base import BaseDataset
from yolo.config.config import DatasetConfig
from yolo.utils import primitives


class YoloDataset(BaseDataset):
    """
    YOLO Dataset Format
    **Directory Structure**

    ```
    dataset/
    ├── images/
    │   ├── subset_1/
    │   │   ├── image1.jpg
    │   │   └── ...
    │   └── subset_2/
    │       ├── image2.jpg
    │       └── ...
    └── labels/
        ├── subset_1/
        │   ├── image1.txt
        │   └── ...
        └── subset_2/
            ├── image2.txt
            └── ...
    ```

    **Annotation Format**

    Each annotation file corresponds to an image and contains one or more lines:

    ```
    <class_id> <center_x> <center_y> <width> <height>
    ```

    - **class_id**: Zero-based index matching a line in `classes.txt`.
    - **center_x**, **center_y**: Normalized coordinates of the bounding box center (values between 0 and 1).
    - **width**, **height**: Normalized dimensions of the bounding box (values between 0 and 1).

    **Normalization**

    Coordinates are normalized relative to image dimensions:

    ```
    center_x = bbox_center_x / image_width
    center_y = bbox_center_y / image_height
    width = bbox_width / image_width
    height = bbox_height / image_height
    ```

    **Notes**
    - Image and annotation filenames must match exactly (excluding extensions).
    - The coordinate system origin `(0,0)` is at the top-left corner of the image.
    - If an image contains no objects, its annotation file should be empty or omitted.
    - Supported image formats include `.jpg`, `.jpeg`, and `.png`.
    - Ensure consistency between training and validation splits.

    """

    def __init__(self, root_path: str, subset: str, **kwargs):
        self.root_path = pathlib.Path(root_path)
        self.subset = subset
        super().__init__(**kwargs)

    def extract_data(self) -> List[primitives.YoloImage]:
        """Prepare the data for the dataset.
        Args:
            dataset_cfg: The dataset configuration.

        Returns:
            List[Image]: The list of images.
        """
        images = []
        for image_path in track((self.root_path / "images" / self.subset).glob("*.*"), description="Loading images"):
            if image_path.suffix not in (".jpg", ".jpeg", ".png"):
                continue
            image = primitives.YoloImage(image_path=image_path)
            image.bboxes = self.load_bounding_boxes(self.root_path / "labels" / self.subset / (image_path.stem + ".txt"))
            images.append(image)
        return images

    def load_bounding_boxes(self, path: pathlib.Path) -> List[primitives.BoundingBox]:
        boxes = []
        if not path.exists():
            return boxes
        with open(path, "r") as file:
            for line in file:
                elements = line.strip().split()
                class_id = int(elements[0])
                center_x, center_y, width, height = map(float, elements[1:])
                boxes.append(primitives.BoundingBox(class_id=class_id, x=center_x, y=center_y, width=width, height=height))

        return boxes
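The class docstring above fixes the on-disk label format. A small worked example of reading one label line and mapping the normalized box back to pixel corner coordinates; the 640x480 image size and the line itself are made up:

```python
line = "3 0.50 0.40 0.20 0.30"  # <class_id> <center_x> <center_y> <width> <height>
image_w, image_h = 640, 480

fields = line.split()
class_id = int(fields[0])
cx, cy, w, h = (float(v) for v in fields[1:])

# De-normalize: scale by the image dimensions.
box_w, box_h = w * image_w, h * image_h            # 128.0, 144.0
center_x, center_y = cx * image_w, cy * image_h    # 320.0, 192.0

# Corner form (x1, y1, x2, y2); the origin (0, 0) is the top-left of the image.
x1, y1 = center_x - box_w / 2, center_y - box_h / 2  # 256.0, 120.0
x2, y2 = center_x + box_w / 2, center_y + box_h / 2  # 384.0, 264.0
```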
3 changes: 1 addition & 2 deletions yolo/lazy.py
@@ -26,13 +26,12 @@ def main(cfg: Config):
    else:
        model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
        model = model.to(device)

    converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)

    if cfg.task.task == "train":
        solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp)
    if cfg.task.task == "validation":
        solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device)
        solver = ModelValidator(cfg.task, model, converter, progress, device)
    if cfg.task.task == "inference":
        solver = ModelTester(cfg, model, converter, progress, device)
    progress.start()
2 changes: 1 addition & 1 deletion yolo/model/yolo.py
@@ -123,7 +123,7 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
            weights: A OrderedDict containing the new weights.
        """
        if isinstance(weights, Path):
            weights = torch.load(weights, map_location=torch.device("cpu"))
            weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
        if "model_state_dict" in weights:
            weights = weights["model_state_dict"]
