From a9cb93a7590cc391f7aed4f3a62cf234d7b277f7 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 20 Sep 2024 01:29:01 -0700 Subject: [PATCH 01/12] update readme + silu_mul --- README.md | 8 ++++++++ kernels/silu_mul.py | 26 ++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 README.md create mode 100644 kernels/silu_mul.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..02422d9 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +# Triton Kernels + +## Supported Kernels +* flash attention +* matmul +* cross entropy + +## Contributing \ No newline at end of file diff --git a/kernels/silu_mul.py b/kernels/silu_mul.py new file mode 100644 index 0000000..70b1b32 --- /dev/null +++ b/kernels/silu_mul.py @@ -0,0 +1,26 @@ +import logging + +import torch +import triton +import triton.language as tl + +from ..utils import pointwise_dynamic + + +@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) +@triton.jit +def silu_and_mul_kernel(x, y): + x_fp32 = x.to(tl.float32) + x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32))) + return x_silu * y + + +class SiluAndMul(torch.autograd.Function): + @staticmethod + def forward(ctx, A, B): + logging.debug("GEMS SILU AND MUL FORWARD") + return silu_and_mul_kernel(A, B) + + +def silu_and_mul(A, B): + return SiluAndMul.apply(A, B) From 64757661a37df9db00ccc6a0697856fc715eb89f Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 20 Sep 2024 01:29:21 -0700 Subject: [PATCH 02/12] Update gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index d300e33..9e9fb12 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ __pycache__/ .pytest_cache **/.cache **/meta-llama/**/* + +# Virtual Environment +venv/ \ No newline at end of file From 426b934a5ea2049a79d8a5567b1586bc1e01756d Mon Sep 17 00:00:00 2001 From: catherinelee274 Date: Fri, 20 Sep 2024 16:22:09 -0700 Subject: [PATCH 03/12] update readme --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 02422d9..f3a80d8 100644 --- a/README.md +++ b/README.md @@ -5,4 +5,5 @@ * matmul * cross entropy -## Contributing \ No newline at end of file +## Contributing +python3 main.py llama_chat_completion --benchmark --ckpt_path --tokenizer_path \ No newline at end of file From 7c84e9db6a2a76c254bf281d3813aeb11819cfa5 Mon Sep 17 00:00:00 2001 From: catherinelee274 Date: Mon, 23 Sep 2024 03:26:43 -0700 Subject: [PATCH 04/12] updates --- README.md | 55 ++++++++++++- benchmarking/benchmark_utils.py | 5 +- kernels/__init__.py | 5 +- kernels/fused_softmax.py | 128 +++++++++++++++++++++++++++++++ models/llama/llama/generation.py | 8 +- test/test_softmax.py | 19 +++++ 6 files changed, 215 insertions(+), 5 deletions(-) create mode 100644 kernels/fused_softmax.py create mode 100644 test/test_softmax.py diff --git a/README.md b/README.md index f3a80d8..66a8594 100644 --- a/README.md +++ b/README.md @@ -6,4 +6,57 @@ * cross entropy ## Contributing -python3 main.py llama_chat_completion --benchmark --ckpt_path --tokenizer_path \ No newline at end of file +python3 main.py llama_chat_completion --benchmark --ckpt_dir --tokenizer_path + + +## Getting started + + + +* Install dependencies + + + +```bash + +python3 -m pip install -r requirements.txt + +``` + + + +* Download llama model + + + +```bash + +export HF_TOKEN=xxx + + + +huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct \ + +--local-dir $HOME/models/llama-3-8b-instruct \ + +--local-dir-use-symlinks False + +``` + + + 
+* Clone repo and get branch + + + +```bash + +git clone https://github.com/shelbyt/kernels.git + +cd kernel + +git checkout cudamode + +python3 test_llama.py + +``` \ No newline at end of file diff --git a/benchmarking/benchmark_utils.py b/benchmarking/benchmark_utils.py index 3477d12..74f6c2f 100644 --- a/benchmarking/benchmark_utils.py +++ b/benchmarking/benchmark_utils.py @@ -1,6 +1,6 @@ from typing import Any, Dict import pandas as pd - +import os def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: series_dict = {k: pd.Series(v.values()) for k, v in benchmarks.items()} @@ -24,5 +24,8 @@ def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: columns = [c for c in df.columns if not "kernel" in c] columns = ["kernel", "kernel_path"] + columns df = df[columns] + folder_name = '.results' + csv_file_path = os.path.jsoin(folder_name, 'output.csv') + df.save_csv(csv_file_path, index=False) df.set_index("kernel", inplace=True) return df diff --git a/kernels/__init__.py b/kernels/__init__.py index dd492d3..ddf7da9 100644 --- a/kernels/__init__.py +++ b/kernels/__init__.py @@ -1,6 +1,7 @@ # from .conv import _conv, conv from . import blocksparse from .cross_entropy import _cross_entropy, cross_entropy +from .fused_softmax import TritonSoftmax, triton_softmax from .flash_attention import attention from .matmul import _matmul, get_higher_dtype, matmul @@ -12,4 +13,6 @@ "matmul", "attention", "get_higher_dtype", -] + "TritonSoftmax", + "triton_softmax" +] \ No newline at end of file diff --git a/kernels/fused_softmax.py b/kernels/fused_softmax.py new file mode 100644 index 0000000..753ad88 --- /dev/null +++ b/kernels/fused_softmax.py @@ -0,0 +1,128 @@ +import torch +import triton +import triton.language as tl +from triton.runtime import driver + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', + 'gfx90a', 'gfx908') + + + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, + num_stages: tl.constexpr): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) + +def softmax(x): + device = torch.cuda.current_device() + properties = driver.active.utils.get_device_properties(device) + NUM_SM = 
properties["multiprocessor_count"] + NUM_REGS = properties["max_num_regs"] + SIZE_SMEM = properties["max_shared_mem"] + WARP_SIZE = properties["warpSize"] + target = triton.runtime.driver.active.get_current_target() + kernels = {} + + n_rows, n_cols = x.shape + + # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + # Another trick we can use is to ask the compiler to use more threads per row by + # increasing the number of warps (`num_warps`) over which each row is distributed. + # You will see in the next tutorial how to auto-tune this value in a more natural + # way so you don't have to come up with manual heuristics yourself. + num_warps = 8 + + # Number of software piepling stages. + # num_stages = 4 if SIZE_SMEM > 200000 else 2 + num_stages = 1 + + # Allocate output + y = torch.empty_like(x) + + # pre-compile kernel to get register usage and compute thread occupancy. + kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) + if kernel is None: + kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, + num_stages=num_stages, num_warps=num_warps, grid=(1, )) + kernel._init_handles() + n_regs = kernel.n_regs + size_smem = kernel.metadata.shared + if is_hip(): + # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available. + # However, this is not always the case. In most cases all registers can be used as regular purpose registers. + # ISA SECTION (3.6.4 for CDNA3) + # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used + # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total + # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is + # not required to be equal numbers of both types. + if is_cdna(): + NUM_GPRS = NUM_REGS * 2 + + # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor. + # When we divide this number with WARP_SIZE we get maximum number of waves that can + # execute on a CU (multi-processor) in parallel. + MAX_NUM_THREADS = properties["max_threads_per_sm"] + max_num_waves = MAX_NUM_THREADS // WARP_SIZE + occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps + else: + occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) + occupancy = min(occupancy, SIZE_SMEM // size_smem) + num_programs = NUM_SM * occupancy + kernels[BLOCK_SIZE] = (kernel, num_programs) + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. 
diff --git a/models/llama/llama/generation.py b/models/llama/llama/generation.py
index 1addd8a..c62631d 100644
--- a/models/llama/llama/generation.py
+++ b/models/llama/llama/generation.py
@@ -19,7 +19,7 @@
 from .math_ops import MathOps
 from .tokenizer import ChatFormat, Dialog, Message, Tokenizer
 from benchmarking import Profiler
-
+from kernels.fused_softmax import triton_softmax
 
 class CompletionPrediction(TypedDict, total=False):
     generation: str
@@ -193,7 +193,11 @@ def generate(
         for cur_pos in range(min_prompt_len, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             if temperature > 0:
-                probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
+                if self.use_triton:
+                    probs = triton_softmax(logits[:, -1] / temperature)
+                else:
+                    probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1) # TODO: softmax
+
                 next_token = sample_top_p(probs, top_p)
             else:
                 next_token = self.Math.argmax(logits[:, -1], dim=-1)
diff --git a/test/test_softmax.py b/test/test_softmax.py
new file mode 100644
index 0000000..0c475b9
--- /dev/null
+++ b/test/test_softmax.py
@@ -0,0 +1,19 @@
+import torch
+import torch.nn.functional as F
+import pytest
+from fused_softmax import triton_softmax  # Import your Triton softmax function
+
+@pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
+def test_softmax_equivalence(input_size):
+    # Create random input tensor of specified size
+    x = torch.randn(*input_size).cuda()
+
+    # Compute softmax using PyTorch
+    pytorch_softmax = F.softmax(x, dim=-1)
+
+    # Compute softmax using Triton
+    triton_output = triton_softmax(x)
+
+    # Assert that both outputs are approximately equal
+    assert torch.allclose(pytorch_softmax, triton_output, atol=1e-5), \
+        f"Triton softmax output doesn't match PyTorch softmax for input size {input_size}"

From 16c404083dc55009eaf8623f1e7ea74b2d5919fc Mon Sep 17 00:00:00 2001
From: catherinelee274
Date: Tue, 24 Sep 2024 07:50:43 +0000
Subject: [PATCH 05/12] updates, removing warning

---
 .results/no_improvements.txt     | 33 ++++++++++++++++++++++++++++++++
 .results/output.csv              | 16 ++++++++++++++++
 .results/with_argmax_softmax.txt | 33 ++++++++++++++++++++++++++++++++
 .results/with_softmax.txt        | 33 ++++++++++++++++++++++++++++++++
 README.md                        | 12 ++++--------
 benchmarking/benchmark_utils.py  | 21 +++++++++++++++++---
 main.py                          |  4 ++++
 models/llama/llama/generation.py |  7 +++++--
 test/README.md                   |  3 +++
 test/test_softmax.py             |  2 +-
 10 files changed, 150 insertions(+), 14 deletions(-)
 create mode 100644 .results/no_improvements.txt
 create mode 100644 .results/output.csv
 create mode 100644 .results/with_argmax_softmax.txt
 create mode 100644 .results/with_softmax.txt
 create mode 100644 test/README.md

diff --git a/.results/no_improvements.txt b/.results/no_improvements.txt
new file mode 100644
index 0000000..cb7f88f
--- /dev/null
+++ b/.results/no_improvements.txt
@@ -0,0 +1,33 @@
+|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
+| kernel                  | kernel_path                                                                                                                | triton                 | non_triton             |
triton-non_triton | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| chat_completion | chat_completion | 23.363035631999992 | 23.20719621399985 | 0.15583941800014145 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| chat_completion | chat_completion.chat_completion | 15.086765727000056 | 15.037501877000068 | 0.04926384999998845 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| generate | chat_completion.chat_completion.generate | 15.085606602000098 | 15.036371463999785 | 0.04923513800031287 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| softmax | chat_completion.chat_completion.generate.softmax | 3.286413995026158e-05 | 3.2326578084264846e-05 | 5.375618659967339e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| transformer_forward | chat_completion.chat_completion.generate.transformer_forward | 0.02763666868558919 | 0.0275613940689651 | 7.527461662408877e-05 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.827654970598593e-05 | 5.889426978790332e-05 | -6.177200819173836e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008483634264688104 | 0.0008461694928391175 | 2.1939336296929223e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.88654680211981e-05 | 5.911337566830836e-05 | -2.4790764711026176e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.000491074871831931 | 0.0004896370724516646 | 1.437799380266374e-06 | 
+|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.806513197629571e-05 | 8.784457036059364e-05 | 2.2056161570207199e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 0.00017238609349734332 | 0.00017186252022342258 | 5.235732739207336e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.327956193072774e-05 | 4.3119774561757945e-05 | 1.597873689697973e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.88178547136412e-05 | 2.8733377409596218e-05 | 8.4477304044983e-08 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00012714412176862338 | 0.00012642639452434758 | 7.177272442757999e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003592040002331487 | 0.00034214400011478574 | 1.706000011836295e-05 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| diff --git a/.results/output.csv b/.results/output.csv new file mode 100644 index 0000000..252b18d --- /dev/null +++ b/.results/output.csv @@ -0,0 +1,16 @@ +kernel,kernel_path,triton,non_triton,triton-non_triton +chat_completion,chat_completion,23.529098738999892,, +chat_completion,chat_completion.chat_completion,15.229127281999808,23.885352947999763,-8.656225665999955 +generate,chat_completion.chat_completion.generate,15.228001847000087,15.719125695999992,-0.491123848999905 +softmax,chat_completion.chat_completion.generate.softmax,3.3104419885907815e-05,15.717982729999676,-15.71794962557979 +transformer_forward,chat_completion.chat_completion.generate.transformer_forward,0.028188293081122473,0.028127436847877902,6.0856233244570984e-05 
+RMSNorm,chat_completion.chat_completion.generate.transformer_forward.RMSNorm,5.8095263698646935e-05,5.832987626748923e-05,-2.3461256884229476e-07 +transform_block_forward,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward,0.0008650784662146882,0.0008631522671130063,1.9261991016819406e-06 +RMSNorm,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm,5.8953483520847516e-05,5.920526296552413e-05,-2.517794446766161e-07 +attention_forward,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward,0.000502045900099974,0.0005004565054525663,1.5893946474076093e-06 +apply_rotary_emb,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb,8.903067526494468e-05,8.906119377373757e-05,-3.05185087928929e-08 +attention,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention,0.00017652477433859322,0.00017632209013441347,2.0268420417975867e-07 +matmul,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul,4.470880419631436e-05,4.464734606365697e-05,6.145813265738515e-08 +softmax,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax,2.94319926492972e-05,2.9444132737926863e-05,-1.2140088629663967e-08 +feed_forward_forward,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward,0.00013109435198944578,0.00013054154848626072,5.528035031850662e-07 +precompute_freqs_cis,chat_completion.precompute_freqs_cis,0.0003498339997349831,0.0003586540001379035,-8.820000402920414e-06 diff --git a/.results/with_argmax_softmax.txt b/.results/with_argmax_softmax.txt new file mode 100644 index 0000000..355a9e9 --- /dev/null +++ b/.results/with_argmax_softmax.txt @@ -0,0 +1,33 @@ +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| kernel | kernel_path | triton | non_triton | triton-non_triton | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| chat_completion | chat_completion | 23.316155643000002 | | | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| chat_completion | chat_completion.chat_completion | 15.026287104999938 | 23.7322098059999 | -8.705922700999963 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| generate | chat_completion.chat_completion.generate | 15.025166173999878 | 15.588987290000205 | -0.5638211160003266 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| softmax | 
chat_completion.chat_completion.generate.softmax | 3.194667139574176e-05 | 15.587871648000146 | -15.58783970132875 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| transformer_forward | chat_completion.chat_completion.generate.transformer_forward | 0.02752115360446059 | 0.027770376318450807 | -0.000249222713990218 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.7826257591456e-05 | 5.821012373749101e-05 | -3.838661460350075e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008447297065781728 | 0.000852438977687303 | -7.709271109130212e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.789487043606993e-05 | 5.838420017865719e-05 | -4.893297425872609e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.0004889343329106125 | 0.0004935158465387692 | -4.581513628156746e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.793432188105586e-05 | 8.822047401022832e-05 | -2.861521291724629e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 0.00017178861625323266 | 0.00017370242444190314 | -1.9138081886704806e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.3013189750603796e-05 | 4.373048187156608e-05 | -7.172921209622871e-07 | 
+|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.8741690918662413e-05 | 2.8887915059976048e-05 | -1.462241413136349e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00012659588773971398 | 0.00012815826540009848 | -1.5623776603845016e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003504739997879369 | 0.0003422939998927177 | 8.179999895219225e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| diff --git a/.results/with_softmax.txt b/.results/with_softmax.txt new file mode 100644 index 0000000..08610ed --- /dev/null +++ b/.results/with_softmax.txt @@ -0,0 +1,33 @@ +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| kernel | kernel_path | triton | non_triton | triton-non_triton | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| chat_completion | chat_completion | 23.529098738999892 | | | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| chat_completion | chat_completion.chat_completion | 15.229127281999808 | 23.885352947999763 | -8.656225665999955 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| generate | chat_completion.chat_completion.generate | 15.228001847000087 | 15.719125695999992 | -0.491123848999905 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| softmax | chat_completion.chat_completion.generate.softmax | 3.3104419885907815e-05 | 15.717982729999676 | -15.71794962557979 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| transformer_forward | 
chat_completion.chat_completion.generate.transformer_forward | 0.028188293081122473 | 0.028127436847877902 | 6.0856233244570984e-05 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.8095263698646935e-05 | 5.832987626748923e-05 | -2.3461256884229476e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008650784662146882 | 0.0008631522671130063 | 1.9261991016819406e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.8953483520847516e-05 | 5.920526296552413e-05 | -2.517794446766161e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.000502045900099974 | 0.0005004565054525663 | 1.5893946474076093e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.903067526494468e-05 | 8.906119377373757e-05 | -3.05185087928929e-08 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 0.00017652477433859322 | 0.00017632209013441347 | 2.0268420417975867e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.470880419631436e-05 | 4.464734606365697e-05 | 6.145813265738515e-08 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.94319926492972e-05 | 2.9444132737926863e-05 | -1.2140088629663967e-08 | 
+|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00013109435198944578 | 0.00013054154848626072 | 5.528035031850662e-07 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| +| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003498339997349831 | 0.0003586540001379035 | -8.820000402920414e-06 | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| diff --git a/README.md b/README.md index 66a8594..7332236 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,9 @@ * cross entropy ## Contributing +``` python3 main.py llama_chat_completion --benchmark --ckpt_dir --tokenizer_path +``` ## Getting started @@ -33,17 +35,11 @@ python3 -m pip install -r requirements.txt export HF_TOKEN=xxx - - -huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct \ - ---local-dir $HOME/models/llama-3-8b-instruct \ - ---local-dir-use-symlinks False +huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --local-dir $HOME/models/llama-3-8b-instruct ``` - + * Clone repo and get branch diff --git a/benchmarking/benchmark_utils.py b/benchmarking/benchmark_utils.py index 74f6c2f..512aa8f 100644 --- a/benchmarking/benchmark_utils.py +++ b/benchmarking/benchmark_utils.py @@ -2,6 +2,19 @@ import pandas as pd import os +def save_metrics(df): + folder_name = '.results' + + # Check if the folder exists, if not, create it + if not os.path.exists(folder_name): + os.makedirs(folder_name) + + csv_file_path = os.path.join(folder_name, 'output.csv') + + df.to_csv(csv_file_path, index=False) + print(f"DataFrame saved to {csv_file_path}") + + def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: series_dict = {k: pd.Series(v.values()) for k, v in benchmarks.items()} series_dict["kernel_path"] = pd.Series( @@ -24,8 +37,10 @@ def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: columns = [c for c in df.columns if not "kernel" in c] columns = ["kernel", "kernel_path"] + columns df = df[columns] - folder_name = '.results' - csv_file_path = os.path.jsoin(folder_name, 'output.csv') - df.save_csv(csv_file_path, index=False) + + + save_metrics(df) + + df.set_index("kernel", inplace=True) return df diff --git a/main.py b/main.py index eac1c68..1e33abd 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ from models.llama import llama_example_chat_completion, llama_example_text_completion from benchmarking import Profiler, compare_benchmarks import pprint +import torch.distributed as dist def main(operation: str, profile=False, benchmark=False, **kwargs): @@ -65,6 +66,9 @@ def main(operation: str, profile=False, benchmark=False, **kwargs): print(output) print("\n==================================\n") + if dist.is_initialized(): + dist.destroy_process_group() + def runner(operation: str, kwargs): if operation == "llama_chat_completion": diff --git a/models/llama/llama/generation.py b/models/llama/llama/generation.py 
b/models/llama/llama/generation.py
index c62631d..8b3e533 100644
--- a/models/llama/llama/generation.py
+++ b/models/llama/llama/generation.py
@@ -196,11 +196,14 @@ def generate(
                 if self.use_triton:
                     probs = triton_softmax(logits[:, -1] / temperature)
                 else:
-                    probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1) # TODO: softmax
+                    probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
 
                 next_token = sample_top_p(probs, top_p)
             else:
-                next_token = self.Math.argmax(logits[:, -1], dim=-1)
+                if self.use_triton:
+                    next_token = self.triton.language.argmax(logits[:, -1], axis=-1)
+                else:
+                    next_token = self.Math.argmax(logits[:, -1], dim=-1)
             next_token = next_token.reshape(-1)
             # only replace token if prompt has already been generated
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 0000000..a8aeef6
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,3 @@
+# Unit Tests for Kernels
+
+Invoke by calling `pytest test/`
\ No newline at end of file
diff --git a/test/test_softmax.py b/test/test_softmax.py
index 0c475b9..6f444ab 100644
--- a/test/test_softmax.py
+++ b/test/test_softmax.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 import pytest
-from fused_softmax import triton_softmax  # Import your Triton softmax function
+from kernels.fused_softmax import triton_softmax
 
 @pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
 def test_softmax_equivalence(input_size):

From 7e47ddd437881a072e505e0856e039e7d5068ff3 Mon Sep 17 00:00:00 2001
From: Catherine Lee
Date: Tue, 24 Sep 2024 01:00:39 -0700
Subject: [PATCH 06/12] Remove files from commit but keep them locally

---
 .results/no_improvements.txt     | 33 -----------------------
 .results/output.csv              | 16 ----------
 .results/with_argmax_softmax.txt | 33 -----------------------
 .results/with_softmax.txt        | 33 -----------------------
 benchmarking/benchmark_utils.py  | 46 --------------------------------
 kernels/silu_mul.py              | 26 ------------------
 6 files changed, 187 deletions(-)
 delete mode 100644 .results/no_improvements.txt
 delete mode 100644 .results/output.csv
 delete mode 100644 .results/with_argmax_softmax.txt
 delete mode 100644 .results/with_softmax.txt
 delete mode 100644 benchmarking/benchmark_utils.py
 delete mode 100644 kernels/silu_mul.py

diff --git a/.results/no_improvements.txt b/.results/no_improvements.txt
deleted file mode 100644
index cb7f88f..0000000
--- a/.results/no_improvements.txt
+++ /dev/null
@@ -1,33 +0,0 @@
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| kernel                  | kernel_path                                                                                                                | triton                 | non_triton             | triton-non_triton       |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| chat_completion         | chat_completion                                                                                                            | 23.363035631999992     | 23.20719621399985      | 0.15583941800014145     |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| chat_completion         | chat_completion.chat_completion                                                                                            | 15.086765727000056     | 15.037501877000068     | 0.04926384999998845     |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| generate | chat_completion.chat_completion.generate | 15.085606602000098 | 15.036371463999785 | 0.04923513800031287 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| softmax | chat_completion.chat_completion.generate.softmax | 3.286413995026158e-05 | 3.2326578084264846e-05 | 5.375618659967339e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| transformer_forward | chat_completion.chat_completion.generate.transformer_forward | 0.02763666868558919 | 0.0275613940689651 | 7.527461662408877e-05 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.827654970598593e-05 | 5.889426978790332e-05 | -6.177200819173836e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008483634264688104 | 0.0008461694928391175 | 2.1939336296929223e-06 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.88654680211981e-05 | 5.911337566830836e-05 | -2.4790764711026176e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.000491074871831931 | 0.0004896370724516646 | 1.437799380266374e-06 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.806513197629571e-05 | 8.784457036059364e-05 | 2.2056161570207199e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 
0.00017238609349734332 | 0.00017186252022342258 | 5.235732739207336e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.327956193072774e-05 | 4.3119774561757945e-05 | 1.597873689697973e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.88178547136412e-05 | 2.8733377409596218e-05 | 8.4477304044983e-08 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00012714412176862338 | 0.00012642639452434758 | 7.177272442757999e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003592040002331487 | 0.00034214400011478574 | 1.706000011836295e-05 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| diff --git a/.results/output.csv b/.results/output.csv deleted file mode 100644 index 252b18d..0000000 --- a/.results/output.csv +++ /dev/null @@ -1,16 +0,0 @@ -kernel,kernel_path,triton,non_triton,triton-non_triton -chat_completion,chat_completion,23.529098738999892,, -chat_completion,chat_completion.chat_completion,15.229127281999808,23.885352947999763,-8.656225665999955 -generate,chat_completion.chat_completion.generate,15.228001847000087,15.719125695999992,-0.491123848999905 -softmax,chat_completion.chat_completion.generate.softmax,3.3104419885907815e-05,15.717982729999676,-15.71794962557979 -transformer_forward,chat_completion.chat_completion.generate.transformer_forward,0.028188293081122473,0.028127436847877902,6.0856233244570984e-05 -RMSNorm,chat_completion.chat_completion.generate.transformer_forward.RMSNorm,5.8095263698646935e-05,5.832987626748923e-05,-2.3461256884229476e-07 -transform_block_forward,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward,0.0008650784662146882,0.0008631522671130063,1.9261991016819406e-06 -RMSNorm,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm,5.8953483520847516e-05,5.920526296552413e-05,-2.517794446766161e-07 -attention_forward,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward,0.000502045900099974,0.0005004565054525663,1.5893946474076093e-06 
-apply_rotary_emb,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb,8.903067526494468e-05,8.906119377373757e-05,-3.05185087928929e-08 -attention,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention,0.00017652477433859322,0.00017632209013441347,2.0268420417975867e-07 -matmul,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul,4.470880419631436e-05,4.464734606365697e-05,6.145813265738515e-08 -softmax,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax,2.94319926492972e-05,2.9444132737926863e-05,-1.2140088629663967e-08 -feed_forward_forward,chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward,0.00013109435198944578,0.00013054154848626072,5.528035031850662e-07 -precompute_freqs_cis,chat_completion.precompute_freqs_cis,0.0003498339997349831,0.0003586540001379035,-8.820000402920414e-06 diff --git a/.results/with_argmax_softmax.txt b/.results/with_argmax_softmax.txt deleted file mode 100644 index 355a9e9..0000000 --- a/.results/with_argmax_softmax.txt +++ /dev/null @@ -1,33 +0,0 @@ -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| kernel | kernel_path | triton | non_triton | triton-non_triton | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| chat_completion | chat_completion | 23.316155643000002 | | | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| chat_completion | chat_completion.chat_completion | 15.026287104999938 | 23.7322098059999 | -8.705922700999963 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| generate | chat_completion.chat_completion.generate | 15.025166173999878 | 15.588987290000205 | -0.5638211160003266 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| softmax | chat_completion.chat_completion.generate.softmax | 3.194667139574176e-05 | 15.587871648000146 | -15.58783970132875 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| transformer_forward | chat_completion.chat_completion.generate.transformer_forward | 0.02752115360446059 | 0.027770376318450807 | -0.000249222713990218 | 
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.7826257591456e-05 | 5.821012373749101e-05 | -3.838661460350075e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008447297065781728 | 0.000852438977687303 | -7.709271109130212e-06 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.789487043606993e-05 | 5.838420017865719e-05 | -4.893297425872609e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.0004889343329106125 | 0.0004935158465387692 | -4.581513628156746e-06 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.793432188105586e-05 | 8.822047401022832e-05 | -2.861521291724629e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 0.00017178861625323266 | 0.00017370242444190314 | -1.9138081886704806e-06 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.3013189750603796e-05 | 4.373048187156608e-05 | -7.172921209622871e-07 | -|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------| -| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.8741690918662413e-05 | 2.8887915059976048e-05 | -1.462241413136349e-07 | 
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00012659588773971398 | 0.00012815826540009848 | -1.5623776603845016e-06 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003504739997879369 | 0.0003422939998927177 | 8.179999895219225e-06 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
diff --git a/.results/with_softmax.txt b/.results/with_softmax.txt
deleted file mode 100644
index 08610ed..0000000
--- a/.results/with_softmax.txt
+++ /dev/null
@@ -1,33 +0,0 @@
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| kernel | kernel_path | triton | non_triton | triton-non_triton |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| chat_completion | chat_completion | 23.529098738999892 | | |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| chat_completion | chat_completion.chat_completion | 15.229127281999808 | 23.885352947999763 | -8.656225665999955 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| generate | chat_completion.chat_completion.generate | 15.228001847000087 | 15.719125695999992 | -0.491123848999905 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| softmax | chat_completion.chat_completion.generate.softmax | 3.3104419885907815e-05 | 15.717982729999676 | -15.71794962557979 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| transformer_forward | chat_completion.chat_completion.generate.transformer_forward | 0.028188293081122473 | 0.028127436847877902 | 6.0856233244570984e-05 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.RMSNorm | 5.8095263698646935e-05 | 5.832987626748923e-05 | -2.3461256884229476e-07 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| transform_block_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward | 0.0008650784662146882 | 0.0008631522671130063 | 1.9261991016819406e-06 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| RMSNorm | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.RMSNorm | 5.8953483520847516e-05 | 5.920526296552413e-05 | -2.517794446766161e-07 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| attention_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward | 0.000502045900099974 | 0.0005004565054525663 | 1.5893946474076093e-06 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| apply_rotary_emb | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.apply_rotary_emb | 8.903067526494468e-05 | 8.906119377373757e-05 | -3.05185087928929e-08 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| attention | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention | 0.00017652477433859322 | 0.00017632209013441347 | 2.0268420417975867e-07 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| matmul | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.matmul | 4.470880419631436e-05 | 4.464734606365697e-05 | 6.145813265738515e-08 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| softmax | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.attention_forward.attention.softmax | 2.94319926492972e-05 | 2.9444132737926863e-05 | -1.2140088629663967e-08 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| feed_forward_forward | chat_completion.chat_completion.generate.transformer_forward.transform_block_forward.feed_forward_forward | 0.00013109435198944578 | 0.00013054154848626072 | 5.528035031850662e-07 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
-| precompute_freqs_cis | chat_completion.precompute_freqs_cis | 0.0003498339997349831 | 0.0003586540001379035 | -8.820000402920414e-06 |
-|-------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------|------------------------|-------------------------|
diff --git a/benchmarking/benchmark_utils.py b/benchmarking/benchmark_utils.py
deleted file mode 100644
index 512aa8f..0000000
--- a/benchmarking/benchmark_utils.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from typing import Any, Dict
-import pandas as pd
-import os
-
-def save_metrics(df):
-    folder_name = '.results'
-
-    # Check if the folder exists, if not, create it
-    if not os.path.exists(folder_name):
-        os.makedirs(folder_name)
-
-    csv_file_path = os.path.join(folder_name, 'output.csv')
-
-    df.to_csv(csv_file_path, index=False)
-    print(f"DataFrame saved to {csv_file_path}")
-
-
-def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
-    series_dict = {k: pd.Series(v.values()) for k, v in benchmarks.items()}
-    series_dict["kernel_path"] = pd.Series(
-        benchmarks[list(benchmarks.keys())[0]].keys()
-    )
-    series_dict["kernel"] = pd.Series(
-        [k.split(".")[-1] for k in series_dict["kernel_path"]]
-    )
-    df = pd.DataFrame()
-
-    for k, v in series_dict.items():
-        df[k] = v
-    columns = [c for c in df.columns if not "kernel" in c]
-    for i in range(len(columns)):
-        for j in range(i + 1, len(columns)):
-            # calculate the difference between the two columns
-            diff_col_name = f"{columns[i]}-{columns[j]}"
-            df[diff_col_name] = df[columns[i]] - df[columns[j]]
-    df.sort_values(by="kernel_path", inplace=True)
-    columns = [c for c in df.columns if not "kernel" in c]
-    columns = ["kernel", "kernel_path"] + columns
-    df = df[columns]
-
-
-    save_metrics(df)
-
-
-    df.set_index("kernel", inplace=True)
-    return df
diff --git a/kernels/silu_mul.py b/kernels/silu_mul.py
deleted file mode 100644
index 70b1b32..0000000
--- a/kernels/silu_mul.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import logging
-
-import torch
-import triton
-import triton.language as tl
-
-from ..utils import pointwise_dynamic
-
-
-@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
-@triton.jit
-def silu_and_mul_kernel(x, y):
-    x_fp32 = x.to(tl.float32)
-    x_silu = tl.fdiv(x_fp32, (1.0 + tl.exp(-x_fp32)))
-    return x_silu * y
-
-
-class SiluAndMul(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, A, B):
-        logging.debug("GEMS SILU AND MUL FORWARD")
-        return silu_and_mul_kernel(A, B)
-
-
-def silu_and_mul(A, B):
-    return SiluAndMul.apply(A, B)

From a8759f1ac743afcbcd75e63b14e4588131312104 Mon Sep 17 00:00:00 2001
From: Catherine Lee
Date: Tue, 24 Sep 2024 01:00:58 -0700
Subject: [PATCH 07/12] update README.md

---
 README.md | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/README.md b/README.md
index 7332236..c44abb5 100644
--- a/README.md
+++ b/README.md
@@ -37,22 +37,4 @@ export HF_TOKEN=xxx
 
 huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --local-dir $HOME/models/llama-3-8b-instruct
 
-```
-
-
-
-* Clone repo and get branch
-
-
-
-```bash
-
-git clone https://github.com/shelbyt/kernels.git
-
-cd kernel
-
-git checkout cudamode
-
-python3 test_llama.py
-
 ```
\ No newline at end of file

From 213c407dedb8af8eb2e092ac38b53cc376e90fe9 Mon Sep 17 00:00:00 2001
From: Catherine Lee
Date: Tue, 24 Sep 2024 22:26:25 -0700
Subject: [PATCH 08/12] Add back benchmarking_utils

---
 benchmarking/benchmark_utils.py | 28 ++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)
 create mode 100644 benchmarking/benchmark_utils.py

diff --git a/benchmarking/benchmark_utils.py b/benchmarking/benchmark_utils.py
new file mode 100644
index 0000000..3477d12
--- /dev/null
+++ b/benchmarking/benchmark_utils.py
@@ -0,0 +1,28 @@
+from typing import Any, Dict
+import pandas as pd
+
+
+def compare_benchmarks(benchmarks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
+    series_dict = {k: pd.Series(v.values()) for k, v in benchmarks.items()}
+    series_dict["kernel_path"] = pd.Series(
+        benchmarks[list(benchmarks.keys())[0]].keys()
+    )
+    series_dict["kernel"] = pd.Series(
+        [k.split(".")[-1] for k in series_dict["kernel_path"]]
+    )
+    df = pd.DataFrame()
+
+    for k, v in series_dict.items():
+        df[k] = v
+    columns = [c for c in df.columns if not "kernel" in c]
+    for i in range(len(columns)):
+        for j in range(i + 1, len(columns)):
+            # calculate the difference between the two columns
+            diff_col_name = f"{columns[i]}-{columns[j]}"
+            df[diff_col_name] = df[columns[i]] - df[columns[j]]
+    df.sort_values(by="kernel_path", inplace=True)
+    columns = [c for c in df.columns if not "kernel" in c]
+    columns = ["kernel", "kernel_path"] + columns
+    df = df[columns]
+    df.set_index("kernel", inplace=True)
+    return df

From bec93123c8774d18e92871cad534f379820bad75 Mon Sep 17 00:00:00 2001
From: Catherine Lee
Date: Thu, 3 Oct 2024 00:11:01 -0700
Subject: [PATCH 09/12] updates--need to debug argmax issue

---
 models/llama/llama/generation.py | 20 ++++++++++++--------
 models/llama/llama/math_ops.py   |  6 ++++--
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/models/llama/llama/generation.py b/models/llama/llama/generation.py
index 8b3e533..64782aa 100644
--- a/models/llama/llama/generation.py
+++ b/models/llama/llama/generation.py
@@ -193,17 +193,21 @@ def generate(
         for cur_pos in range(min_prompt_len, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             if temperature > 0:
-                if self.use_triton:
-                    probs = triton_softmax(logits[:,-1])
-                else:
-                    probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
+                # if self.use_triton:
+                #     probs = triton_softmax(logits[:,-1])
+                # else:
+                #     probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
+                MathOps.softmax(logits[:, -1] / temperature, dim=-1)
+
+
+
                 next_token = sample_top_p(probs, top_p)
             else:
-                if self.use_triton:
-                    next_token = self.triton.language.argmax(logits[:, -1], axis=-1)
-                else:
-                    next_token = self.Math.argmax(logits[:, -1], dim=-1)
+                # if self.use_triton:
+                #     next_token = self.triton.language.argmax(logits[:, -1], axis=-1)
+                # else:
+                #     next_token = self.Math.argmax(logits[:, -1], dim=-1)
+                MathOps.argmax(logits[:,-1], dim = -1)
             next_token = next_token.reshape(-1)
 
             # only replace token if prompt has already been generated
diff --git a/models/llama/llama/math_ops.py b/models/llama/llama/math_ops.py
index 5ef81b2..5ddbf4d 100644
--- a/models/llama/llama/math_ops.py
+++ b/models/llama/llama/math_ops.py
@@ -8,6 +8,7 @@
 from kernels.cross_entropy import cross_entropy
 from kernels.matmul import matmul
 from kernels.flash_attention import attention
+from kernels.fused_softmax import triton_softmax
 from benchmarking import Profiler
 import time
@@ -70,14 +71,15 @@ def attention(self, xq, keys, values, head_dim, mask):
 
     @Profiler.profiling_decorator("softmax")
     def softmax(self, x, dim):
-        if self.use_triton:
-            return F.softmax(x, dim=-1)
+        if self.use_triton and len(x) == 2:
+            return triton_softmax(x, dim=-1)
         else:
             return F.softmax(x, dim=-1)
 
     @Profiler.profiling_decorator("argmax")
     def argmax(self, x, dim):
         if self.use_triton:
+            # TODO: change
             return torch.argmax(x, dim=-1)
         else:
             return torch.argmax(x, dim=-1)

From 88e1bed8ec526c11cedf19e5bce4d3fc4e67ed7c Mon Sep 17 00:00:00 2001
From: Catherine Lee
Date: Mon, 7 Oct 2024 18:36:43 -0700
Subject: [PATCH 10/12] Put softmax in MathOps

- Rename certain functions to conform with naming scheme
- Current triton softmax does not handle > 2 dimensions but will need to
  investigate (probably by looking at llama.cpp)
---
 kernels/__init__.py              |  6 +++---
 kernels/fused_softmax.py         | 12 ++++--------
 models/llama/llama/generation.py | 16 ++--------------
 models/llama/llama/math_ops.py   | 13 ++++++-------
 test/README.md                   |  3 ++-
 test/test_softmax.py             |  4 ++--
 6 files changed, 19 insertions(+), 35 deletions(-)

diff --git a/kernels/__init__.py b/kernels/__init__.py
index ddf7da9..b8547ec 100644
--- a/kernels/__init__.py
+++ b/kernels/__init__.py
@@ -1,7 +1,7 @@
 # from .conv import _conv, conv
 from . import blocksparse
 from .cross_entropy import _cross_entropy, cross_entropy
-from .fused_softmax import TritonSoftmax, triton_softmax
+from .fused_softmax import _softmax, softmax
 from .flash_attention import attention
 from .matmul import _matmul, get_higher_dtype, matmul
 
@@ -13,6 +13,6 @@
     "matmul",
     "attention",
     "get_higher_dtype",
-    "TritonSoftmax",
-    "triton_softmax"
+    "_softmax",
+    "softmax"
 ]
\ No newline at end of file
diff --git a/kernels/fused_softmax.py b/kernels/fused_softmax.py
index 753ad88..76c9ad7 100644
--- a/kernels/fused_softmax.py
+++ b/kernels/fused_softmax.py
@@ -7,14 +7,10 @@ def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
 
-
 def is_cdna():
     return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942', 'gfx90a', 'gfx908')
 
-
-
-
 @triton.jit
 def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                    num_stages: tl.constexpr):
@@ -42,7 +38,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
     output_ptrs = output_row_start_ptr + col_offsets
     tl.store(output_ptrs, softmax_output, mask=mask)
 
-def softmax(x):
+def triton_softmax(x):
     device = torch.cuda.current_device()
     properties = driver.active.utils.get_device_properties(device)
     NUM_SM = properties["multiprocessor_count"]
@@ -116,13 +112,13 @@ def softmax(x):
 
 
 
-class TritonSoftmax(torch.autograd.Function):
+class _softmax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
-        return softmax(x)
+        return triton_softmax(x)
 
     # @staticmethod
     # def backward(ctx, grad_output):
     #     return grad_output, grad_output
 
-triton_softmax = TritonSoftmax.apply
\ No newline at end of file
+softmax = _softmax.apply
\ No newline at end of file
diff --git a/models/llama/llama/generation.py b/models/llama/llama/generation.py
index 64782aa..3fa0185 100644
--- a/models/llama/llama/generation.py
+++ b/models/llama/llama/generation.py
@@ -19,7 +19,6 @@
 from .math_ops import MathOps
 from .tokenizer import ChatFormat, Dialog, Message, Tokenizer
 from benchmarking import Profiler
-from kernels.fused_softmax import triton_softmax
 
 class CompletionPrediction(TypedDict, total=False):
     generation: str
@@ -193,21 +192,10 @@ def generate(
         for cur_pos in range(min_prompt_len, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
             if temperature > 0:
-                # if self.use_triton:
-                #     probs = triton_softmax(logits[:,-1])
-                # else:
-                #     probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
-                MathOps.softmax(logits[:, -1] / temperature, dim=-1)
-
-
-
+                probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
             else:
-                # if self.use_triton:
-                #     next_token = self.triton.language.argmax(logits[:, -1], axis=-1)
-                # else:
-                #     next_token = self.Math.argmax(logits[:, -1], dim=-1)
-                MathOps.argmax(logits[:,-1], dim = -1)
+                self.Math.argmax(logits[:,-1], dim=-1)
             next_token = next_token.reshape(-1)
 
             # only replace token if prompt has already been generated
diff --git a/models/llama/llama/math_ops.py b/models/llama/llama/math_ops.py
index 5ddbf4d..dfddc89 100644
--- a/models/llama/llama/math_ops.py
+++ b/models/llama/llama/math_ops.py
@@ -8,7 +8,7 @@
 from kernels.cross_entropy import cross_entropy
 from kernels.matmul import matmul
 from kernels.flash_attention import attention
-from kernels.fused_softmax import triton_softmax
+from kernels.fused_softmax import softmax
 from benchmarking import Profiler
 import time
@@ -71,18 +71,17 @@
 
     @Profiler.profiling_decorator("softmax")
     def softmax(self, x, dim):
-        if self.use_triton and len(x) == 2:
-            return triton_softmax(x, dim=-1)
+        if self.use_triton and x.ndim == 2:
+            return softmax(x)
         else:
-            return F.softmax(x, dim=-1)
+            return F.softmax(x, dim)
 
     @Profiler.profiling_decorator("argmax")
     def argmax(self, x, dim):
         if self.use_triton:
-            # TODO: change
-            return torch.argmax(x, dim=-1)
+            return self.triton.language.argmax(x, axis=dim)
         else:
-            return torch.argmax(x, dim=-1)
+            return torch.argmax(x, dim=dim)
 
     @Profiler.profiling_decorator("cross_entropy")
     def cross_entropy(self, input_val, target, reduction, ignore_index):
diff --git a/test/README.md b/test/README.md
index a8aeef6..6cc47f7 100644
--- a/test/README.md
+++ b/test/README.md
@@ -1,3 +1,4 @@
 # Unit Tests for Kernels
 
-Invoke by calling ```pytest test/```
\ No newline at end of file
+Invoke all tests by calling ```pytest test/```.
+To run a specific test, call ```pytest test/test_softmax.py```.
\ No newline at end of file
diff --git a/test/test_softmax.py b/test/test_softmax.py
index 6f444ab..45891e2 100644
--- a/test/test_softmax.py
+++ b/test/test_softmax.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 import pytest
-from kernels.fused_softmax import triton_softmax
+from kernels.fused_softmax import softmax
 
 @pytest.mark.parametrize("input_size", [(1024, 1024), (512, 512), (2048, 512)])
 def test_softmax_equivalence(input_size):
@@ -12,7 +12,7 @@ def test_softmax_equivalence(input_size):
     pytorch_softmax = F.softmax(x, dim=-1)
 
     # Compute softmax using Triton
-    triton_output = triton_softmax(x)
+    triton_output = softmax(x)
 
     # Assert that both outputs are approximately equal
     assert torch.allclose(pytorch_softmax, triton_output, atol=1e-5), \

From 243e22b46ce20de0f430f01f365faf07888daee7 Mon Sep 17 00:00:00 2001
From: Catherine Lee
Date: Mon, 7 Oct 2024 18:41:20 -0700
Subject: [PATCH 11/12] Add back accidentally removed next_token var

---
 models/llama/llama/generation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/llama/llama/generation.py b/models/llama/llama/generation.py
index 3fa0185..e6f3fa0 100644
--- a/models/llama/llama/generation.py
+++ b/models/llama/llama/generation.py
@@ -195,7 +195,7 @@ def generate(
                 probs = self.Math.softmax(logits[:, -1] / temperature, dim=-1)
                 next_token = sample_top_p(probs, top_p)
             else:
-                self.Math.argmax(logits[:,-1], dim=-1)
+                next_token = self.Math.argmax(logits[:,-1], dim=-1)
             next_token = next_token.reshape(-1)
 
             # only replace token if prompt has already been generated

From 036e4bc0b57b16f35ebd941e8c69083a30248fe4 Mon Sep 17 00:00:00 2001
From: Catherine Lee
Date: Mon, 7 Oct 2024 18:58:56 -0700
Subject: [PATCH 12/12] self.triton.language.argmax -> triton.language.argmax

---
 models/llama/llama/math_ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/llama/llama/math_ops.py b/models/llama/llama/math_ops.py
index dfddc89..3024eee 100644
--- a/models/llama/llama/math_ops.py
+++ b/models/llama/llama/math_ops.py
@@ -79,7 +79,7 @@ def softmax(self, x, dim):
 
     @Profiler.profiling_decorator("argmax")
     def argmax(self, x, dim):
         if self.use_triton:
-            return self.triton.language.argmax(x, axis=dim)
+            return triton.language.argmax(x, axis=dim)
         else:
             return torch.argmax(x, dim=dim)
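
---

Note on the open item from PATCH 10: the Triton softmax only handles 2-D input, so `MathOps.softmax` falls back to `F.softmax` for everything else. Since the kernel assigns one program per row, one way to lift that restriction is to collapse all leading dimensions into rows and restore the shape afterwards. The sketch below is not part of the patch series; it assumes the `softmax` wrapper from PATCH 10 is importable as shown, and the `softmax_nd` helper name is hypothetical.

```python
import torch

from kernels.fused_softmax import softmax  # 2-D Triton path introduced in PATCH 10


def softmax_nd(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical N-D wrapper: view any (..., n_cols) tensor as (n_rows, n_cols)."""
    if x.ndim == 2:
        return softmax(x)
    # reshape (not view): sliced logits may be non-contiguous, and reshape copies if needed
    out = softmax(x.reshape(-1, x.shape[-1]))
    return out.reshape(x.shape)
```

With a helper along these lines, the `x.ndim == 2` guard in `MathOps.softmax` could be widened, which might also let the Triton path cover the higher-rank attention-score softmax inside `attention_forward`.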