Skip to content

Commit

Permalink
Modify ControlNet inferer so that it takes in context when the diffusion model has context. This should be the standard behavior. Modified tests accordingly.
Browse files Browse the repository at this point in the history

Signed-off-by: Virginia Fernandez <[email protected]>
  • Loading branch information
Virginia Fernandez committed Feb 21, 2025
1 parent b0ed253 commit e8ca78a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 11 deletions.
40 changes: 29 additions & 11 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,13 +1334,15 @@ def __call__( # type: ignore[override]
raise NotImplementedError(f"{mode} condition is not supported")

noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
)

if mode == "concat" and condition is not None:
noisy_image = torch.cat([noisy_image, condition], dim=1)
condition = None

down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)
Expand Down Expand Up @@ -1396,17 +1398,21 @@ def sample( # type: ignore[override]
progress_bar = iter(scheduler.timesteps)
intermediates = []
for t in progress_bar:
# 1. ControlNet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
)
# 2. predict noise model_output
diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat" and conditioning is not None:
# 1. Conditioning
model_input = torch.cat([image, conditioning], dim=1)
# 2. ControlNet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
controlnet_cond=cn_cond,
context=None,
)
# 3. predict noise model_output
model_output = diffuse(
model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
Expand All @@ -1415,6 +1421,12 @@ def sample( # type: ignore[override]
mid_block_additional_residual=mid_block_res_sample,
)
else:
down_block_res_samples, mid_block_res_sample = controlnet(
x=image,
timesteps=torch.Tensor((t,)).to(input_noise.device),
controlnet_cond=cn_cond,
context=conditioning,
)
model_output = diffuse(
image,
timesteps=torch.Tensor((t,)).to(input_noise.device),
Expand Down Expand Up @@ -1485,16 +1497,16 @@ def get_likelihood( # type: ignore[override]
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat" and conditioning is not None:
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
)
model_output = diffuse(
noisy_image,
timesteps=timesteps,
Expand All @@ -1503,6 +1515,12 @@ def get_likelihood( # type: ignore[override]
mid_block_additional_residual=mid_block_res_sample,
)
else:
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image,
timesteps=torch.Tensor((t,)).to(inputs.device),
controlnet_cond=cn_cond,
context=conditioning,
)
model_output = diffuse(
x=noisy_image,
timesteps=timesteps,
Expand Down
9 changes: 9 additions & 0 deletions tests/inferers/test_controlnet_inferers.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
model_params["with_conditioning"] = True
model_params["cross_attention_dim"] = 3
controlnet_params["with_conditioning"] = True
controlnet_params["cross_attention_dim"] = 3
model = DiffusionModelUNet(**model_params)
controlnet = ControlNet(**controlnet_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input
model_params = model_params.copy()
n_concat_channel = 2
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
model_params["cross_attention_dim"] = None
controlnet_params["cross_attention_dim"] = None
model_params["with_conditioning"] = False
controlnet_params["with_conditioning"] = False
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
Expand Down Expand Up @@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat(
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
stage_2_params = stage_2_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
Expand Down Expand Up @@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat(
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
stage_2_params = stage_2_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
Expand Down

0 comments on commit e8ca78a

Please sign in to comment.