diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 7083373859..156677d992 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -1334,13 +1334,15 @@ def __call__(  # type: ignore[override]
             raise NotImplementedError(f"{mode} condition is not supported")
 
         noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-        down_block_res_samples, mid_block_res_sample = controlnet(
-            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
-        )
+
         if mode == "concat" and condition is not None:
             noisy_image = torch.cat([noisy_image, condition], dim=1)
             condition = None
 
+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
+        )
+
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
             diffuse = partial(diffusion_model, seg=seg)
@@ -1396,17 +1398,21 @@ def sample(  # type: ignore[override]
             progress_bar = iter(scheduler.timesteps)
         intermediates = []
         for t in progress_bar:
-            # 1. ControlNet forward
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
-            )
-            # 2. predict noise model_output
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)
 
             if mode == "concat" and conditioning is not None:
+                # 1. Conditioning
                 model_input = torch.cat([image, conditioning], dim=1)
+                # 2. ControlNet forward
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=model_input,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=None,
+                )
+                # 3. predict noise model_output
                 model_output = diffuse(
                     model_input,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1415,6 +1421,12 @@ def sample(  # type: ignore[override]
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=image,
+                    timesteps=torch.Tensor((t,)).to(input_noise.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     image,
                     timesteps=torch.Tensor((t,)).to(input_noise.device),
@@ -1485,9 +1497,6 @@ def get_likelihood(  # type: ignore[override]
         for t in progress_bar:
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-            down_block_res_samples, mid_block_res_sample = controlnet(
-                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
-            )
 
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
@@ -1495,6 +1504,9 @@ def get_likelihood(  # type: ignore[override]
 
             if mode == "concat" and conditioning is not None:
                 noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
+                )
                 model_output = diffuse(
                     noisy_image,
                     timesteps=timesteps,
@@ -1503,6 +1515,12 @@ def get_likelihood(  # type: ignore[override]
                     mid_block_additional_residual=mid_block_res_sample,
                 )
             else:
+                down_block_res_samples, mid_block_res_sample = controlnet(
+                    x=noisy_image,
+                    timesteps=torch.Tensor((t,)).to(inputs.device),
+                    controlnet_cond=cn_cond,
+                    context=conditioning,
+                )
                 model_output = diffuse(
                     x=noisy_image,
                     timesteps=timesteps,
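The inferer-side change is the same in all three entry points (`__call__`, `sample`, `get_likelihood`): the ControlNet forward is moved to after the concat conditioning is applied, and it now receives the cross-attention `context`, so the ControlNet residuals are computed from the same input and conditioning the `DiffusionModelUNet` sees. A minimal sketch of the fixed cross-attention path; the network sizes and tensor shapes here are illustrative assumptions, not values from this PR:

```python
import torch

from monai.inferers import ControlNetDiffusionInferer
from monai.networks.nets import ControlNet, DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

# Both networks must now agree on the conditioning setup, since the
# inferer forwards `condition` to the ControlNet as `context`.
common = dict(
    spatial_dims=2,
    in_channels=1,
    channels=(8, 8),
    attention_levels=(False, True),
    num_res_blocks=1,
    num_head_channels=8,
    norm_num_groups=8,
    with_conditioning=True,
    cross_attention_dim=3,
)
model = DiffusionModelUNet(out_channels=1, **common)
controlnet = ControlNet(conditioning_embedding_num_channels=(8,), **common)

scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = ControlNetDiffusionInferer(scheduler)

inputs = torch.randn(1, 1, 16, 16)    # image batch
noise = torch.randn_like(inputs)
mask = torch.randn(1, 1, 16, 16)      # ControlNet conditioning image
condition = torch.randn(1, 1, 3)      # cross-attention tokens
timesteps = torch.randint(0, scheduler.num_train_timesteps, (1,)).long()

# With the fix, `condition` reaches both the ControlNet and the UNet.
prediction = inferer(
    inputs=inputs,
    diffusion_model=model,
    controlnet=controlnet,
    noise=noise,
    timesteps=timesteps,
    cn_cond=mask,
    condition=condition,
    mode="crossattn",
)
```

In `"concat"` mode the inferer now concatenates first and passes `context=None`, so the ControlNet sees the same widened `noisy_image` tensor as the UNet.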
diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py
index 2ab5cec335..909f2cf398 100644
--- a/tests/inferers/test_controlnet_inferers.py
+++ b/tests/inferers/test_controlnet_inferers.py
@@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
     def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
         model_params["with_conditioning"] = True
         model_params["cross_attention_dim"] = 3
+        controlnet_params["with_conditioning"] = True
+        controlnet_params["cross_attention_dim"] = 3
         model = DiffusionModelUNet(**model_params)
         controlnet = ControlNet(**controlnet_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape):
         model_params = model_params.copy()
         n_concat_channel = 2
         model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         model_params["cross_attention_dim"] = None
+        controlnet_params["cross_attention_dim"] = None
         model_params["with_conditioning"] = False
+        controlnet_params["with_conditioning"] = False
         model = DiffusionModelUNet(**model_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
         model.to(device)
@@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
@@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
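The test changes encode the same contract on the test side: whatever conditioning configuration the UNet is built with (`with_conditioning`/`cross_attention_dim` for cross-attention, widened `in_channels` for concat), the ControlNet must be built to match, and the param dicts are copied before mutation so parameterized cases stay independent. A sketch of concat-mode sampling under that contract; channel counts, shapes, and the ten-step schedule are illustrative assumptions:

```python
import torch

from monai.inferers import ControlNetDiffusionInferer
from monai.networks.nets import ControlNet, DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler

n_concat_channel = 2  # channels appended by the concat conditioning

# Both networks are widened by the concat channels, mirroring the updated
# tests: with the fix, the ControlNet consumes [image, conditioning] too.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1 + n_concat_channel,
    out_channels=1,
    channels=(8, 8),
    attention_levels=(False, False),
    num_res_blocks=1,
    norm_num_groups=8,
    with_conditioning=False,
    cross_attention_dim=None,
)
controlnet = ControlNet(
    spatial_dims=2,
    in_channels=1 + n_concat_channel,
    channels=(8, 8),
    attention_levels=(False, False),
    num_res_blocks=1,
    norm_num_groups=8,
    conditioning_embedding_num_channels=(8,),
    with_conditioning=False,
    cross_attention_dim=None,
)

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=10)
inferer = ControlNetDiffusionInferer(scheduler)

noise = torch.randn(1, 1, 16, 16)                        # sampling starts from noise
mask = torch.randn(1, 1, 16, 16)                         # ControlNet conditioning image
conditioning = torch.randn(1, n_concat_channel, 16, 16)  # concat condition

sample = inferer.sample(
    input_noise=noise,
    diffusion_model=model,
    controlnet=controlnet,
    cn_cond=mask,
    scheduler=scheduler,
    conditioning=conditioning,
    mode="concat",
)
```

Without the widened (and copied) `controlnet_params`, the ControlNet would be called with a 3-channel input while built for 1 channel, which is exactly the mismatch the new test settings guard against.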