diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst
index 4e3a32b6fe..49297e1a3d 100644
--- a/docs/source/bundle.rst
+++ b/docs/source/bundle.rst
@@ -34,6 +34,13 @@ Model Bundle
     :members:
     :special-members:
 
+
+`nnUNet Bundle`
+---------------
+.. autoclass:: ModelnnUNetWrapper
+    :members:
+    :special-members:
+
 `Scripts`
 ---------
 .. autofunction:: ckpt_export
@@ -50,3 +57,6 @@ Model Bundle
 .. autofunction:: init_bundle
 .. autofunction:: push_to_hf_hub
 .. autofunction:: update_kwargs
+.. autofunction:: get_nnunet_trainer
+.. autofunction:: get_nnunet_monai_predictor
+.. autofunction:: convert_nnunet_to_monai_bundle
diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py
index 3f3c8d545e..305bf9eb6a 100644
--- a/monai/bundle/__init__.py
+++ b/monai/bundle/__init__.py
@@ -13,6 +13,7 @@
 
 from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
 from .config_parser import ConfigParser
+from .nnunet import ModelnnUNetWrapper, convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer
 from .properties import InferProperties, MetaProperties, TrainProperties
 from .reference_resolver import ReferenceResolver
 from .scripts import (
diff --git a/monai/bundle/nnunet.py b/monai/bundle/nnunet.py
new file mode 100644
index 0000000000..4eca036c17
--- /dev/null
+++ b/monai/bundle/nnunet.py
@@ -0,0 +1,414 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import os
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch.backends import cudnn
+
+from monai.data.meta_tensor import MetaTensor
+from monai.utils import optional_import
+
+join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
+load_json, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_json")
+
+__all__ = ["get_nnunet_trainer", "get_nnunet_monai_predictor", "convert_nnunet_to_monai_bundle", "ModelnnUNetWrapper"]
+
+
+def get_nnunet_trainer(
+    dataset_name_or_id,
+    configuration,
+    fold,
+    trainer_class_name="nnUNetTrainer",
+    plans_identifier="nnUNetPlans",
+    pretrained_weights=None,
+    num_gpus=1,
+    use_compressed_data=False,
+    export_validation_probabilities=False,
+    continue_training=False,
+    only_run_validation=False,
+    disable_checkpointing=False,
+    val_with_best=False,
+    device="cuda",
+    pretrained_model=None,
+):
+    """
+    Get the nnUNet trainer instance based on the provided configuration.
+    The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
+    optimizer, loss function, DataLoader, etc.
+
+    Example::
+
+        from monai.engines import SupervisedTrainer
+        from monai.bundle.nnunet import get_nnunet_trainer
+
+        dataset_name_or_id = 'Task101_PROSTATE'
+        fold = 0
+        configuration = '3d_fullres'
+        nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)
+
+        trainer = SupervisedTrainer(
+            device=nnunet_trainer.device,
+            max_epochs=nnunet_trainer.num_epochs,
+            train_data_loader=nnunet_trainer.dataloader_train,
+            network=nnunet_trainer.network,
+            optimizer=nnunet_trainer.optimizer,
+            loss_function=nnunet_trainer.loss,
+            epoch_length=nnunet_trainer.num_iterations_per_epoch,
+        )
+
+    Parameters
+    ----------
+    dataset_name_or_id : Union[str, int]
+        The name or ID of the dataset to be used.
+    configuration : str
+        The configuration name for the training.
+    fold : Union[int, str]
+        The fold number or 'all' for cross-validation.
+    trainer_class_name : str, optional
+        The class name of the trainer to be used. Default is 'nnUNetTrainer'.
+    plans_identifier : str, optional
+        Identifier for the plans to be used. Default is 'nnUNetPlans'.
+    pretrained_weights : str, optional
+        Path to the pretrained weights file.
+    num_gpus : int, optional
+        Number of GPUs to be used. Default is 1. Values greater than 1 are not supported yet.
+    use_compressed_data : bool, optional
+        Whether to use compressed data. Default is False.
+    export_validation_probabilities : bool, optional
+        Whether to export validation probabilities. Default is False. Currently not used by this function.
+    continue_training : bool, optional
+        Whether to continue training from a checkpoint. Default is False.
+    only_run_validation : bool, optional
+        Whether to only run validation. Default is False.
+    disable_checkpointing : bool, optional
+        Whether to disable checkpointing. Default is False.
+    val_with_best : bool, optional
+        Whether to validate with the best model. Default is False. Currently not used by this function.
+    device : str, optional
+        The device to be used for training. Default is 'cuda'.
+    pretrained_model : str, optional
+        Path to the pretrained model file.
+
+    Returns
+    -------
+    nnunet_trainer
+        The nnUNet trainer instance.
+
+    Raises
+    ------
+    ValueError
+        If ``fold`` is a string that is neither 'all' nor an integer, or if both ``continue_training`` and
+        ``only_run_validation`` are set.
+    NotImplementedError
+        If ``num_gpus`` is greater than 1.
+    """
+    # From nnUNet/nnunetv2/run/run_training.py#run_training
+    if isinstance(fold, str) and fold != "all":
+        try:
+            fold = int(fold)
+        except ValueError as e:
+            raise ValueError(
+                f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!'
+            ) from e
+
+    if continue_training and only_run_validation:
+        raise ValueError("Cannot set continue_training and only_run_validation at the same time.")
+
+    if int(num_gpus) > 1:
+        # Multi-GPU training is not wired up; fail explicitly instead of returning an unbound trainer.
+        raise NotImplementedError("Multi-GPU training (num_gpus > 1) is not supported yet.")
+
+    from nnunetv2.run.run_training import get_trainer_from_args, maybe_load_checkpoint
+
+    nnunet_trainer = get_trainer_from_args(
+        str(dataset_name_or_id),
+        configuration,
+        fold,
+        trainer_class_name,
+        plans_identifier,
+        use_compressed_data,
+        device=torch.device(device),
+    )
+    if disable_checkpointing:
+        nnunet_trainer.disable_checkpointing = disable_checkpointing
+
+    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
+    nnunet_trainer.on_train_start()  # Added to Initialize Trainer
+    if torch.cuda.is_available():
+        cudnn.deterministic = False
+        cudnn.benchmark = True
+
+    if pretrained_model is not None:
+        # Map onto the requested device so a checkpoint saved on another device can still be loaded.
+        state_dict = torch.load(pretrained_model, map_location=torch.device(device))
+        if "network_weights" in state_dict:
+            # ``_orig_mod`` only exists when the network has been wrapped by torch.compile.
+            network = getattr(nnunet_trainer.network, "_orig_mod", nnunet_trainer.network)
+            network.load_state_dict(state_dict["network_weights"])
+    return nnunet_trainer
+
+
+class ModelnnUNetWrapper(torch.nn.Module):
+    """
+    A wrapper class for nnUNet model integration with MONAI framework.
+    The wrapper can be used to integrate the nnUNet Bundle within MONAI framework for inference.
+
+    Parameters
+    ----------
+    predictor : object
+        The nnUNet predictor object used for inference.
+    model_folder : str
+        The folder path where the model and related files are stored.
+    model_name : str, optional
+        The name of the model file, by default "model.pt".
+
+    Attributes
+    ----------
+    predictor : nnUNetPredictor
+        The nnUNet predictor object used for inference.
+    network_weights : torch.nn.Module
+        The network weights of the model.
+
+    Notes
+    -----
+    This class integrates nnUNet model with MONAI framework by loading necessary configurations,
+    restoring network architecture, and setting up the predictor for inference.
+    """
+
+    def __init__(self, predictor, model_folder, model_name="model.pt"):
+        super().__init__()
+        self.predictor = predictor
+
+        model_training_output_dir = model_folder
+
+        from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+        # Block Added from nnUNet/nnunetv2/inference/predict_from_raw_data.py#nnUNetPredictor
+        dataset_json = load_json(join(model_training_output_dir, "dataset.json"))
+        plans = load_json(join(model_training_output_dir, "plans.json"))
+        plans_manager = PlansManager(plans)
+
+        # The model folder holds a single set of weights, so no per-fold loop is needed here.
+        checkpoint = torch.load(
+            join(model_training_output_dir, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
+        )
+        monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
+        trainer_name = checkpoint["trainer_name"]
+        configuration_name = checkpoint["init_args"]["configuration"]
+        inference_allowed_mirroring_axes = checkpoint.get("inference_allowed_mirroring_axes")
+        parameters = [monai_checkpoint["network_weights"]]
+
+        configuration_manager = plans_manager.get_configuration(configuration_name)
+        # restore network
+        import nnunetv2
+        from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+        from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+
+        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
+        trainer_class = recursive_find_python_class(
+            join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, "nnunetv2.training.nnUNetTrainer"
+        )
+        if trainer_class is None:
+            raise RuntimeError(
+                f"Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. "
+                f"Please place it there (in any .py file)!"
+            )
+        network = trainer_class.build_network_architecture(
+            configuration_manager.network_arch_class_name,
+            configuration_manager.network_arch_init_kwargs,
+            configuration_manager.network_arch_init_kwargs_req_import,
+            num_input_channels,
+            plans_manager.get_label_manager(dataset_json).num_segmentation_heads,
+            enable_deep_supervision=False,
+        )
+
+        predictor.plans_manager = plans_manager
+        predictor.configuration_manager = configuration_manager
+        predictor.list_of_parameters = parameters
+        predictor.network = network
+        predictor.dataset_json = dataset_json
+        predictor.trainer_name = trainer_name
+        predictor.allowed_mirroring_axes = inference_allowed_mirroring_axes
+        predictor.label_manager = plans_manager.get_label_manager(dataset_json)
+        if os.environ.get("nnUNet_compile", "").lower() in ("true", "1", "t"):
+            print("Using torch.compile")
+            # Compile the network held by the predictor; the wrapper itself defines no ``network`` attribute.
+            predictor.network = torch.compile(predictor.network)
+        # End Block
+        self.network_weights = self.predictor.network
+
+    def forward(self, x):
+        """
+        Forward pass for the nnUNet model.
+
+        :no-index:
+
+        Args:
+            x (Union[torch.Tensor, Tuple[MetaTensor]]): Input tensor or a tuple/list of MetaTensors. If the input is
+                a tuple or list, it is assumed to be a decollated batch. Otherwise, it is assumed to be a collated batch.
+
+        Returns:
+            MetaTensor: The output tensor with the same metadata as the input.
+
+        Raises:
+            TypeError: If the input is neither a tensor carrying metadata nor a tuple/list of MetaTensors.
+
+        Notes:
+            - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple.
+            - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor.
+            - The filenames are used to generate predictions using the nnUNet predictor.
+            - The predictions are converted to torch tensors, with added batch and channel dimensions.
+            - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
+        """
+        is_decollated = isinstance(x, (tuple, list))
+        if is_decollated:  # batch is decollated (sequence of tensors)
+            input_files = [img.meta["filename_or_obj"][0] for img in x]
+            reference = x[0]
+        elif hasattr(x, "meta"):  # batch is collated
+            input_files = x.meta["filename_or_obj"]
+            if isinstance(input_files, str):
+                input_files = [input_files]
+            reference = x
+        else:
+            raise TypeError(
+                f"Expected a MetaTensor or a tuple/list of MetaTensors, got {type(x).__name__}."
+            )
+
+        # input_files should be a list of file paths, one per modality
+        prediction_output = self.predictor.predict_from_files(
+            [input_files],
+            None,
+            save_probabilities=False,
+            overwrite=True,
+            num_processes_preprocessing=2,
+            num_processes_segmentation_export=2,
+            folder_with_segs_from_prev_stage=None,
+            num_parts=1,
+            part_id=0,
+        )
+        # prediction_output is a list of numpy arrays, with dimensions (H, W, D), output from ArgMax
+
+        # Add batch and channel dimensions to every prediction, then concatenate along the batch dimension
+        out_tensors = [torch.from_numpy(np.expand_dims(out, (0, 1))) for out in prediction_output]
+        out_tensor = torch.cat(out_tensors, 0)
+
+        return MetaTensor(out_tensor, meta=reference.meta)
+
+
+def get_nnunet_monai_predictor(model_folder, model_name="model.pt"):
+    """
+    Initializes and returns a `ModelnnUNetWrapper` containing the corresponding `nnUNetPredictor`.
+    The model folder should contain the following files, created during training:
+
+        - dataset.json: from the nnUNet results folder
+        - plans.json: from the nnUNet results folder
+        - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
+        - model.pt: The checkpoint file containing the model weights.
+
+    The returned wrapper object can be used for inference with MONAI framework:
+
+    Example::
+
+        from monai.bundle.nnunet import get_nnunet_monai_predictor
+
+        model_folder = 'path/to/monai_bundle/model'
+        model_name = 'model.pt'
+        wrapper = get_nnunet_monai_predictor(model_folder, model_name)
+
+        # Perform inference
+        input_data = ...
+        output = wrapper(input_data)
+
+    Parameters
+    ----------
+    model_folder : str
+        The folder where the model is stored.
+    model_name : str, optional
+        The name of the model file, by default "model.pt".
+
+    Returns
+    -------
+    ModelnnUNetWrapper
+        A wrapper object that contains the nnUNetPredictor and the loaded model.
+    """
+
+    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+
+    predictor = nnUNetPredictor(
+        tile_step_size=0.5,
+        use_gaussian=True,
+        use_mirroring=False,
+        device=torch.device("cuda", 0),
+        verbose=False,
+        verbose_preprocessing=False,
+        allow_tqdm=True,
+    )
+    # initializes the network architecture, loads the checkpoint
+    wrapper = ModelnnUNetWrapper(predictor, model_folder, model_name)
+    return wrapper
+
+
+def convert_nnunet_to_monai_bundle(nnunet_config, bundle_root_folder, fold=0):
+    """
+    Convert nnUNet model checkpoints and configuration to MONAI bundle format.
+
+    Parameters
+    ----------
+    nnunet_config : dict
+        Configuration dictionary for nnUNet, containing keys such as 'dataset_name_or_id', 'nnunet_configuration',
+        'nnunet_trainer', and 'nnunet_plans'.
+    bundle_root_folder : str
+        Root folder where the MONAI bundle will be saved. Its ``models`` subfolder is created if it does not exist.
+    fold : int, optional
+        Fold number of the nnUNet model to be converted, by default 0.
+
+    Returns
+    -------
+    None
+    """
+
+    nnunet_trainer = nnunet_config.get("nnunet_trainer", "nnUNetTrainer")
+    nnunet_plans = nnunet_config.get("nnunet_plans", "nnUNetPlans")
+    nnunet_configuration = nnunet_config.get("nnunet_configuration", "3d_fullres")
+
+    from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+
+    dataset_name = maybe_convert_to_dataset_name(nnunet_config["dataset_name_or_id"])
+    nnunet_model_folder = Path(os.environ["nnUNet_results"]).joinpath(
+        dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
+    )
+    fold_folder = nnunet_model_folder.joinpath(f"fold_{fold}")
+
+    # Load on CPU so the conversion does not require the device the checkpoints were saved from.
+    cpu = torch.device("cpu")
+    nnunet_checkpoint_final = torch.load(fold_folder.joinpath("checkpoint_final.pth"), map_location=cpu)
+    nnunet_checkpoint_best = torch.load(fold_folder.joinpath("checkpoint_best.pth"), map_location=cpu)
+
+    models_folder = Path(bundle_root_folder).joinpath("models")
+    models_folder.mkdir(parents=True, exist_ok=True)
+
+    nnunet_checkpoint = {
+        "inference_allowed_mirroring_axes": nnunet_checkpoint_final.get("inference_allowed_mirroring_axes"),
+        "init_args": nnunet_checkpoint_final["init_args"],
+        "trainer_name": nnunet_checkpoint_final["trainer_name"],
+    }
+    torch.save(nnunet_checkpoint, models_folder.joinpath("nnunet_checkpoint.pth"))
+
+    torch.save({"network_weights": nnunet_checkpoint_final["network_weights"]}, models_folder.joinpath("model.pt"))
+    torch.save({"network_weights": nnunet_checkpoint_best["network_weights"]}, models_folder.joinpath("best_model.pt"))
+
+    for file_name in ("plans.json", "dataset.json"):
+        shutil.copy(nnunet_model_folder.joinpath(file_name), models_folder.joinpath(file_name))
diff --git a/tests/integration/test_integration_nnunet_bundle.py b/tests/integration/test_integration_nnunet_bundle.py
new file mode 100644
index 0000000000..4e04f3f5cf
--- /dev/null
+++ b/tests/integration/test_integration_nnunet_bundle.py
@@ -0,0 +1,145 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+import numpy as np
+
+from monai.apps.nnunet import nnUNetV2Runner
+from monai.bundle.config_parser import ConfigParser
+from monai.bundle.nnunet import convert_nnunet_to_monai_bundle, get_nnunet_monai_predictor, get_nnunet_trainer
+from monai.data import DataLoader, Dataset, create_test_image_3d
+from monai.transforms import Compose, Decollated, EnsureChannelFirstd, LoadImaged, SaveImaged, Transposed
+from monai.utils import optional_import
+from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick
+
+_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")
+_, has_nnunet = optional_import("nnunetv2")
+
+sim_datalist: dict[str, list[dict]] = {
+    "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
+    "training": [
+        {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"},
+        {"fold": 0, "image": "tr_image_002.fake.nii.gz", "label": "tr_label_002.fake.nii.gz"},
+        {"fold": 1, "image": "tr_image_003.fake.nii.gz", "label": "tr_label_003.fake.nii.gz"},
+        {"fold": 1, "image": "tr_image_004.fake.nii.gz", "label": "tr_label_004.fake.nii.gz"},
+        {"fold": 2, "image": "tr_image_005.fake.nii.gz", "label": "tr_label_005.fake.nii.gz"},
+        {"fold": 2, "image": "tr_image_006.fake.nii.gz", "label": "tr_label_006.fake.nii.gz"},
+        {"fold": 3, "image": "tr_image_007.fake.nii.gz", "label": "tr_label_007.fake.nii.gz"},
+        {"fold": 3, "image": "tr_image_008.fake.nii.gz", "label": "tr_label_008.fake.nii.gz"},
+        {"fold": 4, "image": "tr_image_009.fake.nii.gz", "label": "tr_label_009.fake.nii.gz"},
+        {"fold": 4, "image": "tr_image_010.fake.nii.gz", "label": "tr_label_010.fake.nii.gz"},
+    ],
+}
+
+
+@skip_if_quick
+@SkipIfBeforePyTorchVersion((1, 13, 0))
+@unittest.skipIf(not has_tb, "no tensorboard summary writer")
+@unittest.skipIf(not has_nnunet, "no nnunetv2")
+class TestnnUNetBundle(unittest.TestCase):
+
+    def setUp(self) -> None:
+
+        import nibabel as nib
+
+        self.test_dir = tempfile.TemporaryDirectory()
+        test_path = self.test_dir.name
+
+        sim_dataroot = os.path.join(test_path, "dataroot")
+        if not os.path.isdir(sim_dataroot):
+            os.makedirs(sim_dataroot)
+
+        self.sim_dataroot = sim_dataroot
+        # Generate a fake dataset
+        for d in sim_datalist["testing"] + sim_datalist["training"]:
+            im, seg = create_test_image_3d(24, 24, 24, rad_max=10, num_seg_classes=2)
+            nib_image = nib.Nifti1Image(im, affine=np.eye(4))
+            image_fpath = os.path.join(sim_dataroot, d["image"])
+            nib.save(nib_image, image_fpath)
+
+            if "label" in d:
+                nib_image = nib.Nifti1Image(seg, affine=np.eye(4))
+                label_fpath = os.path.join(sim_dataroot, d["label"])
+                nib.save(nib_image, label_fpath)
+
+        sim_json_datalist = os.path.join(sim_dataroot, "sim_input.json")
+        ConfigParser.export_config_file(sim_datalist, sim_json_datalist)
+
+        data_src_cfg = os.path.join(test_path, "data_src_cfg.yaml")
+        data_src = {"modality": "CT", "datalist": sim_json_datalist, "dataroot": sim_dataroot}
+
+        ConfigParser.export_config_file(data_src, data_src_cfg)
+        self.data_src_cfg = data_src_cfg
+        self.test_path = test_path
+
+    @skip_if_no_cuda
+    def test_nnunet_bundle(self) -> None:
+        runner = nnUNetV2Runner(
+            input_config=self.data_src_cfg, trainer_class_name="nnUNetTrainer_1epoch", work_dir=self.test_path
+        )
+        with skip_if_downloading_fails():
+            runner.run(run_train=False, run_find_best_configuration=False, run_predict_ensemble_postprocessing=False)
+
+        nnunet_trainer = get_nnunet_trainer(
+            dataset_name_or_id=runner.dataset_name, fold=0, configuration="3d_fullres"
+        )
+
+        print("Max Epochs: ", nnunet_trainer.num_epochs)
+        print("Num Iterations: ", nnunet_trainer.num_iterations_per_epoch)
+        print("Train Batch dims: ", next(nnunet_trainer.dataloader_train.generator)["data"].shape)
+        print("Val Batch dims: ", next(nnunet_trainer.dataloader_val.generator)["data"].shape)
+        print("Network: ", nnunet_trainer.network)
+        print("Optimizer: ", nnunet_trainer.optimizer)
+        print("Loss Function: ", nnunet_trainer.loss)
+        print("LR Scheduler: ", nnunet_trainer.lr_scheduler)
+        print("Device: ", nnunet_trainer.device)
+        runner.train_single_model("3d_fullres", fold=0)
+
+        nnunet_config = {"dataset_name_or_id": "001", "nnunet_trainer": "nnUNetTrainer_1epoch"}
+        self.bundle_root = os.path.join(self.test_path, "bundle_root")  # keep artifacts inside the temp dir
+
+        Path(self.bundle_root).joinpath("models").mkdir(parents=True, exist_ok=True)
+        convert_nnunet_to_monai_bundle(nnunet_config, self.bundle_root, 0)
+
+        data_transforms = Compose([LoadImaged(keys="image"), EnsureChannelFirstd(keys="image")])
+        dataset = Dataset(
+            data=[{"image": os.path.join(self.test_path, "dataroot", "val_001.fake.nii.gz")}], transform=data_transforms
+        )
+        data_loader = DataLoader(dataset, batch_size=1)
+        batch = next(iter(data_loader))
+
+        predictor = get_nnunet_monai_predictor(Path(self.bundle_root).joinpath("models"))
+        pred_batch = predictor(batch["image"])
+        Path(self.sim_dataroot).joinpath("predictions").mkdir(parents=True, exist_ok=True)
+
+        post_processing_transforms = Compose(
+            [
+                Decollated(keys=None, detach=True),
+                Transposed(keys="pred", indices=[0, 3, 2, 1]),
+                SaveImaged(
+                    keys="pred", output_dir=Path(self.sim_dataroot).joinpath("predictions"), output_postfix="pred"
+                ),
+            ]
+        )
+        post_processing_transforms({"pred": pred_batch})
+
+    def tearDown(self) -> None:
+        self.test_dir.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 12f494be9c..2d68f099a7 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -127,6 +127,7 @@ def run_testsuit():
         "test_integration_bundle_run",
         "test_integration_autorunner",
         "test_integration_nnunetv2_runner",
+        "test_integration_nnunet_bundle",
         "test_invert",
         "test_invertd",
         "test_iterable_dataset",