From f1b49fd60d3e69f600d9ff7f2bad8e2353447fe4 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 11:24:18 +0000 Subject: [PATCH 01/14] add rectified flow noise scheduler to monai Signed-off-by: Can-Zhao --- monai/inferers/inferer.py | 19 +- monai/networks/schedulers/__init__.py | 1 + monai/networks/schedulers/rectified_flow.py | 283 ++++++++++++++++++++ monai/utils/jupyter_utils.py | 2 +- tests/test_diffusion_inferer.py | 18 +- 5 files changed, 317 insertions(+), 6 deletions(-) create mode 100644 monai/networks/schedulers/rectified_flow.py diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 769b6cc0e7..61fbacd1a7 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -39,7 +39,7 @@ SPADEAutoencoderKL, SPADEDiffusionModelUNet, ) -from monai.networks.schedulers import Scheduler +from monai.networks.schedulers import RFlowScheduler, Scheduler from monai.transforms import CenterSpatialCrop, SpatialPad from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import from monai.visualize import CAM, GradCAM, GradCAMpp @@ -859,12 +859,19 @@ def sample( if not scheduler: scheduler = self.scheduler image = input_noise + + all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) if verbose and has_tqdm: - progress_bar = tqdm(scheduler.timesteps) + progress_bar = tqdm( + zip(scheduler.timesteps, all_next_timesteps), + total=min(len(scheduler.timesteps), len(all_next_timesteps)), + ) else: progress_bar = iter(scheduler.timesteps) + progress_bar = iter(zip(scheduler.timesteps, all_next_timesteps)) intermediates = [] - for t in progress_bar: + + for t, next_t in progress_bar: # 1. predict noise model_output diffusion_model = ( partial(diffusion_model, seg=seg) @@ -882,9 +889,13 @@ def sample( ) # 2. compute previous image: x_t -> x_t-1 - image, _ = scheduler.step(model_output, t, image) + if not isinstance(scheduler, RFlowScheduler): + image, _ = scheduler.step(model_output, t, image) + else: + image, _ = scheduler.step(model_output, t, image, next_t) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) + if save_intermediates: return image, intermediates else: diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py index 29e9020d65..b7b34f9a77 100644 --- a/monai/networks/schedulers/__init__.py +++ b/monai/networks/schedulers/__init__.py @@ -14,4 +14,5 @@ from .ddim import DDIMScheduler from .ddpm import DDPMScheduler from .pndm import PNDMScheduler +from .rectified_flow import RFlowScheduler from .scheduler import NoiseSchedules, Scheduler diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py new file mode 100644 index 0000000000..6a848f0762 --- /dev/null +++ b/monai/networks/schedulers/rectified_flow.py @@ -0,0 +1,283 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# ========================================================================= +# Adapted from https://github.com/hpcaitech/Open-Sora/blob/main/opensora/schedulers/rf/rectified_flow.py +# which has the following license: +# https://github.com/hpcaitech/Open-Sora/blob/main/LICENSE +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +from typing import Any + +import numpy as np +import torch +from torch.distributions import LogisticNormal + +from .scheduler import Scheduler + + +def timestep_transform( + t, input_img_size_numel, base_img_size_numel=32 * 32 * 32, scale=1.0, num_train_timesteps=1000, spatial_dim=3 +): + """ + Applies a transformation to the timestep based on image resolution scaling. + + Args: + t (torch.Tensor): The original timestep(s). + input_img_size_numel (torch.Tensor): The input image's size (H * W * D). + base_img_size_numel (int): reference H*W*D size, usually smaller than input_img_size_numel. + scale (float): Scaling factor for the transformation. + num_train_timesteps (int): Total number of training timesteps. + spatial_dim (int): Number of spatial dimensions in the image. + + Returns: + torch.Tensor: Transformed timestep(s). + """ + t = t / num_train_timesteps + ratio_space = (input_img_size_numel / base_img_size_numel).pow(1.0 / spatial_dim) + + ratio = ratio_space * scale + new_t = ratio * t / (1 + (ratio - 1) * t) + + new_t = new_t * num_train_timesteps + return new_t + + +class RFlowScheduler(Scheduler): + """ + A rectified flow scheduler for guiding the diffusion process in a generative model. + + Supports uniform and logit-normal sampling methods, timestep transformation for + different resolutions, and noise addition during diffusion. + + Attributes: + num_train_timesteps (int): Total number of training timesteps. + use_discrete_timesteps (bool): Whether to use discrete timesteps. + sample_method (str): Training time step sampling method ('uniform' or 'logit-normal'). + loc (float): Location parameter for logit-normal distribution, used only if sample_method='logit-normal'. + scale (float): Scale parameter for logit-normal distribution, used only if sample_method='logit-normal'. + use_timestep_transform (bool): Whether to apply timestep transformation. + If true, there will be more inference timesteps at early(noisy) stages for larger image volumes. + transform_scale (float): Scaling factor for timestep transformation, used only if use_timestep_transform=True. + steps_offset (int): Offset added to computed timesteps, used only if use_timestep_transform=True. + base_img_size_numel (int): Reference image volume size for scaling, used only if use_timestep_transform=True. + + Example: + + .. 
code-block:: python + + # define a scheduler + noise_scheduler = RFlowScheduler( + num_train_timesteps = 1000, + use_discrete_timesteps = True, + sample_method = 'logit-normal', + use_timestep_transform = True, + base_img_size_numel = 32 * 32 * 32 + ) + + # during training + inputs = torch.ones(2,4,64,64,64) + noise = torch.randn_like(inputs) + timesteps = noise_scheduler.sample_timesteps(inputs) + noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + predicted_velocity = diffusion_unet( + x=noisy_inputs, + timesteps=timesteps + ) + loss = loss_l1(predicted_velocity, (inputs - noise)) + + # during inference + noisy_inputs = torch.randn(2,4,64,64,64) + input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]) + noise_scheduler.set_timesteps( + num_inference_steps=30, input_img_size_numel=input_img_size_numel) + ) + all_next_timesteps = torch.cat( + (noise_scheduler.timesteps[1:], torch.tensor([0], dtype=noise_scheduler.timesteps.dtype)) + ) + for t, next_t in tqdm( + zip(noise_scheduler.timesteps, all_next_timesteps), + total=min(len(noise_scheduler.timesteps), len(all_next_timesteps)), + ): + predicted_velocity = diffusion_unet( + x=noisy_inputs, + timesteps=timesteps + ) + noisy_inputs, _ = noise_scheduler.step(predicted_velocity, t, noisy_inputs, next_t) + final_output = noisy_inputs + """ + + def __init__( + self, + num_train_timesteps: int = 1000, + use_discrete_timesteps: bool = True, + sample_method: str = "uniform", + loc: float = 0.0, + scale: float = 1.0, + use_timestep_transform: bool = False, + transform_scale: float = 1.0, + steps_offset: int = 0, + base_img_size_numel: int = 32 * 32 * 32, + ): + self.num_train_timesteps = num_train_timesteps + self.use_discrete_timesteps = use_discrete_timesteps + self.base_img_size_numel = base_img_size_numel + + # sample method + if sample_method not in ["uniform", "logit-normal"]: + raise ValueError( + f"sample_method = {sample_method}, which has to be chosen from ['uniform', 'logit-normal']." + ) + self.sample_method = sample_method + if sample_method == "logit-normal": + self.distribution = LogisticNormal(torch.tensor([loc]), torch.tensor([scale])) + self.sample_t = lambda x: self.distribution.sample((x.shape[0],))[:, 0].to(x.device) + + # timestep transform + self.use_timestep_transform = use_timestep_transform + self.transform_scale = transform_scale + self.steps_offset = steps_offset + + def add_noise( + self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + """ + Adds noise to the original samples based on the given timesteps. + + Args: + original_samples (torch.FloatTensor): The original sample tensor. + noise (torch.FloatTensor): Noise tensor to be added. + timesteps (torch.IntTensor): Timesteps corresponding to each sample. + + Returns: + torch.FloatTensor: The noisy sample tensor. 
+ """ + timepoints = timesteps.float() / self.num_train_timesteps + timepoints = 1 - timepoints # [1,1/1000] + + # timepoint (bsz) noise: (bsz, 4, frame, w ,h) + # expand timepoint to noise shape + timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) + timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) + + return timepoints * original_samples + (1 - timepoints) * noise + + def set_timesteps( + self, + num_inference_steps: int, + device: str | torch.device | None = None, + input_img_size_numel: int | None = None, + ) -> None: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model. + device: target device to put the data. + input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True. + """ + if num_inference_steps > self.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + # prepare timesteps + timesteps = [ + (1.0 - i / self.num_inference_steps) * self.num_train_timesteps for i in range(self.num_inference_steps) + ] + if self.use_discrete_timesteps: + timesteps = [int(round(t)) for t in timesteps] + if self.use_timestep_transform: + timesteps = [ + timestep_transform( + t, + input_img_size_numel=input_img_size_numel, + base_img_size_numel=self.base_img_size_numel, + num_train_timesteps=self.num_train_timesteps, + ) + for t in timesteps + ] + timesteps = np.array(timesteps).astype(np.float16) + if self.use_discrete_timesteps: + timesteps = timesteps.astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.steps_offset + + def sample_timesteps(self, x_start): + """ + Randomly samples training timesteps using the chosen sampling method. + + Args: + x_start (torch.Tensor): The input tensor for sampling. + + Returns: + torch.Tensor: Sampled timesteps. + """ + if self.sample_method == "uniform": + t = torch.rand((x_start.shape[0],), device=x_start.device) * self.num_train_timesteps + elif self.sample_method == "logit-normal": + t = self.sample_t(x_start) * self.num_train_timesteps + + if self.use_discrete_timesteps: + t = t.long() + + if self.use_timestep_transform: + input_img_size_numel = torch.prod(torch.tensor(x_start.shape[-3:])) + t = timestep_transform( + t, + input_img_size_numel=input_img_size_numel, + base_img_size_numel=self.base_img_size_numel, + num_train_timesteps=self.num_train_timesteps, + ) + + return t + + def step( + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep=None + ) -> tuple[torch.Tensor, Any]: + """ + Predict the sample at the previous timestep. Core function to propagate the diffusion + process from the learned model outputs. + + Args: + model_output: direct output from learned diffusion model. + timestep: current discrete timestep in the diffusion chain. + sample: current instance of sample being created by diffusion process. + next_timestep: next discrete timestep in the diffusion chain. 
+ Returns: + pred_prev_sample: Predicted previous sample + None + """ + v_pred = model_output + if next_timestep is None: + dt = 1.0 / self.num_inference_steps + else: + dt = timestep - next_timestep + dt = dt / self.num_train_timesteps + z = sample + v_pred * dt + + return z, None diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index c93e93dcb9..b1b43a6767 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -234,7 +234,7 @@ def plot_engine_status( def _get_loss_from_output( - output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor, + output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor ) -> torch.Tensor: """Returns a single value from the network output, which is a dict or tensor.""" diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index 7f37025d3c..6b74452288 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -19,7 +19,7 @@ from monai.inferers import DiffusionInferer from monai.networks.nets import DiffusionModelUNet -from monai.networks.schedulers import DDIMScheduler, DDPMScheduler +from monai.networks.schedulers import DDIMScheduler, DDPMScheduler, RFlowScheduler from monai.utils import optional_import _, has_scipy = optional_import("scipy") @@ -120,6 +120,22 @@ def test_ddim_sampler(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_rflow_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = RFlowScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned(self, model_params, input_shape): From d7ea3ed8465acfe31a6382facb458e0b81b6b00a Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 16:51:25 +0000 Subject: [PATCH 02/14] add rectified flow for accelerated diffusion model Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 6a848f0762..d0657c54c7 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -97,7 +97,7 @@ class RFlowScheduler(Scheduler): ) # during training - inputs = torch.ones(2,4,64,64,64) + inputs = torch.ones(2,4,64,64,32) noise = torch.randn_like(inputs) timesteps = noise_scheduler.sample_timesteps(inputs) noisy_inputs = noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) @@ -108,7 +108,7 @@ class RFlowScheduler(Scheduler): loss = loss_l1(predicted_velocity, (inputs - noise)) # during inference - noisy_inputs = torch.randn(2,4,64,64,64) + noisy_inputs = torch.randn(2,4,64,64,32) input_img_size_numel = torch.prod(torch.tensor(noisy_inputs.shape[-3:]) noise_scheduler.set_timesteps( num_inference_steps=30, 
input_img_size_numel=input_img_size_numel) From c14ce67c062d96609e6017b46d157e5ab5a46757 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 18:29:05 +0000 Subject: [PATCH 03/14] reformat Signed-off-by: Can-Zhao --- monai/utils/jupyter_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index b1b43a6767..c93e93dcb9 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -234,7 +234,7 @@ def plot_engine_status( def _get_loss_from_output( - output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor + output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor, ) -> torch.Tensor: """Returns a single value from the network output, which is a dict or tensor.""" From 1294ceb31adee26f0f159c364ee0fea418a42664 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 18:54:00 +0000 Subject: [PATCH 04/14] reformat Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 22 ++++++++++----------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index d0657c54c7..995739c2aa 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -159,19 +159,17 @@ def __init__( self.transform_scale = transform_scale self.steps_offset = steps_offset - def add_noise( - self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor - ) -> torch.FloatTensor: + def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: """ - Adds noise to the original samples based on the given timesteps. + Add noise to the original samples. Args: - original_samples (torch.FloatTensor): The original sample tensor. - noise (torch.FloatTensor): Noise tensor to be added. - timesteps (torch.IntTensor): Timesteps corresponding to each sample. + original_samples: original samples + noise: noise to add to samples + timesteps: timesteps tensor indicating the timestep to be computed for each sample. Returns: - torch.FloatTensor: The noisy sample tensor. + noisy_samples: sample with added noise """ timepoints = timesteps.float() / self.num_train_timesteps timepoints = 1 - timepoints # [1,1/1000] @@ -221,10 +219,10 @@ def set_timesteps( ) for t in timesteps ] - timesteps = np.array(timesteps).astype(np.float16) + timesteps_np = np.array(timesteps).astype(np.float16) if self.use_discrete_timesteps: - timesteps = timesteps.astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) + timesteps_np = timesteps_np.astype(np.int64) + self.timesteps = torch.from_numpy(timesteps_np).to(device) self.timesteps += self.steps_offset def sample_timesteps(self, x_start): @@ -257,7 +255,7 @@ def sample_timesteps(self, x_start): return t def step( - self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep=None + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: int | None = None ) -> tuple[torch.Tensor, Any]: """ Predict the sample at the previous timestep. 
Core function to propagate the diffusion From 4bf6c02ff849ef782d275651daab0983b112a570 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 19:58:43 +0000 Subject: [PATCH 05/14] reformat Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 48 +++++++++++++-------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 995739c2aa..1e961ac847 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -28,7 +28,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Union import numpy as np import torch @@ -171,15 +171,16 @@ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timeste Returns: noisy_samples: sample with added noise """ - timepoints = timesteps.float() / self.num_train_timesteps + timepoints: torch.Tensor = timesteps.float() / self.num_train_timesteps timepoints = 1 - timepoints # [1,1/1000] # timepoint (bsz) noise: (bsz, 4, frame, w ,h) # expand timepoint to noise shape timepoints = timepoints.unsqueeze(1).unsqueeze(1).unsqueeze(1).unsqueeze(1) timepoints = timepoints.repeat(1, noise.shape[1], noise.shape[2], noise.shape[3], noise.shape[4]) + noisy_samples: torch.Tensor = timepoints * original_samples + (1 - timepoints) * noise - return timepoints * original_samples + (1 - timepoints) * noise + return noisy_samples def set_timesteps( self, @@ -255,27 +256,38 @@ def sample_timesteps(self, x_start): return t def step( - self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: int | None = None - ) -> tuple[torch.Tensor, Any]: + self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None + ) -> tuple[torch.Tensor, None]: """ - Predict the sample at the previous timestep. Core function to propagate the diffusion - process from the learned model outputs. + Predicts the next sample in the diffusion process. Args: - model_output: direct output from learned diffusion model. - timestep: current discrete timestep in the diffusion chain. - sample: current instance of sample being created by diffusion process. - next_timestep: next discrete timestep in the diffusion chain. + model_output (torch.Tensor): Output from the trained diffusion model. + timestep (int): Current timestep in the diffusion chain. + sample (torch.Tensor): Current sample in the process. + next_timestep (Union[int, None]): Optional next timestep. + Returns: - pred_prev_sample: Predicted previous sample - None + tuple[torch.Tensor, None]: Predicted sample at the next step and additional info. """ + # Ensure num_inference_steps exists and is a valid integer + if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int): + raise AttributeError( + "num_inference_steps is missing or not an integer in the class." + "Please run self.set_timesteps(num_inference_steps,device,input_img_size_numel) to set it." 
+ ) + v_pred = model_output - if next_timestep is None: - dt = 1.0 / self.num_inference_steps + + if next_timestep is not None: + next_timestep = int(next_timestep) + dt: float = ( + float(timestep - next_timestep) / self.num_train_timesteps + ) # Now next_timestep is guaranteed to be int else: - dt = timestep - next_timestep - dt = dt / self.num_train_timesteps - z = sample + v_pred * dt + dt = ( + 1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0 + ) # Avoid division by zero + z = sample + v_pred * dt return z, None From a22d9bcedb29385167636b810ac765fd4d3b0f76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Mar 2025 19:59:18 +0000 Subject: [PATCH 06/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/schedulers/rectified_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 1e961ac847..3627a395a5 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -28,7 +28,7 @@ from __future__ import annotations -from typing import Any, Union +from typing import Union import numpy as np import torch From df69ed8e4f8e5848fa81a704159a5b3ce7f6509c Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 20:01:10 +0000 Subject: [PATCH 07/14] reformat Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 1e961ac847..795ca3d85b 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -196,9 +196,10 @@ def set_timesteps( device: target device to put the data. input_img_size_numel: int, H*W*D of the image, used with self.use_timestep_transform is True. """ - if num_inference_steps > self.num_train_timesteps: + if num_inference_steps > self.num_train_timesteps or num_inference_steps < 1: raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:" + f"`num_inference_steps`: {num_inference_steps} should be at least 1, " + "and cannot be larger than `self.num_train_timesteps`:" f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle" f" maximal {self.num_train_timesteps} timesteps." ) From d6ff59129bd0ad91d7225453cf33e2d1f57dc3d4 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 21:10:29 +0000 Subject: [PATCH 08/14] add prev_original Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 26b908ce4c..27a4473d9f 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -258,7 +258,7 @@ def sample_timesteps(self, x_start): def step( self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, next_timestep: Union[int, None] = None - ) -> tuple[torch.Tensor, None]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ Predicts the next sample in the diffusion process. 
@@ -290,5 +290,7 @@ def step( 1.0 / float(self.num_inference_steps) if self.num_inference_steps > 0 else 0.0 ) # Avoid division by zero - z = sample + v_pred * dt - return z, None + pred_post_sample = sample + v_pred * dt + pred_original_sample = sample + v_pred * timestep/self.num_train_timesteps + + return pred_post_sample, pred_original_sample From 28f2021cd9dfcd47ac5ba41e78c9e49c9d145aee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Mar 2025 21:11:03 +0000 Subject: [PATCH 09/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/schedulers/rectified_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 27a4473d9f..3f79500c63 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -292,5 +292,5 @@ def step( pred_post_sample = sample + v_pred * dt pred_original_sample = sample + v_pred * timestep/self.num_train_timesteps - + return pred_post_sample, pred_original_sample From 39186fcd202a6497bc0f1fccc8f3209190d0a0a9 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 21:31:36 +0000 Subject: [PATCH 10/14] black Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 27a4473d9f..858f1c183a 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -291,6 +291,6 @@ def step( ) # Avoid division by zero pred_post_sample = sample + v_pred * dt - pred_original_sample = sample + v_pred * timestep/self.num_train_timesteps - + pred_original_sample = sample + v_pred * timestep / self.num_train_timesteps + return pred_post_sample, pred_original_sample From 929d7a3dbaf5deb94c3ca5da8bde514afe92204c Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 22:29:09 +0000 Subject: [PATCH 11/14] add doc Signed-off-by: Can-Zhao --- docs/source/networks.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index e2e509a99b..6ba7577955 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -750,3 +750,27 @@ Utilities .. automodule:: monai.apps.reconstruction.networks.nets.utils :members: + +Noise Schedulers +---------------- +.. currentmodule:: monai.networks.schedulers + +`AHNet` +~~~~~~~ +.. autoclass:: Scheduler + :members: + +.. autoclass:: NoiseSchedules + :members: + +.. autoclass:: DDPMScheduler + :members: + +.. autoclass:: DDIMScheduler + :members: + +.. autoclass:: PNDMScheduler + :members: + +.. autoclass:: RFlowScheduler + :members: From bdf6d0393168e23d8b5320a1e3b930c5e045ae69 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 22:36:29 +0000 Subject: [PATCH 12/14] add doc Signed-off-by: Can-Zhao --- docs/source/networks.rst | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 6ba7577955..11e8ed1fb5 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -755,22 +755,32 @@ Noise Schedulers ---------------- .. currentmodule:: monai.networks.schedulers -`AHNet` -~~~~~~~ +`Scheduler` +~~~~~~~~~~~ .. 
autoclass:: Scheduler :members: +`NoiseSchedules` +~~~~~~~~~~~~~~~~ .. autoclass:: NoiseSchedules :members: +`DDPMScheduler` +~~~~~~~~~~~~~~~ .. autoclass:: DDPMScheduler :members: +`DDIMScheduler` +~~~~~~~~~~~~~~~ .. autoclass:: DDIMScheduler :members: +`PNDMScheduler` +~~~~~~~~~~~~~~~ .. autoclass:: PNDMScheduler :members: +`RFlowScheduler` +~~~~~~~~~~~~~~~~ .. autoclass:: RFlowScheduler :members: From 7c94a8dae8344211658a20e2d08d7bf12e19e230 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 5 Mar 2025 22:53:17 +0000 Subject: [PATCH 13/14] add doc Signed-off-by: Can-Zhao --- docs/source/networks.rst | 1 + monai/networks/schedulers/rectified_flow.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 11e8ed1fb5..0119c6db4d 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -753,6 +753,7 @@ Utilities Noise Schedulers ---------------- +.. automodule:: monai.networks.schedulers .. currentmodule:: monai.networks.schedulers `Scheduler` diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index 858f1c183a..a3002f59b8 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -71,7 +71,7 @@ class RFlowScheduler(Scheduler): Supports uniform and logit-normal sampling methods, timestep transformation for different resolutions, and noise addition during diffusion. - Attributes: + Args: num_train_timesteps (int): Total number of training timesteps. use_discrete_timesteps (bool): Whether to use discrete timesteps. sample_method (str): Training time step sampling method ('uniform' or 'logit-normal'). From 9f4ae11926fa14d2659f5ec3b3f98bb3aa18c447 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Thu, 6 Mar 2025 17:18:04 +0000 Subject: [PATCH 14/14] update doc Signed-off-by: Can-Zhao --- monai/networks/schedulers/rectified_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/schedulers/rectified_flow.py b/monai/networks/schedulers/rectified_flow.py index a3002f59b8..5bdeae0931 100644 --- a/monai/networks/schedulers/rectified_flow.py +++ b/monai/networks/schedulers/rectified_flow.py @@ -269,7 +269,7 @@ def step( next_timestep (Union[int, None]): Optional next timestep. Returns: - tuple[torch.Tensor, None]: Predicted sample at the next step and additional info. + tuple[torch.Tensor, torch.Tensor]: Predicted sample at the next step and additional info. """ # Ensure num_inference_steps exists and is a valid integer if not hasattr(self, "num_inference_steps") or not isinstance(self.num_inference_steps, int):
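
---

Note for reviewers: below is a minimal, self-contained sketch of the sampling path this series adds. It mirrors the `test_rflow_sampler` case introduced in `tests/test_diffusion_inferer.py`; the `DiffusionModelUNet` configuration and the input shape are illustrative assumptions chosen to keep the example small (the test itself draws them from the shared `TEST_CASES` table), not values prescribed by the scheduler.

```python
# Illustrative sketch only: exercises the RFlowScheduler + DiffusionInferer path
# added by this series. The UNet hyperparameters and input shape below are
# assumptions for demonstration; they are not part of the patch.
import torch

from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import RFlowScheduler

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(8,),
    attention_levels=(True,),
    norm_num_groups=8,
    num_res_blocks=1,
    num_head_channels=8,
).to(device)
model.eval()

# Velocity-prediction scheduler; with this series applied, DiffusionInferer.sample
# passes (t, next_t) pairs to scheduler.step for RFlowScheduler instances.
scheduler = RFlowScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=10)

noise = torch.randn((1, 1, 8, 8), device=device)
inferer = DiffusionInferer(scheduler=scheduler)

with torch.no_grad():
    sample, intermediates = inferer.sample(
        input_noise=noise,
        diffusion_model=model,
        scheduler=scheduler,
        save_intermediates=True,
        intermediate_steps=1,
    )

print(sample.shape)        # torch.Size([1, 1, 8, 8])
print(len(intermediates))  # 10, one snapshot per inference step
```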