[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS #6036

Open. Wants to merge 80 commits into base: main.

Changes from all commits (80 commits):
2be6218
Support Repack from GPTQ.
LeiWang1999 Jul 1, 2024
b92de92
chore: Remove unused input_size and output_size variables in MarlinLi…
LeiWang1999 Jul 1, 2024
71ea469
Support BitNet Model for 1.58bits.
LeiWang1999 Jul 16, 2024
dfa6b2f
Lint Fix
LeiWang1999 Jul 16, 2024
8d2c635
lint fix
LeiWang1999 Jul 16, 2024
41bb18e
Lint Fix for line length
LeiWang1999 Jul 16, 2024
29ac34d
Support Loading 1.58B Model with BitBLAS Format
LeiWang1999 Jul 17, 2024
7f69aef
Improve performance for bitnet
LeiWang1999 Jul 19, 2024
01a789a
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Jul 19, 2024
a973123
fix lm_head for gptq model refactor
LeiWang1999 Jul 19, 2024
aea1f4c
linx fix
LeiWang1999 Jul 19, 2024
17128d5
handle compressed scale weight.
LeiWang1999 Aug 13, 2024
1741ed4
lint fix
LeiWang1999 Aug 13, 2024
726a1f7
remove partial weight load for sw
LeiWang1999 Aug 15, 2024
68c8052
apply torch compile for uncompressed weight.
LeiWang1999 Aug 15, 2024
6eb2870
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Aug 15, 2024
52418ef
merge bug fix
LeiWang1999 Aug 15, 2024
a15ba12
lint fix
LeiWang1999 Aug 15, 2024
53babae
fix torch compile issue
LeiWang1999 Aug 18, 2024
40a4e53
bug fix.
LeiWang1999 Aug 20, 2024
d316a87
BENCHMARK SCRIPTS
LeiWang1999 Aug 20, 2024
4d40275
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Aug 20, 2024
bffc05b
Implement Test
LeiWang1999 Aug 20, 2024
8b0972b
lint fix
LeiWang1999 Aug 20, 2024
8e1a7e8
install bitblas by default to pass the doc gen.
LeiWang1999 Aug 20, 2024
7fbbccf
hide the bitblas import
LeiWang1999 Aug 20, 2024
c487e69
import fix
LeiWang1999 Aug 20, 2024
0d364e7
remove all bitnet related items
LeiWang1999 Aug 21, 2024
d8e9448
remove bitnet
LeiWang1999 Aug 21, 2024
112ef48
format fix
LeiWang1999 Aug 21, 2024
7a88731
remove commen requirements
LeiWang1999 Aug 21, 2024
1fed839
remove dep
LeiWang1999 Aug 21, 2024
4c4cf59
lint fix
LeiWang1999 Aug 21, 2024
68d8ff1
remove red
LeiWang1999 Aug 21, 2024
c4e818c
optimize tuning message
LeiWang1999 Aug 21, 2024
546553f
update
LeiWang1999 Aug 21, 2024
e74368a
Update docs/source/quantization/bitblas.rst
LeiWang1999 Aug 22, 2024
051e11a
Update docs/source/quantization/bitblas.rst
LeiWang1999 Aug 22, 2024
7f6708a
Update docs/source/quantization/bitblas.rst
LeiWang1999 Aug 22, 2024
e3c8159
Update docs/source/quantization/bitblas.rst
LeiWang1999 Aug 22, 2024
b335ea8
Merge branch 'bitblas-intg' of https://github.com/LeiWang1999/vllm-bi…
LeiWang1999 Aug 22, 2024
5b8cc8c
typo fix
LeiWang1999 Aug 23, 2024
6308db5
support bfloat16
LeiWang1999 Aug 23, 2024
763c71e
update recommend installing command
LeiWang1999 Aug 23, 2024
2ad2a14
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Aug 23, 2024
31a7bbd
lint fix
LeiWang1999 Aug 23, 2024
4f7ab2f
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Oct 23, 2024
92e3a08
Add commit_id.py with commit hash
LeiWang1999 Oct 23, 2024
cd924dd
lint fix
LeiWang1999 Oct 23, 2024
b8bb550
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Nov 1, 2024
9d9e115
BitBLAS quantization updates
Nov 11, 2024
c190b2e
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
Nov 11, 2024
ae7a169
revert marlin change
LeiWang1999 Nov 13, 2024
b2a82e6
also supports sm 70
LeiWang1999 Nov 13, 2024
757aed0
replace parameter
LeiWang1999 Nov 13, 2024
36ba50a
review handling
LeiWang1999 Nov 13, 2024
8359a98
remove commit id
LeiWang1999 Nov 13, 2024
742be3d
lint fix
LeiWang1999 Nov 13, 2024
fa1c932
remove debug print
LeiWang1999 Nov 13, 2024
df7a5f6
bug fix
LeiWang1999 Dec 19, 2024
9adb4d5
lint fix
LeiWang1999 Dec 19, 2024
f0a1ec3
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Dec 19, 2024
7d5dd06
merge upstream
LeiWang1999 Dec 19, 2024
8d7881b
modify commit id
LeiWang1999 Dec 19, 2024
217fa5a
add bitblas to index
LeiWang1999 Dec 19, 2024
d6586e7
Update docs/source/quantization/bitblas.rst
LeiWang1999 Dec 19, 2024
430ca44
Update docs/source/quantization/bitblas.rst
LeiWang1999 Dec 19, 2024
f2af59e
force use MINIMUM_BITBLAS_VERSION
LeiWang1999 Dec 19, 2024
d868cac
lint fix
LeiWang1999 Dec 19, 2024
6ec9800
lint fix
LeiWang1999 Dec 19, 2024
bcbad57
lint fix
LeiWang1999 Dec 19, 2024
6cc9022
Merge branch 'bitblas-intg' of https://github.com/LeiWang1999/vllm-bi…
LeiWang1999 Dec 19, 2024
3703449
conflict resolved
LeiWang1999 Dec 20, 2024
3343ace
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Feb 7, 2025
8ef6545
Add BitBLASLinearKernel for optimized mixed precision linear operations
LeiWang1999 Feb 7, 2025
f6e9a35
Enhance BitBLAS support by adding tile size handling and adjusting sh…
LeiWang1999 Feb 12, 2025
4160497
Add BitBLAS documentation and remove obsolete RST file
LeiWang1999 Feb 13, 2025
2bc4dec
fix comments
LeiWang1999 Feb 13, 2025
1050f92
Update supported hardware documentation to include BitBLAS compatibility
LeiWang1999 Feb 13, 2025
4a5ebb4
Implement additional optimizations for BitBLAS performance and stability
LeiWang1999 Feb 13, 2025
236 changes: 236 additions & 0 deletions benchmarks/kernels/benchmark_bitblas.py
@@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    MINIMUM_BITBLAS_VERSION)

try:
    import bitblas
    if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
        raise ImportError("bitblas version is wrong. Please "
                          f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
    bitblas_import_exception = e
    raise ValueError("Trying to use the bitblas backend, but could not import "
                     f"with the following error: {bitblas_import_exception}. "
                     "Please install bitblas through the following command: "
                     f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
                     ) from bitblas_import_exception

from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target

from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(
description="Benchmark BitBLAS int4 on a specific target.")

# Add arguments to the parser
parser.add_argument(
"--target",
type=str,
default=auto_detect_nvidia_target(),
help="Specify the target device for benchmarking.",
)
parser.add_argument("--group_size",
type=int,
default=None,
help="Group size for grouped quantization.")
parser.add_argument(
"--A_dtype",
type=str,
default="float16",
choices=["float16", "float32", "float64", "int32", "int8"],
help="Data type of activation A.",
)
parser.add_argument(
"--W_dtype",
type=str,
default="int4",
choices=[
"float16",
"float32",
"float64",
"int32",
"int8",
"int4",
"int2",
"int1",
"nf4",
"fp4_e2m1",
],
help="Data type of weight W.",
)
parser.add_argument(
"--accum_dtype",
type=str,
default="float16",
choices=["float16", "int32"],
help="Data type for accumulation.",
)
parser.add_argument(
"--out_dtype",
type=str,
default="float16",
choices=["float16", "float32", "int32", "int8"],
help="Data type for output.",
)
parser.add_argument(
"--layout",
type=str,
default="nt",
choices=["nt", "nn"],
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
parser.add_argument("--with_bias",
action="store_true",
help="Include bias in the benchmark.")
parser.add_argument(
"--with_scaling",
action="store_true",
help="Include scaling factor in the quantization.",
)
parser.add_argument("--with_zeros",
action="store_true",
help="Include zeros in the quantization.")
parser.add_argument(
"--zeros_mode",
type=str,
default=None,
choices=["original", "rescale", "quantized"],
help="Specify the mode for calculating zeros.",
)

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
group_size = args.group_size
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode

# Define a list of shared arguments that repeat in every config
shared_args = [
A_dtype,
W_dtype,
out_dtype,
accum_dtype,
layout,
with_bias,
group_size,
with_scaling,
with_zeros,
zeros_mode,
]

# Define just the (M, K, N) shapes in a more compact list
shapes = [
# square test
(1, 16384, 16384),
# BLOOM-176B
(1, 43008, 14336),
(1, 14336, 14336),
(1, 57344, 14336),
(1, 14336, 57344),
# OPT-65B
(1, 9216, 9216),
(1, 36864, 9216),
(1, 9216, 36864),
(1, 22016, 8192),
# LLAMA-70B/65B
(1, 8192, 22016),
(1, 8192, 8192),
(1, 28672, 8192),
(1, 8192, 28672),
# square test
(16384, 16384, 16384),
# BLOOM-176B
(8192, 43008, 14336),
(8192, 14336, 14336),
(8192, 57344, 14336),
(8192, 14336, 57344),
# OPT-65B
(8192, 9216, 9216),
(8192, 36864, 9216),
(8192, 9216, 36864),
(8192, 22016, 8192),
# LLAMA-70B/65B
(8192, 8192, 22016),
(8192, 8192, 8192),
(8192, 28672, 8192),
(8192, 8192, 28672),
]

# Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
for shape in shapes]

benchmark_sets = []
benchmark_sets.extend(test_shapes)

benchmark_results = {}
for config_class, operator, input_args in benchmark_sets:
    config = config_class(*input_args)
    matmul = operator(config, target=target, enable_tuning=True)
    kernel_latency = matmul.profile_latency()

    print("Time cost is: {:.3f} ms".format(kernel_latency))

    profile_config = {
        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
            "BitBLAS_top20_latency": kernel_latency,
        }
    }

    benchmark_results.update(profile_config)

# Define headers for the table
headers = [
"PrimFunc",
"Input Arguments",
"BitBLAS Top20 Latency",
]

# Calculate column widths for pretty printing
col_widths = [0, 0, 0]
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
    col_widths[1] = max(col_widths[1],
                        len(input_args_str) + 2,
                        len(headers[1]) + 2)
    col_widths[2] = max(col_widths[2],
                        len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
                        len(headers[2]) + 2)
    # break only if you want to measure widths from a single example;
    # otherwise, let it loop over all items.

# Print header
for i, header in enumerate(headers):
    headers[i] = header.ljust(col_widths[i])
print("".join(headers))
print("-" * sum(col_widths))

# Print rows
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    row = [
        func_name,
        input_args_str,
        f"{values['BitBLAS_top20_latency']:.3f} ms",
    ]
    row_str = "".join(
        [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
    print(row_str)
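
For a quick sanity check outside the full sweep above, a single configuration can be tuned and profiled on its own. The sketch below reuses the same BitBLAS API that the script imports (`MatmulConfig`, `Matmul`, `auto_detect_nvidia_target`); the keyword names follow BitBLAS's documented `MatmulConfig` interface, and the shape/dtype values are arbitrary examples.

```python
# Minimal single-shape profiling sketch; assumes bitblas is installed and a
# CUDA-capable GPU is visible. Keyword names mirror the BitBLAS MatmulConfig API.
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target

config = MatmulConfig(
    M=1,                # batch / sequence dimension
    N=16384,            # output features
    K=16384,            # input features
    A_dtype="float16",  # activation dtype
    W_dtype="int4",     # 4-bit quantized weights
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",        # A not transposed, W transposed
    with_bias=False,
)
matmul = Matmul(config, target=auto_detect_nvidia_target(), enable_tuning=True)
print(f"Latency: {matmul.profile_latency():.3f} ms")
```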
40 changes: 40 additions & 0 deletions docs/source/features/quantization/bitblas.md
@@ -0,0 +1,40 @@
# BitBLAS

vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS supports a wider range of precision combinations.

Below are the steps to use BitBLAS with vLLM. First, install BitBLAS:

```console
pip install bitblas>=0.1.0
```
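
To confirm that the installed version meets vLLM's minimum requirement, a quick check like the following can help (vLLM defines the exact minimum as `MINIMUM_BITBLAS_VERSION` in its bitblas utilities; the `0.1.0` value below is only an assumption matching the install command above):

```python
# Sketch: verify that BitBLAS is importable and recent enough.
# MIN_VERSION mirrors vLLM's MINIMUM_BITBLAS_VERSION; 0.1.0 is assumed here.
from packaging.version import Version

import bitblas

MIN_VERSION = "0.1.0"
if Version(bitblas.__version__) < Version(MIN_VERSION):
    raise RuntimeError(f"bitblas {bitblas.__version__} is too old; "
                       f"run `pip install -U 'bitblas>={MIN_VERSION}'`")
print(f"bitblas {bitblas.__version__} is ready")
```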

vLLM reads the model's config file and supports loading pre-quantized checkpoints.

You can find pre-quantized models on:

- [Hugging Face (BitBLAS)](https://huggingface.co/models?other=bitblas)
- [Hugging Face (GPTQ)](https://huggingface.co/models?other=gptq)

Usually, these repositories have a `quantize_config.json` file that includes a `quantization_config` section.
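
For GPTQ checkpoints, that section usually looks something like the following; the field names follow the common GPTQ convention, and the values here are illustrative only, not taken from a specific model.

```python
# Illustrative quantization settings as they commonly appear in GPTQ
# checkpoints (values are examples only).
quantization_config = {
    "bits": 4,           # weight precision
    "group_size": 128,   # quantization group size (-1 means per-channel)
    "desc_act": False,   # activation-order ("act-order") quantization
    "sym": True,         # symmetric quantization
    "quant_method": "gptq",
}
```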

## Read bitblas format checkpoint

```python
from vllm import LLM
import torch

# "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized checkpoint.
model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, quantization="bitblas")
```

## Read gptq format checkpoint

```python
from vllm import LLM
import torch

# "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized checkpoint.
model_id = "hxbgsyxh/llama-13b-4bit-g-1"
llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True, quantization="bitblas", max_model_len=1024)
```
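
Once loaded, generation works the same as for any other vLLM model. A minimal sketch (the prompt and sampling settings below are arbitrary examples):

```python
from vllm import SamplingParams

# Greedy decoding keeps results easy to compare with the unquantized model.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["What is the capital of France?"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```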
11 changes: 11 additions & 0 deletions docs/source/features/quantization/supported_hardware.md
@@ -52,6 +52,17 @@ The table below shows the compatibility of various quantization implementations
* ✗
* ✗
* ✗
- * BitBLAS (GPTQ)
* ✅︎
* ✅︎
* ✅︎
* ✅︎
* ✅︎
* ✅︎
* ✗
* ✗
* ✗
* ✗
- * INT8 (W8A8)
* ✗
* ✅︎
62 changes: 62 additions & 0 deletions tests/models/test_bitblas.py
@@ -0,0 +1,62 @@
"""Compare the outputs of a GPTQ model to a bitblas model.

Note: GPTQ and bitblas kernels are not bitwise identical. As a result, this
test only confirms that the top tokens selected by the bitblas and GPTQ models
appear in each other's top-3 selections.

Note: bitblas internally uses locks to synchronize threads, which can introduce
very slight nondeterminism. As a result, the test is re-run up to 3 times.

Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass

import pytest

from .utils import check_logprobs_close


@dataclass
class ModelPair:
    model_bitblas: str
    model_gptq: str


model_pairs = [
    ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas",
              model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with vllm_runner(model_pair.model_bitblas,
                     dtype=dtype,
                     quantization="bitblas") as bitblas_model:
        bitblas_outputs = bitblas_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model_pair.model_gptq, dtype=dtype,
                     quantization="gptq") as gptq_model:
        gptq_outputs = gptq_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=bitblas_outputs,
        name_0="gptq",
        name_1="bitblas",
    )
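
The comparison itself is delegated to `check_logprobs_close` from `tests/models/utils.py`. As a rough illustration of the idea only (not the actual helper, which handles more output formats and produces detailed failure messages), the check amounts to: wherever the two models pick different tokens, each model's pick must still appear in the other's returned top-k logprobs. A simplified sketch, assuming each output is a `(token_ids, text, logprobs)` tuple with `logprobs` a per-position dict keyed by token id:

```python
# Simplified sketch of a logprob-closeness check; the real helper in
# tests/models/utils.py is more thorough.
def logprobs_roughly_close(outputs_a, outputs_b) -> bool:
    for (ids_a, _, lp_a), (ids_b, _, lp_b) in zip(outputs_a, outputs_b):
        for pos, (tok_a, tok_b) in enumerate(zip(ids_a, ids_b)):
            if tok_a == tok_b:
                continue
            # On disagreement, each chosen token must be in the other's top-k.
            if tok_a not in lp_b[pos] or tok_b not in lp_a[pos]:
                return False
    return True
```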