Convert SlidingWindowInferer such that it can be exported to ONNX #6526
I'm very interested in your question, but I have a few questions I'd like to ask you.
I am working with images that are very wide: the width is 12 times the height. I decided to train my model on patches, and the predictions are better with a sliding window than when passing the whole image, even though my model is fully convolutional. Moreover, I want to deploy the model on a small device and reduce inference time as much as possible, which is the reason for the conversion to ONNX.
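Regarding the export step itself, a common pattern is to export only the network's forward pass on a single ROI-sized patch and keep the windowing logic outside the ONNX graph. Below is a minimal sketch under assumptions made up for illustration (a small 2D stand-in network, a 64x64 patch, the file name model.onnx); the actual model and shapes would differ.

```python
import torch
import torch.nn as nn

# Stand-in for the real fully convolutional network (placeholder for illustration only).
net = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 2, 1)).eval()

# Export the forward pass on one ROI-sized patch; the batch axis is kept dynamic
# so the sliding-window loop can feed several patches per call.
dummy_patch = torch.randn(1, 1, 64, 64)
torch.onnx.export(
    net,
    dummy_patch,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    opset_version=17,
)
```

The exported graph then only ever sees fixed-size patches; how those patches are produced and recombined stays in whatever code runs on the device.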
Sorry, may I ask how long inference takes on the small device with the original model, and whether it has been confirmed that the .onnx model is actually faster than the original one? Thanks.
Hi, I have another question: do you mean you want to convert your PyTorch model to ONNX, or some other model you are using?
MONAI’s sliding window inferer is agnostic to the underlying engine. Out of the box it is built for PyTorch models, but it isn’t inherently tied to PyTorch. In other words, if you have an ONNX (or any other) runtime, you can write a thin wrapper (a predictor function) that converts each patch to a NumPy array, runs it through your engine, and converts the result back to a tensor.
Here's how you can do it:

```python
import onnxruntime
import torch
import numpy as np
from monai.inferers import sliding_window_inference

# Load the ONNX model
session = onnxruntime.InferenceSession("model.onnx")

def onnx_predictor(patch: torch.Tensor) -> torch.Tensor:
    """
    Predictor function that wraps ONNX Runtime inference for a single patch.

    Args:
        patch (torch.Tensor): Input tensor of shape [B, C, D, H, W].

    Returns:
        torch.Tensor: Output prediction as a tensor.
    """
    # Convert the patch tensor to a NumPy array (ensuring it is on CPU)
    patch_np = patch.cpu().numpy()
    # Get the input name for the ONNX model
    input_name = session.get_inputs()[0].name
    # Run inference via ONNX Runtime
    outputs = session.run(None, {input_name: patch_np})
    # Assume the model's output is the first element in outputs
    output_np = outputs[0]
    # Convert the output back to a torch.Tensor
    output_tensor = torch.from_numpy(output_np)
    return output_tensor

full_image = torch.randn(1, 1, 128, 128, 128)
roi_size = (64, 64, 64)
sw_batch_size = 4
overlap = 0.25

# Run sliding window inference using the custom ONNX predictor
prediction = sliding_window_inference(
    inputs=full_image,
    roi_size=roi_size,
    sw_batch_size=sw_batch_size,
    predictor=onnx_predictor,
    overlap=overlap,
)
print("Prediction shape:", prediction.shape)
```
@MohamedAtta-AI thanks for your comment. I like the idea, but I wanted everything in ONNX so that I no longer have to install torch and its dependencies in the deployment phase.
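If the goal is to avoid installing torch at deploy time altogether, the same wrapping idea can be rewritten with only NumPy and ONNX Runtime. A rough sketch, not MONAI code: it assumes a 2D model taking [B, C, H, W] float32 input, a file named model.onnx, and constant (rather than Gaussian) weighting when averaging overlapping windows.

```python
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name

def _window_starts(length: int, roi: int, step: int) -> list:
    """Start positions covering the whole axis; the last window is clipped to the edge."""
    starts = list(range(0, max(length - roi, 0) + 1, step))
    if length > roi and starts[-1] != length - roi:
        starts.append(length - roi)
    return starts

def sliding_window_onnx(image: np.ndarray, roi_size=(64, 64), overlap=0.25) -> np.ndarray:
    """Torch-free sliding window inference over a [B, C, H, W] float32 image."""
    _, _, h, w = image.shape
    rh, rw = roi_size
    step_h = max(1, int(rh * (1 - overlap)))
    step_w = max(1, int(rw * (1 - overlap)))
    output, counts = None, None
    for y in _window_starts(h, rh, step_h):
        for x in _window_starts(w, rw, step_w):
            patch = image[:, :, y:y + rh, x:x + rw]
            pred = session.run(None, {input_name: patch})[0]
            if output is None:
                # Allocate accumulators once the output channel count is known.
                output = np.zeros((pred.shape[0], pred.shape[1], h, w), dtype=np.float32)
                counts = np.zeros((1, 1, h, w), dtype=np.float32)
            output[:, :, y:y + rh, x:x + rw] += pred
            counts[:, :, y:y + rh, x:x + rw] += 1.0
    # Average contributions in overlapping regions.
    return output / np.maximum(counts, 1.0)

prediction = sliding_window_onnx(np.random.randn(1, 1, 128, 1536).astype(np.float32))
print("Prediction shape:", prediction.shape)
```

MONAI's own aggregation also supports Gaussian importance weighting and padding of inputs smaller than the ROI; this sketch skips both for brevity.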
I implemented a UNet that uses monai.inferers.SlidingWindowInferer in the inference phase, but this cannot be converted to ONNX.
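For reference, a rough sketch of the kind of setup being described (the network, ROI size, and wrapper class are placeholders, not the actual code). Wrapping the inferer and the network in a single module is what does not export cleanly: torch.onnx.export traces tensor operations, so the Python patch loop inside SlidingWindowInferer is unrolled for the traced input size rather than captured as dynamic windowing.

```python
import torch
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import UNet

# Placeholder network and inferer settings, not the actual configuration.
net = UNet(spatial_dims=3, in_channels=1, out_channels=2,
           channels=(16, 32, 64, 128), strides=(2, 2, 2)).eval()
inferer = SlidingWindowInferer(roi_size=(64, 64, 64), sw_batch_size=4, overlap=0.25)

class WholePipeline(torch.nn.Module):
    """Bundles the sliding window inferer with the network as one module."""
    def __init__(self, network, inferer):
        super().__init__()
        self.network = network
        self.inferer = inferer

    def forward(self, x):
        return self.inferer(inputs=x, network=self.network)

# This is the export that does not work as hoped; left commented out on purpose.
# torch.onnx.export(WholePipeline(net, inferer), torch.randn(1, 1, 96, 96, 96), "pipeline.onnx")
```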