Skip to content

Commit

Permalink
Re-order skipUnless in test_restormer.py, Signed-off-by: Cano-Muniz, …
Browse files Browse the repository at this point in the history
…Santiago <[email protected]>
  • Loading branch information
phisanti committed Mar 8, 2025
1 parent da0a186 commit f17e06e
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/networks/nets/test_restormer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@

class TestMDTATransformerBlock(unittest.TestCase):

@skipUnless(has_einops, "Requires einops")
@parameterized.expand(TEST_CASES_TRANSFORMER)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
if flash and not torch.cuda.is_available():
self.skipTest("Flash attention requires CUDA")
Expand All @@ -111,6 +111,7 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_
class TestOverlapPatchEmbed(unittest.TestCase):

@parameterized.expand(TEST_CASES_PATCHEMBED)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape):
net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim)
with eval_mode(net):
Expand All @@ -120,8 +121,8 @@ def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected

class TestRestormer(unittest.TestCase):

@skipUnless(has_einops, "Requires einops")
@parameterized.expand(TEST_CASES_RESTORMER)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, input_param, input_shape, expected_shape):
if input_param.get("flash_attention", False) and not torch.cuda.is_available():
self.skipTest("Flash attention requires CUDA")
Expand Down

0 comments on commit f17e06e

Please sign in to comment.