Skip to content

Commit

Permalink
Merge branch 'dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
phisanti authored Mar 8, 2025
2 parents 3c2dbc6 + 7c26e5a commit da0a186
Show file tree
Hide file tree
Showing 64 changed files with 231 additions and 256 deletions.
10 changes: 3 additions & 7 deletions .github/workflows/cron.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,13 @@ jobs:
strategy:
matrix:
environment:
- "PT113+CUDA118"
- "PT210+CUDA121"
- "PT230+CUDA121"
- "PT240+CUDA126"
- "PTLATEST+CUDA126"
include:
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
- environment: PT113+CUDA118
pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu121"
base: "nvcr.io/nvidia/pytorch:22.10-py3" # CUDA 11.8
- environment: PT210+CUDA121
pytorch: "pytorch==2.1.0 torchvision==0.16.0 --extra-index-url https://download.pytorch.org/whl/cu121"
- environment: PT230+CUDA121
pytorch: "pytorch==2.3.0 torchvision==0.18.0 --extra-index-url https://download.pytorch.org/whl/cu121"
base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.1
- environment: PT240+CUDA126
pytorch: "pytorch==2.4.0 torchvision==0.19.0 --extra-index-url https://download.pytorch.org/whl/cu121"
Expand Down
26 changes: 14 additions & 12 deletions .github/workflows/pythonapp-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,21 @@ jobs:
strategy:
matrix:
environment:
- "PT113+CUDA116"
- "PT210+CUDA121DOCKER"
- "PT230+CUDA124DOCKER"
- "PT240+CUDA125DOCKER"
- "PT250+CUDA126DOCKER"
include:
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
- environment: PT113+CUDA116
pytorch: "torch==1.13.1 torchvision==0.14.1"
base: "nvcr.io/nvidia/cuda:11.6.1-devel-ubuntu18.04"
- environment: PT210+CUDA121DOCKER
# 23.08: 2.1.0a0+29c30b1
- environment: PT230+CUDA124DOCKER
# 24.04: 2.3.0a0+6ddf5cf85e
pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error
base: "nvcr.io/nvidia/pytorch:23.08-py3"
- environment: PT210+CUDA121DOCKER
# 24.08: 2.3.0a0+40ec155e58.nv24.3
base: "nvcr.io/nvidia/pytorch:24.04-py3"
- environment: PT240+CUDA125DOCKER
# 24.06: 2.4.0a0+f70bd71a48
pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error
base: "nvcr.io/nvidia/pytorch:24.06-py3"
- environment: PT250+CUDA126DOCKER
# 24.08: 2.5.0a0+872d972e41
pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error
base: "nvcr.io/nvidia/pytorch:24.08-py3"
container:
Expand All @@ -49,7 +51,7 @@ jobs:
apt-get update
apt-get install -y wget
if [ ${{ matrix.environment }} = "PT113+CUDA116" ]
if [ ${{ matrix.environment }} = "PT230+CUDA124" ]
then
PYVER=3.9 PYSFX=3 DISTUTILS=python3-distutils && \
apt-get update && apt-get install -y --no-install-recommends \
Expand Down Expand Up @@ -114,7 +116,7 @@ jobs:
# build for the current self-hosted CI Tesla V100
BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --build --disttests
./runtests.sh --quick --unittests
if [ ${{ matrix.environment }} = "PT113+CUDA116" ]; then
if [ ${{ matrix.environment }} = "PT230+CUDA124" ]; then
# test the clang-format tool downloading once
coverage run -m tests.clang_format_utils
fi
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pythonapp-min.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ jobs:
strategy:
fail-fast: false
matrix:
pytorch-version: ['1.13.1', '2.0.1', '2.2.2', '2.3.1', '2.4.1', 'latest']
pytorch-version: ['2.3.1', '2.4.1', '2.5.1', 'latest']
timeout-minutes: 40
steps:
- uses: actions/checkout@v4
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ jobs:
- if: runner.os == 'windows'
name: Install torch cpu from pytorch.org (Windows only)
run: |
python -m pip install torch==1.13.1+cpu torchvision==0.14.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install torch==2.4.1 torchvision==0.19.1+cpu --index-url https://download.pytorch.org/whl/cpu
- if: runner.os == 'Linux'
name: Install itk pre-release (Linux only)
run: |
Expand All @@ -103,7 +103,7 @@ jobs:
- name: Install the dependencies
run: |
python -m pip install --user --upgrade pip wheel
python -m pip install torch==1.13.1 torchvision==0.14.1
python -m pip install torch==2.4.1 torchvision==0.19.1
cat "requirements-dev.txt"
python -m pip install -r requirements-dev.txt
python -m pip list
Expand Down Expand Up @@ -155,7 +155,7 @@ jobs:
# install the latest pytorch for testing
# however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated
# fresh torch installation according to pyproject.toml
python -m pip install torch>=1.13.1 torchvision
python -m pip install torch>=2.3.0 torchvision
- name: Check packages
run: |
pip uninstall monai
Expand Down
4 changes: 2 additions & 2 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-f https://download.pytorch.org/whl/cpu/torch-1.13.1%2Bcpu-cp39-cp39-linux_x86_64.whl
torch>=1.13.1
-f https://download.pytorch.org/whl/cpu/torch-2.3.0%2Bcpu-cp39-cp39-linux_x86_64.whl
torch>=2.3.0
pytorch-ignite==0.4.11
numpy>=1.20
itk>=5.2
Expand Down
4 changes: 2 additions & 2 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ channels:
- nvidia
- conda-forge
dependencies:
- numpy>=1.24,<2.0
- pytorch>=1.13.1
- numpy>=1.24,<3.0
- pytorch>=2.3.0
- torchio
- torchvision
- pytorch-cuda>=11.6
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/deepedit/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d

with torch.no_grad():
if engine.amp:
with torch.cuda.amp.autocast():
with torch.autocast("cuda"):
predictions = engine.inferer(inputs, engine.network)
else:
predictions = engine.inferer(inputs, engine.network)
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/deepgrow/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d
engine.network.eval()
with torch.no_grad():
if engine.amp:
with torch.cuda.amp.autocast():
with torch.autocast("cuda"):
predictions = engine.inferer(inputs, engine.network)
else:
predictions = engine.inferer(inputs, engine.network)
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/detection/networks/retinanet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def forward(self, images: torch.Tensor):
nesterov=True,
)
torch.save(detector.network.state_dict(), 'model.pt') # save model
detector.network.load_state_dict(torch.load('model.pt')) # load model
detector.network.load_state_dict(torch.load('model.pt', weights_only=True)) # load model
"""

def __init__(
Expand Down
10 changes: 5 additions & 5 deletions monai/apps/detection/networks/retinanet_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def __init__(

for layer in self.conv.children():
if isinstance(layer, conv_type): # type: ignore
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.constant_(layer.bias, 0)
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type]

self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
Expand Down Expand Up @@ -167,8 +167,8 @@ def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int):

for layer in self.conv.children():
if isinstance(layer, conv_type): # type: ignore
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.zeros_(layer.bias)
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
torch.nn.init.zeros_(layer.bias) # type: ignore[arg-type]

def forward(self, x: list[Tensor]) -> list[Tensor]:
"""
Expand Down Expand Up @@ -297,7 +297,7 @@ def __init__(
)
self.feature_extractor = feature_extractor

self.feature_map_channels: int = self.feature_extractor.out_channels
self.feature_map_channels: int = self.feature_extractor.out_channels # type: ignore[assignment]
self.num_anchors = num_anchors
self.classification_head = RetinaNetClassificationHead(
self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/box_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,15 @@ def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor:

pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype) # type: ignore[union-attr]

# When convert float32 to float16, Inf or Nan may occur
if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
raise ValueError("pred_whd_axis is NaN or Inf.")

# Distance from center to box's corner.
c_to_c_whd_axis = (
torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis # type: ignore[arg-type]
)

pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)
Expand Down
2 changes: 1 addition & 1 deletion monai/apps/mmars/mmars.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def load_from_mmar(
return torch.jit.load(_model_file, map_location=map_location)

# loading with `torch.load`
model_dict = torch.load(_model_file, map_location=map_location)
model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
if weights_only:
return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly

Expand Down
2 changes: 1 addition & 1 deletion monai/apps/reconstruction/networks/blocks/varnetblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def soft_dc(self, x: Tensor, ref_kspace: Tensor, mask: Tensor) -> Tensor:
Returns:
Output of DC block with the same shape as x
"""
return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight # type: ignore

def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
"""
Expand Down
7 changes: 3 additions & 4 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def load(
if load_ts_module is True:
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
# loading with `torch.load`
model_dict = torch.load(full_path, map_location=torch.device(device))
model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)

if not isinstance(model_dict, Mapping):
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
Expand Down Expand Up @@ -1279,9 +1279,8 @@ def verify_net_in_out(
if input_dtype == torch.float16:
# fp16 can only be executed in gpu mode
net.to("cuda")
from torch.cuda.amp import autocast

with autocast():
with torch.autocast("cuda"):
output = net(test_data.cuda(), **extra_forward_args_)
net.to(device_)
else:
Expand Down Expand Up @@ -1330,7 +1329,7 @@ def _export(
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
else:
ckpt = torch.load(ckpt_file)
ckpt = torch.load(ckpt_file, weights_only=True)
copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])

# Use the given converter to convert a model and save with metadata, config content
Expand Down
11 changes: 2 additions & 9 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import warnings
from collections.abc import Callable, Sequence
from copy import copy, deepcopy
from inspect import signature
from multiprocessing.managers import ListProxy
from multiprocessing.pool import ThreadPool
from pathlib import Path
Expand Down Expand Up @@ -372,10 +371,7 @@ def _cachecheck(self, item_transformed):

if hashfile is not None and hashfile.is_file(): # cache hit
try:
if "weights_only" in signature(torch.load).parameters:
return torch.load(hashfile, weights_only=False)
else:
return torch.load(hashfile)
return torch.load(hashfile, weights_only=False)
except PermissionError as e:
if sys.platform != "win32":
raise e
Expand Down Expand Up @@ -1674,7 +1670,4 @@ def _load_meta_cache(self, meta_hash_file_name):
if meta_hash_file_name in self._meta_cache:
return self._meta_cache[meta_hash_file_name]
else:
if "weights_only" in signature(torch.load).parameters:
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
else:
return torch.load(self.cache_dir / meta_hash_file_name)
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
2 changes: 1 addition & 1 deletion monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
if isinstance(_affine, torch.Tensor):
spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
else:
spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) # type: ignore[operator]
if suppress_zeros:
spacing[spacing == 0] = 1.0
spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion monai/data/video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def get_available_codecs() -> dict[str, str]:
for codec, ext in all_codecs.items():
writer = cv2.VideoWriter()
fname = os.path.join(tmp_dir, f"test{ext}")
fourcc = cv2.VideoWriter_fourcc(*codec)
fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined]
noviderr = writer.open(fname, fourcc, 1, (10, 10))
if noviderr:
codecs[codec] = ext
Expand Down
27 changes: 11 additions & 16 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import look_up_option, pytorch_after
from monai.utils.module import look_up_option

if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
Expand Down Expand Up @@ -82,8 +82,8 @@ class Evaluator(Workflow):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.
"""

Expand Down Expand Up @@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
Expand Down Expand Up @@ -269,13 +269,8 @@ def __init__(
amp_kwargs=amp_kwargs,
)
if compile:
if pytorch_after(2, 1):
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
else:
warnings.warn(
"Network compilation (compile=True) not supported for Pytorch versions before 2.1, no compilation done"
)
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
network = torch.compile(network, **compile_kwargs) # type: ignore[assignment]
self.network = network
self.compile = compile
self.inferer = SimpleInferer() if inferer is None else inferer
Expand Down Expand Up @@ -329,7 +324,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
# execute forward computation
with engine.mode(engine.network):
if engine.amp:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
with torch.autocast("cuda", **engine.amp_kwargs):
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
else:
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
Expand Down Expand Up @@ -399,8 +394,8 @@ class EnsembleEvaluator(Evaluator):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.
"""

Expand Down Expand Up @@ -492,7 +487,7 @@ def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tenso
for idx, network in enumerate(engine.networks):
with engine.mode(network):
if engine.amp:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
with torch.autocast("cuda", **engine.amp_kwargs):
if isinstance(engine.state.output, dict):
engine.state.output.update(
{engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
Expand Down
Loading

0 comments on commit da0a186

Please sign in to comment.