From 7c26e5af385eb5f7a813fa405c6f3fc87b7511fa Mon Sep 17 00:00:00 2001
From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Date: Sat, 8 Mar 2025 00:14:24 +0000
Subject: [PATCH] Enable PyTorch 2.6 (#8309)

Partially addresses #8303.

### Description

This raises the maximum NumPy version bound to below 3.0 so that testing covers NumPy 2.x; the incompatibilities previously seen with NumPy 2.x appear to be resolved by newer versions of the dependencies.

This also includes fixes for PyTorch 2.6, mostly relating to `torch.load` and `autocast` usage: `torch.cuda.amp.autocast()` is deprecated in favor of `torch.autocast("cuda")`, and `torch.load` now defaults to `weights_only=True`, so every call site passes the argument explicitly (illustrative before/after sketches follow the engine diffs and the final test diff).

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 monai/apps/deepedit/interaction.py            |  2 +-
 monai/apps/deepgrow/interaction.py            |  2 +-
 .../detection/networks/retinanet_detector.py  |  2 +-
 .../detection/networks/retinanet_network.py   | 10 +-
 monai/apps/detection/utils/box_coder.py       |  4 +-
 monai/apps/mmars/mmars.py                     |  2 +-
 .../networks/blocks/varnetblock.py            |  2 +-
 monai/bundle/scripts.py                       |  7 +-
 monai/data/dataset.py                         | 11 +--
 monai/data/utils.py                           |  2 +-
 monai/data/video_dataset.py                   |  2 +-
 monai/engines/evaluator.py                    | 16 ++--
 monai/engines/trainer.py                      | 18 ++--
 monai/engines/utils.py                        |  2 +-
 monai/engines/workflow.py                     |  4 +-
 monai/fl/client/monai_algo.py                 |  2 +-
 monai/handlers/checkpoint_loader.py           |  2 +-
 monai/inferers/inferer.py                     | 12 +--
 monai/inferers/merger.py                      | 29 +++---
 monai/losses/perceptual.py                    |  2 +-
 monai/losses/sure_loss.py                     |  2 +-
 .../blocks/feature_pyramid_network.py         |  6 +-
 monai/networks/layers/vector_quantizer.py     |  4 +-
 monai/networks/nets/hovernet.py               |  9 +-
 monai/networks/nets/resnet.py                 |  4 +-
 monai/networks/nets/senet.py                  |  2 +-
 monai/networks/nets/swin_unetr.py             | 95 +++++++++----------
 monai/networks/nets/transchex.py              |  5 +-
 monai/networks/nets/vista3d.py                | 14 +--
 monai/networks/utils.py                       |  9 +-
 monai/transforms/intensity/array.py           |  2 +-
 monai/transforms/spatial/array.py             | 12 +--
 monai/utils/state_cacher.py                   |  2 +-
 requirements.txt                              |  4 +-
 runtests.sh                                   |  2 +-
 tests/bundle/test_bundle_download.py          | 22 +++--
 tests/config/test_cv2_dist.py                 |  3 +-
 tests/data/meta_tensor/test_meta_tensor.py    |  4 +-
 .../test_integration_classification_2d.py     |  2 +-
 .../test_integration_fast_train.py            |  4 +-
 .../test_integration_segmentation_3d.py       |  2 +-
 .../test_compute_multiscalessim_metric.py     |  6 +-
 tests/networks/nets/test_autoencoderkl.py     |  2 +-
 tests/networks/nets/test_controlnet.py        |  2 +-
 .../nets/test_diffusion_model_unet.py         |  2 +-
 .../networks/nets/test_network_consistency.py |  2 +-
 tests/networks/nets/test_swin_unetr.py        |  2 +-
 tests/networks/nets/test_transformer.py       |  2 +-
 tests/networks/test_save_state.py             |  2 +-
 49 files changed, 184 insertions(+), 178 deletions(-)

diff --git a/monai/apps/deepedit/interaction.py b/monai/apps/deepedit/interaction.py
index 07302575c6..33e50700ca 100644
--- a/monai/apps/deepedit/interaction.py
+++ b/monai/apps/deepedit/interaction.py
@@ -72,7 +72,7 @@ def __call__(self, engine:
SupervisedTrainer | SupervisedEvaluator, batchdata: d with torch.no_grad(): if engine.amp: - with torch.cuda.amp.autocast(): + with torch.autocast("cuda"): predictions = engine.inferer(inputs, engine.network) else: predictions = engine.inferer(inputs, engine.network) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index fa3a28bfef..287f2d607c 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -67,7 +67,7 @@ def __call__(self, engine: SupervisedTrainer | SupervisedEvaluator, batchdata: d engine.network.eval() with torch.no_grad(): if engine.amp: - with torch.cuda.amp.autocast(): + with torch.autocast("cuda"): predictions = engine.inferer(inputs, engine.network) else: predictions = engine.inferer(inputs, engine.network) diff --git a/monai/apps/detection/networks/retinanet_detector.py b/monai/apps/detection/networks/retinanet_detector.py index a0573d6cd1..e996ae81bc 100644 --- a/monai/apps/detection/networks/retinanet_detector.py +++ b/monai/apps/detection/networks/retinanet_detector.py @@ -180,7 +180,7 @@ def forward(self, images: torch.Tensor): nesterov=True, ) torch.save(detector.network.state_dict(), 'model.pt') # save model - detector.network.load_state_dict(torch.load('model.pt')) # load model + detector.network.load_state_dict(torch.load('model.pt', weights_only=True)) # load model """ def __init__( diff --git a/monai/apps/detection/networks/retinanet_network.py b/monai/apps/detection/networks/retinanet_network.py index ca6a8f5c19..ead57d74c2 100644 --- a/monai/apps/detection/networks/retinanet_network.py +++ b/monai/apps/detection/networks/retinanet_network.py @@ -88,8 +88,8 @@ def __init__( for layer in self.conv.children(): if isinstance(layer, conv_type): # type: ignore - torch.nn.init.normal_(layer.weight, std=0.01) - torch.nn.init.constant_(layer.bias, 0) + torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type] + torch.nn.init.constant_(layer.bias, 0) # type: ignore[arg-type] self.cls_logits = conv_type(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) torch.nn.init.normal_(self.cls_logits.weight, std=0.01) @@ -167,8 +167,8 @@ def __init__(self, in_channels: int, num_anchors: int, spatial_dims: int): for layer in self.conv.children(): if isinstance(layer, conv_type): # type: ignore - torch.nn.init.normal_(layer.weight, std=0.01) - torch.nn.init.zeros_(layer.bias) + torch.nn.init.normal_(layer.weight, std=0.01) # type: ignore[arg-type] + torch.nn.init.zeros_(layer.bias) # type: ignore[arg-type] def forward(self, x: list[Tensor]) -> list[Tensor]: """ @@ -297,7 +297,7 @@ def __init__( ) self.feature_extractor = feature_extractor - self.feature_map_channels: int = self.feature_extractor.out_channels + self.feature_map_channels: int = self.feature_extractor.out_channels # type: ignore[assignment] self.num_anchors = num_anchors self.classification_head = RetinaNetClassificationHead( self.feature_map_channels, self.num_anchors, self.num_classes, spatial_dims=self.spatial_dims diff --git a/monai/apps/detection/utils/box_coder.py b/monai/apps/detection/utils/box_coder.py index 504ae21d0f..d0f3adf71d 100644 --- a/monai/apps/detection/utils/box_coder.py +++ b/monai/apps/detection/utils/box_coder.py @@ -221,7 +221,7 @@ def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor: pred_ctr_xyx_axis = dxyz_axis * whd_axis[:, None] + ctr_xyz_axis[:, None] pred_whd_axis = torch.exp(dwhd_axis) * whd_axis[:, None] - pred_whd_axis = 
pred_whd_axis.to(dxyz_axis.dtype) + pred_whd_axis = pred_whd_axis.to(dxyz_axis.dtype) # type: ignore[union-attr] # When convert float32 to float16, Inf or Nan may occur if torch.isnan(pred_whd_axis).any() or torch.isinf(pred_whd_axis).any(): @@ -229,7 +229,7 @@ def decode_single(self, rel_codes: Tensor, reference_boxes: Tensor) -> Tensor: # Distance from center to box's corner. c_to_c_whd_axis = ( - torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis + torch.tensor(0.5, dtype=pred_ctr_xyx_axis.dtype, device=pred_whd_axis.device) * pred_whd_axis # type: ignore[arg-type] ) pred_boxes.append(pred_ctr_xyx_axis - c_to_c_whd_axis) diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index 31c88a17be..1fc0690cc9 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -241,7 +241,7 @@ def load_from_mmar( return torch.jit.load(_model_file, map_location=map_location) # loading with `torch.load` - model_dict = torch.load(_model_file, map_location=map_location) + model_dict = torch.load(_model_file, map_location=map_location, weights_only=True) if weights_only: return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly diff --git a/monai/apps/reconstruction/networks/blocks/varnetblock.py b/monai/apps/reconstruction/networks/blocks/varnetblock.py index 75dc7e15ce..289505a057 100644 --- a/monai/apps/reconstruction/networks/blocks/varnetblock.py +++ b/monai/apps/reconstruction/networks/blocks/varnetblock.py @@ -55,7 +55,7 @@ def soft_dc(self, x: Tensor, ref_kspace: Tensor, mask: Tensor) -> Tensor: Returns: Output of DC block with the same shape as x """ - return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight + return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight # type: ignore def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor: """ diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index b43f7e0fa0..6f35179e96 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -760,7 +760,7 @@ def load( if load_ts_module is True: return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files) # loading with `torch.load` - model_dict = torch.load(full_path, map_location=torch.device(device)) + model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True) if not isinstance(model_dict, Mapping): warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.") @@ -1279,9 +1279,8 @@ def verify_net_in_out( if input_dtype == torch.float16: # fp16 can only be executed in gpu mode net.to("cuda") - from torch.cuda.amp import autocast - with autocast(): + with torch.autocast("cuda"): output = net(test_data.cuda(), **extra_forward_args_) net.to(device_) else: @@ -1330,7 +1329,7 @@ def _export( # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file) else: - ckpt = torch.load(ckpt_file) + ckpt = torch.load(ckpt_file, weights_only=True) copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt]) # Use the given converter to convert a model and save with metadata, config content diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 8c53338d66..691425994d 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -22,7 +22,6 @@ import warnings from 
collections.abc import Callable, Sequence from copy import copy, deepcopy -from inspect import signature from multiprocessing.managers import ListProxy from multiprocessing.pool import ThreadPool from pathlib import Path @@ -372,10 +371,7 @@ def _cachecheck(self, item_transformed): if hashfile is not None and hashfile.is_file(): # cache hit try: - if "weights_only" in signature(torch.load).parameters: - return torch.load(hashfile, weights_only=False) - else: - return torch.load(hashfile) + return torch.load(hashfile, weights_only=False) except PermissionError as e: if sys.platform != "win32": raise e @@ -1674,7 +1670,4 @@ def _load_meta_cache(self, meta_hash_file_name): if meta_hash_file_name in self._meta_cache: return self._meta_cache[meta_hash_file_name] else: - if "weights_only" in signature(torch.load).parameters: - return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False) - else: - return torch.load(self.cache_dir / meta_hash_file_name) + return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False) diff --git a/monai/data/utils.py b/monai/data/utils.py index d03dbd3234..988b813272 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -753,7 +753,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z if isinstance(_affine, torch.Tensor): spacing = torch.sqrt(torch.sum(_affine * _affine, dim=0)) else: - spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) + spacing = np.sqrt(np.sum(_affine * _affine, axis=0)) # type: ignore[operator] if suppress_zeros: spacing[spacing == 0] = 1.0 spacing_, *_ = convert_to_dst_type(spacing, dst=affine, dtype=dtype) diff --git a/monai/data/video_dataset.py b/monai/data/video_dataset.py index 031e85db26..9ff23ebeff 100644 --- a/monai/data/video_dataset.py +++ b/monai/data/video_dataset.py @@ -177,7 +177,7 @@ def get_available_codecs() -> dict[str, str]: for codec, ext in all_codecs.items(): writer = cv2.VideoWriter() fname = os.path.join(tmp_dir, f"test{ext}") - fourcc = cv2.VideoWriter_fourcc(*codec) + fourcc = cv2.VideoWriter_fourcc(*codec) # type: ignore[attr-defined] noviderr = writer.open(fname, fourcc, 1, (10, 10)) if noviderr: codecs[codec] = ext diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 35d4928465..836b407ac5 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -82,8 +82,8 @@ class Evaluator(Workflow): default to `True`. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. - amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: - https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.autocast. """ @@ -214,8 +214,8 @@ class SupervisedEvaluator(Evaluator): default to `True`. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. - amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: - https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.autocast. compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to `torch.Tensor` before forward pass, then converted back afterward with copied meta information. 
compile_kwargs: dict of the args for `torch.compile()` API, for more details: @@ -324,7 +324,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten # execute forward computation with engine.mode(engine.network): if engine.amp: - with torch.cuda.amp.autocast(**engine.amp_kwargs): + with torch.autocast("cuda", **engine.amp_kwargs): engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) else: engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) @@ -394,8 +394,8 @@ class EnsembleEvaluator(Evaluator): default to `True`. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. - amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: - https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.autocast. """ @@ -487,7 +487,7 @@ def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tenso for idx, network in enumerate(engine.networks): with engine.mode(network): if engine.amp: - with torch.cuda.amp.autocast(**engine.amp_kwargs): + with torch.autocast("cuda", **engine.amp_kwargs): if isinstance(engine.state.output, dict): engine.state.output.update( {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)} diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index fdb45fbab8..b69a5015bb 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -125,8 +125,8 @@ class SupervisedTrainer(Trainer): more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. - amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: - https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.autocast. compile: whether to use `torch.compile`, default is False. If True, MetaTensor inputs will be converted to `torch.Tensor` before forward pass, then converted back afterward with copied meta information. compile_kwargs: dict of the args for `torch.compile()` API, for more details: @@ -249,7 +249,7 @@ def _compute_pred_loss(): engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none) if engine.amp and engine.scaler is not None: - with torch.cuda.amp.autocast(**engine.amp_kwargs): + with torch.autocast("cuda", **engine.amp_kwargs): _compute_pred_loss() engine.scaler.scale(engine.state.output[Keys.LOSS]).backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) @@ -335,8 +335,8 @@ class GanTrainer(Trainer): more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. - amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: - https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.autocast. 
""" @@ -512,8 +512,8 @@ class AdversarialTrainer(Trainer): more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. - amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: - https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.autocast. """ def __init__( @@ -683,7 +683,7 @@ def _compute_generator_loss() -> None: engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none) if engine.amp and engine.state.g_scaler is not None: - with torch.cuda.amp.autocast(**engine.amp_kwargs): + with torch.autocast("cuda", **engine.amp_kwargs): _compute_generator_loss() engine.state.output[Keys.LOSS] = ( @@ -731,7 +731,7 @@ def _compute_discriminator_loss() -> None: engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none) if engine.amp and engine.state.d_scaler is not None: - with torch.cuda.amp.autocast(**engine.amp_kwargs): + with torch.autocast("cuda", **engine.amp_kwargs): _compute_discriminator_loss() engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward() diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 8e19a18601..9095f8d943 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -309,7 +309,7 @@ def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_nam self.scheduler = scheduler def get_target(self, images, noise, timesteps): - return self.scheduler.get_velocity(images, noise, timesteps) + return self.scheduler.get_velocity(images, noise, timesteps) # type: ignore[operator] def default_make_latent( diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 0c36da6d3d..ecb0c4a070 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -90,8 +90,8 @@ class Workflow(Engine): default to `True`. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for `device`, `non_blocking`. - amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: - https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + amp_kwargs: dict of the args for `torch.autocast("cuda")` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.autocast. Raises: TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``. 
diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index a3ac58c221..6e9a6fd1fe 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -574,7 +574,7 @@ def get_weights(self, extra=None): model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type])) if not os.path.isfile(model_path): raise ValueError(f"No best model checkpoint exists at {model_path}") - weights = torch.load(model_path, map_location="cpu") + weights = torch.load(model_path, map_location="cpu", weights_only=True) # if weights contain several state dicts, use the one defined by `save_dict_key` if isinstance(weights, dict) and self.save_dict_key in weights: weights = weights.get(self.save_dict_key) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index f48968ecfd..16cb875d03 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - checkpoint = torch.load(self.load_path, map_location=self.map_location) + checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False) k, _ = list(self.load_dict.items())[0] # single object and checkpoint is directly a state_dict diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 156677d992..df23b9aea0 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -882,7 +882,7 @@ def sample( ) # 2. compute previous image: x_t -> x_t-1 - image, _ = scheduler.step(model_output, t, image) + image, _ = scheduler.step(model_output, t, image) # type: ignore[operator] if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) if save_intermediates: @@ -986,8 +986,8 @@ def get_likelihood( predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image # get the posterior mean and variance - posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) - posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator] + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator] log_posterior_variance = torch.log(posterior_variance) log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance @@ -1436,7 +1436,7 @@ def sample( # type: ignore[override] ) # 3. 
compute previous image: x_t -> x_t-1 - image, _ = scheduler.step(model_output, t, image) + image, _ = scheduler.step(model_output, t, image) # type: ignore[operator] if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) if save_intermediates: @@ -1562,8 +1562,8 @@ def get_likelihood( # type: ignore[override] predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image # get the posterior mean and variance - posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) - posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) + posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator] + posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator] log_posterior_variance = torch.log(posterior_variance) log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index 1344207e18..a1ab8e8a56 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -53,8 +53,11 @@ def __init__( cropped_shape: Sequence[int] | None = None, device: torch.device | str | None = None, ) -> None: - self.merged_shape = merged_shape - self.cropped_shape = self.merged_shape if cropped_shape is None else cropped_shape + if merged_shape is None: + raise ValueError("Argument `merged_shape` must be provided") + + self.merged_shape: tuple[int, ...] = tuple(merged_shape) + self.cropped_shape: tuple[int, ...] = tuple(self.merged_shape if cropped_shape is None else cropped_shape) self.device = device self.is_finalized = False @@ -231,9 +234,9 @@ def __init__( dtype: np.dtype | str = "float32", value_dtype: np.dtype | str = "float32", count_dtype: np.dtype | str = "uint8", - store: zarr.storage.Store | str = "merged.zarr", - value_store: zarr.storage.Store | str | None = None, - count_store: zarr.storage.Store | str | None = None, + store: zarr.storage.Store | str = "merged.zarr", # type: ignore + value_store: zarr.storage.Store | str | None = None, # type: ignore + count_store: zarr.storage.Store | str | None = None, # type: ignore compressor: str | None = None, value_compressor: str | None = None, count_compressor: str | None = None, @@ -251,18 +254,18 @@ def __init__( if version_geq(get_package_version("zarr"), "3.0.0"): if value_store is None: self.tmpdir = TemporaryDirectory() - self.value_store = zarr.storage.LocalStore(self.tmpdir.name) + self.value_store = zarr.storage.LocalStore(self.tmpdir.name) # type: ignore else: - self.value_store = value_store + self.value_store = value_store # type: ignore if count_store is None: self.tmpdir = TemporaryDirectory() - self.count_store = zarr.storage.LocalStore(self.tmpdir.name) + self.count_store = zarr.storage.LocalStore(self.tmpdir.name) # type: ignore else: - self.count_store = count_store + self.count_store = count_store # type: ignore else: self.tmpdir = None - self.value_store = zarr.storage.TempStore() if value_store is None else value_store - self.count_store = zarr.storage.TempStore() if count_store is None else count_store + self.value_store = zarr.storage.TempStore() if value_store is None else value_store # type: ignore + self.count_store = zarr.storage.TempStore() if count_store is None else count_store # type: ignore self.chunks = chunks self.compressor = compressor self.value_compressor = value_compressor @@ -314,7 
+317,7 @@ def aggregate(self, values: torch.Tensor, location: Sequence[int]) -> None: map_slice = ensure_tuple_size(map_slice, values.ndim, pad_val=slice(None), pad_from_start=True) with self.lock: self.values[map_slice] += values.numpy() - self.counts[map_slice] += 1 + self.counts[map_slice] += 1 # type: ignore[operator] def finalize(self) -> zarr.Array: """ @@ -332,7 +335,7 @@ def finalize(self) -> zarr.Array: if not self.is_finalized: # use chunks for division to fit into memory for chunk in iterate_over_chunks(self.values.chunks, self.values.cdata_shape): - self.output[chunk] = self.values[chunk] / self.counts[chunk] + self.output[chunk] = self.values[chunk] / self.counts[chunk] # type: ignore[operator] # finalize the shape self.output.resize(self.cropped_shape) # set finalize flag to protect performing in-place division again diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py index a8ae90993a..ee653fac9d 100644 --- a/monai/losses/perceptual.py +++ b/monai/losses/perceptual.py @@ -374,7 +374,7 @@ def __init__( else: network = torchvision.models.resnet50(weights=None) if pretrained is True: - state_dict = torch.load(pretrained_path) + state_dict = torch.load(pretrained_path, weights_only=True) if pretrained_state_dict_key is not None: state_dict = state_dict[pretrained_state_dict_key] network.load_state_dict(state_dict) diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py index ebf25613a6..fa8820885d 100644 --- a/monai/losses/sure_loss.py +++ b/monai/losses/sure_loss.py @@ -92,7 +92,7 @@ def sure_loss_function( y_ref = operator(x) # get perturbed output - x_perturbed = x + eps * perturb_noise + x_perturbed = x + eps * perturb_noise # type: ignore y_perturbed = operator(x_perturbed) # divergence divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore diff --git a/monai/networks/blocks/feature_pyramid_network.py b/monai/networks/blocks/feature_pyramid_network.py index 7de899803c..759a4efe0d 100644 --- a/monai/networks/blocks/feature_pyramid_network.py +++ b/monai/networks/blocks/feature_pyramid_network.py @@ -54,7 +54,9 @@ from collections import OrderedDict from collections.abc import Callable +from typing import cast +import torch import torch.nn.functional as F from torch import Tensor, nn @@ -194,8 +196,8 @@ def __init__( conv_type_: type[nn.Module] = Conv[Conv.CONV, spatial_dims] for m in self.modules(): if isinstance(m, conv_type_): - nn.init.kaiming_uniform_(m.weight, a=1) - nn.init.constant_(m.bias, 0.0) + nn.init.kaiming_uniform_(cast(torch.Tensor, m.weight), a=1) + nn.init.constant_(cast(torch.Tensor, m.bias), 0.0) if extra_blocks is not None: if not isinstance(extra_blocks, ExtraFPNBlock): diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py index 9c354e1009..0ff7143b69 100644 --- a/monai/networks/layers/vector_quantizer.py +++ b/monai/networks/layers/vector_quantizer.py @@ -100,7 +100,7 @@ def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, to torch.Tensor: Quantization indices of shape [B,H,W,D,1] """ - with torch.cuda.amp.autocast(enabled=False): + with torch.autocast("cuda", enabled=False): encoding_indices_view = list(inputs.shape) del encoding_indices_view[1] @@ -138,7 +138,7 @@ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Quantize space representation of encoding_indices in channel first format. 
""" - with torch.cuda.amp.autocast(enabled=False): + with torch.autocast("cuda", enabled=False): embedding: torch.Tensor = ( self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous() ) diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index 3745b66bb5..b773af91d4 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -633,9 +633,9 @@ def _remap_preact_resnet_model(model_url: str): # download the pretrained weights into torch hub's default dir weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth") download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) - state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[ - "desc" - ] + map_location = None if torch.cuda.is_available() else torch.device("cpu") + state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)["desc"] + for key in list(state_dict.keys()): new_key = None if pattern_conv0.match(key): @@ -668,7 +668,8 @@ def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = No # download the pretrained weights into torch hub's default dir weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth") download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) - state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu")) + map_location = None if torch.cuda.is_available() else torch.device("cpu") + state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True) if state_dict_key is not None: state_dict = state_dict[state_dict_key] diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index d62722478e..d24b86d27d 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -493,7 +493,7 @@ def _resnet( if isinstance(pretrained, str): if Path(pretrained).exists(): logger.info(f"Loading weights from {pretrained}...") - model_state_dict = torch.load(pretrained, map_location=device) + model_state_dict = torch.load(pretrained, map_location=device, weights_only=True) else: # Throw error raise FileNotFoundError("The pretrained checkpoint file is not found") @@ -665,7 +665,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat raise EntryNotFoundError( f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}" ) from None - checkpoint = torch.load(pretrained_path, map_location=torch.device(device)) + checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True) else: raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]") logger.info(f"{filename} downloaded") diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py index 51435a9ea2..c14118ad20 100644 --- a/monai/networks/nets/senet.py +++ b/monai/networks/nets/senet.py @@ -302,7 +302,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool): if isinstance(model_url, dict): download_url(model_url["url"], filepath=model_url["filename"]) - state_dict = torch.load(model_url["filename"], map_location=None) + state_dict = torch.load(model_url["filename"], map_location=None, weights_only=True) else: state_dict = load_state_dict_from_url(model_url, progress=progress) for key in list(state_dict.keys()): diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index cfc5dda41f..22e1e6f659 100644 --- 
a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -272,53 +272,50 @@ def __init__( self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels) def load_from(self, weights): + layers1_0: BasicLayer = self.swinViT.layers1[0] # type: ignore[assignment] + layers2_0: BasicLayer = self.swinViT.layers2[0] # type: ignore[assignment] + layers3_0: BasicLayer = self.swinViT.layers3[0] # type: ignore[assignment] + layers4_0: BasicLayer = self.swinViT.layers4[0] # type: ignore[assignment] + wstate = weights["state_dict"] + with torch.no_grad(): - self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"]) - self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"]) - for bname, block in self.swinViT.layers1[0].blocks.named_children(): - block.load_from(weights, n_block=bname, layer="layers1") - self.swinViT.layers1[0].downsample.reduction.weight.copy_( - weights["state_dict"]["module.layers1.0.downsample.reduction.weight"] - ) - self.swinViT.layers1[0].downsample.norm.weight.copy_( - weights["state_dict"]["module.layers1.0.downsample.norm.weight"] - ) - self.swinViT.layers1[0].downsample.norm.bias.copy_( - weights["state_dict"]["module.layers1.0.downsample.norm.bias"] - ) - for bname, block in self.swinViT.layers2[0].blocks.named_children(): - block.load_from(weights, n_block=bname, layer="layers2") - self.swinViT.layers2[0].downsample.reduction.weight.copy_( - weights["state_dict"]["module.layers2.0.downsample.reduction.weight"] - ) - self.swinViT.layers2[0].downsample.norm.weight.copy_( - weights["state_dict"]["module.layers2.0.downsample.norm.weight"] - ) - self.swinViT.layers2[0].downsample.norm.bias.copy_( - weights["state_dict"]["module.layers2.0.downsample.norm.bias"] - ) - for bname, block in self.swinViT.layers3[0].blocks.named_children(): - block.load_from(weights, n_block=bname, layer="layers3") - self.swinViT.layers3[0].downsample.reduction.weight.copy_( - weights["state_dict"]["module.layers3.0.downsample.reduction.weight"] - ) - self.swinViT.layers3[0].downsample.norm.weight.copy_( - weights["state_dict"]["module.layers3.0.downsample.norm.weight"] - ) - self.swinViT.layers3[0].downsample.norm.bias.copy_( - weights["state_dict"]["module.layers3.0.downsample.norm.bias"] - ) - for bname, block in self.swinViT.layers4[0].blocks.named_children(): - block.load_from(weights, n_block=bname, layer="layers4") - self.swinViT.layers4[0].downsample.reduction.weight.copy_( - weights["state_dict"]["module.layers4.0.downsample.reduction.weight"] - ) - self.swinViT.layers4[0].downsample.norm.weight.copy_( - weights["state_dict"]["module.layers4.0.downsample.norm.weight"] - ) - self.swinViT.layers4[0].downsample.norm.bias.copy_( - weights["state_dict"]["module.layers4.0.downsample.norm.bias"] - ) + self.swinViT.patch_embed.proj.weight.copy_(wstate["module.patch_embed.proj.weight"]) + self.swinViT.patch_embed.proj.bias.copy_(wstate["module.patch_embed.proj.bias"]) + for bname, block in layers1_0.blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers1") # type: ignore[operator] + + if layers1_0.downsample is not None: + d = layers1_0.downsample + d.reduction.weight.copy_(wstate["module.layers1.0.downsample.reduction.weight"]) # type: ignore + d.norm.weight.copy_(wstate["module.layers1.0.downsample.norm.weight"]) # type: ignore + d.norm.bias.copy_(wstate["module.layers1.0.downsample.norm.bias"]) # type: ignore + + for bname, block in 
layers2_0.blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers2") # type: ignore[operator] + + if layers2_0.downsample is not None: + d = layers2_0.downsample + d.reduction.weight.copy_(wstate["module.layers2.0.downsample.reduction.weight"]) # type: ignore + d.norm.weight.copy_(wstate["module.layers2.0.downsample.norm.weight"]) # type: ignore + d.norm.bias.copy_(wstate["module.layers2.0.downsample.norm.bias"]) # type: ignore + + for bname, block in layers3_0.blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers3") # type: ignore[operator] + + if layers3_0.downsample is not None: + d = layers3_0.downsample + d.reduction.weight.copy_(wstate["module.layers3.0.downsample.reduction.weight"]) # type: ignore + d.norm.weight.copy_(wstate["module.layers3.0.downsample.norm.weight"]) # type: ignore + d.norm.bias.copy_(wstate["module.layers3.0.downsample.norm.bias"]) # type: ignore + + for bname, block in layers4_0.blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers4") # type: ignore[operator] + + if layers4_0.downsample is not None: + d = layers4_0.downsample + d.reduction.weight.copy_(wstate["module.layers4.0.downsample.reduction.weight"]) # type: ignore + d.norm.weight.copy_(wstate["module.layers4.0.downsample.norm.weight"]) # type: ignore + d.norm.bias.copy_(wstate["module.layers4.0.downsample.norm.bias"]) # type: ignore @torch.jit.unused def _check_input_size(self, spatial_shape): @@ -532,7 +529,7 @@ def forward(self, x, mask): q = q * self.scale attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.clone()[:n, :n].reshape(-1) + self.relative_position_index.clone()[:n, :n].reshape(-1) # type: ignore[operator] ].reshape(n, n, -1) relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() attn = attn + relative_position_bias.unsqueeze(0) @@ -691,7 +688,7 @@ def load_from(self, weights, n_block, layer): self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]]) self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]]) self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]]) - self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) + self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) # type: ignore[operator] self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]]) self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]]) self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]]) @@ -1118,7 +1115,7 @@ def filter_swinunetr(key, value): ) ssl_weights_path = "./ssl_pretrained_weights.pth" download_url(resource, ssl_weights_path) - ssl_weights = torch.load(ssl_weights_path)["model"] + ssl_weights = torch.load(ssl_weights_path, weights_only=True)["model"] dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr) diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py index 6bfff3c956..bd756ec214 100644 --- a/monai/networks/nets/transchex.py +++ b/monai/networks/nets/transchex.py @@ -43,7 +43,7 @@ def __init__(self, *inputs, **kwargs) -> None: def init_bert_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # type: ignore[union-attr,arg-type] elif 
isinstance(module, torch.nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -68,7 +68,8 @@ def from_pretrained( weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir) model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs) if state_dict is None and not from_tf: - state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None) + map_location = "cpu" if not torch.cuda.is_available() else None + state_dict = torch.load(weights_path, map_location=map_location, weights_only=True) if from_tf: return load_tf_weights_in_bert(model, weights_path) old_keys = [] diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index 6ecb664b85..a5c2cc13ef 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -315,7 +315,7 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False): """ if auto_freeze != self.auto_freeze: if hasattr(self.image_encoder, "set_auto_grad"): - self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) # type: ignore[operator] else: for param in self.image_encoder.parameters(): param.requires_grad = (not auto_freeze) and (not point_freeze) @@ -325,7 +325,7 @@ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False): if point_freeze != self.point_freeze: if hasattr(self.image_encoder, "set_auto_grad"): - self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) # type: ignore[operator] else: for param in self.image_encoder.parameters(): param.requires_grad = (not auto_freeze) and (not point_freeze) @@ -543,10 +543,10 @@ def forward( point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore point_embedding[point_labels == -1] = 0.0 point_embedding[point_labels == -1] += self.not_a_point_embed.weight - point_embedding[point_labels == 0] += self.point_embeddings[0].weight - point_embedding[point_labels == 1] += self.point_embeddings[1].weight - point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight - point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight + point_embedding[point_labels == 0] += self.point_embeddings[0].weight # type: ignore[arg-type] + point_embedding[point_labels == 1] += self.point_embeddings[1].weight # type: ignore[arg-type] + point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight # type: ignore[operator] + point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight # type: ignore[operator] output_tokens = self.mask_tokens.weight output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1) @@ -884,7 +884,7 @@ def _pe_encoding(self, coords: torch.torch.Tensor) -> torch.torch.Tensor: coords = 2 * coords - 1 # [bs=1,N=2,2] @ [2,128] # [bs=1, N=2, 128] - coords = coords @ self.positional_encoding_gaussian_matrix + coords = coords @ self.positional_encoding_gaussian_matrix # type: ignore[operator] coords = 2 * np.pi * coords # outputs d_1 x ... 
x d_n x C shape # [bs=1, N=2, 128+128=256] diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 2279bed0b4..a41d4b1e33 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -22,7 +22,7 @@ from collections.abc import Callable, Mapping, Sequence from contextlib import contextmanager from copy import deepcopy -from typing import Any +from typing import Any, Iterable import numpy as np import torch @@ -1238,7 +1238,7 @@ def __init__(self, mod): def forward(self, x): dtype = x.dtype - with torch.amp.autocast("cuda", enabled=False): + with torch.autocast("cuda", enabled=False): ret = self.mod.forward(x.to(torch.float32)).to(dtype) return ret @@ -1255,7 +1255,7 @@ def __init__(self, mod): def forward(self, *args): from_dtype = args[0].dtype - with torch.amp.autocast("cuda", enabled=False): + with torch.autocast("cuda", enabled=False): ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)) return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) @@ -1291,7 +1291,8 @@ def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable def expansion_fn(mod: nn.Module) -> nn.Module | None: if not isinstance(mod, base_t): return None - args = [getattr(mod, name, None) for name in mod.__constants__] + constants: Iterable = mod.__constants__ # type: ignore[assignment] + args = [getattr(mod, name, None) for name in constants] out = dest_t(*args) return out diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 8fe658ad3e..ed0a1ad9ac 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1856,7 +1856,7 @@ def interp(self, x: NdarrayOrTensor, xp: NdarrayOrTensor, fp: NdarrayOrTensor) - indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1 indices = ns.clip(indices, 0, len(m) - 1) - f = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape) + f: NdarrayOrTensor = (m[indices] * x.reshape(-1) + b[indices]).reshape(x.shape) f[x < xp[0]] = fp[0] f[x > xp[-1]] = fp[-1] return f diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index e4ed196eff..a75bb390cd 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1758,13 +1758,13 @@ def __call__( if self.affine is None: affine = torch.eye(spatial_dims + 1, device=_device) if self.rotate_params: - affine @= create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) + affine @= create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) # type: ignore[assignment] if self.shear_params: - affine @= create_shear(spatial_dims, self.shear_params, device=_device, backend=_b) + affine @= create_shear(spatial_dims, self.shear_params, device=_device, backend=_b) # type: ignore[assignment] if self.translate_params: - affine @= create_translate(spatial_dims, self.translate_params, device=_device, backend=_b) + affine @= create_translate(spatial_dims, self.translate_params, device=_device, backend=_b) # type: ignore[assignment] if self.scale_params: - affine @= create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) + affine @= create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) # type: ignore[assignment] else: affine = self.affine # type: ignore affine = to_affine_nd(spatial_dims, affine) @@ -1780,7 +1780,7 @@ def __call__( grid_ = ((affine @ sc) @ grid_.view((grid_.shape[0], -1))).view([-1] + list(grid_.shape[1:])) else: grid_ = (affine @ 
grid_.view((grid_.shape[0], -1))).view([-1] + list(grid_.shape[1:])) - return grid_, affine + return grid_, affine # type: ignore[return-value] class RandAffineGrid(Randomizable, LazyTransform): @@ -3257,7 +3257,7 @@ def filter_threshold(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tuple[NdarrayOrTensor, numpy.ndarray]: tuple of filtered patches and locations. """ n_dims = len(image_np.shape) - idx = argwhere(image_np.sum(tuple(range(1, n_dims))) < self.threshold).reshape(-1) + idx = argwhere(image_np.sum(tuple(range(1, n_dims))) < self.threshold).reshape(-1) # type: ignore[operator] idx_np = convert_data_type(idx, np.ndarray)[0] return image_np[idx], locations[idx_np] diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index 60a074544b..c59436525c 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any: fn = self.cached[key]["obj"] # pytype: disable=attribute-error if not os.path.exists(fn): # pytype: disable=wrong-arg-types raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.") - data_obj = torch.load(fn, map_location=lambda storage, location: storage) + data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False) # copy back to device if necessary if "device" in self.cached[key]: data_obj = data_obj.to(self.cached[key]["device"]) diff --git a/requirements.txt b/requirements.txt index 452a62adda..ad394ce807 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -torch>=2.3.0,<2.6; sys_platform != 'win32' -torch>=2.4.1,<2.6; sys_platform == 'win32' +torch>=2.3.0; sys_platform != 'win32' +torch>=2.4.1; sys_platform == 'win32' numpy>=1.24,<3.0 diff --git a/runtests.sh b/runtests.sh index 2a399d5c3a..fd7df79722 100755 --- a/runtests.sh +++ b/runtests.sh @@ -120,7 +120,7 @@ function print_usage { # FIXME: https://github.com/Project-MONAI/MONAI/issues/4354 protobuf_major_version=$("${PY_EXE}" -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1) -if [ "$protobuf_major_version" -ge "4" ] +if [ ! 
-z "$protobuf_major_version" ] && [ "$protobuf_major_version" -ge "4" ] then export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python fi diff --git a/tests/bundle/test_bundle_download.py b/tests/bundle/test_bundle_download.py index 38620d98ff..da58a6313e 100644 --- a/tests/bundle/test_bundle_download.py +++ b/tests/bundle/test_bundle_download.py @@ -266,6 +266,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) with skip_if_downloading_fails(): # download bundle, and load weights from the downloaded path with tempfile.TemporaryDirectory() as tempdir: + bundle_root = os.path.join(tempdir, bundle_name) # load weights weights = load( name=bundle_name, @@ -278,7 +279,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) return_state_dict=True, ) # prepare network - with open(os.path.join(tempdir, bundle_name, bundle_files[2])) as f: + with open(os.path.join(bundle_root, bundle_files[2])) as f: net_args = json.load(f)["network_def"] model_name = net_args["_target_"] del net_args["_target_"] @@ -288,9 +289,13 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file) model.eval() # prepare data and test - input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[4]), map_location=device) + input_tensor = torch.load( + os.path.join(bundle_root, bundle_files[4]), map_location=device, weights_only=True + ) output = model.forward(input_tensor) - expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[3]), map_location=device) + expected_output = torch.load( + os.path.join(bundle_root, bundle_files[3]), map_location=device, weights_only=True + ) assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False) # load instantiated model directly and test, since the bundle has been downloaded, @@ -350,7 +355,7 @@ def test_load_weights_with_net_override(self, bundle_name, device, net_override) config_file=f"{tempdir}/spleen_ct_segmentation/configs/train.json", workflow_type="train" ) expected_model = workflow.network_def.to(device) - expected_model.load_state_dict(torch.load(model_path)) + expected_model.load_state_dict(torch.load(model_path, weights_only=True)) expected_output = expected_model(input_tensor) assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False) @@ -378,6 +383,7 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, with skip_if_downloading_fails(): # load ts module with tempfile.TemporaryDirectory() as tempdir: + bundle_root = os.path.join(tempdir, bundle_name) # load ts module model_ts, metadata, extra_file_dict = load( name=bundle_name, @@ -393,9 +399,13 @@ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, ) # prepare and test ts - input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[1]), map_location=device) + input_tensor = torch.load( + os.path.join(bundle_root, bundle_files[1]), map_location=device, weights_only=True + ) output = model_ts.forward(input_tensor) - expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[0]), map_location=device) + expected_output = torch.load( + os.path.join(bundle_root, bundle_files[0]), map_location=device, weights_only=True + ) assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False) # test metadata self.assertTrue(metadata["pytorch_version"] == "1.7.1") diff --git a/tests/config/test_cv2_dist.py b/tests/config/test_cv2_dist.py index 2ef8e5b10f..3bcb68e553 100644 
--- a/tests/config/test_cv2_dist.py +++ b/tests/config/test_cv2_dist.py @@ -16,7 +16,6 @@ import numpy as np import torch import torch.distributed as dist -from torch.cuda.amp import autocast # FIXME: test for the workaround of https://github.com/Project-MONAI/MONAI/issues/5291 from monai.config.deviceconfig import print_config @@ -33,7 +32,7 @@ def main_worker(rank, ngpus_per_node, port): model, device_ids=[rank], output_device=rank, find_unused_parameters=False ) x = torch.ones(1, 1, 12, 12, 12).to(rank) - with autocast(enabled=True): + with torch.autocast("cuda"): model(x) if dist.is_initialized(): diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py index cd3def4de1..f52d70e7b6 100644 --- a/tests/data/meta_tensor/test_meta_tensor.py +++ b/tests/data/meta_tensor/test_meta_tensor.py @@ -245,7 +245,7 @@ def test_pickling(self): with tempfile.TemporaryDirectory() as tmp_dir: fname = os.path.join(tmp_dir, "im.pt") torch.save(m, fname) - m2 = torch.load(fname) + m2 = torch.load(fname, weights_only=False) self.check(m2, m, ids=False) @skip_if_no_cuda @@ -256,7 +256,7 @@ def test_amp(self): conv = torch.nn.Conv2d(im.shape[1], 5, 3) conv.to(device) im_conv = conv(im) - with torch.cuda.amp.autocast(): + with torch.autocast("cuda"): im_conv2 = conv(im) self.check(im_conv2, im_conv, ids=False, rtol=1e-2, atol=1e-2) diff --git a/tests/integration/test_integration_classification_2d.py b/tests/integration/test_integration_classification_2d.py index fd9e58aaf8..aecfa2efab 100644 --- a/tests/integration/test_integration_classification_2d.py +++ b/tests/integration/test_integration_classification_2d.py @@ -166,7 +166,7 @@ def run_inference_test(root_dir, test_x, test_y, device="cuda:0", num_workers=10 model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(np.unique(test_y))).to(device) model_filename = os.path.join(root_dir, "best_metric_model.pth") - model.load_state_dict(torch.load(model_filename)) + model.load_state_dict(torch.load(model_filename, weights_only=True)) y_true = [] y_pred = [] with eval_mode(model): diff --git a/tests/integration/test_integration_fast_train.py b/tests/integration/test_integration_fast_train.py index f9beb5613d..814c4b182c 100644 --- a/tests/integration/test_integration_fast_train.py +++ b/tests/integration/test_integration_fast_train.py @@ -186,7 +186,7 @@ def test_train_timing(self): step += 1 optimizer.zero_grad() # set AMP for training - with torch.cuda.amp.autocast(): + with torch.autocast("cuda"): outputs = model(batch_data["image"]) loss = loss_function(outputs, batch_data["label"]) scaler.scale(loss).backward() @@ -207,7 +207,7 @@ def test_train_timing(self): roi_size = (96, 96, 96) sw_batch_size = 4 # set AMP for validation - with torch.cuda.amp.autocast(): + with torch.autocast("cuda"): val_outputs = sliding_window_inference(val_data["image"], roi_size, sw_batch_size, model) val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)] diff --git a/tests/integration/test_integration_segmentation_3d.py b/tests/integration/test_integration_segmentation_3d.py index fb2937739f..7c30150505 100644 --- a/tests/integration/test_integration_segmentation_3d.py +++ b/tests/integration/test_integration_segmentation_3d.py @@ -216,7 +216,7 @@ def run_inference_test(root_dir, device="cuda:0"): ).to(device) model_filename = os.path.join(root_dir, "best_metric_model.pth") - model.load_state_dict(torch.load(model_filename)) + model.load_state_dict(torch.load(model_filename, weights_only=True)) with 
eval_mode(model): # resampling with align_corners=True or dtype=float64 will generate # slight different results between PyTorch 1.5 an 1.6 diff --git a/tests/metrics/test_compute_multiscalessim_metric.py b/tests/metrics/test_compute_multiscalessim_metric.py index 3df8026c2b..d85e6f7bf6 100644 --- a/tests/metrics/test_compute_multiscalessim_metric.py +++ b/tests/metrics/test_compute_multiscalessim_metric.py @@ -32,7 +32,7 @@ def test2d_gaussian(self): metric(preds, target) result = metric.aggregate() expected_value = 0.023176 - self.assertTrue(expected_value - result.item() < 0.000001) + self.assertAlmostEqual(expected_value, result.item(), 4) def test2d_uniform(self): set_determinism(0) @@ -45,7 +45,7 @@ def test2d_uniform(self): metric(preds, target) result = metric.aggregate() expected_value = 0.022655 - self.assertTrue(expected_value - result.item() < 0.000001) + self.assertAlmostEqual(expected_value, result.item(), 4) def test3d_gaussian(self): set_determinism(0) @@ -58,7 +58,7 @@ def test3d_gaussian(self): metric(preds, target) result = metric.aggregate() expected_value = 0.061796 - self.assertTrue(expected_value - result.item() < 0.000001) + self.assertAlmostEqual(expected_value, result.item(), 4) def input_ill_input_shape2d(self): metric = MultiScaleSSIMMetric(spatial_dims=3, weights=[0.5, 0.5]) diff --git a/tests/networks/nets/test_autoencoderkl.py b/tests/networks/nets/test_autoencoderkl.py index 0a3db60830..2d4c5b66ca 100644 --- a/tests/networks/nets/test_autoencoderkl.py +++ b/tests/networks/nets/test_autoencoderkl.py @@ -330,7 +330,7 @@ def test_compatibility_with_monai_generative(self): weight_path = os.path.join(tmpdir, filename) download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) - net.load_old_state_dict(torch.load(weight_path), verbose=False) + net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False) if __name__ == "__main__": diff --git a/tests/networks/nets/test_controlnet.py b/tests/networks/nets/test_controlnet.py index 9503518762..6158dc2eef 100644 --- a/tests/networks/nets/test_controlnet.py +++ b/tests/networks/nets/test_controlnet.py @@ -208,7 +208,7 @@ def test_compatibility_with_monai_generative(self): weight_path = os.path.join(tmpdir, filename) download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) - net.load_old_state_dict(torch.load(weight_path), verbose=False) + net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False) if __name__ == "__main__": diff --git a/tests/networks/nets/test_diffusion_model_unet.py b/tests/networks/nets/test_diffusion_model_unet.py index a7c823709d..3bca26882c 100644 --- a/tests/networks/nets/test_diffusion_model_unet.py +++ b/tests/networks/nets/test_diffusion_model_unet.py @@ -578,7 +578,7 @@ def test_compatibility_with_monai_generative(self): weight_path = os.path.join(tmpdir, filename) download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) - net.load_old_state_dict(torch.load(weight_path), verbose=False) + net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False) if __name__ == "__main__": diff --git a/tests/networks/nets/test_network_consistency.py b/tests/networks/nets/test_network_consistency.py index e09826de75..4ce198b92f 100644 --- a/tests/networks/nets/test_network_consistency.py +++ b/tests/networks/nets/test_network_consistency.py @@ -55,7 +55,7 @@ def test_network_consistency(self, net_name, data_path, json_path): print("JSON path: " + json_path) # Load 
data - loaded_data = torch.load(data_path) + loaded_data = torch.load(data_path, weights_only=True) # Load json from file json_file = open(json_path) diff --git a/tests/networks/nets/test_swin_unetr.py b/tests/networks/nets/test_swin_unetr.py index 4908907bfe..2c4532ecc4 100644 --- a/tests/networks/nets/test_swin_unetr.py +++ b/tests/networks/nets/test_swin_unetr.py @@ -128,7 +128,7 @@ def test_filter_swinunetr(self, input_param, key, value): data_spec["url"], weight_path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"] ) - ssl_weight = torch.load(weight_path)["model"] + ssl_weight = torch.load(weight_path, weights_only=True)["model"] net = SwinUNETR(**input_param) dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr) assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False) diff --git a/tests/networks/nets/test_transformer.py b/tests/networks/nets/test_transformer.py index f9264ba153..daf424c174 100644 --- a/tests/networks/nets/test_transformer.py +++ b/tests/networks/nets/test_transformer.py @@ -101,7 +101,7 @@ def test_compatibility_with_monai_generative(self): weight_path = os.path.join(tmpdir, filename) download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) - net.load_old_state_dict(torch.load(weight_path), verbose=False) + net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False) if __name__ == "__main__": diff --git a/tests/networks/test_save_state.py b/tests/networks/test_save_state.py index 0581a3ce1f..329065da2b 100644 --- a/tests/networks/test_save_state.py +++ b/tests/networks/test_save_state.py @@ -64,7 +64,7 @@ def test_file(self, src, expected_keys, create_dir=True, atomic=True, func=None, if kwargs is None: kwargs = {} save_state(src=src, path=path, create_dir=create_dir, atomic=atomic, func=func, **kwargs) - ckpt = dict(torch.load(path)) + ckpt = dict(torch.load(path, weights_only=True)) for k in ckpt.keys(): self.assertIn(k, expected_keys)
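For reference, the `torch.load` pattern applied throughout this patch: PyTorch 2.6 flips the default of `weights_only` to `True`, which refuses to unpickle anything other than tensors, state dicts, and a small allowlist of types, so each call site above now states its intent explicitly. A minimal sketch with illustrative file names, written to the working directory:

```python
import torch

net = torch.nn.Linear(4, 2)

# A state dict holds only tensors, so it loads under the new default;
# passing weights_only=True keeps the behaviour explicit on older PyTorch.
torch.save(net.state_dict(), "model.pt")
net.load_state_dict(torch.load("model.pt", weights_only=True))

# Arbitrary pickled objects (e.g. cached MetaTensors with metadata, or a
# checkpoint dict holding whole modules) are rejected by weights_only=True
# and need an explicit opt-out, as in dataset.py and checkpoint_loader.py:
torch.save({"net": net, "epoch": 3}, "ckpt.pt")  # pickles an nn.Module
ckpt = torch.load("ckpt.pt", weights_only=False)
net.load_state_dict(ckpt["net"].state_dict())
```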