From 960c59b4d72e91a14509704f06d88e200a16d64f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20K=C3=A4nzig?= <36882833+nkaenzig@users.noreply.github.com> Date: Wed, 19 Feb 2025 07:40:38 +0100 Subject: [PATCH] Fix `packaging` imports in version comparison logic (#8347) Fixes #8349 ### Description The current behaviour is that `pkging, has_ver = optional_import("packaging.Version")` always returns `has_ver=False` because the import always fails (the `Version` class is exposed by the `packaging.version` submodule). This issue previously didn't surface, because when the import fails, it would just continue to use the fallback logic. However, there seem to be more hidden and more severe implications, which ultimately led me to discovering this particular bug: Function like `floor_divide()` in `monai.transforms` that check the module version using this logic are called many times in common ML dataloading workflows. The failed imports somehow can lead to OOM errors and the main process being killed (see https://github.com/Project-MONAI/MONAI/issues/8348). Maybe when `optional_import` fails to import a module, the lazy exceptions somehow stack up in memory when this function is called many times in a short time period? ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --- monai/utils/module.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index d3f2ff09f2..7bbbb4ab1e 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -540,11 +540,11 @@ def version_leq(lhs: str, rhs: str) -> bool: """ lhs, rhs = str(lhs), str(rhs) - pkging, has_ver = optional_import("packaging.Version") + pkging, has_ver = optional_import("packaging.version") if has_ver: try: - return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs)) - except pkging.version.InvalidVersion: + return cast(bool, pkging.Version(lhs) <= pkging.Version(rhs)) + except pkging.InvalidVersion: return True lhs_, rhs_ = parse_version_strs(lhs, rhs) @@ -567,12 +567,12 @@ def version_geq(lhs: str, rhs: str) -> bool: """ lhs, rhs = str(lhs), str(rhs) - pkging, has_ver = optional_import("packaging.Version") + pkging, has_ver = optional_import("packaging.version") if has_ver: try: - return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs)) - except pkging.version.InvalidVersion: + return cast(bool, pkging.Version(lhs) >= pkging.Version(rhs)) + except pkging.InvalidVersion: return True lhs_, rhs_ = parse_version_strs(lhs, rhs)