diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 86e1b1d3ae..ec5ff25946 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -106,7 +106,11 @@ def __init__( self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size - self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) + self.out_proj: Union[nn.Linear, nn.Identity] + if include_fc: + self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) + else: + self.out_proj = nn.Identity() self.qkv: Union[nn.Linear, nn.Identity] self.to_q: Union[nn.Linear, nn.Identity]