Skip to content

Commit

Permalink
Use device fixture for more tests in test_core.py (#5885)
Browse files (browse the repository at this point in the history)
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Feb 11, 2025
1 parent 2bc85dc commit 420c290
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions python/test/unit/language/test_core.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2441,7 +2441,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
negative_config = [('cumsum', 'float32', (32, 32), -1, False, 4)]


-def test_sum_dtype():
+def test_sum_dtype(device):

@triton.jit
def kernel_dtype(out_ptr, init, in_dtype: tl.constexpr, out_dtype: tl.constexpr):
Expand All @@ -2461,7 +2461,7 @@ def kernel_default_float(out_ptr):
x = tl.sum(x)
tl.store(out_ptr, x)

-out = torch.empty(1, dtype=torch.int32, device='cuda')
+out = torch.empty(1, dtype=torch.int32, device=device)
kernel_dtype[(1, )](out, init=1, in_dtype=tl.int1, out_dtype=None)
assert out[0] == 32 * 32

Expand All @@ -2477,9 +2477,9 @@ def kernel_default_float(out_ptr):
kernel_default_int[(1, )](out)
assert out[0] == 32 * 32

-out = torch.empty(1, dtype=torch.bfloat16, device='cuda')
+out = torch.empty(1, dtype=torch.bfloat16, device=device)
kernel_default_float[(1, )](out)
-torch.testing.assert_close(out[0], torch.tensor(32 * 32, dtype=torch.bfloat16, device='cuda'))
+torch.testing.assert_close(out[0], torch.tensor(32 * 32, dtype=torch.bfloat16, device=device))


@triton.jit
Expand Down Expand Up @@ -2675,16 +2675,16 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):


@pytest.mark.parametrize("M, N", [(1, 64), (2, 32), (4, 16), (8, 8), (16, 4), (32, 2), (64, 1)])
-def test_scan_1d(M, N):
+def test_scan_1d(M, N, device):

@triton.jit
def scan_kernel(out_ptr, in_ptr, M: tl.constexpr, N: tl.constexpr):
input = tl.load(in_ptr + tl.arange(0, M))
output = tl.cumsum(input).reshape([1, M]).broadcast_to([N, M])
tl.store(out_ptr + tl.arange(0, M * N), output.reshape([M * N]))

-x = torch.randint(-100, 100, (M, ), dtype=torch.int32, device='cuda')
-output = torch.empty(M * N, dtype=torch.int32, device='cuda')
+x = torch.randint(-100, 100, (M, ), dtype=torch.int32, device=device)
+output = torch.empty(M * N, dtype=torch.int32, device=device)

scan_kernel[(1, )](output, x, M, N)

Expand Down Expand Up @@ -4813,14 +4813,14 @@ def kernel():


@pytest.mark.interpreter
-def test_tma_load_block_shape_err():
+def test_tma_load_block_shape_err(device):

@triton.jit
def kernel(ptr):
desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [1, 32])
desc.load([0, 0])

-input = torch.empty((128, 128), dtype=torch.int32, device='cuda')
+input = torch.empty((128, 128), dtype=torch.int32, device=device)
errc = triton.CompilationError if not is_interpreter() else InterpreterError
with pytest.raises(errc) as e:
kernel[(1, )](input)
Expand All @@ -4829,14 +4829,14 @@ def kernel(ptr):


@pytest.mark.interpreter
-def test_tma_store_block_shape_err():
+def test_tma_store_block_shape_err(device):

@triton.jit
def kernel(ptr):
desc = tl._experimental_make_tensor_descriptor(ptr, [128, 128], [128, 1], [8, 8])
desc.store([0, 0], tl.zeros((1, 32), dtype=tl.int16))

-input = torch.empty((128, 128), dtype=torch.int16, device='cuda')
+input = torch.empty((128, 128), dtype=torch.int16, device=device)
errc = triton.CompilationError if not is_interpreter() else InterpreterError
with pytest.raises(errc) as e:
kernel[(1, )](input)
Expand Down

0 comments on commit 420c290

Please sign in to comment.