-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes #11
Open
catherinelee274
wants to merge
12
commits into
triton-lang:main
Choose a base branch
from
catherinelee274:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes #11
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
a9cb93a
update readme + silu_mul
catherinelee274 6475766
Update gitignore
catherinelee274 426b934
update readme
catherinelee274 7c84e9d
updates
catherinelee274 16c4040
udpates, removing warning
catherinelee274 7e47ddd
Remove files from commit but keep them locally
catherinelee274 a8759f1
update README.md
catherinelee274 213c407
Add back benchmarking_utils
catherinelee274 bec9312
updates--need to debug argmax issue
catherinelee274 88e1bed
Put softmax in MathOps
catherinelee274 243e22b
Add back accidentally removed next_token var
catherinelee274 036e4bc
self.triton.langauge.argmax -> triton.langauge.argmax
catherinelee274 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,6 @@ __pycache__/ | |
.pytest_cache | ||
**/.cache | ||
**/meta-llama/**/* | ||
|
||
# Virtual Environment | ||
venv/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Triton Kernels | ||
|
||
## Supported Kernels | ||
* flash attention | ||
* matmul | ||
* cross entropy | ||
|
||
## Contributing | ||
``` | ||
python3 main.py llama_chat_completion --benchmark --ckpt_dir <model_checkpoint_path> --tokenizer_path <model_tokenizer_path> | ||
``` | ||
|
||
|
||
## Getting started | ||
|
||
|
||
|
||
* Install dependencies | ||
|
||
|
||
|
||
```bash | ||
|
||
python3 -m pip install -r requirements.txt | ||
|
||
``` | ||
|
||
|
||
|
||
* Download llama model | ||
|
||
|
||
|
||
```bash | ||
|
||
export HF_TOKEN=xxx | ||
|
||
huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --local-dir $HOME/models/llama-3-8b-instruct | ||
|
||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
from triton.runtime import driver | ||
|
||
|
||
def is_hip(): | ||
return triton.runtime.driver.active.get_current_target().backend == "hip" | ||
|
||
def is_cdna(): | ||
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', | ||
'gfx90a', 'gfx908') | ||
|
||
@triton.jit | ||
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, | ||
num_stages: tl.constexpr): | ||
# starting row of the program | ||
row_start = tl.program_id(0) | ||
row_step = tl.num_programs(0) | ||
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): | ||
# The stride represents how much we need to increase the pointer to advance 1 row | ||
row_start_ptr = input_ptr + row_idx * input_row_stride | ||
# The block size is the next power of two greater than n_cols, so we can fit each | ||
# row in a single block | ||
col_offsets = tl.arange(0, BLOCK_SIZE) | ||
input_ptrs = row_start_ptr + col_offsets | ||
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols | ||
mask = col_offsets < n_cols | ||
row = tl.load(input_ptrs, mask=mask, other=-float('inf')) | ||
# Subtract maximum for numerical stability | ||
row_minus_max = row - tl.max(row, axis=0) | ||
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) | ||
numerator = tl.exp(row_minus_max) | ||
denominator = tl.sum(numerator, axis=0) | ||
softmax_output = numerator / denominator | ||
# Write back output to DRAM | ||
output_row_start_ptr = output_ptr + row_idx * output_row_stride | ||
output_ptrs = output_row_start_ptr + col_offsets | ||
tl.store(output_ptrs, softmax_output, mask=mask) | ||
|
||
def triton_softmax(x): | ||
device = torch.cuda.current_device() | ||
properties = driver.active.utils.get_device_properties(device) | ||
NUM_SM = properties["multiprocessor_count"] | ||
NUM_REGS = properties["max_num_regs"] | ||
SIZE_SMEM = properties["max_shared_mem"] | ||
WARP_SIZE = properties["warpSize"] | ||
target = triton.runtime.driver.active.get_current_target() | ||
kernels = {} | ||
|
||
n_rows, n_cols = x.shape | ||
|
||
# The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` | ||
BLOCK_SIZE = triton.next_power_of_2(n_cols) | ||
|
||
# Another trick we can use is to ask the compiler to use more threads per row by | ||
# increasing the number of warps (`num_warps`) over which each row is distributed. | ||
# You will see in the next tutorial how to auto-tune this value in a more natural | ||
# way so you don't have to come up with manual heuristics yourself. | ||
num_warps = 8 | ||
|
||
# Number of software piepling stages. | ||
# num_stages = 4 if SIZE_SMEM > 200000 else 2 | ||
num_stages = 1 | ||
|
||
# Allocate output | ||
y = torch.empty_like(x) | ||
|
||
# pre-compile kernel to get register usage and compute thread occupancy. | ||
kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) | ||
if kernel is None: | ||
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, | ||
num_stages=num_stages, num_warps=num_warps, grid=(1, )) | ||
kernel._init_handles() | ||
n_regs = kernel.n_regs | ||
size_smem = kernel.metadata.shared | ||
if is_hip(): | ||
# NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. | ||
# However, this is not always the case. In most cases all registers can be used as regular purpose registers. | ||
# ISA SECTION (3.6.4 for CDNA3) | ||
# VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used | ||
# with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total | ||
# VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is | ||
# not required to be equal numbers of both types. | ||
if is_cdna(): | ||
NUM_GPRS = NUM_REGS * 2 | ||
|
||
# MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. | ||
# When we divide this number with WARP_SIZE we get maximum number of waves that can | ||
# execute on a CU (multi-processor) in parallel. | ||
MAX_NUM_THREADS = properties["max_threads_per_sm"] | ||
max_num_waves = MAX_NUM_THREADS // WARP_SIZE | ||
occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps | ||
else: | ||
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) | ||
occupancy = min(occupancy, SIZE_SMEM // size_smem) | ||
num_programs = NUM_SM * occupancy | ||
kernels[BLOCK_SIZE] = (kernel, num_programs) | ||
|
||
num_programs = min(num_programs, n_rows) | ||
|
||
# Create a number of persistent programs. | ||
kernel[(num_programs, 1, 1)]( | ||
y, | ||
x, | ||
x.stride(0), | ||
y.stride(0), | ||
n_rows, | ||
n_cols, | ||
) | ||
return y | ||
|
||
|
||
|
||
class _softmax(torch.autograd.Function): | ||
@staticmethod | ||
def forward(ctx, x): | ||
return triton_softmax(x) | ||
|
||
# @staticmethod | ||
# def backward(ctx, grad_output): | ||
# return grad_output, grad_output | ||
|
||
softmax = _softmax.apply |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Unit Tests for Kernels | ||
|
||
Invoke all tests by calling ```pytest test/```. | ||
To run a specific test, call ```pytest test/test_softmax.py```. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
import pytest | ||
from kernels.fused_softmax import softmax | ||
|
||
@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)]) | ||
def test_softmax_equivalence(input_size): | ||
# Create random input tensor of specified size | ||
x = torch.randn(*input_size).cuda() | ||
|
||
# Compute softmax using PyTorch | ||
pytorch_softmax = F.softmax(x, dim=-1) | ||
|
||
# Compute softmax using Triton | ||
triton_output = softmax(x) | ||
|
||
# Assert that both outputs are approximately equal | ||
assert torch.allclose(pytorch_softmax, triton_output, atol=1e-5), \ | ||
f"Triton softmax output doesn't match PyTorch softmax for input size {input_size}" |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding tests!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, this might not be ideal because we are not calling softmax from triton.ops like the other tests. I ran into issues with doing it that way.