diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py
index a96fcc9e216..4aee63d5f17 100644
--- a/monai/networks/blocks/rel_pos_embedding.py
+++ b/monai/networks/blocks/rel_pos_embedding.py
@@ -17,17 +17,7 @@
 from monai.networks.blocks.attention_utils import add_decomposed_rel_pos
 
 
-class RelativePosEmbedding(nn.Module):
-    def __init__(
-        self,
-    ) -> None:
-        super().__init__()
-
-    def forward(self, x: torch.Tensor, att_mat: torch.Tensor) -> torch.Tensor:
-        ...
-
-
-class DecomposedRelativePosEmbedding(RelativePosEmbedding):
+class DecomposedRelativePosEmbedding(nn.Module):
     def __init__(self, s_input_dims: Tuple, c_dim: int, num_heads: int) -> None:
         """
         Args:
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 6062b5352f3..44b9e035150 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -18,22 +18,32 @@
 import torch
 from parameterized import parameterized
 
+from monai.networks.layers.factories import RelPosEmbedding
 from monai.networks import eval_mode
 from monai.networks.blocks.selfattention import SABlock
 from monai.utils import optional_import
 
+
 einops, has_einops = optional_import("einops")
 
 TEST_CASE_SABLOCK = []
 for dropout_rate in np.linspace(0, 1, 4):
     for hidden_size in [360, 480, 600, 768]:
         for num_heads in [4, 6, 8, 12]:
-            test_case = [
-                {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate},
-                (2, 512, hidden_size),
-                (2, 512, hidden_size),
-            ]
-            TEST_CASE_SABLOCK.append(test_case)
+            for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
+                for input_size in [(16, 32), (8, 8, 8)]:
+                    test_case = [
+                        {
+                            "hidden_size": hidden_size,
+                            "num_heads": num_heads,
+                            "dropout_rate": dropout_rate,
+                            "rel_pos_embedding": rel_pos_embedding,
+                            "input_size": input_size,
+                        },
+                        (2, 512, hidden_size),
+                        (2, 512, hidden_size),
+                    ]
+                    TEST_CASE_SABLOCK.append(test_case)
 
 
 class TestResBlock(unittest.TestCase):