Skip to content

Commit

Permalink
Break down test_group_fused_layernorm_sigmoid_mul to avoid timeout (#…
Browse files Browse the repository at this point in the history
…1031)

Summary:
Pull Request resolved: #1031

test_group_fused_layernorm_sigmoid_mul can timeout. Break it down into smaller pieces.

Reviewed By: ColinPeppler

Differential Revision: D64557485

fbshipit-source-id: 216d03ac8d4af2bbe2608c8d655b2bc17230c750
  • Loading branch information
henrylhtsang authored and facebook-github-bot committed Oct 18, 2024
1 parent 437b48a commit aef8f6e
Showing 1 changed file with 25 additions and 8 deletions.
33 changes: 25 additions & 8 deletions tests/unittest/ops/test_layernorm_sigmoid_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,14 +912,6 @@ def test_group_fused_layernorm_sigmoid_mul(self, dtype: str):
dtype=dtype,
)

# Make sure we test the boundary between being able to fit the arguments in constant memory vs not.
for num_groups in range(38, 41):
self._test_group_fused_layernorm_sigmoid_mul(
[[1024, 256]] * num_groups,
use_size_op=True,
dtype=dtype,
)

# < 1024 kernel
self._test_group_fused_layernorm_sigmoid_mul(
[[4, 16]],
Expand Down Expand Up @@ -986,13 +978,38 @@ def test_group_fused_layernorm_sigmoid_mul(self, dtype: str):
[[128, 1025], [128, 0], [128, 1023]],
dtype=dtype,
)

@parameterized.expand(
[
param("float16"),
param("float32"),
param("bfloat16"),
]
)
def test_group_fused_layernorm_sigmoid_mul_long(self, dtype: str):
# Make sure we test the boundary between being able to fit the arguments in constant memory vs not.
for num_groups in range(38, 41):
self._test_group_fused_layernorm_sigmoid_mul(
[[1024, 256]] * num_groups,
use_size_op=True,
dtype=dtype,
)

# Ditto boundary test
for num_groups_divided_by_3 in range(12, 15):
self._test_group_fused_layernorm_sigmoid_mul(
[[1024, 1025], [1024, 1276], [1024, 1023]] * num_groups_divided_by_3,
dtype=dtype,
)

@parameterized.expand(
[
param("float16"),
param("float32"),
param("bfloat16"),
]
)
def test_group_fused_layernorm_sigmoid_mul_nd(self, dtype: str):
# ND
self._test_group_fused_layernorm_sigmoid_mul(
[[2, 512, 256, 16], [2, 512, 128, 4]],
Expand Down

0 comments on commit aef8f6e

Please sign in to comment.