Skip to content

Commit

Permalink
Modify ControlNet inferer so that it takes in context when the diffus… (
Browse files Browse the repository at this point in the history
#8360)

Fixes #8344 

### Description
The ControlNet inferers (latent and non-latent) currently do not pass the
conditioning to the ControlNet when conditioning is used. In theory, the
ControlNet should exhibit the same behaviour as the diffusion model and
receive the same conditioning.
I've changed this behaviour by modifying the `__call__`, `sample` and
`get_likelihood` methods of ControlNetDiffusionInferer and
ControlNetLatentDiffusionInferer.
I've also modified the tests to take this into account.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes (modified, rather than new)
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.

Signed-off-by: Virginia Fernandez <[email protected]>
Co-authored-by: Virginia Fernandez <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
4 people authored Feb 24, 2025
1 parent a790590 commit ab07523
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 11 deletions.
40 changes: 29 additions & 11 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,13 +1334,15 @@ def __call__( # type: ignore[override]
raise NotImplementedError(f"{mode} condition is not supported")

noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
)

if mode == "concat" and condition is not None:
noisy_image = torch.cat([noisy_image, condition], dim=1)
condition = None

down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)
Expand Down Expand Up @@ -1396,17 +1398,21 @@ def sample( # type: ignore[override]
progress_bar = iter(scheduler.timesteps)
intermediates = []
for t in progress_bar:
# 1. ControlNet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
)
# 2. predict noise model_output
diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat" and conditioning is not None:
# 1. Conditioning
model_input = torch.cat([image, conditioning], dim=1)
# 2. ControlNet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
controlnet_cond=cn_cond,
context=None,
)
# 3. predict noise model_output
model_output = diffuse(
model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
Expand All @@ -1415,6 +1421,12 @@ def sample( # type: ignore[override]
mid_block_additional_residual=mid_block_res_sample,
)
else:
down_block_res_samples, mid_block_res_sample = controlnet(
x=image,
timesteps=torch.Tensor((t,)).to(input_noise.device),
controlnet_cond=cn_cond,
context=conditioning,
)
model_output = diffuse(
image,
timesteps=torch.Tensor((t,)).to(input_noise.device),
Expand Down Expand Up @@ -1485,16 +1497,16 @@ def get_likelihood( # type: ignore[override]
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat" and conditioning is not None:
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
)
model_output = diffuse(
noisy_image,
timesteps=timesteps,
Expand All @@ -1503,6 +1515,12 @@ def get_likelihood( # type: ignore[override]
mid_block_additional_residual=mid_block_res_sample,
)
else:
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image,
timesteps=torch.Tensor((t,)).to(inputs.device),
controlnet_cond=cn_cond,
context=conditioning,
)
model_output = diffuse(
x=noisy_image,
timesteps=timesteps,
Expand Down
9 changes: 9 additions & 0 deletions tests/inferers/test_controlnet_inferers.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
model_params["with_conditioning"] = True
model_params["cross_attention_dim"] = 3
controlnet_params["with_conditioning"] = True
controlnet_params["cross_attention_dim"] = 3
model = DiffusionModelUNet(**model_params)
controlnet = ControlNet(**controlnet_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input
model_params = model_params.copy()
n_concat_channel = 2
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
model_params["cross_attention_dim"] = None
controlnet_params["cross_attention_dim"] = None
model_params["with_conditioning"] = False
controlnet_params["with_conditioning"] = False
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
Expand Down Expand Up @@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat(
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
stage_2_params = stage_2_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
Expand Down Expand Up @@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat(
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
stage_2_params = stage_2_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
Expand Down

0 comments on commit ab07523

Please sign in to comment.