diff --git a/.gitignore b/.gitignore
index d300e33..9e9fb12 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,3 +4,6 @@ __pycache__/
 .pytest_cache
 **/.cache
 **/meta-llama/**/*
+
+# Virtual Environment
+venv/
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..c44abb5
--- /dev/null
+++ b/README.md
@@ -0,0 +1,40 @@
+# Triton Kernels
+
+## Supported Kernels
+* flash attention
+* matmul
+* cross entropy
+
+## Benchmarking
+
+```bash
+python3 main.py llama_chat_completion --benchmark --ckpt_dir <path to checkpoint dir> --tokenizer_path <path to tokenizer>
+```
+
+## Getting started
+
+* Install dependencies
+
+```bash
+python3 -m pip install -r requirements.txt
+```
+
+* Download the Llama model
+
+```bash
+export HF_TOKEN=xxx
+huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --local-dir $HOME/models/llama-3-8b-instruct
+```
\ No newline at end of file
diff --git a/kernels/__init__.py b/kernels/__init__.py
index dd492d3..b8547ec 100644
--- a/kernels/__init__.py
+++ b/kernels/__init__.py
@@ -1,6 +1,7 @@
 # from .conv import _conv, conv
 from . import blocksparse
 from .cross_entropy import _cross_entropy, cross_entropy
+from .fused_softmax import _softmax, softmax
 from .flash_attention import attention
 from .matmul import _matmul, get_higher_dtype, matmul
 
@@ -12,4 +13,6 @@
     "matmul",
     "attention",
     "get_higher_dtype",
-]
+    "_softmax",
+    "softmax",
+]
\ No newline at end of file
diff --git a/kernels/fused_softmax.py b/kernels/fused_softmax.py
new file mode 100644
index 0000000..76c9ad7
--- /dev/null
+++ b/kernels/fused_softmax.py
@@ -0,0 +1,124 @@
+import torch
+import triton
+import triton.language as tl
+from triton.runtime import driver
+
+
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+def is_cdna():
+    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
+                                                                                   'gfx90a', 'gfx908')
+
+
+@triton.jit
+def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
+                   num_stages: tl.constexpr):
+    # starting row of the program
+    row_start = tl.program_id(0)
+    row_step = tl.num_programs(0)
+    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
+        # The stride represents how much we need to increase the pointer to advance 1 row
+        row_start_ptr = input_ptr + row_idx * input_row_stride
+        # The block size is the next power of two greater than n_cols, so we can fit each
+        # row in a single block
+        col_offsets = tl.arange(0, BLOCK_SIZE)
+        input_ptrs = row_start_ptr + col_offsets
+        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
+        mask = col_offsets < n_cols
+        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
+        # Subtract maximum for numerical stability
+        row_minus_max = row - tl.max(row, axis=0)
+        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
+        numerator = tl.exp(row_minus_max)
+        denominator = tl.sum(numerator, axis=0)
+        softmax_output = numerator / denominator
+        # Write back output to DRAM
+        output_row_start_ptr = output_ptr + row_idx * output_row_stride
+        output_ptrs = output_row_start_ptr + col_offsets
+        tl.store(output_ptrs, softmax_output, mask=mask)
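+
+
+# Host-side wrapper: queries the device limits, compiles softmax_kernel for the
+# required BLOCK_SIZE, estimates how many programs can be resident at once, and
+# launches the kernel on a persistent grid.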
+def triton_softmax(x):
+    device = torch.cuda.current_device()
+    properties = driver.active.utils.get_device_properties(device)
+    NUM_SM = properties["multiprocessor_count"]
+    NUM_REGS = properties["max_num_regs"]
+    SIZE_SMEM = properties["max_shared_mem"]
+    WARP_SIZE = properties["warpSize"]
+    target = triton.runtime.driver.active.get_current_target()
+    kernels = {}
+
+    n_rows, n_cols = x.shape
+
+    # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
+    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+
+    # Another trick we can use is to ask the compiler to use more threads per row by
+    # increasing the number of warps (`num_warps`) over which each row is distributed.
+    # This value could also be auto-tuned instead of being picked by hand.
+    num_warps = 8
+
+    # Number of software pipelining stages.
+    # num_stages = 4 if SIZE_SMEM > 200000 else 2
+    num_stages = 1
+
+    # Allocate output
+    y = torch.empty_like(x)
+
+    # Pre-compile the kernel to get register usage and compute thread occupancy.
+    kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
+    if kernel is None:
+        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
+                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
+        kernel._init_handles()
+        n_regs = kernel.n_regs
+        size_smem = kernel.metadata.shared
+        if is_hip():
+            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
+            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
+            # ISA SECTION (3.6.4 for CDNA3)
+            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
+            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
+            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
+            # not required to be equal numbers of both types.
+            NUM_GPRS = NUM_REGS
+            if is_cdna():
+                NUM_GPRS = NUM_REGS * 2
+
+            # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
+            # When we divide this number by WARP_SIZE we get the maximum number of waves that can
+            # execute on a CU (multi-processor) in parallel.
+            MAX_NUM_THREADS = properties["max_threads_per_sm"]
+            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
+            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
+        else:
+            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
+        occupancy = min(occupancy, SIZE_SMEM // size_smem)
+        num_programs = NUM_SM * occupancy
+        kernels[BLOCK_SIZE] = (kernel, num_programs)
+
+    num_programs = min(num_programs, n_rows)
+
+    # Create a number of persistent programs.
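+    # The grid was capped at n_rows above; each program then strides through the
+    # rows with step tl.num_programs(0) (see softmax_kernel), so a smaller grid
+    # still covers every row.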
+    kernel[(num_programs, 1, 1)](
+        y,
+        x,
+        x.stride(0),
+        y.stride(0),
+        n_rows,
+        n_cols,
+    )
+    return y
+
+
+class _softmax(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        return triton_softmax(x)
+
+    # @staticmethod
+    # def backward(ctx, grad_output):
+    #     return grad_output, grad_output
+
+
+softmax = _softmax.apply
\ No newline at end of file
diff --git a/main.py b/main.py
index eac1c68..1e33abd 100644
--- a/main.py
+++ b/main.py
@@ -7,6 +7,7 @@
 from models.llama import llama_example_chat_completion, llama_example_text_completion
 from benchmarking import Profiler, compare_benchmarks
 import pprint
+import torch.distributed as dist
 
 
 def main(operation: str, profile=False, benchmark=False, **kwargs):
@@ -65,6 +66,9 @@ def main(operation: str, profile=False, benchmark=False, **kwargs):
         print(output)
         print("\n==================================\n")
 
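+    # Shut down the torch.distributed process group, if one was initialized, so
+    # the process exits cleanly.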
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
 
 def runner(operation: str, kwargs):
     if operation == "llama_chat_completion":
diff --git a/models/llama/llama/generation.py b/models/llama/llama/generation.py
index 1addd8a..e6f3fa0 100644
--- a/models/llama/llama/generation.py
+++ b/models/llama/llama/generation.py
@@ -20,7 +20,6 @@
 from .tokenizer import ChatFormat, Dialog, Message, Tokenizer
 from benchmarking import Profiler
 
-
 class CompletionPrediction(TypedDict, total=False):
     generation: str
     tokens: List[str]  # not required
@@ -196,7 +195,7 @@ def generate(
                 probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
             else:
-                next_token = self.Math.argmax(logits[:, -1], dim=-1)
+                next_token = self.Math.argmax(logits[:,-1], dim=-1)
 
             next_token = next_token.reshape(-1)
             # only replace token if prompt has already been generated
diff --git a/models/llama/llama/math_ops.py b/models/llama/llama/math_ops.py
index 5ef81b2..3024eee 100644
--- a/models/llama/llama/math_ops.py
+++ b/models/llama/llama/math_ops.py
@@ -8,6 +8,7 @@
 from kernels.cross_entropy import cross_entropy
 from kernels.matmul import matmul
 from kernels.flash_attention import attention
+from kernels.fused_softmax import softmax
 from benchmarking import Profiler
 import time
 
@@ -70,17 +71,17 @@ def attention(self, xq, keys, values, head_dim, mask):
 
     @Profiler.profiling_decorator("softmax")
     def softmax(self, x, dim):
-        if self.use_triton:
-            return F.softmax(x, dim=-1)
+        if self.use_triton and x.ndim == 2:
+            return softmax(x)
         else:
-            return F.softmax(x, dim=-1)
+            return F.softmax(x, dim)
 
     @Profiler.profiling_decorator("argmax")
     def argmax(self, x, dim):
         if self.use_triton:
-            return torch.argmax(x, dim=-1)
+            return torch.argmax(x, dim=dim)
         else:
-            return torch.argmax(x, dim=-1)
+            return torch.argmax(x, dim=dim)
 
     @Profiler.profiling_decorator("cross_entropy")
     def cross_entropy(self, input_val, target, reduction, ignore_index):
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 0000000..6cc47f7
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,4 @@
+# Unit Tests for Kernels
+
+Invoke all tests by calling `pytest test/`.
+To run a specific test, call `pytest test/test_softmax.py`.
\ No newline at end of file
diff --git a/test/test_softmax.py b/test/test_softmax.py
new file mode 100644
index 0000000..45891e2
--- /dev/null
+++ b/test/test_softmax.py
@@ -0,0 +1,19 @@
+import torch
+import torch.nn.functional as F
+import pytest
+from kernels.fused_softmax import softmax
+
+@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
+def test_softmax_equivalence(input_size):
+    # Create random input tensor of specified size
+    x = torch.randn(*input_size).cuda()
+
+    # Compute softmax using PyTorch
+    pytorch_softmax = F.softmax(x, dim=-1)
+
+    # Compute softmax using Triton
+    triton_output = softmax(x)
+
+    # Assert that both outputs are approximately equal
+    assert torch.allclose(pytorch_softmax, triton_output, atol=1e-5), \
+        f"Triton softmax output doesn't match PyTorch softmax for input size {input_size}"
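For reference, a minimal usage sketch of the new kernel, mirroring `test/test_softmax.py`. It assumes a CUDA device and that the repository root is on `PYTHONPATH`; the tensor shape is illustrative only.

```python
# Compare the fused Triton softmax against the PyTorch reference.
import torch
import torch.nn.functional as F

from kernels import softmax  # re-exported via kernels/__init__.py

x = torch.randn(2048, 512, device="cuda")  # 2-D input, softmax over the last dim
y_triton = softmax(x)
y_torch = F.softmax(x, dim=-1)
print(torch.allclose(y_triton, y_torch, atol=1e-5))
```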