From 69138f9767bba95c9244eb7c206564154e015476 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 4 Sep 2024 10:04:41 -0400
Subject: [PATCH] [CLEANUP]

---
 examples/cython_tests/mqa.pyx             | 54 +++++++++++++++++++++++
 examples/cython_tests/mqa_test.py         | 54 +++++++++++++++++++++++
 examples/cython_tests/new_c_example.py    | 10 +++++
 examples/cython_tests/setup.py            | 15 +++++++
 examples/cython_tests/torch_extension.pyx | 20 +++++++++
 5 files changed, 153 insertions(+)
 create mode 100644 examples/cython_tests/mqa.pyx
 create mode 100644 examples/cython_tests/mqa_test.py
 create mode 100644 examples/cython_tests/new_c_example.py
 create mode 100644 examples/cython_tests/setup.py
 create mode 100644 examples/cython_tests/torch_extension.pyx

diff --git a/examples/cython_tests/mqa.pyx b/examples/cython_tests/mqa.pyx
new file mode 100644
index 00000000..eda6674c
--- /dev/null
+++ b/examples/cython_tests/mqa.pyx
@@ -0,0 +1,54 @@
+import torch
+from torch import nn
+cimport cython
+
+cdef class MultiQueryAttention:
+    cdef int embed_dim
+    cdef int num_heads
+    cdef int head_dim
+    cdef object query_proj  # nn.Linear held as a generic Python object
+    cdef object key_proj    # nn.Linear held as a generic Python object
+    cdef object value_proj  # nn.Linear held as a generic Python object
+    cdef object out_proj    # nn.Linear held as a generic Python object
+
+    def __cinit__(self, int embed_dim, int num_heads):
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+        # MQA: queries get per-head projections; keys/values share a single head
+        self.query_proj = nn.Linear(embed_dim, embed_dim)
+        self.key_proj = nn.Linear(embed_dim, self.head_dim)
+        self.value_proj = nn.Linear(embed_dim, self.head_dim)
+        self.out_proj = nn.Linear(embed_dim, embed_dim)
+
+    @cython.boundscheck(False)
+    @cython.wraparound(False)
+    def forward(self, query, key, value):
+        cdef int batch_size, seq_len
+
+        # Inputs are torch.Tensor objects of shape (batch_size, seq_len, embed_dim)
+        batch_size, seq_len = query.size(0), query.size(1)
+
+        # Linear projections
+        queries = self.query_proj(query)
+        keys = self.key_proj(key)
+        values = self.value_proj(value)
+
+        # Reshape: per-head queries; broadcast the single shared key/value head
+        queries = queries.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        keys = keys.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+        values = values.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+
+        # Scaled dot-product attention
+        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
+        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
+        attn_output = torch.matmul(attn_weights, values)
+
+        # Merge the heads and project the output
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
+        output = self.out_proj(attn_output)
+
+        return output
diff --git a/examples/cython_tests/mqa_test.py b/examples/cython_tests/mqa_test.py
new file mode 100644
index 00000000..40633fed
--- /dev/null
+++ b/examples/cython_tests/mqa_test.py
@@ -0,0 +1,54 @@
+import timeit
+import torch
+from zeta import MultiQueryAttention as PyTorchMQA
+from mqa import MultiQueryAttention as CythonMQA
+
+# Input parameters
+batch_size = 32
+seq_len = 128
+embed_dim = 512
+num_heads = 8
+
+# Create sample input tensors
+query = torch.randn(batch_size, seq_len, embed_dim)
+key = torch.randn(batch_size, seq_len, embed_dim)
+value = torch.randn(batch_size, seq_len, embed_dim)
+
+# Initialize the PyTorch Multi-Query Attention layer (from the zeta package)
+pytorch_mqa = PyTorchMQA(dim=embed_dim, heads=num_heads)
+
+# Initialize the Cython Multi-Query Attention layer (from the mqa module)
+cython_mqa = CythonMQA(embed_dim, num_heads)
+
+
+# Define functions for benchmarking
+def run_pytorch_mqa():
+    output, _, _ = pytorch_mqa(query)
+    return output
+
+
+def run_cython_mqa():
+    output = cython_mqa.forward(query, key, value)
+    return output
+
+
+# Warm-up runs (important to avoid cold-start effects in the timings)
+for _ in range(20):
+    run_pytorch_mqa()
+    run_cython_mqa()
+
+# Benchmark the PyTorch implementation
+pytorch_time = timeit.timeit(
+    "run_pytorch_mqa()", globals=globals(), number=1000
+)
+
+# Benchmark the Cython implementation
+cython_time = timeit.timeit("run_cython_mqa()", globals=globals(), number=1000)
+
+# Print the results
+print(f"PyTorch MQA execution time: {pytorch_time:.6f} seconds")
+print(f"Cython MQA execution time: {cython_time:.6f} seconds")
+if cython_time < pytorch_time:
+    print(f"Cython is faster by: {pytorch_time / cython_time:.2f}x")
+else:
+    print(f"PyTorch is faster by: {cython_time / pytorch_time:.2f}x")
diff --git a/examples/cython_tests/new_c_example.py b/examples/cython_tests/new_c_example.py
new file mode 100644
index 00000000..f8ad3350
--- /dev/null
+++ b/examples/cython_tests/new_c_example.py
@@ -0,0 +1,10 @@
+import torch
+import torch_extension  # Import the compiled Cython module
+
+# Create a PyTorch tensor
+input_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0])
+
+# Use the Cython function to apply sin element-wise
+output_tensor = torch_extension.apply_sin(input_tensor)
+
+print(output_tensor)
diff --git a/examples/cython_tests/setup.py b/examples/cython_tests/setup.py
new file mode 100644
index 00000000..298de35f
--- /dev/null
+++ b/examples/cython_tests/setup.py
@@ -0,0 +1,15 @@
+from setuptools import setup, Extension
+from torch.utils.cpp_extension import BuildExtension
+from Cython.Build import cythonize
+
+# Build both Cython modules used by the examples in this directory
+setup(
+    name="mqa",
+    ext_modules=cythonize(
+        [
+            Extension("mqa", sources=["mqa.pyx"], language="c++"),
+            Extension("torch_extension", sources=["torch_extension.pyx"], language="c++"),
+        ]
+    ),
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/examples/cython_tests/torch_extension.pyx b/examples/cython_tests/torch_extension.pyx
new file mode 100644
index 00000000..e2bbc01a
--- /dev/null
+++ b/examples/cython_tests/torch_extension.pyx
@@ -0,0 +1,20 @@
+import torch  # Use standard Python import for PyTorch
+cimport cython
+import numpy as np
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def apply_sin(input_tensor):
+    """Apply sin element-wise by round-tripping through NumPy."""
+
+    # Convert the PyTorch tensor to a NumPy array
+    # (zero-copy, but the input must be a CPU tensor)
+    np_array = input_tensor.numpy()
+
+    # Apply sin element-wise using NumPy
+    np_output = np.sin(np_array)
+
+    # Convert the NumPy array back to a PyTorch tensor (also zero-copy)
+    output_tensor = torch.from_numpy(np_output)
+
+    return output_tensor
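
Reviewer note (not part of the patch): after building the extensions in place with python setup.py build_ext --inplace from examples/cython_tests/, the Cython MultiQueryAttention can be exercised with a minimal sketch like the one below. It only checks the output shape, assuming the __cinit__(embed_dim, num_heads) and forward(query, key, value) signatures defined above.

    import torch
    from mqa import MultiQueryAttention

    # Small layer and a single self-attention forward pass
    layer = MultiQueryAttention(512, 8)   # embed_dim=512, num_heads=8
    q = torch.randn(2, 16, 512)           # (batch_size, seq_len, embed_dim)
    out = layer.forward(q, q, q)          # query = key = value
    assert out.shape == (2, 16, 512)      # output preserves the input shape
    print("Cython MQA forward OK:", tuple(out.shape))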
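
Similarly, torch_extension.apply_sin can be sanity-checked against torch.sin. This is a hypothetical helper, again assuming the in-place build; since apply_sin round-trips through NumPy, its results should match torch.sin on CPU tensors.

    import torch
    import torch_extension  # compiled from torch_extension.pyx

    x = torch.linspace(0.0, 3.0, steps=8)  # sample CPU tensor
    assert torch.allclose(torch_extension.apply_sin(x), torch.sin(x))
    print("apply_sin matches torch.sin")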