Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jcaip committed Feb 12, 2025
1 parent d503e5d commit c093681
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 55 deletions.
1 change: 1 addition & 0 deletions benchmarks/benchmark_gpu_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def sparse_func():
# (32, 32, 16),
(4096, 14336, 1),
# (14336, 4096, 1),
# (14336, 4096, 1),
# (11008, 4096, 16),
# (16, 4096, 4096),
# (4096, 4096, 11008),
Expand Down
41 changes: 41 additions & 0 deletions test/sparsity/test_bsr_sum_prod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch

import triton
import triton.language as tl
import pdb

from torchao.sparsity.utils import create_block_sparse_tensor
from torchao.sparsity.blocksparse import BlockSparseTensor
from torch.library import wrap_triton, triton_op



@torch.compile(dynamic=False, fullgraph=True)
def test(w, x):
b = x.unsqueeze(0)
out= (torch.mul(w, b)).sum(dim=1)
return out

# Ad-hoc driver that feeds a BlockSparseTensor through the compiled `test`
# function defined above. It needs a CUDA device: every tensor is moved with
# .cuda(). The result is not verified (the assert_close at the end is
# commented out).

# Print tensors in full and on very wide lines so debug prints are not elided.
torch.set_printoptions(profile='full', linewidth=100000)
# Fix the RNG seed so the torch.randn values below are reproducible.
torch.manual_seed(0)
# NOTE(review): `size` is not referenced anywhere else in this script.
size = 98432

# Gradient tracking is disabled for the statements in this block.
with torch.no_grad():
# Wrap the helper with torch.compiler.disable before it is called.
create_block_sparse_tensor = torch.compiler.disable(create_block_sparse_tensor)
# 32x32 bfloat16 tensor from the helper, multiplied elementwise by random
# values. Helper arguments are presumably (rows, cols, blocksize, sparsity,
# dtype) -- TODO confirm against torchao.sparsity.utils.
a = create_block_sparse_tensor(32, 32, 16, 0.5, torch.bfloat16).cuda() * torch.randn(32, 32, dtype=torch.bfloat16).cuda()
# Scale each 16x16 quadrant by a different factor (the final `*= 1` leaves
# the top-right quadrant unchanged).
a[:16, :16] *= 4
a[16:, 16:] *= 4
a[16:, :16] *= 2
a[:16, 16:] *= 1
# print(a)
# print(x)
# Convert the dense tensor to a BlockSparseTensor; the second argument is
# presumably the block size (it matches the 16 used above) -- TODO confirm.
w = BlockSparseTensor.from_dense(a, 16).detach()
# Column vector holding 0..31 as bfloat16, shape (32, 1).
x = torch.arange(32).reshape((32, 1)).to(torch.bfloat16).cuda()
# expected= test(a.unsqueeze(2), x)
# print(expected)
# print("strides", w.unsqueeze(2).stride())
# print("strides", w.stride())
# Run the compiled function on the block-sparse weight with a trailing axis
# added; `out` is computed but never checked.
out = test(w.unsqueeze(2), x)
# print(out)

# torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)
14 changes: 14 additions & 0 deletions test/sparsity/test_supermask.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
quantize_,
)
from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
from torchao.sparsity.blocksparse import BlockSparseTensor
from torchao.sparsity.utils import create_block_sparse_tensor
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
Expand Down Expand Up @@ -58,6 +60,18 @@ def test_from_linear(self):
supermask_linear = SupermaskLinear.from_linear(linear, sparsity_level=0.5, blocksize=4)
assert supermask_linear.weight.shape == linear.weight.shape

def test_fastpath(self):
    """Compare the multiply-and-sum result of a BlockSparseTensor with dense.

    A 128x128 bfloat16 tensor is built with ``create_block_sparse_tensor``
    and the reference value is computed from it directly as
    ``torch.mul(w.unsqueeze(2), x.unsqueeze(0)).sum(dim=1)``. The same
    expression is then evaluated with the weight converted through
    ``BlockSparseTensor.from_dense(..., 64)``, and the two results must agree
    within ``rtol=1e-2`` / ``atol=1e-2``.

    Raises:
        unittest.SkipTest: if no CUDA device is available. Every tensor in
            this test is moved with ``.cuda()``, so without the guard the
            test would error instead of being reported as skipped.
    """
    # Local import keeps this method self-contained; the file's import block
    # is not touched.
    import unittest

    if not torch.cuda.is_available():
        raise unittest.SkipTest("test_fastpath requires a CUDA device")

    dense_weight = create_block_sparse_tensor(128, 128, 64, 0.5, torch.bfloat16).cuda()
    x = torch.randn(128, 1).to(torch.bfloat16).cuda()

    # Reference result computed from the dense tensor.
    expected = torch.mul(dense_weight.unsqueeze(2), x.unsqueeze(0)).sum(dim=1)

    # Same expression, with the weight wrapped as a BlockSparseTensor.
    sparse_weight = BlockSparseTensor.from_dense(dense_weight, 64)
    out = torch.mul(sparse_weight.unsqueeze(2), x.unsqueeze(0)).sum(dim=1)

    torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)


common_utils.instantiate_parametrized_tests(TestSupermask)

Expand Down
28 changes: 27 additions & 1 deletion torchao/_models/llama/benchmark_results.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion torchao/_models/llama/bsr_benchmarks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8192
#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8192 --sparsity bsr-0.9-64
#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --sparsity bsr-0.9-64
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --sparsity bsr-0.9-32
#python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --sparsity bsr-0.9-32
75 changes: 54 additions & 21 deletions torchao/kernel/bsr_triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from torch.sparse._triton_ops_meta import get_meta



TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int(
os.getenv("TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE", 2)
)
Expand Down Expand Up @@ -354,13 +353,14 @@ def bsr_dense_addmm_meta(
# verbose=True,
# )
# get padded key
padded_key = (M, K, 16, Ms, Ks, beta == 0, beta == 1, alpha == 1)
meta = get_meta(
"bsr_dense_addmm",
padded_key,
device_name,
version=(_version, version_dtype, sparsity),
)
# padded_key = (M, K, 16, Ms, Ks, beta == 0, beta == 1, alpha == 1)
# meta = get_meta(
# "bsr_dense_addmm",
# padded_key,
# device_name,
# version=(_version, version_dtype, sparsity),
# )
pass
# breakpoint()
# return meta
# message
Expand All @@ -372,7 +372,7 @@ def bsr_dense_addmm_meta(

SPLIT_N = SPLIT_N or max(N // Ms, 1)
GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4
num_stages = num_stages or 1
num_stages = num_stages or 4
num_warps = num_warps or 4
return dict(
SPLIT_N=SPLIT_N,
Expand Down Expand Up @@ -505,7 +505,6 @@ def _int_bsr_dense_addmm(
def bsr_dense_addmm(
input: torch.Tensor,
bsr: torch.Tensor,
row_indices: torch.Tensor,
dense: torch.Tensor,
*,
beta=1,
Expand Down Expand Up @@ -657,7 +656,6 @@ def kernel(grid, *sliced_tensors):
BLOCKSIZE_ROW=BM,
BLOCKSIZE_INNER=BK,
BLOCKSIZE_COL=BN,
BLOCKSIZE_K=32,
allow_tf32=dot_out_dtype == tl.float32,
acc_dtype=dot_out_dtype,
**meta,
Expand Down Expand Up @@ -746,7 +744,6 @@ def _bsr_strided_addmm_kernel(
BLOCKSIZE_ROW: tl.constexpr,
BLOCKSIZE_COL: tl.constexpr,
BLOCKSIZE_INNER: tl.constexpr,
BLOCKSIZE_K: tl.constexpr,
acc_dtype: tl.constexpr,
allow_tf32: tl.constexpr,
GROUP_SIZE_ROW: tl.constexpr,
Expand Down Expand Up @@ -782,10 +779,10 @@ def _bsr_strided_addmm_kernel(
row_block_arange = tl.arange(0, BLOCKSIZE_ROW)
inner_block_arange = tl.arange(0, BLOCKSIZE_INNER)

PADDED_BLOCKSIZE_COL : tl.constexpr = 16
# if BLOCKSIZE_COL < 16 or BLOCKSIZE_COL % 16 != 0:
# else:
# PADDED_BLOCKSIZE_COL: tl.constexpr = BLOCKSIZE_COL
if BLOCKSIZE_COL < 16 or BLOCKSIZE_COL % 16 != 0:
PADDED_BLOCKSIZE_COL : tl.constexpr = 16
else:
PADDED_BLOCKSIZE_COL: tl.constexpr = BLOCKSIZE_COL

col_block_arange = tl.arange(0, PADDED_BLOCKSIZE_COL)

Expand Down Expand Up @@ -826,11 +823,8 @@ def _bsr_strided_addmm_kernel(
)

output_acc_block = tl.zeros((BLOCKSIZE_ROW, PADDED_BLOCKSIZE_COL), dtype=acc_dtype)

nsub_blocks = tl.cdiv(BLOCKSIZE_ROW, BLOCKSIZE_K)


for i in range(row_nnz):
# offsets = tl.arange(0, PADDED_BLOCKSIZE_COL)[None, :]
for _ in range(row_nnz):
values_block = tl.load(values_block_ptrs)

# find which row of dense needs to get loaded
Expand All @@ -850,6 +844,45 @@ def _bsr_strided_addmm_kernel(
values_block_ptrs += values_nnz_stride
col_index_nnz_ptr += col_indices_stride

if not alpha_is_one:
output_acc_block *= alpha

if not left_alpha_is_one:
left_alpha_ptrs = (
left_alpha_ptr
+ left_alpha_batch_stride * batch_pid
+ left_alpha_tiled_row_stride * row_block_pid
+ left_alpha_tiled_col_stride * col_block_pid
+ left_alpha_row_block_stride * row_block_arange[:, None]
+ left_alpha_col_block_stride * col_block_arange[None, :]
)
output_acc_block *= tl.load(left_alpha_ptrs)

if not right_alpha_is_one:
right_alpha_ptrs = (
right_alpha_ptr
+ right_alpha_batch_stride * batch_pid
+ right_alpha_tiled_row_stride * row_block_pid
+ right_alpha_tiled_col_stride * col_block_pid
+ right_alpha_row_block_stride * row_block_arange[:, None]
+ right_alpha_col_block_stride * col_block_arange[None, :]
)
output_acc_block *= tl.load(right_alpha_ptrs)

if beta_is_nonzero:
input_ptrs = (
input_ptr
+ input_batch_stride * batch_pid
+ input_tiled_row_stride * row_block_pid
+ input_tiled_col_stride * col_block_pid
+ input_row_block_stride * row_block_arange[:, None]
+ input_col_block_stride * col_block_arange[None, :]
)
if beta_is_one:
output_acc_block += tl.load(input_ptrs)
else:
output_acc_block += beta * tl.load(input_ptrs)

# write back the result
tl.store(output_ptrs, output_acc_block.to(output_ptr.dtype.element_ty), mask=col_block_arange[None, :]< BLOCKSIZE_COL)

Expand Down
Loading

0 comments on commit c093681

Please sign in to comment.