Protect autotuner with synchronization #5893

Closed
wants to merge 1 commit

Conversation

saagarjha
Contributor

The autotuner calls kernels in the middle of configuring shared state, which drops the GIL. If multiple threads call into the autotuner at the same time, they will trample each other's attempts. We only need one thread to run the autotuning code, so we can guard the parameter selection with a lock to avoid this problem.
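
A minimal sketch of the locking pattern this change proposes, for illustration only (this is not Triton's actual Autotuner; the class and helper below are made up):

import threading
import time


class LockedAutotuner:
    """Illustrative wrapper: serialize only the parameter selection, not every launch."""

    def __init__(self, kernel, configs):
        self.kernel = kernel
        self.configs = configs          # candidate parameter dicts
        self.best_config = None         # shared state the race would corrupt
        self._lock = threading.Lock()

    def __call__(self, *args, **kwargs):
        if self.best_config is None:
            # Benchmarking blocks on the GPU and releases the GIL, so two
            # threads can interleave here; the lock lets one thread tune
            # while the others wait and then reuse its result.
            with self._lock:
                if self.best_config is None:  # re-check after acquiring
                    self.best_config = min(
                        self.configs,
                        key=lambda cfg: self._bench(cfg, *args, **kwargs),
                    )
        return self.kernel(*args, **kwargs, **self.best_config)

    def _bench(self, config, *args, **kwargs):
        # Crude wall-clock timing as a stand-in for a proper benchmark.
        start = time.perf_counter()
        self.kernel(*args, **kwargs, **config)
        return time.perf_counter() - start

Only the parameter selection is serialized; once best_config is populated, launches proceed without taking the lock.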

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because I'm not entirely sure how to test the absence of a race.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

saagarjha requested a review from ptillet as a code owner on February 12, 2025 at 05:05
@Jokeren
Contributor

Jokeren commented Feb 12, 2025

If multiple threads call into the autotuner at the same time

Can you give an example of your use case? It sounds like you should protect it outside of the autotuner.

@Jokeren
Contributor

Jokeren commented Feb 12, 2025

To actually support multi-threading, I believe more work needs to be done, and there's a pending PR that we never merged.

@Jokeren
Contributor

Jokeren commented Feb 12, 2025

#5436

@saagarjha
Contributor Author

Our use case isn't anything particularly special; we're just calling into @triton.jit kernels from multiple threads. This works by itself, but if you stick the @triton.autotune decorator on it (which I assume is supposed to apply transparently), things stop working.

I didn't realize that the other issue I filed actually had a PR put up for it. I'd be happy to see that fixed too, but it's orthogonal to the problem being solved here. My concern in #4806 was that autotuning takes a while, so it should run its benchmarks in parallel internally; the issue here is that I can't call an autotuned kernel from multiple threads at all. We do this not to save on autotuning time, but to run multiple kernels at once on different streams.

@Jokeren
Contributor

Jokeren commented Feb 12, 2025

Our use case isn't anything particularly special; we're just calling into @triton.jit kernels from multiple threads

You meant multiple threads calling the same kernel? Then I think you can create a lock before the kernel, right?

@Jokeren
Contributor

Jokeren commented Feb 12, 2025

The autotuner calls kernels in the middle of configuring shared state, which drops the GIL

I also do not understand which part you're referring to that drops the GIL.

@saagarjha
Contributor Author

You meant multiple threads calling the same kernel?

Yes.

Then I think you can create a lock before the kernel, right?

We could do that, sure. It breaks the abstraction of the autotuner, though, because we would need to add a lock conditioned on whether or not the kernel uses the autotuner.

I also do not understand which part you're referring to that drops the GIL.

This happens when you wait for a kernel to complete; for example, if you synchronize on a stream in PyTorch, it will drop the GIL for you. This means the autotuning code is unprotected against reentrant calls while the benchmarked kernels finish. You will note that I do not actually protect the call that launches a kernel, because that is already thread-safe; it's just the shared state inside the autotuner that needs to be guarded by the lock.
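
For reference, the caller-side lock suggested above would look roughly like the following hypothetical sketch (the kernel is a stand-in, not code from this PR):

import threading

import triton


@triton.autotune(configs=[triton.Config({"N": n}) for n in (1, 2)], key=[])
@triton.jit
def some_kernel(N: triton.language.constexpr = 0):
    pass


# Every call site has to know the kernel is autotuned and take this lock,
# and it serializes every launch rather than just the one tuning pass.
autotune_lock = threading.Lock()


def launch():
    with autotune_lock:
        some_kernel[(1,)]()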

@Jokeren
Contributor

Jokeren commented Feb 12, 2025

Are you able to post your code? I'm still confused about the usage. More specifically, when would you use multiple streams to tune a single kernel? Is it because you split configurations among streams?

If not, and all threads share the same configuration space, then I think the protection should be done outside of the autotuner. It wasn't designed to be thread-safe in the first place.

@saagarjha
Contributor Author

Sure, here's a simple test script:

#!/usr/bin/env python3

import torch
import threading
import triton


@triton.autotune(
    configs=[triton.Config({"N": n}) for n in range(100)],
    key=[],
)
@triton.jit
def test(N: triton.language.constexpr = 0):
    pass


if __name__ == "__main__":

    def run():
        with torch.cuda.stream(torch.cuda.Stream()):
            test[(1,)]()

    threads = [threading.Thread(target=run) for _ in range(10)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

If you comment out the @triton.autotune, this works just fine. I don't actually want multiple threads to autotune the kernel; I just want the first call to figure out the parameters, after which the rest of the threads can call it directly. You say the autotuner is not designed to be thread-safe, but given that @triton.jit itself should be, and that the autotuner is designed to wrap the actual kernel call transparently, it seems reasonable to expect that to extend to autotuning as well. Even if that argument isn't compelling, doing this externally is not straightforward, because the state you'd want to observe (namely, whether autotuning has finished) is not exposed.
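
Building on the test script above, a purely external workaround might look roughly like this hypothetical sketch; the caller has to track "has tuning finished?" itself, because the autotuner does not expose that state:

first_call_done = threading.Event()
first_call_lock = threading.Lock()


def run_guarded():
    with torch.cuda.stream(torch.cuda.Stream()):
        if not first_call_done.is_set():
            with first_call_lock:
                if not first_call_done.is_set():
                    test[(1,)]()  # first call performs the autotuning
                    first_call_done.set()
                    return
        test[(1,)]()  # later calls reuse the already-selected config

Even then, this just re-implements state the autotuner already holds internally, which is the point made above.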

@saagarjha
Contributor Author

@Jokeren just wanted to check in; did that help justify why this is useful or are you still not interested in this PR? I can close it if that's the case (and we'll carry it locally).

@Jokeren
Contributor

Jokeren commented Feb 28, 2025

Did that help justify why this is useful or are you still not interested in this PR?

Sorry, I'm a bit conservative about making Triton's autotuner more complex than necessary. It indeed lacks several features currently, such as caching configuration files, tuning with multiple threads, or handling the case you mentioned. There are pros and cons associated with these additions. Perhaps, for now, we could keep these changes as experimental within your local branches.

@saagarjha
Contributor Author

Ok, I'll close this. Thanks for clarifying!

saagarjha closed this on Feb 28, 2025