Skip to content

Commit

Permalink
feat: remove tensorrt (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
AmineDiro authored Feb 14, 2025
1 parent e790d83 commit 8b8abbc
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 64 deletions.
30 changes: 1 addition & 29 deletions libs/megaparse/src/megaparse/examples/parsing_process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from pathlib import Path
from typing import IO, Any, List, Tuple

Expand Down Expand Up @@ -84,33 +83,6 @@ def get_strategy_page(
return StrategyEnum.FAST


def _get_providers(device=DeviceEnum.CPU) -> List[str]:
    """Resolve the ordered list of onnxruntime execution providers for *device*.

    Raises ValueError when the provider required for the requested device is
    not present in the local onnxruntime build; any unrecognised device falls
    back to CPU after emitting a UserWarning.
    """
    available = rt.get_available_providers()
    print("Available providers:", available)
    # Cheapest case first: CPU needs no capability check.
    if device == DeviceEnum.CPU:
        return ["CPUExecutionProvider"]
    if device == DeviceEnum.CUDA:
        # TODO: support openvino, directml etc
        if "CUDAExecutionProvider" not in available:
            raise ValueError(
                "onnxruntime can't find CUDAExecutionProvider in list of available providers"
            )
        return ["TensorrtExecutionProvider", "CUDAExecutionProvider"]
    if device == DeviceEnum.COREML:
        if "CoreMLExecutionProvider" not in available:
            raise ValueError(
                "onnxruntime can't find CoreMLExecutionProvider in list of available providers"
            )
        return ["CoreMLExecutionProvider"]
    # Unknown device: warn the caller and degrade gracefully to CPU.
    warnings.warn(
        "Device not supported, using CPU",
        UserWarning,
        stacklevel=2,
    )
    return ["CPUExecutionProvider"]


def validate_input(
file_path: Path | str | None = None,
file: IO[bytes] | None = None,
Expand Down Expand Up @@ -238,7 +210,7 @@ def main():
else:
text_det_config = TextDetConfig()
general_options = rt.SessionOptions()
providers = _get_providers(device=device)
providers = get_providers(device=device)
engine_config = EngineConfig(
session_options=general_options,
providers=providers,
Expand Down
26 changes: 0 additions & 26 deletions libs/megaparse/src/megaparse/parser/doctr_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import uuid
import warnings
from typing import Any, Dict, List, Tuple, Type
from uuid import UUID

Expand Down Expand Up @@ -108,31 +107,6 @@ def __init__(
self.detect_orientation = detect_orientation
self.detect_language = detect_language

def _get_providers(self) -> List[str]:
    """Select the onnxruntime execution providers matching ``self.device``.

    Raises ValueError if the provider needed for the configured device is
    missing from the installed onnxruntime; an unsupported device value
    triggers a UserWarning and falls back to the CPU provider.
    """
    installed = rt.get_available_providers()
    device = self.device
    if device == DeviceEnum.CUDA:
        # TODO: support openvino, directml etc
        if "CUDAExecutionProvider" not in installed:
            raise ValueError(
                "onnxruntime can't find CUDAExecutionProvider in list of available providers"
            )
        return ["TensorrtExecutionProvider", "CUDAExecutionProvider"]
    if device == DeviceEnum.COREML:
        if "CoreMLExecutionProvider" not in installed:
            raise ValueError(
                "onnxruntime can't find CoreMLExecutionProvider in list of available providers"
            )
        return ["CoreMLExecutionProvider"]
    if device == DeviceEnum.CPU:
        return ["CPUExecutionProvider"]
    # Anything else is unsupported: warn and degrade to CPU.
    warnings.warn(
        "Device not supported, using CPU",
        UserWarning,
        stacklevel=2,
    )
    return ["CPUExecutionProvider"]

def get_text_detections(self, pages: list[Page], **kwargs) -> List[Page]:
rasterized_pages = [np.array(page.rasterized) for page in pages]
# Dimension check
Expand Down
11 changes: 2 additions & 9 deletions libs/megaparse/src/megaparse/utils/onnx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings
from typing import List

import onnxruntime as rt
Expand All @@ -12,12 +11,11 @@ def get_providers(device: DeviceEnum) -> List[str]:
prov = rt.get_available_providers()
logger.info("Available providers:", prov)
if device == DeviceEnum.CUDA:
# TODO: support openvino, directml etc
if "CUDAExecutionProvider" not in prov:
raise ValueError(
"onnxruntime can't find CUDAExecutionProvider in list of available providers"
)
return ["TensorrtExecutionProvider", "CUDAExecutionProvider"]
return ["CUDAExecutionProvider"]
elif device == DeviceEnum.COREML:
if "CoreMLExecutionProvider" not in prov:
raise ValueError(
Expand All @@ -27,9 +25,4 @@ def get_providers(device: DeviceEnum) -> List[str]:
elif device == DeviceEnum.CPU:
return ["CPUExecutionProvider"]
else:
warnings.warn(
"Device not supported, using CPU",
UserWarning,
stacklevel=2,
)
return ["CPUExecutionProvider"]
raise ValueError("device not in (CUDA,CoreML,CPU)")

0 comments on commit 8b8abbc

Please sign in to comment.