
Triton interpreter cannot handle parameters that alias #5791

Open
saagarjha opened this issue Feb 2, 2025 · 13 comments

@saagarjha
Contributor

Describe the bug

When invoking the interpreter, Triton makes a copy of the tensors passed in so that they can be operated on directly (e.g. by copying them to the host). Unfortunately, the straightforward way to do this, where each input tensor gets its own host copy, the kernel runs on the copies, and each copy is then written back, is subtly incorrect. In particular, a kernel like this:

import torch
import triton
import triton.language as tl

@triton.jit
def aliasing_test(buffer, buffer2):
    tl.store(buffer, 1)

# Run with TRITON_INTERPRET=1 so the interpreter is used.
if __name__ == "__main__":
    buffer = torch.zeros(1, device="cuda")
    aliasing_test[(1,)](buffer, buffer)
    print(buffer)

should print "1" but it prints "0". This is because buffer is copied in from the kernel and then buffer2 (which isn't written to) overwrites it with the original value.

Environment details

Triton: built from main
GPU: H100

@saagarjha saagarjha added the bug label Feb 2, 2025
@Jokeren
Contributor

Jokeren commented Feb 2, 2025

Yeah, I acknowledge the problem and thank you for reporting it.

I think we should first check the storage of all tensors to determine whether any tensor is a slice of, or identical to, another tensor in the input arguments. These child tensors should then be excluded from the copy process and instead be rebuilt as slices of the copied "parent" tensors. When storing data back to the GPU, a similar process would be performed.

It's not a priority since most kernels do not have such a case.
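
Roughly something like this (a hypothetical sketch, not existing interpreter code): group arguments by their storage data_ptr, copy each storage to the host once, and rebuild every argument as a strided view of that copy.

import torch

def copy_args_preserving_aliases(args):
    # Hypothetical sketch: one host copy per unique storage, with every
    # argument rebuilt as a view of that copy so aliases stay aliased.
    storage_copies = {}  # storage data_ptr -> host byte buffer
    host_args = []
    for a in args:
        key = a.untyped_storage().data_ptr()
        if key not in storage_copies:
            raw = torch.empty(a.untyped_storage().nbytes(), dtype=torch.uint8)
            raw.untyped_storage().copy_(a.untyped_storage())
            storage_copies[key] = raw
        raw = storage_copies[key]
        # Reinterpret the byte buffer with this argument's dtype, then apply its
        # sizes/strides/offset so writes through one alias reach the others.
        typed = raw.view(a.dtype)
        host_args.append(torch.as_strided(typed, a.size(), a.stride(), a.storage_offset()))
    return host_args, storage_copies

Writing back would walk storage_copies the same way, copying each host buffer into the original storage exactly once.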

@Jokeren
Contributor

Jokeren commented Feb 2, 2025

Let me know if you want to propose a fix and I will assign the issue to you.

@saagarjha
Contributor Author

So the reason I filed this as a bug instead of just opening a PR is that this actually seems nontrivial to solve ;) Figuring out if two parameters alias seems like a hard problem, because you have to perform an intersection of two arbitrarily-strided tensors, and I don't know if there is any API or non-annoying way to compute this. I can keep thinking about it but if you have ideas I'd be happy to hear them.

@Jokeren
Contributor

Jokeren commented Feb 2, 2025

It won't be that difficult to solve. Any tensor that is a view of another will share the same storage and maintain the same storage data_ptr. Likely I won't have time to handle the problem this week. Let me keep this issue open and feel free to take it later, or I can ask someone else to take it.

@saagarjha
Contributor Author

Well, you can have aliasing occur from any pointer really, not just views into the same tensor. Like, you can wrap a tensor around an existing allocation, and you'll run into issues if there is any overlap. The following two tensors partially overlap for example:

  • base 0x1000, stride 2, size 5
  • base 0x1002, stride 3, size 4

They overlap at 0x1002 and 0x1008. I think the general version of this would require computing some sort of LCM stride for all the parameters and then figuring out what the layout of the overlap is.
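
For the example above, a brute-force check over the touched addresses (fine as an illustration; element width and multiple dimensions are ignored) finds exactly those two collisions:

def touched(base, stride, size):
    # Start addresses of every element of a 1-D strided tensor.
    return {base + i * stride for i in range(size)}

a = touched(0x1000, 2, 5)   # 0x1000, 0x1002, 0x1004, 0x1006, 0x1008
b = touched(0x1002, 3, 4)   # 0x1002, 0x1005, 0x1008, 0x100B
print([hex(x) for x in sorted(a & b)])   # ['0x1002', '0x1008']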

@lezcano
Contributor

lezcano commented Feb 3, 2025

In PyTorch, to see if two tensors are views of the same tensor, you can do x._base is y._base (if x is not a view, _base will return None, so in that case you have to take x as its base, but yeah).
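
For instance (relying on the private _base attribute, so only a sketch):

import torch

def share_base(x, y):
    # _base is None for a tensor that is not a view, so fall back to the tensor itself.
    x_base = x if x._base is None else x._base
    y_base = y if y._base is None else y._base
    return x_base is y_base

t = torch.zeros(10)
print(share_base(t[:5], t[2:]))        # True: both are views of t
print(share_base(t, t[2:]))            # True: t is its own base
print(share_base(t, torch.zeros(10)))  # False: unrelated allocations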

@saagarjha
Contributor Author

Right, but my point is that you can form a tensor around any memory address, including one owned by someone else, without it needing to be a view. I'm fine with going "yeah, but that is dumb and we are not going to do aliasing checks for that", but ideally that is something we decide on rather than just missing that case by accident :)

@Jokeren
Contributor

Jokeren commented Feb 3, 2025

Right, but my point is that you can form a tensor around any memory address, including one owned by someone else, without it needing to be a view

Can you show me how you create overlapping tensors with different t.untyped_storage().data_ptr() in Python?

@peterbell10
Contributor

Here is a minimal example:

import numpy as np
import torch

a = np.random.randn(100)
b = torch.from_numpy(a)
c = torch.from_numpy(a[1:])  # overlaps b, but wraps a different storage object
assert b.untyped_storage().data_ptr() != c.untyped_storage().data_ptr()

I'm inclined to say this is too much of an edge case to focus on though.

@Jokeren
Contributor

Jokeren commented Feb 3, 2025

Here is a minimal example: […]

I'm inclined to say this is too much of an edge case to focus on though.

Fair enough example. Then detecting the maximum address range of each input argument using its strides and sizes would be required.
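
Something along these lines, perhaps (hypothetical helpers; this over-approximates by treating any overlapping byte range as aliasing, even if the strided elements never actually collide):

import torch

def byte_range(t):
    # Half-open range of byte addresses the tensor can touch, derived from its
    # data_ptr, sizes, strides, and element size.
    start = t.data_ptr()
    if t.numel() == 0:
        return start, start
    span = sum((size - 1) * stride for size, stride in zip(t.size(), t.stride()))
    return start, start + (span + 1) * t.element_size()

def may_alias(a, b):
    a_start, a_end = byte_range(a)
    b_start, b_end = byte_range(b)
    return a_start < b_end and b_start < a_end

base = torch.zeros(16)
print(may_alias(base[0:8], base[4:12]))  # True: byte ranges overlap
print(may_alias(base[0:4], base[8:12]))  # False: disjoint ranges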

@lezcano
Contributor

lezcano commented Feb 3, 2025

I think that we can do something that works on a best effort basis, and that would already cover 99% of the use cases. Then we can leave a note somewhere mentioning that if you twist the interpreter's arm a bit too much with views + funny constructors (not regular PyTorch ops) it might struggle.

@Jokeren
Contributor

Jokeren commented Feb 3, 2025

I think that we can do something that works on a best effort basis, and that would already cover 99% of the use cases. Then we can leave a note somewhere mentioning that if you twist the interpreter's arm a bit too much with views + funny constructors (not regular PyTorch ops) it might struggle.

Agreed.

@saagarjha
Contributor Author

Alright, I’m happy to take it in that case. At the very least this will let me clear out some of the workarounds we have in our code :)
