[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS #6036

Open

wants to merge 80 commits into base: main

Changes from 59 commits (80 commits total)

Commits
2be6218
Support Repack from GPTQ.
LeiWang1999 Jul 1, 2024
b92de92
chore: Remove unused input_size and output_size variables in MarlinLi…
LeiWang1999 Jul 1, 2024
71ea469
Support BitNet Model for 1.58bits.
LeiWang1999 Jul 16, 2024
dfa6b2f
Lint Fix
LeiWang1999 Jul 16, 2024
8d2c635
lint fix
LeiWang1999 Jul 16, 2024
41bb18e
Lint Fix for line length
LeiWang1999 Jul 16, 2024
29ac34d
Support Loading 1.58B Model with BitBLAS Format
LeiWang1999 Jul 17, 2024
7f69aef
Improve performance for bitnet
LeiWang1999 Jul 19, 2024
01a789a
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Jul 19, 2024
a973123
fix lm_head for gptq model refactor
LeiWang1999 Jul 19, 2024
aea1f4c
linx fix
LeiWang1999 Jul 19, 2024
17128d5
handle compressed scale weight.
LeiWang1999 Aug 13, 2024
1741ed4
lint fix
LeiWang1999 Aug 13, 2024
726a1f7
remove partial weight load for sw
LeiWang1999 Aug 15, 2024
68c8052
apply torch compile for uncompressed weight.
LeiWang1999 Aug 15, 2024
6eb2870
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Aug 15, 2024
52418ef
merge bug fix
LeiWang1999 Aug 15, 2024
a15ba12
lint fix
LeiWang1999 Aug 15, 2024
53babae
fix torch compile issue
LeiWang1999 Aug 18, 2024
40a4e53
bug fix.
LeiWang1999 Aug 20, 2024
d316a87
BENCHMARK SCRIPTS
LeiWang1999 Aug 20, 2024
4d40275
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Aug 20, 2024
bffc05b
Implement Test
LeiWang1999 Aug 20, 2024
8b0972b
lint fix
LeiWang1999 Aug 20, 2024
8e1a7e8
install bitblas by default to pass the doc gen.
LeiWang1999 Aug 20, 2024
7fbbccf
hide the bitblas import
LeiWang1999 Aug 20, 2024
c487e69
import fix
LeiWang1999 Aug 20, 2024
0d364e7
remove all bitnet related items
LeiWang1999 Aug 21, 2024
d8e9448
remove bitnet
LeiWang1999 Aug 21, 2024
112ef48
format fix
LeiWang1999 Aug 21, 2024
7a88731
remove commen requirements
LeiWang1999 Aug 21, 2024
1fed839
remove dep
LeiWang1999 Aug 21, 2024
4c4cf59
lint fix
LeiWang1999 Aug 21, 2024
68d8ff1
remove red
LeiWang1999 Aug 21, 2024
c4e818c
optimize tuning message
LeiWang1999 Aug 21, 2024
546553f
update
LeiWang1999 Aug 21, 2024
e74368a
Update docs/source/quantization/bitblas.rst
LeiWang1999 Aug 22, 2024
051e11a
Update docs/source/quantization/bitblas.rst
LeiWang1999 Aug 22, 2024
7f6708a
Update docs/source/quantization/bitblas.rst
LeiWang1999 Aug 22, 2024
e3c8159
Update docs/source/quantization/bitblas.rst
LeiWang1999 Aug 22, 2024
b335ea8
Merge branch 'bitblas-intg' of https://github.com/LeiWang1999/vllm-bi…
LeiWang1999 Aug 22, 2024
5b8cc8c
typo fix
LeiWang1999 Aug 23, 2024
6308db5
support bfloat16
LeiWang1999 Aug 23, 2024
763c71e
update recommend installing command
LeiWang1999 Aug 23, 2024
2ad2a14
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Aug 23, 2024
31a7bbd
lint fix
LeiWang1999 Aug 23, 2024
4f7ab2f
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Oct 23, 2024
92e3a08
Add commit_id.py with commit hash
LeiWang1999 Oct 23, 2024
cd924dd
lint fix
LeiWang1999 Oct 23, 2024
b8bb550
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Nov 1, 2024
9d9e115
BitBLAS quantization updates
Nov 11, 2024
c190b2e
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
Nov 11, 2024
ae7a169
revert marlin change
LeiWang1999 Nov 13, 2024
b2a82e6
also supports sm 70
LeiWang1999 Nov 13, 2024
757aed0
replace parameter
LeiWang1999 Nov 13, 2024
36ba50a
review handling
LeiWang1999 Nov 13, 2024
8359a98
remove commit id
LeiWang1999 Nov 13, 2024
742be3d
lint fix
LeiWang1999 Nov 13, 2024
fa1c932
remove debug print
LeiWang1999 Nov 13, 2024
df7a5f6
bug fix
LeiWang1999 Dec 19, 2024
9adb4d5
lint fix
LeiWang1999 Dec 19, 2024
f0a1ec3
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Dec 19, 2024
7d5dd06
merge upstream
LeiWang1999 Dec 19, 2024
8d7881b
modify commit id
LeiWang1999 Dec 19, 2024
217fa5a
add bitblas to index
LeiWang1999 Dec 19, 2024
d6586e7
Update docs/source/quantization/bitblas.rst
LeiWang1999 Dec 19, 2024
430ca44
Update docs/source/quantization/bitblas.rst
LeiWang1999 Dec 19, 2024
f2af59e
force use MINIMUM_BITBLAS_VERSION
LeiWang1999 Dec 19, 2024
d868cac
lint fix
LeiWang1999 Dec 19, 2024
6ec9800
lint fix
LeiWang1999 Dec 19, 2024
bcbad57
lint fix
LeiWang1999 Dec 19, 2024
6cc9022
Merge branch 'bitblas-intg' of https://github.com/LeiWang1999/vllm-bi…
LeiWang1999 Dec 19, 2024
3703449
conflict resolved
LeiWang1999 Dec 20, 2024
3343ace
Merge branch 'main' of https://github.com/vllm-project/vllm into bitb…
LeiWang1999 Feb 7, 2025
8ef6545
Add BitBLASLinearKernel for optimized mixed precision linear operations
LeiWang1999 Feb 7, 2025
f6e9a35
Enhance BitBLAS support by adding tile size handling and adjusting sh…
LeiWang1999 Feb 12, 2025
4160497
Add BitBLAS documentation and remove obsolete RST file
LeiWang1999 Feb 13, 2025
2bc4dec
fix comments
LeiWang1999 Feb 13, 2025
1050f92
Update supported hardware documentation to include BitBLAS compatibility
LeiWang1999 Feb 13, 2025
4a5ebb4
Implement additional optimizations for BitBLAS performance and stability
LeiWang1999 Feb 13, 2025
Files changed
686 changes: 686 additions & 0 deletions benchmarks/kernels/benchmark_bitblas.py

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions docs/source/quantization/bitblas.rst
@@ -0,0 +1,41 @@
.. _bitblas:

BitBLAS
==================

vLLM now supports `BitBLAS <https://github.com/microsoft/BitBLAS>`_ for more efficient and flexible model inference.
Compared to other quantization frameworks, BitBLAS supports a wider range of precision combinations.

Below are the steps to utilize BitBLAS with vLLM.

.. code-block:: console

    $ pip install "bitblas>=0.0.1.dev15"

vLLM reads the model's config file and supports pre-quantized checkpoints.

You can find pre-quantized models at https://huggingface.co/models?other=bitblas, https://huggingface.co/models?other=bitnet, or https://huggingface.co/models?other=gptq.

These repositories usually contain a ``quantize_config.json`` file that includes a ``quantization_config`` section.
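
The snippet below is a minimal sketch of how you might inspect that configuration before loading the model. It assumes the repository actually ships a ``quantize_config.json``; the exact file name and fields vary between checkpoints.

.. code-block:: python

    import json

    from huggingface_hub import hf_hub_download

    # Assumption: the repository stores its quantization settings in
    # quantize_config.json; adjust the file name if your checkpoint differs.
    config_path = hf_hub_download(repo_id="hxbgsyxh/llama-13b-4bit-g-1-bitblas",
                                  filename="quantize_config.json")
    with open(config_path) as f:
        print(json.dumps(json.load(f), indent=2))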

Read a BitBLAS format checkpoint
----------------------------------

.. code-block:: python

    from vllm import LLM
    import torch

    # "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized BitBLAS-format checkpoint.
    model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas"
    llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
              quantization="bitblas")
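
Once the ``LLM`` object is built, generation works the same as for any other vLLM model. The snippet below is an untested sketch; the prompt and sampling settings are arbitrary.

.. code-block:: python

    from vllm import SamplingParams

    prompts = ["What is the capital of France?"]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

    # Reuse the llm object created above with quantization="bitblas".
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        print(output.outputs[0].text)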

Read a GPTQ format checkpoint
-------------------------------

.. code-block:: python

    from vllm import LLM
    import torch

    # "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized GPTQ checkpoint.
    model_id = "hxbgsyxh/llama-13b-4bit-g-1"
    llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True,
              quantization="bitblas", max_model_len=1024)
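
When a GPTQ checkpoint is passed with ``quantization="bitblas"``, vLLM repacks the GPTQ weights into the BitBLAS layout as the model is loaded (see the ``Support Repack from GPTQ`` commit in this PR), so no separate offline conversion step should be needed.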

62 changes: 62 additions & 0 deletions tests/models/test_bitblas.py
@@ -0,0 +1,62 @@
"""Compare the outputs of a GPTQ model to a bitblas model.

Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.

Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.

Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass

import pytest

from .utils import check_logprobs_close


@dataclass
class ModelPair:
model_bitblas: str
model_gptq: str


model_pairs = [
ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas",
model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model_pair.model_bitblas,
dtype=dtype,
quantization="bitblas") as bitblas_model:
bitblas_outputs = bitblas_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model_pair.model_gptq, dtype=dtype,
quantization="gptq") as gptq_model:
gptq_outputs = gptq_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=bitblas_outputs,
name_0="gptq",
name_1="bitblas",
)
60 changes: 60 additions & 0 deletions tests/models/test_gptq_bitblas.py
@@ -0,0 +1,60 @@
"""Compare the outputs of a GPTQ model to a bitblas model.

Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.

Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.

Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass

import pytest

from .utils import check_logprobs_close


@dataclass
class ModelPair:
model_gptq: str


model_pairs = [
ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model_pair.model_gptq,
dtype=dtype,
quantization="bitblas") as bitblas_model:
bitblas_outputs = bitblas_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model_pair.model_gptq, dtype=dtype,
quantization="gptq") as gptq_model:
gptq_outputs = gptq_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=bitblas_outputs,
name_0="gptq",
name_1="gptq_bitblas",
)
6 changes: 5 additions & 1 deletion tests/quantization/test_lm_head.py
@@ -7,7 +7,10 @@
import pytest
import torch

from vllm.model_executor.layers.quantization.bitblas import BitBLASLinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_bitblas import (
GPTQBitBLASLinearMethod)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
@@ -36,7 +39,8 @@ def test_lm_head(
if lm_head_quantized:
assert isinstance(
lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod,
GPTQBitBLASLinearMethod, BitBLASLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method,
UnquantizedEmbeddingMethod)
66 changes: 48 additions & 18 deletions vllm/config.py
@@ -388,8 +388,8 @@ def _verify_quantization(self) -> None:
]
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8"
"gptq_bitblas", "bitblas", "awq_marlin", "fbgemm_fp8",
"compressed_tensors", "compressed-tensors", "experts_int8"
]
tpu_supported_quantization = ["tpu_int8"]
neuron_supported_quantization = ["neuron_quant"]
@@ -584,10 +584,27 @@ def get_vocab_size(self) -> int:
def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size

def find_flash_attn_supported_head_dims(self, head_dim: int) -> int:
"""
Find the closest head dimension to the given head dimension that
is supported by Flash Attention.
"""
from vllm.attention.backends.flash_attn import FlashAttentionBackend

FLASHATTN_SUPPORTED_HEAD_DIMS = (
FlashAttentionBackend.get_supported_head_sizes())

for supported_head_dim in FLASHATTN_SUPPORTED_HEAD_DIMS:
if head_dim <= supported_head_dim:
return supported_head_dim
raise ValueError(
f"Head dimension {head_dim} is not supported by Flash Attention."
f"Supported head dimensions are {FLASHATTN_SUPPORTED_HEAD_DIMS}.")

def get_head_size(self) -> int:
# TODO remove hard code
if hasattr(self.hf_text_config, "model_type"
) and self.hf_text_config.model_type == 'deepseek_v2':
if (hasattr(self.hf_text_config, "model_type")
and self.hf_text_config.model_type == "deepseek_v2"):
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return 256
@@ -623,8 +640,11 @@ def get_total_num_kv_heads(self) -> int:
return self.hf_config.attn_config["kv_n_heads"]
return self.hf_config.num_attention_heads
if self.hf_config.model_type == "dbrx":
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)
return getattr(
self.hf_config.attn_config,
"kv_n_heads",
self.hf_config.num_attention_heads,
)

if self.is_attention_free:
return 0
@@ -814,6 +834,7 @@ class TokenizerPoolConfig:
The way the config will be used depends on the
pool type.
"""

pool_size: int
pool_type: Union[str, Type["BaseTokenizerGroup"]]
extra_config: dict
@@ -849,9 +870,11 @@ def create_config(
else:
tokenizer_pool_extra_config_parsed = (
tokenizer_pool_extra_config or {})
tokenizer_pool_config = cls(tokenizer_pool_size,
tokenizer_pool_type,
tokenizer_pool_extra_config_parsed)
tokenizer_pool_config = cls(
tokenizer_pool_size,
tokenizer_pool_type,
tokenizer_pool_extra_config_parsed,
)
else:
tokenizer_pool_config = None
return tokenizer_pool_config
@@ -1006,6 +1029,7 @@ def __init__(
# current node and we aren't in a ray placement group.

from vllm.executor import ray_utils

backend = "mp"
ray_found = ray_utils.ray_is_available()
if (current_platform.is_cuda()
@@ -1021,8 +1045,10 @@
backend = "ray"
else:
from ray import is_initialized as ray_is_initialized

if ray_is_initialized():
from ray.util import get_current_placement_group

if get_current_placement_group():
backend = "ray"
self.distributed_executor_backend = backend
@@ -1380,8 +1406,8 @@ def maybe_create_spec_config(

draft_hf_config = draft_model_config.hf_config

if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
if num_speculative_tokens is not None and hasattr(
draft_hf_config, "num_lookahead_tokens"):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens

n_predict = getattr(draft_hf_config, "n_predict", None)
@@ -1713,7 +1739,8 @@ def verify_with_model_config(self, model_config: ModelConfig):
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization and model_config.quantization not in [
"awq", "gptq"
"awq",
"gptq",
]:
# TODO support marlin
logger.warning("%s quantization is not tested with LoRA yet.",
@@ -1876,8 +1903,8 @@ def _get_and_verify_max_len(
for key in possible_keys:
max_len = getattr(hf_config, key, None)
if max_len is not None:
max_len_key = key if max_len < derived_max_model_len \
else max_len_key
max_len_key = (key
if max_len < derived_max_model_len else max_len_key)
derived_max_model_len = min(derived_max_model_len, max_len)

# If sliding window is manually disabled, max_length should be less
@@ -1906,8 +1933,10 @@
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
"%s. Assuming the model's maximum length is %d.", possible_keys,
default_max_len)
"%s. Assuming the model's maximum length is %d.",
possible_keys,
default_max_len,
)
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
@@ -2001,10 +2030,10 @@ class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""

# Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
guided_decoding_backend: str = 'outlines'
guided_decoding_backend: str = "outlines"

def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer']
valid_guided_backends = ["outlines", "lm-format-enforcer"]
backend = self.guided_decoding_backend
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
@@ -2014,6 +2043,7 @@ def __post_init__(self):
@dataclass
class ObservabilityConfig:
"""Configuration for observability."""

otlp_traces_endpoint: Optional[str] = None

# Collecting detailed timing information for each request can be expensive.