Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Softmax kernel in Triton. Use softmax kernel and argmax in Llama generation.py. + Small changes #11

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ __pycache__/
.pytest_cache
**/.cache
**/meta-llama/**/*

# Virtual Environment
venv/
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Triton Kernels

## Supported Kernels
* flash attention
* matmul
* cross entropy

## Contributing
```
python3 main.py llama_chat_completion --benchmark --ckpt_dir <model_checkpoint_path> --tokenizer_path <model_tokenizer_path>
```


## Getting started



* Install dependencies



```bash

python3 -m pip install -r requirements.txt

```



* Download llama model



```bash

export HF_TOKEN=xxx

huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --local-dir $HOME/models/llama-3-8b-instruct

```
5 changes: 4 additions & 1 deletion kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# from .conv import _conv, conv
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .fused_softmax import _softmax, softmax
from .flash_attention import attention
from .matmul import _matmul, get_higher_dtype, matmul

Expand All @@ -12,4 +13,6 @@
"matmul",
"attention",
"get_higher_dtype",
]
"_softmax",
"softmax"
]
124 changes: 124 additions & 0 deletions kernels/fused_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import torch
import triton
import triton.language as tl
from triton.runtime import driver


def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"

def is_cdna():
return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
'gfx90a', 'gfx908')

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
num_stages: tl.constexpr):
# starting row of the program
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
# The stride represents how much we need to increase the pointer to advance 1 row
row_start_ptr = input_ptr + row_idx * input_row_stride
# The block size is the next power of two greater than n_cols, so we can fit each
# row in a single block
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write back output to DRAM
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)

def triton_softmax(x):
device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}

n_rows, n_cols = x.shape

# The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
BLOCK_SIZE = triton.next_power_of_2(n_cols)

# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
num_warps = 8

# Number of software piepling stages.
# num_stages = 4 if SIZE_SMEM > 200000 else 2
num_stages = 1

# Allocate output
y = torch.empty_like(x)

# pre-compile kernel to get register usage and compute thread occupancy.
kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
if kernel is None:
kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
num_stages=num_stages, num_warps=num_warps, grid=(1, ))
kernel._init_handles()
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
if is_hip():
# NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
# However, this is not always the case. In most cases all registers can be used as regular purpose registers.
# ISA SECTION (3.6.4 for CDNA3)
# VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
# with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
# VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
# not required to be equal numbers of both types.
if is_cdna():
NUM_GPRS = NUM_REGS * 2

# MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
# When we divide this number with WARP_SIZE we get maximum number of waves that can
# execute on a CU (multi-processor) in parallel.
MAX_NUM_THREADS = properties["max_threads_per_sm"]
max_num_waves = MAX_NUM_THREADS // WARP_SIZE
occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
else:
occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
occupancy = min(occupancy, SIZE_SMEM // size_smem)
num_programs = NUM_SM * occupancy
kernels[BLOCK_SIZE] = (kernel, num_programs)

num_programs = min(num_programs, n_rows)

# Create a number of persistent programs.
kernel[(num_programs, 1, 1)](
y,
x,
x.stride(0),
y.stride(0),
n_rows,
n_cols,
)
return y



class _softmax(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return triton_softmax(x)

# @staticmethod
# def backward(ctx, grad_output):
# return grad_output, grad_output

softmax = _softmax.apply
4 changes: 4 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from models.llama import llama_example_chat_completion, llama_example_text_completion
from benchmarking import Profiler, compare_benchmarks
import pprint
import torch.distributed as dist


def main(operation: str, profile=False, benchmark=False, **kwargs):
Expand Down Expand Up @@ -65,6 +66,9 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
print(output)
print("\n==================================\n")

if dist.is_initialized():
dist.destroy_process_group()


def runner(operation: str, kwargs):
if operation == "llama_chat_completion":
Expand Down
3 changes: 1 addition & 2 deletions models/llama/llama/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from .tokenizer import ChatFormat, Dialog, Message, Tokenizer
from benchmarking import Profiler


class CompletionPrediction(TypedDict, total=False):
generation: str
tokens: List[str] # not required
Expand Down Expand Up @@ -196,7 +195,7 @@ def generate(
probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = self.Math.argmax(logits[:, -1], dim=-1)
next_token = self.Math.argmax(logits[:,-1], dim=-1)

next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
Expand Down
11 changes: 6 additions & 5 deletions models/llama/llama/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from kernels.cross_entropy import cross_entropy
from kernels.matmul import matmul
from kernels.flash_attention import attention
from kernels.fused_softmax import softmax
from benchmarking import Profiler
import time

Expand Down Expand Up @@ -70,17 +71,17 @@ def attention(self, xq, keys, values, head_dim, mask):

@Profiler.profiling_decorator("softmax")
def softmax(self, x, dim):
if self.use_triton:
return F.softmax(x, dim=-1)
if self.use_triton and x.ndim == 2:
return softmax(x)
else:
return F.softmax(x, dim=-1)
return F.softmax(x, dim)

@Profiler.profiling_decorator("argmax")
def argmax(self, x, dim):
if self.use_triton:
return torch.argmax(x, dim=-1)
return triton.language.argmax(x, axis=dim)
else:
return torch.argmax(x, dim=-1)
return torch.argmax(x, dim=dim)

@Profiler.profiling_decorator("cross_entropy")
def cross_entropy(self, input_val, target, reduction, ignore_index):
Expand Down
4 changes: 4 additions & 0 deletions test/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Unit Tests for Kernels

Invoke all tests by calling ```pytest test/```.
To run a specific test, call ```pytest test/test_softmax.py```.
19 changes: 19 additions & 0 deletions test/test_softmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
import torch.nn.functional as F
import pytest
from kernels.fused_softmax import softmax

@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding tests!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, this might not be ideal because we are not calling softmax from triton.ops like the other tests. I ran into issues with doing it that way.

def test_softmax_equivalence(input_size):
# Create random input tensor of specified size
x = torch.randn(*input_size).cuda()

# Compute softmax using PyTorch
pytorch_softmax = F.softmax(x, dim=-1)

# Compute softmax using Triton
triton_output = softmax(x)

# Assert that both outputs are approximately equal
assert torch.allclose(pytorch_softmax, triton_output, atol=1e-5), \
f"Triton softmax output doesn't match PyTorch softmax for input size {input_size}"