From ab0752343b65198dae12d7389e441cfafeb9890a Mon Sep 17 00:00:00 2001
From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com>
Date: Mon, 24 Feb 2025 10:11:58 +0000
Subject: [PATCH] Modify ControlNet inferer so that it takes in context when
 the diffus… (#8360)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #8344

### Description

The ControlNet inferers (latent and non-latent) currently pass the conditioning to the diffusion model but not to the ControlNet, even though the ControlNet should, in principle, receive the same inputs as the diffusion model. I've changed this behaviour, which meant modifying `ControlNetDiffusionInferer` and `ControlNetLatentDiffusionInferer` (specifically the `__call__`, `sample`, and `get_likelihood` methods), and I've updated the tests to take this into account.
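In short, both inferers now route the conditioning like this (a condensed sketch of one denoising step, paraphrased from the `__call__` diff below; the SPADE-specific `diffuse = partial(diffusion_model, seg=seg)` indirection is omitted for brevity):

```python
# "concat" mode: the conditioning enters through extra input channels, so it is
# concatenated onto the noisy input *before* the ControlNet forward pass and no
# cross-attention context is passed to either network.
if mode == "concat" and condition is not None:
    noisy_image = torch.cat([noisy_image, condition], dim=1)
    condition = None

# "crossattn" mode (condition left as-is): the ControlNet now receives the same
# context as the diffusion model, where previously it was called without one.
down_block_res_samples, mid_block_res_sample = controlnet(
    x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
)
prediction = diffusion_model(
    x=noisy_image,
    timesteps=timesteps,
    context=condition,
    down_block_additional_residuals=down_block_res_samples,
    mid_block_additional_residual=mid_block_res_sample,
)
```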
### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes (modified, rather than new).
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.

Signed-off-by: Virginia Fernandez
Co-authored-by: Virginia Fernandez
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/inferers/inferer.py                  | 40 ++++++++++++++++------
 tests/inferers/test_controlnet_inferers.py |  9 +++++
 2 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 7083373859..156677d992 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -1334,13 +1334,15 @@ def __call__(  # type: ignore[override]
             raise NotImplementedError(f"{mode} condition is not supported")
 
         noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-        down_block_res_samples, mid_block_res_sample = controlnet(
-            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
-        )
+
         if mode == "concat" and condition is not None:
             noisy_image = torch.cat([noisy_image, condition], dim=1)
             condition = None
 
+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
+        )
+
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
             diffuse = partial(diffusion_model, seg=seg)
@@ -1396,17 +1398,21 @@ def sample(  # type: ignore[override]
             progress_bar = iter(scheduler.timesteps)
         intermediates = []
         for t in progress_bar:
-            # 1. ControlNet forward
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
-            )
-            # 2. predict noise model_output
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
 
             if mode == "concat" and conditioning is not None:
+                # 1. Conditioning
                 model_input = torch.cat([image, conditioning], dim=1)
+                # 2. ControlNet forward
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=model_input,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=None,
+                )
+                # 3. predict noise model_output
                 model_output = diffuse(
                     model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1415,6 +1421,12 @@ def sample(  # type: ignore[override]
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=image,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     image,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1485,9 +1497,6 @@ def get_likelihood(  # type: ignore[override]
         for t in progress_bar:
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
-            )
 
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -1495,6 +1504,9 @@ def get_likelihood(  # type: ignore[override]
 
             if mode == "concat" and conditioning is not None:
                 noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
+                )
                 model_output = diffuse(
                     noisy_image,
                     timesteps=timesteps,
@@ -1503,6 +1515,12 @@ def get_likelihood(  # type: ignore[override]
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image,
+                    timesteps=torch.Tensor((t,)).to(inputs.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     x=noisy_image,
                     timesteps=timesteps,
diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py
index 2ab5cec335..909f2cf398 100644
--- a/tests/inferers/test_controlnet_inferers.py
+++ b/tests/inferers/test_controlnet_inferers.py
@@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
     def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
         model_params["with_conditioning"] = True
         model_params["cross_attention_dim"] = 3
+        controlnet_params["with_conditioning"] = True
+        controlnet_params["cross_attention_dim"] = 3
         model = DiffusionModelUNet(**model_params)
         controlnet = ControlNet(**controlnet_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input
         model_params = model_params.copy()
         n_concat_channel = 2
         model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         model_params["cross_attention_dim"] = None
+        controlnet_params["cross_attention_dim"] = None
         model_params["with_conditioning"] = False
+        controlnet_params["with_conditioning"] = False
         model = DiffusionModelUNet(**model_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
         model.to(device)
@@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
@@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else: