From 547ac943c2e606c783600dd884049a571df6a475 Mon Sep 17 00:00:00 2001
From: John Zielke
Date: Tue, 4 Feb 2025 14:10:19 +0000
Subject: [PATCH 1/2] selfattention block: Remove the fc linear layer if it is not used

Signed-off-by: John Zielke
---
 monai/networks/blocks/selfattention.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 86e1b1d3ae..ec5ff25946 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -106,7 +106,11 @@ def __init__(
 
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
-        self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
+        self.out_proj: Union[nn.Linear, nn.Identity]
+        if include_fc:
+            self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
+        else:
+            self.out_proj = nn.Identity()
 
         self.qkv: Union[nn.Linear, nn.Identity]
         self.to_q: Union[nn.Linear, nn.Identity]

From 14115648189d6644933cb91b3092ebb085ccbaa0 Mon Sep 17 00:00:00 2001
From: John Zielke
Date: Wed, 12 Feb 2025 18:32:20 +0000
Subject: [PATCH 2/2] Fix old state dict loading and add tests

Signed-off-by: John Zielke
---
 monai/networks/nets/diffusion_model_unet.py |  6 +++---
 tests/networks/blocks/test_selfattention.py | 21 +++++++++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py
index 65d6053acc..11196bb343 100644
--- a/monai/networks/nets/diffusion_model_unet.py
+++ b/monai/networks/nets/diffusion_model_unet.py
@@ -1847,9 +1847,9 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
             new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
 
             # projection
-            new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
-            new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
-
+            if f"{block}.attn.out_proj.weight" in new_state_dict and f"{block}.attn.out_proj.bias" in new_state_dict:
+                new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
+                new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
         # fix the cross attention blocks
         cross_attention_blocks = [
             k.replace(".out_proj.weight", "")
diff --git a/tests/networks/blocks/test_selfattention.py b/tests/networks/blocks/test_selfattention.py
index 494f64cad8..7d209659b9 100644
--- a/tests/networks/blocks/test_selfattention.py
+++ b/tests/networks/blocks/test_selfattention.py
@@ -227,6 +227,27 @@ def test_flash_attention(self):
             out_2 = block_wo_flash_attention(test_data)
             assert_allclose(out_1, out_2, atol=1e-4)
 
+    @parameterized.expand([[True, True], [True, False], [False, True], [False, False]])
+    def test_no_extra_weights_if_no_fc(self, include_fc, use_combined_linear):
+        input_param = {
+            "hidden_size": 360,
+            "num_heads": 4,
+            "dropout_rate": 0.0,
+            "rel_pos_embedding": None,
+            "input_size": (16, 32),
+            "include_fc": include_fc,
+            "use_combined_linear": use_combined_linear,
+        }
+        net = SABlock(**input_param)
+        if not include_fc:
+            self.assertNotIn("out_proj.weight", net.state_dict())
+            self.assertNotIn("out_proj.bias", net.state_dict())
+            self.assertIsInstance(net.out_proj, torch.nn.Identity)
+        else:
+            self.assertIn("out_proj.weight", net.state_dict())
+            self.assertIn("out_proj.bias", net.state_dict())
+            self.assertIsInstance(net.out_proj, torch.nn.Linear)
+
 
 if __name__ == "__main__":
     unittest.main()