Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: vgrau98 <[email protected]>
  • Loading branch information
vgrau98 committed Jan 7, 2024
1 parent 0b2b96d commit 4b2c852
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
12 changes: 1 addition & 11 deletions monai/networks/blocks/rel_pos_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,7 @@
from monai.networks.blocks.attention_utils import add_decomposed_rel_pos


class RelativePosEmbedding(nn.Module):
def __init__(
self,
) -> None:
super().__init__()

def forward(self, x: torch.Tensor, att_mat: torch.Tensor) -> torch.Tensor:
...


class DecomposedRelativePosEmbedding(RelativePosEmbedding):
class DecomposedRelativePosEmbedding(nn.Module):
def __init__(self, s_input_dims: Tuple, c_dim: int, num_heads: int) -> None:
"""
Args:
Expand Down
22 changes: 16 additions & 6 deletions tests/test_selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,32 @@
import torch
from parameterized import parameterized

from monai.networks.layers.factories import RelPosEmbedding
from monai.networks import eval_mode
from monai.networks.blocks.selfattention import SABlock
from monai.utils import optional_import


einops, has_einops = optional_import("einops")

TEST_CASE_SABLOCK = []
for dropout_rate in np.linspace(0, 1, 4):
for hidden_size in [360, 480, 600, 768]:
for num_heads in [4, 6, 8, 12]:
test_case = [
{"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)
for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
for input_size in [(16, 32), (8, 8, 8)]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)


class TestResBlock(unittest.TestCase):
Expand Down

0 comments on commit 4b2c852

Please sign in to comment.