Misreport "Cannot have return statements inside while or for" if values returned by function are disregarded #5768

Open
xinyazhang opened this issue Jan 30, 2025 · 0 comments
Describe the bug

If a caller disregards the value returned by a @triton.jit function called inside a while or for loop, the compiler reports:

Cannot have `return` statements inside `while` or `for` statements in triton (note that this also applies to `return` statements that are inside functions transitively called from within `while`/`for` statements)

even though the code should compile without problems.

Minimal Reproducer

#!/usr/bin/env python

import torch
import triton
import triton.language as tl

@triton.jit
def inner_store(ptrs, regs, mask):
    tl.store(ptrs, regs, mask=mask)
    return regs  # Return something

@triton.jit
def add_kernel(x_ptr,
               y_ptr,
               output_ptr,
               n_elements,
               BLOCK_SIZE: tl.constexpr,
               ):
    for block_start in range(0, n_elements, BLOCK_SIZE):
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        if block_start < 4096:
            # But ignore the return value
            inner_store(output_ptr + offsets, output, mask=mask)

DEVICE='cuda'
DTYPE=torch.float16

size = 256
x = torch.rand(size, dtype=DTYPE, device=DEVICE)
y = torch.rand(size, dtype=DTYPE, device=DEVICE)
output = torch.empty_like(x)
grid = lambda meta: (1,)
add_kernel[grid](x, y, output, output.numel(), BLOCK_SIZE=32, num_stages=1)
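For illustration only, here is a minimal, hypothetical sketch (plain Python AST, not Triton's actual compiler code) of the kind of check that could produce the quoted error: it flags any call inside a `for`/`while` body whose callee transitively contains a value-returning `return`, without checking whether the call's result is actually used — which is exactly the situation the reproducer hits. All names and the detection logic here are assumptions for the sketch.

```python
import ast
import textwrap

# Stripped-down stand-in for the reproducer: inner_store returns a value,
# and add_kernel calls it inside a for loop while discarding the result.
SOURCE = textwrap.dedent("""
def inner_store(ptrs, regs, mask):
    return regs

def add_kernel(n):
    for i in range(n):
        inner_store(1, 2, 3)
""")

def top_level_functions(tree):
    return {f.name: f for f in tree.body if isinstance(f, ast.FunctionDef)}

def has_value_return(fn, fns, seen):
    # True if fn, or any function it calls transitively, has `return <value>`.
    if fn.name in seen:
        return False
    seen.add(fn.name)
    for node in ast.walk(fn):
        if isinstance(node, ast.Return) and node.value is not None:
            return True
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            callee = fns.get(node.func.id)
            if callee and has_value_return(callee, fns, seen):
                return True
    return False

def check(source):
    # Flag every call inside a for/while body whose callee transitively
    # returns a value -- even if the caller discards the result.
    tree = ast.parse(source)
    fns = top_level_functions(tree)
    errors = []
    for fn in fns.values():
        for loop in ast.walk(fn):
            if isinstance(loop, (ast.For, ast.While)):
                for node in ast.walk(loop):
                    if (isinstance(node, ast.Call)
                            and isinstance(node.func, ast.Name)
                            and node.func.id in fns
                            and has_value_return(fns[node.func.id], fns, set())):
                        errors.append(f"{fn.name}: call to {node.func.id} "
                                      "returns a value inside a loop")
    return errors

print(check(SOURCE))
# The discarded return value is still flagged, mirroring the misreport.
```

Under this sketch, the fix the issue asks for would be to additionally check whether the call's result is consumed before raising the error.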

Environment details

Triton: 0fa1dbd
GPU: Advanced Micro Devices, Inc. [AMD/ATI] Aldebaran/MI200 [Instinct MI210] (rev 02)
Python: 3.10.15

xinyazhang added the bug label Jan 30, 2025