Enable PyTorch 2.6 #8309

Merged — 24 commits merged into dev from numpy2_reintro, Mar 8, 2025

Commits (24)
dfae9a0
Changing Numpy version limit to include 2.x but not a theoretical 3.x
ericspod Jan 20, 2025
3eabf20
Changing Numpy version limit to include 2.x but not a theoretical 3.x
ericspod Jan 20, 2025
1689c12
Merge branch 'dev' into numpy2_reintro
ericspod Feb 3, 2025
7b75792
Merge branch 'dev' into numpy2_reintro
ericspod Feb 12, 2025
c50b7aa
Merge branch 'dev' into numpy2_reintro
ericspod Feb 25, 2025
a7b615e
Cleaning up autocast usage
ericspod Feb 25, 2025
4685eca
Update torch.load usage to eliminate complaint messages
ericspod Feb 25, 2025
199db95
Merge branch 'dev' into numpy2_reintro
ericspod Feb 25, 2025
77ff9b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 25, 2025
7ee3bf1
Formatting
ericspod Feb 25, 2025
7551ec5
Merge branch 'numpy2_reintro' of github.com:ericspod/MONAI into numpy…
ericspod Feb 25, 2025
f69ab50
Fix
ericspod Feb 25, 2025
3e0b517
Fix requirements
ericspod Feb 25, 2025
5db7a56
Change one torch.load
ericspod Feb 25, 2025
3dc4fe1
Change one torch.load and other fixes
ericspod Feb 25, 2025
c6a079a
Merge branch 'dev' into numpy2_reintro
ericspod Mar 4, 2025
db1d2af
Update requirements.txt
ericspod Mar 4, 2025
160df23
Fixes for mypy type checking
ericspod Mar 7, 2025
eaa9987
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2025
0cdfe12
Precision fix for test
ericspod Mar 7, 2025
8eca43a
Merge branch 'numpy2_reintro' of github.com:ericspod/MONAI into numpy…
ericspod Mar 7, 2025
6c98f21
Merger fix
ericspod Mar 7, 2025
12a6ac9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2025
310e44e
Merge branch 'dev' into numpy2_reintro
ericspod Mar 7, 2025
2 changes: 1 addition & 1 deletion monai/apps/deepedit/interaction.py
@@ -72,7 +72,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d

with torch.no_grad():
if engine.amp:
with torch.cuda.amp.autocast():
with torch.autocast("cuda"):
predictions = engine.inferer(inputs, engine.network)
else:
predictions = engine.inferer(inputs, engine.network)
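This hunk and the similar ones below migrate from `torch.cuda.amp.autocast()`, deprecated in recent PyTorch releases, to the device-agnostic `torch.autocast(device_type)`. A minimal sketch of the pattern, assuming a CUDA device is available (the network and input here are illustrative, not MONAI code):

```python
import torch

net = torch.nn.Linear(4, 2).cuda()
x = torch.randn(8, 4, device="cuda")

# Deprecated spelling, emits a FutureWarning on newer PyTorch:
#   with torch.cuda.amp.autocast():
# Device-agnostic replacement; kwargs such as dtype pass through unchanged:
with torch.autocast("cuda", dtype=torch.float16):
    y = net(x)  # eligible ops run in float16 inside this block
```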
2 changes: 1 addition & 1 deletion monai/apps/deepgrow/interaction.py
@@ -67,7 +67,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d
engine.network.eval()
with torch.no_grad():
if engine.amp:
with torch.cuda.amp.autocast():
with torch.autocast("cuda"):
predictions = engine.inferer(inputs, engine.network)
else:
predictions = engine.inferer(inputs, engine.network)
2 changes: 1 addition & 1 deletion monai/apps/detection/networks/retinanet_detector.py
@@ -180,7 +180,7 @@ def forward(self, images: torch.Tensor):
nesterov=True,
)
torch.save(detector.network.state_dict(), 'model.pt') # save model
detector.network.load_state_dict(torch.load('model.pt')) # load model
detector.network.load_state_dict(torch.load('model.pt', weights_only=True)) # load model
"""

def __init__(
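PyTorch 2.6 flips the default of `torch.load`'s `weights_only` argument to `True`, restricting unpickling to tensors and plain containers; passing the flag explicitly, as this docstring now does, keeps behaviour identical on older releases as well. A small sketch of the round trip (file name illustrative):

```python
import torch

model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), "model.pt")

# A plain state_dict contains only tensors, so the restricted
# weights_only=True loader accepts it; untrusted files stay safe.
state = torch.load("model.pt", weights_only=True)
model.load_state_dict(state)
```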
10 changes: 5 additions & 5 deletions monai/apps/detection/networks/retinanet_network.py
@@ -88,8 +88,8 @@ def __init__(

for layer in self.conv.children():
if isinstance(layer, conv_type): # type: ignore
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.constant_(layer.bias, 0)
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type]

self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
@@ -167,8 +167,8 @@ def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int):

for layer in self.conv.children():
if isinstance(layer, conv_type): # type: ignore
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.zeros_(layer.bias)
torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type]
torch.nn.init.zeros_(layer.bias) # type: ignore[arg-type]

def forward(self, x: list[Tensor]) -> list[Tensor]:
"""
@@ -297,7 +297,7 @@ def __init__(
)
self.feature_extractor = feature_extractor

self.feature_map_channels: int = self.feature_extractor.out_channels
self.feature_map_channels: int = self.feature_extractor.out_channels # type: ignore[assignment]
self.num_anchors = num_anchors
self.classification_head = RetinaNetClassificationHead(
self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims
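These hunks narrow bare `# type: ignore` comments to error-code-qualified ones such as `# type: ignore[arg-type]`, which silence only the named mypy error and let any new error on the same line surface. A hypothetical snippet (not MONAI code) showing why the qualifier is needed here at all:

```python
import torch

layer: torch.nn.Module = torch.nn.Conv2d(3, 8, kernel_size=3)

# Attribute access on a value typed as plain nn.Module goes through
# Module.__getattr__, which mypy types as "Tensor | Module"; the init
# functions expect Tensor, hence the arg-type complaint being ignored:
torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
```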
4 changes: 2 additions & 2 deletions monai/apps/detection/utils/box_coder.py
@@ -221,15 +221,15 @@ def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor:

pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None]
pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None]
pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype)
pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype) # type: ignore[union-attr]

# When converting float32 to float16, Inf or NaN may occur
if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any():
raise ValueError("pred_whd_axis is NaN or Inf.")

# Distance from center to box's corner.
c_to_c_whd_axis = (
torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis
torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis # type: ignore[arg-type]
)

pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis)
2 changes: 1 addition & 1 deletion monai/apps/mmars/mmars.py
@@ -241,7 +241,7 @@ def load_from_mmar(
return torch.jit.load(_model_file, map_location=map_location)

# loading with `torch.load`
model_dict = torch.load(_model_file, map_location=map_location)
model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
if weights_only:
return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly

2 changes: 1 addition & 1 deletion monai/apps/reconstruction/networks/blocks/varnetblock.py
@@ -55,7 +55,7 @@ def soft_dc(self, x: Tensor, ref_kspace: Tensor, mask: Tensor) -> Tensor:
Returns:
Output of DC block with the same shape as x
"""
return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight
return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight # type: ignore

def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
"""
7 changes: 3 additions & 4 deletions monai/bundle/scripts.py
@@ -760,7 +760,7 @@ def load(
if load_ts_module is True:
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
# loading with `torch.load`
model_dict = torch.load(full_path, map_location=torch.device(device))
model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)

if not isinstance(model_dict, Mapping):
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
@@ -1279,9 +1279,8 @@ def verify_net_in_out(
if input_dtype == torch.float16:
# fp16 can only be executed in gpu mode
net.to("cuda")
from torch.cuda.amp import autocast

with autocast():
with torch.autocast("cuda"):
output = net(test_data.cuda(), **extra_forward_args_)
net.to(device_)
else:
@@ -1330,7 +1329,7 @@ def _export(
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
else:
ckpt = torch.load(ckpt_file)
ckpt = torch.load(ckpt_file, weights_only=True)
copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])

# Use the given converter to convert a model and save with metadata, config content
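The `verify_net_in_out` hunk also drops the `from torch.cuda.amp import autocast` import, since `torch.autocast` needs nothing beyond `torch` itself. A hedged sketch of the fp16 verification pattern that code follows, with illustrative names and assuming a CUDA device:

```python
import torch

net = torch.nn.Linear(4, 2).eval().cuda()
test_data = torch.randn(1, 4)

with torch.no_grad(), torch.autocast("cuda"):
    output = net(test_data.cuda())  # the matmul runs in float16 on CUDA
print(output.dtype)  # typically torch.float16 under autocast
```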
11 changes: 2 additions & 9 deletions monai/data/dataset.py
@@ -22,7 +22,6 @@
import warnings
from collections.abc import Callable, Sequence
from copy import copy, deepcopy
from inspect import signature
from multiprocessing.managers import ListProxy
from multiprocessing.pool import ThreadPool
from pathlib import Path
@@ -372,10 +371,7 @@ def _cachecheck(self, item_transformed):

if hashfile is not None and hashfile.is_file(): # cache hit
try:
if "weights_only" in signature(torch.load).parameters:
return torch.load(hashfile, weights_only=False)
else:
return torch.load(hashfile)
return torch.load(hashfile, weights_only=False)
except PermissionError as e:
if sys.platform != "win32":
raise e
@@ -1674,7 +1670,4 @@ def _load_meta_cache(self, meta_hash_file_name):
if meta_hash_file_name in self._meta_cache:
return self._meta_cache[meta_hash_file_name]
else:
if "weights_only" in signature(torch.load).parameters:
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
else:
return torch.load(self.cache_dir / meta_hash_file_name)
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
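Two deliberate choices in this file: the `signature()` probe is gone because every PyTorch version the project now supports accepts `weights_only`, and the cache loads keep `weights_only=False` because the cache files are written by MONAI itself and may contain pickled objects the restricted unpickler rejects. For files that should stay in restricted mode, PyTorch also offers an allowlist; a sketch, where `MyMeta` stands in for a known custom class:

```python
import torch
import torch.serialization

class MyMeta:
    """Stand-in for a custom class stored inside a cache file."""

torch.save({"meta": MyMeta(), "t": torch.ones(2)}, "cache.pt")

# Trusted file produced by the same program: full unpickling is fine.
obj = torch.load("cache.pt", weights_only=False)

# Alternative: stay in restricted mode but allowlist the known class.
with torch.serialization.safe_globals([MyMeta]):
    obj = torch.load("cache.pt", weights_only=True)
```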
2 changes: 1 addition & 1 deletion monai/data/utils.py
@@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
if isinstance(_affine, torch.Tensor):
spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0))
else:
spacing = np.sqrt(np.sum(_affine * _affine, axis=0))
spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) # type: ignore[operator]
if suppress_zeros:
spacing[spacing == 0] = 1.0
spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion monai/data/video_dataset.py
@@ -177,7 +177,7 @@ def get_available_codecs() -> dict[str, str]:
for codec, ext in all_codecs.items():
writer = cv2.VideoWriter()
fname = os.path.join(tmp_dir, f"test{ext}")
fourcc = cv2.VideoWriter_fourcc(*codec)
fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined]
noviderr = writer.open(fname, fourcc, 1, (10, 10))
if noviderr:
codecs[codec] = ext
16 changes: 8 additions & 8 deletions monai/engines/evaluator.py
@@ -82,8 +82,8 @@ class Evaluator(Workflow):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.

"""

@@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -324,7 +324,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
# execute forward computation
with engine.mode(engine.network):
if engine.amp:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
with torch.autocast("cuda", **engine.amp_kwargs):
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
else:
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
@@ -394,8 +394,8 @@ class EnsembleEvaluator(Evaluator):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.

"""

@@ -487,7 +487,7 @@ def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tenso
for idx, network in enumerate(engine.networks):
with engine.mode(network):
if engine.amp:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
with torch.autocast("cuda", **engine.amp_kwargs):
if isinstance(engine.state.output, dict):
engine.state.output.update(
{engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
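Only the docstrings change in the class descriptions here: `amp_kwargs` was already forwarded verbatim into the autocast call, so existing configurations keep working and only the documented entry point moves from `torch.cuda.amp.autocast()` to `torch.autocast("cuda", ...)`. A hedged sketch of how a dtype override flows through (loader and network are toy stand-ins; assumes a CUDA device that supports bfloat16):

```python
import torch
from monai.data import DataLoader, Dataset
from monai.engines import SupervisedEvaluator

net = torch.nn.Linear(4, 2).cuda()
data = [{"image": torch.randn(4), "label": torch.randn(2)}]
val_loader = DataLoader(Dataset(data), batch_size=1)

evaluator = SupervisedEvaluator(
    device=torch.device("cuda"),
    val_data_loader=val_loader,
    network=net,
    amp=True,
    # Lands in `torch.autocast("cuda", **amp_kwargs)` inside _iteration:
    amp_kwargs={"dtype": torch.bfloat16},
)
evaluator.run()
```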
18 changes: 9 additions & 9 deletions monai/engines/trainer.py
@@ -125,8 +125,8 @@ class SupervisedTrainer(Trainer):
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.
compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to
`torch.Tensor` before forward pass, then converted back afterward with copied meta information.
compile_kwargs: dict of the args for `torch.compile()` API, for more details:
@@ -249,7 +249,7 @@ def _compute_pred_loss():
engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

if engine.amp and engine.scaler is not None:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
with torch.autocast("cuda", **engine.amp_kwargs):
_compute_pred_loss()
engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
@@ -335,8 +335,8 @@ class GanTrainer(Trainer):
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.

"""

@@ -512,8 +512,8 @@ class AdversarialTrainer(Trainer):
more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.
"""

def __init__(
@@ -683,7 +683,7 @@ def _compute_generator_loss() -> None:
engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

if engine.amp and engine.state.g_scaler is not None:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
with torch.autocast("cuda", **engine.amp_kwargs):
_compute_generator_loss()

engine.state.output[Keys.LOSS] = (
@@ -731,7 +731,7 @@ def _compute_discriminator_loss() -> None:
engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)

if engine.amp and engine.state.d_scaler is not None:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
with torch.autocast("cuda", **engine.amp_kwargs):
_compute_discriminator_loss()

engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
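All of these trainer call sites pair autocast with a gradient scaler, matching the upstream AMP recipe; in PyTorch 2.x the scaler also has a device-agnostic spelling (`torch.amp.GradScaler("cuda")`). A minimal sketch of the step the `_iteration` methods implement, with illustrative names:

```python
import torch

net = torch.nn.Linear(4, 2).cuda()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
scaler = torch.amp.GradScaler("cuda")
loss_fn = torch.nn.MSELoss()
x = torch.randn(8, 4, device="cuda")
target = torch.randn(8, 2, device="cuda")

opt.zero_grad(set_to_none=True)
with torch.autocast("cuda"):       # forward + loss in mixed precision
    loss = loss_fn(net(x), target)
scaler.scale(loss).backward()      # scale to avoid fp16 gradient underflow
scaler.step(opt)                   # unscales; skips step if inf/nan found
scaler.update()
```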
2 changes: 1 addition & 1 deletion monai/engines/utils.py
@@ -309,7 +309,7 @@ def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_nam
self.scheduler = scheduler

def get_target(self, images, noise, timesteps):
return self.scheduler.get_velocity(images, noise, timesteps)
return self.scheduler.get_velocity(images, noise, timesteps) # type: ignore[operator]


def default_make_latent(
4 changes: 2 additions & 2 deletions monai/engines/workflow.py
@@ -90,8 +90,8 @@ class Workflow(Engine):
default to `True`.
to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
`device`, `non_blocking`.
amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details:
https://pytorch.org/docs/stable/amp.html#torch.autocast.

Raises:
TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.
2 changes: 1 addition & 1 deletion monai/fl/client/monai_algo.py
@@ -574,7 +574,7 @@ def get_weights(self, extra=None):
model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
if not os.path.isfile(model_path):
raise ValueError(f"No best model checkpoint exists at {model_path}")
weights = torch.load(model_path, map_location="cpu")
weights = torch.load(model_path, map_location="cpu", weights_only=True)
# if weights contain several state dicts, use the one defined by `save_dict_key`
if isinstance(weights, dict) and self.save_dict_key in weights:
weights = weights.get(self.save_dict_key)
2 changes: 1 addition & 1 deletion monai/handlers/checkpoint_loader.py
@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
checkpoint = torch.load(self.load_path, map_location=self.map_location)
checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False)

k, _ = list(self.load_dict.items())[0]
# single object and checkpoint is directly a state_dict
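This is the one `torch.load` call switched to `weights_only=False` rather than `True`: the checkpoints read here are written by Ignite's `Checkpoint` and may nest optimizers, schedulers, or other arbitrary Python objects, which the restricted unpickler could reject. The practical rule of thumb, sketched with illustrative contents:

```python
import torch

checkpoint = {
    "net": torch.nn.Linear(4, 2).state_dict(),  # tensors: restricted-safe
    "epoch": 3,                                 # builtins: restricted-safe
    "rng_state": torch.get_rng_state(),
}
torch.save(checkpoint, "ckpt.pt")

# weights_only=True suffices for files holding only tensors/containers;
# weights_only=False (trusted sources only) once custom classes appear.
ckpt = torch.load("ckpt.pt", map_location="cpu", weights_only=False)
```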