From 18f44e2842f8fca371092887b8570453ee970ede Mon Sep 17 00:00:00 2001 From: harikdava Date: Sun, 9 Feb 2025 20:28:37 +0100 Subject: [PATCH 1/3] using supervision annotators --- requirements.txt | 1 + yolo/__init__.py | 2 ++ yolo/tools/drawer.py | 52 +++++++++++---------------------------- yolo/tools/solver.py | 5 +++- yolo/utils/model_utils.py | 16 ++++++++++++ 5 files changed, 37 insertions(+), 39 deletions(-) diff --git a/requirements.txt b/requirements.txt index f6d336cb..0a427b80 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ requests rich torch torchvision +supervision wandb diff --git a/yolo/__init__.py b/yolo/__init__.py index b4b98d7f..edbf279c 100644 --- a/yolo/__init__.py +++ b/yolo/__init__.py @@ -10,6 +10,7 @@ YOLORichModelSummary, YOLORichProgressBar, ) +from yolo.utils.model_utils import prediction_to_sv from yolo.utils.model_utils import PostProcess all = [ @@ -30,4 +31,5 @@ "FastModelLoader", "TrainModel", "PostProcess", + "prediction_to_sv", ] diff --git a/yolo/tools/drawer.py b/yolo/tools/drawer.py index ca0fbdbd..a76a43e6 100644 --- a/yolo/tools/drawer.py +++ b/yolo/tools/drawer.py @@ -1,9 +1,9 @@ -import random -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np import torch -from PIL import Image, ImageDraw, ImageFont +from PIL import Image +import supervision as sv from torchvision.transforms.functional import to_pil_image from yolo.config.config import ModelConfig @@ -13,7 +13,7 @@ def draw_bboxes( img: Union[Image.Image, torch.Tensor], - bboxes: List[List[Union[int, float]]], + detections: sv.Detections, *, idx2label: Optional[list] = None, ): @@ -32,41 +32,17 @@ def draw_bboxes( img = img[0] img = to_pil_image(img) - if isinstance(bboxes, list) or bboxes.ndim == 3: - bboxes = bboxes[0] - + box_annotator = sv.ColorAnnotator(color_lookup=sv.ColorLookup.CLASS) + label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.CLASS) img = img.copy() - label_size = img.size[1] / 30 - draw = ImageDraw.Draw(img, "RGBA") - - try: - font = ImageFont.truetype("arial.ttf", int(label_size)) - except IOError: - font = ImageFont.load_default(int(label_size)) - - for bbox in bboxes: - class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox] - x_min, x_max = min(x_min, x_max), max(x_min, x_max) - y_min, y_max = min(y_min, y_max), max(y_min, y_max) - bbox = [(x_min, y_min), (x_max, y_max)] - - random.seed(int(class_id)) - color_map = (random.randint(0, 200), random.randint(0, 200), random.randint(0, 200)) - - draw.rounded_rectangle(bbox, outline=(*color_map, 200), radius=5, width=2) - draw.rounded_rectangle(bbox, fill=(*color_map, 100), radius=5) - - class_text = str(idx2label[int(class_id)] if idx2label else int(class_id)) - label_text = f"{class_text}" + (f" {conf[0]: .0%}" if conf else "") - - text_bbox = font.getbbox(label_text) - text_width = text_bbox[2] - text_bbox[0] - text_height = (text_bbox[3] - text_bbox[1]) * 1.5 - - text_background = [(x_min, y_min), (x_min + text_width, y_min + text_height)] - draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2) - draw.text((x_min, y_min), label_text, fill="white", font=font) - + img = box_annotator.annotate(img, detections=detections) + if idx2label: + labels = [ + f"{str(idx2label[int(class_id)])} {confidence:.2f}" + for class_id, confidence + in zip(detections.class_id, detections.confidence) + ] + img = label_annotator.annotate(img, labels=labels, detections=detections) return img diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py index 8246a66d..0c59aabb 100644 --- a/yolo/tools/solver.py +++ b/yolo/tools/solver.py @@ -7,6 +7,7 @@ from yolo.config.config import Config from yolo.model.yolo import create_model from yolo.tools.data_loader import create_dataloader +from yolo.utils.model_utils import prediction_to_sv from yolo.tools.drawer import draw_bboxes from yolo.tools.loss_functions import create_loss_function from yolo.utils.bounding_box_utils import create_converter, to_metrics_format @@ -126,7 +127,9 @@ def predict_dataloader(self): def predict_step(self, batch, batch_idx): images, rev_tensor, origin_frame = batch predicts = self.post_process(self(images), rev_tensor=rev_tensor) - img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list) + detections = prediction_to_sv(predicts) # convert to sv format + class_list = [str(label) for label in self.cfg.dataset.class_list] + img = draw_bboxes(origin_frame, detections, idx2label=class_list) if getattr(self.predict_loader, "is_stream", None): fps = self._display_stream(img) else: diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 9d6c0ce5..7ef86f6f 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -12,6 +12,7 @@ from torch import Tensor, no_grad from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler +import supervision as sv from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig from yolo.model.yolo import YOLO @@ -222,3 +223,18 @@ def predicts_to_json(img_paths, predicts, rev_tensor): } batch_json.append(bbox) return batch_json + + +def prediction_to_sv(predicts: List[Tensor]) -> sv.Detections: + """ + Convert the prediction to the format of the Supervision + Args: + predicts: + rev_tensor: + + Returns: + sv.Detections: The detections in the Supervision format + """ + predicts = predicts[0].detach().cpu().numpy() + detections = sv.Detections(xyxy=predicts[:, 1:5], class_id=predicts[:, 0].astype(int), confidence=predicts[:, 5]) + return detections From 34c1d3e665a4269e098daa046de92b27f499a720 Mon Sep 17 00:00:00 2001 From: harikdava Date: Sun, 9 Feb 2025 20:42:32 +0100 Subject: [PATCH 2/3] tests added --- tests/test_tools/test_drawer.py | 9 ++++++--- tests/test_utils/test_model_utils.py | 21 +++++++++++++++++++++ yolo/utils/model_utils.py | 2 ++ 3 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 tests/test_utils/test_model_utils.py diff --git a/tests/test_tools/test_drawer.py b/tests/test_tools/test_drawer.py index 07912764..44014fd8 100644 --- a/tests/test_tools/test_drawer.py +++ b/tests/test_tools/test_drawer.py @@ -1,8 +1,9 @@ import sys from pathlib import Path +import numpy as np from PIL import Image -from torch import tensor +import supervision as sv project_root = Path(__file__).resolve().parent.parent.parent sys.path.append(str(project_root)) @@ -24,6 +25,8 @@ def test_draw_model_by_model(model: YOLO): def test_draw_bboxes(): """Test drawing bounding boxes on an image.""" - predictions = tensor([[0, 60, 60, 160, 160, 0.5], [0, 40, 40, 120, 120, 0.5]]) + detections = sv.Detections(xyxy=np.asarray([[60, 60, 160, 160], [40, 40, 120, 120]]), + confidence=np.asarray([0.5, 0.5]), + class_id=np.asarray([0, 0])) pil_image = Image.open("tests/data/images/train/000000050725.jpg") - draw_bboxes(pil_image, [predictions]) + draw_bboxes(pil_image, detections) diff --git a/tests/test_utils/test_model_utils.py b/tests/test_utils/test_model_utils.py new file mode 100644 index 00000000..1362bd53 --- /dev/null +++ b/tests/test_utils/test_model_utils.py @@ -0,0 +1,21 @@ +from yolo.utils.model_utils import prediction_to_sv +import torch +import numpy as np + + +def test_prediction_to_sv(): + + predictions = [] + detections = prediction_to_sv(predictions) + assert len(detections) == 0 + + xyxy = torch.tensor([[60, 60, 160, 160], [40, 40, 120, 120]], dtype=torch.float32) + confidence = torch.tensor([0.5, 0.5], dtype=torch.float32).unsqueeze(1) + class_id = torch.tensor([0, 0], dtype=torch.int64).unsqueeze(1) + predictions = [torch.cat([class_id, xyxy, confidence], dim=1)] + + detections = prediction_to_sv(predictions) + assert len(detections) == 2 + assert np.allclose(detections.xyxy, xyxy.numpy()) + assert np.allclose(detections.confidence, confidence.numpy().flatten()) + assert np.allclose(detections.class_id, class_id.numpy().flatten()) diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 7ef86f6f..102fb578 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -235,6 +235,8 @@ def prediction_to_sv(predicts: List[Tensor]) -> sv.Detections: Returns: sv.Detections: The detections in the Supervision format """ + if len(predicts) == 0: + return sv.Detections.empty() predicts = predicts[0].detach().cpu().numpy() detections = sv.Detections(xyxy=predicts[:, 1:5], class_id=predicts[:, 0].astype(int), confidence=predicts[:, 5]) return detections From 39c3bbf7bee5c30cda24e57c0b481f9929ac69e6 Mon Sep 17 00:00:00 2001 From: harikdava Date: Mon, 10 Feb 2025 19:52:16 +0100 Subject: [PATCH 3/3] update --- yolo/tools/drawer.py | 2 +- yolo/utils/model_utils.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/yolo/tools/drawer.py b/yolo/tools/drawer.py index a76a43e6..69c5900d 100644 --- a/yolo/tools/drawer.py +++ b/yolo/tools/drawer.py @@ -12,7 +12,7 @@ def draw_bboxes( - img: Union[Image.Image, torch.Tensor], + img: Image.Image | torch.Tensor, detections: sv.Detections, *, idx2label: Optional[list] = None, diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 102fb578..96277f31 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -228,6 +228,7 @@ def predicts_to_json(img_paths, predicts, rev_tensor): def prediction_to_sv(predicts: List[Tensor]) -> sv.Detections: """ Convert the prediction to the format of the Supervision + Args: predicts: rev_tensor: