diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py
index ab08d84390..9b54b7a765 100644
--- a/tests/networks/nets/test_restormer.py
+++ b/tests/networks/nets/test_restormer.py
@@ -88,8 +88,8 @@ class TestMDTATransformerBlock(unittest.TestCase):
-    @skipUnless(has_einops, "Requires einops")
     @parameterized.expand(TEST_CASES_TRANSFORMER)
+    @skipUnless(has_einops, "Requires einops")
     def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
         if flash and not torch.cuda.is_available():
             self.skipTest("Flash attention requires CUDA")
@@ -111,6 +111,7 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_
 class TestOverlapPatchEmbed(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES_PATCHEMBED)
+    @skipUnless(has_einops, "Requires einops")
     def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):
         net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)
         with eval_mode(net):
@@ -120,8 +121,8 @@ class TestRestormer(unittest.TestCase):
-    @skipUnless(has_einops, "Requires einops")
     @parameterized.expand(TEST_CASES_RESTORMER)
+    @skipUnless(has_einops, "Requires einops")
     def test_shape(self, input_param, input_shape, expected_shape):
         if input_param.get("flash_attention", False) and not torch.cuda.is_available():
             self.skipTest("Flash attention requires CUDA")
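Context for the ordering in this patch: Python applies decorators bottom-up, so placing `@skipUnless(has_einops, ...)` between `@parameterized.expand(...)` and the `def` line lets the skip guard wrap the test body before `expand` generates the per-case methods; with the guard on the outside it only wraps the placeholder that `expand` leaves behind, and the generated cases can run unguarded when einops is missing. A minimal sketch of that behaviour, not taken from the patch (the `HAS_EINOPS` flag and `TestDecoratorOrder` class are illustrative stand-ins for MONAI's `has_einops` and the test classes above):

```python
import unittest
from unittest import skipUnless

from parameterized import parameterized

HAS_EINOPS = False  # hypothetical stand-in for MONAI's has_einops flag


class TestDecoratorOrder(unittest.TestCase):
    # Decorators apply bottom-up: skipUnless wraps test_shape first, then
    # parameterized.expand generates one method per case from the already
    # skip-guarded function, so every generated case reports as skipped
    # when HAS_EINOPS is False instead of failing on a missing import.
    @parameterized.expand([(1,), (2,)])
    @skipUnless(HAS_EINOPS, "Requires einops")
    def test_shape(self, value):
        self.assertGreater(value, 0)


if __name__ == "__main__":
    unittest.main()
```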