diff --git a/tests/unittest/ops/test_layernorm_sigmoid_mul.py b/tests/unittest/ops/test_layernorm_sigmoid_mul.py index 0d41ff55d..756734e31 100644 --- a/tests/unittest/ops/test_layernorm_sigmoid_mul.py +++ b/tests/unittest/ops/test_layernorm_sigmoid_mul.py @@ -912,14 +912,6 @@ def test_group_fused_layernorm_sigmoid_mul(self, dtype: str): dtype=dtype, ) - # Make sure we test the boundary between being able to fit the arguments in constant memory vs not. - for num_groups in range(38, 41): - self._test_group_fused_layernorm_sigmoid_mul( - [[1024, 256]] * num_groups, - use_size_op=True, - dtype=dtype, - ) - # < 1024 kernel self._test_group_fused_layernorm_sigmoid_mul( [[4, 16]], @@ -986,6 +978,23 @@ def test_group_fused_layernorm_sigmoid_mul(self, dtype: str): [[128, 1025], [128, 0], [128, 1023]], dtype=dtype, ) + + @parameterized.expand( + [ + param("float16"), + param("float32"), + param("bfloat16"), + ] + ) + def test_group_fused_layernorm_sigmoid_mul_long(self, dtype: str): + # Make sure we test the boundary between being able to fit the arguments in constant memory vs not. + for num_groups in range(38, 41): + self._test_group_fused_layernorm_sigmoid_mul( + [[1024, 256]] * num_groups, + use_size_op=True, + dtype=dtype, + ) + # Ditto boundary test for num_groups_divided_by_3 in range(12, 15): self._test_group_fused_layernorm_sigmoid_mul( @@ -993,6 +1002,14 @@ def test_group_fused_layernorm_sigmoid_mul(self, dtype: str): dtype=dtype, ) + @parameterized.expand( + [ + param("float16"), + param("float32"), + param("bfloat16"), + ] + ) + def test_group_fused_layernorm_sigmoid_mul_nd(self, dtype: str): # ND self._test_group_fused_layernorm_sigmoid_mul( [[2, 512, 256, 16], [2, 512, 128, 4]],