Skip to content

Commit

Permalink
🍏 Integrate post processing in model for coreml exports
Browse files Browse the repository at this point in the history
  • Loading branch information
ramonhollands committed Feb 23, 2025
1 parent 718d28c commit 89ea875
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 25 deletions.
26 changes: 15 additions & 11 deletions yolo/model/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class YOLO(nn.Module):
parameters, and any other relevant configuration details.
"""

def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export_mode : bool =False):
def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export_mode: bool = False):
super(YOLO, self).__init__()
self.num_classes = class_num
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
Expand Down Expand Up @@ -88,14 +88,14 @@ def generate_anchors(self, image_size: List[int], strides: List[int]):
all_anchors = torch.cat(anchors, dim=0)
all_scalers = torch.cat(scaler, dim=0)
return all_anchors, all_scalers

def get_strides(self, output, input_width) -> List[int]:
W = input_width
strides = []
for predict_head in output:
_, _, *anchor_num = predict_head[2].shape
strides.append(W // anchor_num[1])

return strides

def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = None):
Expand Down Expand Up @@ -130,24 +130,26 @@ def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] =
if self.export_mode:

preds_cls, preds_anc, preds_box = [], [], []
for layer_output in output['Main']:
for layer_output in output["Main"]:
pred_cls, pred_anc, pred_box = layer_output
preds_cls.append(pred_cls.permute(0, 2, 3, 1).reshape(pred_cls.shape[0], -1, pred_cls.shape[1]))
preds_anc.append(pred_anc.permute(0, 3, 4, 1, 2).reshape(pred_anc.shape[0], -1, pred_anc.shape[2], pred_anc.shape[1]))
preds_anc.append(
pred_anc.permute(0, 3, 4, 1, 2).reshape(pred_anc.shape[0], -1, pred_anc.shape[2], pred_anc.shape[1])
)
preds_box.append(pred_box.permute(0, 2, 3, 1).reshape(pred_box.shape[0], -1, pred_box.shape[1]))

preds_cls = torch.concat(preds_cls, dim=1).to(x[0][0].device)
preds_anc = torch.concat(preds_anc, dim=1).to(x[0][0].device)
preds_box = torch.concat(preds_box, dim=1).to(x[0][0].device)
strides = self.get_strides(output['Main'], input_width)
anchor_grid, scaler = self.generate_anchors([input_width,input_height], strides) #

strides = self.get_strides(output["Main"], input_width)
anchor_grid, scaler = self.generate_anchors([input_width, input_height], strides) #
anchor_grid = anchor_grid.to(x[0][0].device)
scaler = scaler.to(x[0][0].device)
pred_LTRB = preds_box * scaler.view(1, -1, 1)
lt, rb = pred_LTRB.chunk(2, dim=-1)
preds_box = torch.cat([anchor_grid - lt, anchor_grid + rb], dim=-1)

return preds_cls, preds_anc, preds_box

return output
Expand Down Expand Up @@ -221,7 +223,9 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
self.model.load_state_dict(model_state_dict)


def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export_mode: bool = False) -> YOLO:
def create_model(
model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export_mode: bool = False
) -> YOLO:
"""Constructs and returns a model from a Dictionary configuration file.
Args:
Expand Down
4 changes: 3 additions & 1 deletion yolo/tools/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
class BaseModel(LightningModule):
def __init__(self, cfg: Config, export_mode: bool = False):
super().__init__()
self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export_mode=export_mode)
self.model = create_model(
cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export_mode=export_mode
)

def forward(self, x):
return self.model(x)
Expand Down
8 changes: 4 additions & 4 deletions yolo/utils/deploy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def __init__(self, cfg: Config, model: YOLO):
cfg.weight = Path("weights") / f"{cfg.model.name}.pt"

extention = self.compiler
if self.compiler == 'coreml':
extention = 'mlpackage'
if self.compiler == "coreml":
extention = "mlpackage"

self.model_path = f"{Path(cfg.weight).stem}.{extention}"

Expand Down Expand Up @@ -139,10 +139,10 @@ def coreml_forward(self, x: Tensor):
model_outputs = []
predictions = self.predict({"x": x})

output_keys = ['preds_cls', 'preds_anc', 'preds_box']
output_keys = ["preds_cls", "preds_anc", "preds_box"]
for key in output_keys:
model_outputs.append(torch.from_numpy(predictions[key]).to(device))

return model_outputs

models.MLModel.__call__ = coreml_forward
Expand Down
15 changes: 8 additions & 7 deletions yolo/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[s
self.model_path = model_path
else:
extention = self.format
if self.format == 'coreml':
extention = 'mlpackage'
if self.format == "coreml":
extention = "mlpackage"

self.model_path = f"{Path(self.cfg.weight).stem}.{extention}"

self.output_names: List[str] = [
Expand All @@ -36,9 +36,7 @@ def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[s
"9_bbox_deltas_large",
]

self.output_names: List[str] = [
"preds_cls", "preds_anc", "preds_box"
]
self.output_names: List[str] = ["preds_cls", "preds_anc", "preds_box"]

def export_onnx(self, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None):
logger.info(f":package: Exporting model to onnx format")
Expand Down Expand Up @@ -89,12 +87,15 @@ def export_coreml(self):
exported_program = torch.export.export(self.model, example_inputs)

import logging

import coremltools as ct

# Convert to Core ML program using the Unified Conversion API.
logging.getLogger("coremltools").disabled = True

model_from_export = ct.convert(exported_program, outputs=[ct.TensorType(name=name) for name in self.output_names], convert_to="mlprogram")
model_from_export = ct.convert(
exported_program, outputs=[ct.TensorType(name=name) for name in self.output_names], convert_to="mlprogram"
)

model_from_export.save(self.model_path)
logger.info(f":white_check_mark: Model exported to coreml format {self.model_path}")
4 changes: 2 additions & 2 deletions yolo/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ def __call__(
prediction = self.converter(predict["Main"])
else:
prediction = predict

pred_class, _, pred_bbox = predict[:3]
pred_conf = prediction[3] if len(prediction) == 4 else None

if rev_tensor is not None:
pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf)
Expand Down

0 comments on commit 89ea875

Please sign in to comment.