Gather does not work if index is much longer than value #5836

Open
SchrodingerZhu opened this issue Feb 6, 2025 · 1 comment

Describe the bug

import triton
import triton.language as tl
import torch


@triton.jit
def test(values, index, output):
    # The source tensor has only 4 elements...
    val = tl.load(values + tl.arange(0, 4))
    # ...while the index tensor has 4096 elements.
    idx = tl.load(index + tl.arange(0, 4096))
    # Gathering with an index much longer than the source triggers the abort.
    result = val.gather(idx, axis=0)
    tl.store(output + tl.arange(0, 4096), result)


a = torch.tensor([1, 2, 3, 4], device='cuda')
b = torch.zeros((4096,), device='cuda', dtype=torch.int32)
c = torch.empty((4096,), device='cuda')
test[lambda _: (4,)](a, b, c)

Running the program aborts the Python interpreter with the following assertion failure:

python: /home/ubuntu/triton/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp:232: void mlir::triton::gpu::setOptimizedGatherLayout(mlir::triton::GatherOp, mlir::RewriterBase&): Assertion `GatherLoweringHelper(op).isWarpLocal()' failed.
Aborted
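
For comparison, here is a minimal host-side sketch of the result I would expect from the kernel. It assumes that `gather` along axis 0 means `out[i] = values[idx[i]]`, as in `torch.gather`. The helper name `gather_1d` is hypothetical and is not part of Triton; it only illustrates the expected semantics in pure Python.

```python
def gather_1d(values, index):
    """Reference 1-D gather (assumed semantics): out[i] = values[index[i]]."""
    n = len(values)
    for i in index:
        # Reject out-of-range indices explicitly instead of relying on
        # Python's negative-index wraparound.
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for {n} values")
    return [values[i] for i in index]


# Same shapes as the reproducer: 4 source values, 4096 indices, all zero.
values = [1, 2, 3, 4]
index = [0] * 4096
expected = gather_1d(values, index)
```

Under that assumption, every one of the 4096 outputs should be `values[0]`, that is, 1. The index being longer than the source should not matter to the result.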

Environment details

Triton: triton==3.2.0+git94643b23
Python: 3.10

Mogball self-assigned this Feb 6, 2025

zinccat commented Feb 20, 2025

hi, any update on this?
