Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrating supervision annotators #169

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ requests
rich
torch
torchvision
supervision
wandb
9 changes: 6 additions & 3 deletions tests/test_tools/test_drawer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import sys
from pathlib import Path

import numpy as np
from PIL import Image
from torch import tensor
import supervision as sv

project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(project_root))
Expand All @@ -24,6 +25,8 @@ def test_draw_model_by_model(model: YOLO):

def test_draw_bboxes():
"""Test drawing bounding boxes on an image."""
predictions = tensor([[0, 60, 60, 160, 160, 0.5], [0, 40, 40, 120, 120, 0.5]])
detections = sv.Detections(xyxy=np.asarray([[60, 60, 160, 160], [40, 40, 120, 120]]),
confidence=np.asarray([0.5, 0.5]),
class_id=np.asarray([0, 0]))
pil_image = Image.open("tests/data/images/train/000000050725.jpg")
draw_bboxes(pil_image, [predictions])
draw_bboxes(pil_image, detections)
21 changes: 21 additions & 0 deletions tests/test_utils/test_model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from yolo.utils.model_utils import prediction_to_sv
import torch
import numpy as np


def test_prediction_to_sv():
    """Check ``prediction_to_sv`` on an empty batch and on a two-box prediction."""
    # An empty prediction list must produce an empty detections object.
    assert len(prediction_to_sv([])) == 0

    boxes = torch.tensor([[60, 60, 160, 160], [40, 40, 120, 120]], dtype=torch.float32)
    scores = torch.tensor([0.5, 0.5], dtype=torch.float32).unsqueeze(1)
    labels = torch.tensor([0, 0], dtype=torch.int64).unsqueeze(1)
    # Columns are concatenated in the order: class id, xyxy box, confidence.
    batch = [torch.cat((labels, boxes, scores), dim=1)]

    result = prediction_to_sv(batch)
    assert len(result) == 2
    assert np.allclose(result.xyxy, boxes.numpy())
    assert np.allclose(result.confidence, scores.numpy().ravel())
    assert np.allclose(result.class_id, labels.numpy().ravel())
2 changes: 2 additions & 0 deletions yolo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
YOLORichModelSummary,
YOLORichProgressBar,
)
from yolo.utils.model_utils import prediction_to_sv
from yolo.utils.model_utils import PostProcess

all = [
Expand All @@ -30,4 +31,5 @@
"FastModelLoader",
"TrainModel",
"PostProcess",
"prediction_to_sv",
]
52 changes: 14 additions & 38 deletions yolo/tools/drawer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import random
from typing import List, Optional, Union
from typing import Optional, Union

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from PIL import Image
import supervision as sv
from torchvision.transforms.functional import to_pil_image

from yolo.config.config import ModelConfig
Expand All @@ -13,7 +13,7 @@

def draw_bboxes(
img: Union[Image.Image, torch.Tensor],
bboxes: List[List[Union[int, float]]],
detections: sv.Detections,
*,
idx2label: Optional[list] = None,
):
Expand All @@ -32,41 +32,17 @@ def draw_bboxes(
img = img[0]
img = to_pil_image(img)

if isinstance(bboxes, list) or bboxes.ndim == 3:
bboxes = bboxes[0]

box_annotator = sv.ColorAnnotator(color_lookup=sv.ColorLookup.CLASS)
label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.CLASS)
img = img.copy()
label_size = img.size[1] / 30
draw = ImageDraw.Draw(img, "RGBA")

try:
font = ImageFont.truetype("arial.ttf", int(label_size))
except IOError:
font = ImageFont.load_default(int(label_size))

for bbox in bboxes:
class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
x_min, x_max = min(x_min, x_max), max(x_min, x_max)
y_min, y_max = min(y_min, y_max), max(y_min, y_max)
bbox = [(x_min, y_min), (x_max, y_max)]

random.seed(int(class_id))
color_map = (random.randint(0, 200), random.randint(0, 200), random.randint(0, 200))

draw.rounded_rectangle(bbox, outline=(*color_map, 200), radius=5, width=2)
draw.rounded_rectangle(bbox, fill=(*color_map, 100), radius=5)

class_text = str(idx2label[int(class_id)] if idx2label else int(class_id))
label_text = f"{class_text}" + (f" {conf[0]: .0%}" if conf else "")

text_bbox = font.getbbox(label_text)
text_width = text_bbox[2] - text_bbox[0]
text_height = (text_bbox[3] - text_bbox[1]) * 1.5

text_background = [(x_min, y_min), (x_min + text_width, y_min + text_height)]
draw.rounded_rectangle(text_background, fill=(*color_map, 175), radius=2)
draw.text((x_min, y_min), label_text, fill="white", font=font)

img = box_annotator.annotate(img, detections=detections)
if idx2label:
labels = [
f"{str(idx2label[int(class_id)])} {confidence:.2f}"
for class_id, confidence
in zip(detections.class_id, detections.confidence)
]
img = label_annotator.annotate(img, labels=labels, detections=detections)
return img


Expand Down
5 changes: 4 additions & 1 deletion yolo/tools/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.tools.data_loader import create_dataloader
from yolo.utils.model_utils import prediction_to_sv
from yolo.tools.drawer import draw_bboxes
from yolo.tools.loss_functions import create_loss_function
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
Expand Down Expand Up @@ -126,7 +127,9 @@ def predict_dataloader(self):
def predict_step(self, batch, batch_idx):
images, rev_tensor, origin_frame = batch
predicts = self.post_process(self(images), rev_tensor=rev_tensor)
img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
detections = prediction_to_sv(predicts) # convert to sv format
class_list = [str(label) for label in self.cfg.dataset.class_list]
img = draw_bboxes(origin_frame, detections, idx2label=class_list)
if getattr(self.predict_loader, "is_stream", None):
fps = self._display_stream(img)
else:
Expand Down
18 changes: 18 additions & 0 deletions yolo/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch import Tensor, no_grad
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
import supervision as sv

from yolo.config.config import IDX_TO_ID, NMSConfig, OptimizerConfig, SchedulerConfig
from yolo.model.yolo import YOLO
Expand Down Expand Up @@ -222,3 +223,20 @@ def predicts_to_json(img_paths, predicts, rev_tensor):
}
batch_json.append(bbox)
return batch_json


def prediction_to_sv(predicts: List[Tensor]) -> sv.Detections:
    """
    Convert post-processed predictions to a Supervision ``Detections`` object.

    Only the first tensor in ``predicts`` is converted; any further entries
    are ignored.

    Args:
        predicts: List of 2-D prediction tensors. Each row is read as
            ``[class_id, x1, y1, x2, y2, confidence]``; the confidence
            column may be absent.

    Returns:
        sv.Detections: The detections in the Supervision format. Empty when
            ``predicts`` contains no tensors.
    """
    if len(predicts) == 0:
        return sv.Detections.empty()

    # Supervision works on NumPy arrays, so leave the autograd graph and the
    # accelerator before converting.
    rows = predicts[0].detach().cpu().numpy()

    # Rows without a sixth column carry no score; pass ``None`` rather than
    # failing with an IndexError on the column lookup.
    confidence = rows[:, 5] if rows.shape[1] > 5 else None

    return sv.Detections(
        xyxy=rows[:, 1:5],
        class_id=rows[:, 0].astype(int),
        confidence=confidence,
    )