[ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing (#12501)

tjtanaa authored Feb 7, 2025
1 parent 0630d45 commit eaa92d4
Showing 8 changed files with 295 additions and 32 deletions.
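
For context on the feature name: per-token-activation, per-channel-weight (PTPC) FP8 quantization gives every activation row (token) its own scale and every weight output channel its own scale, instead of a single per-tensor scale. The snippet below is an illustrative reference in plain PyTorch, not the kernels added by this commit; the helper name and shapes are made up, and production code keeps operands in FP8 and applies the scales inside a fused scaled GEMM (e.g. `torch._scaled_mm`) rather than dequantizing.

```python
import torch

def ptpc_fp8_linear_ref(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Reference per-token-activation, per-channel-weight FP8 matmul."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One scale per token (row of x) and one per output channel (row of w).
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    w_scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    x_fp8 = (x / x_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    w_fp8 = (w / w_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # Dequantize for clarity; real kernels fold the scales into the GEMM.
    return (x_fp8.float() * x_scale) @ (w_fp8.float() * w_scale).t()

x = torch.randn(4, 64)
w = torch.randn(128, 64)   # (out_features, in_features)
ref = x @ w.t()
approx = ptpc_fp8_linear_ref(x, w)
print((approx - ref).abs().max() / ref.abs().max())  # small relative error expected

```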
2 changes: 1 addition & 1 deletion Dockerfile.rocm_base
@@ -6,7 +6,7 @@ ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="8d4926e"
ARG PYTORCH_BRANCH="3a585126"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
64 changes: 45 additions & 19 deletions docs/source/getting_started/installation/gpu/rocm.inc.md
@@ -1,6 +1,6 @@
# Installation

vLLM supports AMD GPUs with ROCm 6.2.
vLLM supports AMD GPUs with ROCm 6.3.

:::{attention}
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
@@ -9,7 +9,7 @@
## Requirements

- GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
- ROCm 6.2
- ROCm 6.3

## Set up using Python

@@ -24,9 +24,15 @@ Currently, there are no pre-built ROCm wheels.
- [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html)
- [PyTorch](https://pytorch.org/)

For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`.
For installing PyTorch, you can start from a fresh docker image, e.g., `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0` or `rocm/pytorch-nightly`. If you are using a docker image, you can skip to Step 3.

Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/)
Alternatively, you can install PyTorch using PyTorch wheels. You can check the PyTorch installation guide in the PyTorch [Getting Started](https://pytorch.org/get-started/locally/) documentation. Example:

```console
# Install PyTorch
$ pip uninstall torch -y
$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.3
```

1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton)

@@ -37,7 +43,7 @@ Currently, there are no pre-built ROCm wheels.
pip uninstall -y triton
git clone https://github.com/OpenAI/triton.git
cd triton
git checkout e192dba
git checkout e5be006
cd python
pip3 install .
cd ../..
@@ -49,15 +55,15 @@ Currently, there are no pre-built ROCm wheels.

2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile)

Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
Alternatively, wheels intended for vLLM use can be accessed under the releases.

For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.

```console
git clone https://github.com/ROCm/flash-attention.git
cd flash-attention
git checkout 3cea2fb
git checkout b7d29fb
git submodule update --init
GPU_ARCHS="gfx90a" python3 setup.py install
cd ..
@@ -67,20 +73,16 @@ Currently, there are no pre-built ROCm wheels.
You might need to downgrade the "ninja" version to 1.10 as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
:::

3. Build vLLM. For example, vLLM on ROCM 6.2 can be built with the following steps:
3. Build vLLM. For example, vLLM on ROCm 6.3 can be built with the following steps:

```bash
$ pip install --upgrade pip

# Install PyTorch
$ pip uninstall torch -y
$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.2

# Build & install AMD SMI
$ pip install /opt/rocm/share/amd_smi

# Install dependencies
$ pip install --upgrade numba scipy huggingface-hub[cli]
$ pip install --upgrade numba scipy huggingface-hub[cli,hf_transfer] setuptools_scm
$ pip install "numpy<2"
$ pip install -r requirements-rocm.txt

@@ -104,7 +106,7 @@ Currently, there are no pre-built ROCm wheels.
For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization).
:::

## Set up using Docker
## Set up using Docker (Recommended)

### Pre-built images

@@ -120,7 +122,12 @@ for instructions on how to use this prebuilt docker image.

Building the Docker image from source is the recommended way to use vLLM with ROCm.

First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
#### (Optional) Build an image with ROCm software stack

Build a docker image from <gh-file:Dockerfile.rocm_base>, which sets up the ROCm software stack needed by vLLM.
**This step is optional, as this rocm_base image is usually prebuilt and stored at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under the tag `rocm/vllm-dev:base` to speed up the user experience.**
If you choose to build this rocm_base image yourself, the steps are as follows.

It is important that the user kicks off the docker build using BuildKit. Either set `DOCKER_BUILDKIT=1` as an environment variable when calling the `docker build` command, or enable BuildKit in the docker daemon configuration `/etc/docker/daemon.json` as follows and restart the daemon:

```console
@@ -131,7 +138,26 @@
}
```

<gh-file:Dockerfile.rocm> uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
To build vLLM on ROCm 6.3 for MI200 and MI300 series, you can use the default:

```console
DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm_base -t rocm/vllm-dev:base .
```

#### Build an image with vLLM

First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
It is important that the user kicks off the docker build using BuildKit. Either set `DOCKER_BUILDKIT=1` as an environment variable when calling the `docker build` command, or enable BuildKit in the docker daemon configuration `/etc/docker/daemon.json` as follows and restart the daemon:

```console
{
  "features": {
    "buildkit": true
  }
}
```

<gh-file:Dockerfile.rocm> uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2 in older vLLM branches.
It provides flexibility to customize the build of docker image using the following arguments:

- `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using <gh-file:Dockerfile.rocm_base>
@@ -141,13 +167,13 @@

Their values can be passed in when running `docker build` with `--build-arg` options.

To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
To build vLLM on ROCm 6.3 for MI200 and MI300 series, you can use the default:

```console
DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
```

To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
To build vLLM on ROCm 6.3 for Radeon RX 7900 series (gfx1100), you should pick the alternative base image:

```console
DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm .
```
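
Once an image is built, a container can be launched from it. The exact flags vary by system; a typical ROCm invocation looks roughly like the following sketch, where the image tag and the mounted cache path are placeholders:

```console
docker run -it \
    --network=host \
    --ipc=host \
    --group-add=video \
    --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    --device /dev/kfd \
    --device /dev/dri \
    -v <path/to/hf_cache>:/root/.cache/huggingface \
    vllm-rocm \
    bash
```
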
49 changes: 38 additions & 11 deletions tests/quantization/test_fp8.py
@@ -55,10 +55,21 @@ def check_model(model):

assert isinstance(attn.quant_method, Fp8KVCacheMethod)

# NOTE: it is valid for scales to be 1.0 (default value), but
# we know these checkpoints have scales < 1.0
assert 0.0 < attn._k_scale < 1.0
assert 0.0 < attn._v_scale < 1.0
if not current_platform.is_rocm():
# NOTE: This code path requires validation on Non-CUDA platform
# NOTE: it is valid for scales to be 1.0 (default value), but
# we know these checkpoints have scales < 1.0
assert 0.0 < attn._k_scale < 1.0
assert 0.0 < attn._v_scale < 1.0
else:
# NOTE: This code path is for ROCm platform
# NOTE: it is valid for scales to be 1.0 (default value), but
# we know these checkpoints have scales < 1.0
# However on ROCm platform, the _k_scale and _v_scale will be
# scaled by a factor of 2 as described in
# vllm/model_executor/layers/quantization/kv_cache.py
assert 0.0 < attn._k_scale < (1.0 * 2.0)
assert 0.0 < attn._v_scale < (1.0 * 2.0)

llm.apply_model(check_model)
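
The factor of 2 in the ROCm branch above comes from the narrower FP8 format used there: MI300-class GPUs use `float8_e4m3fnuz`, whose representable range is roughly half that of the OCP `float8_e4m3fn` format the checkpoints assume, so the loaded scales are doubled. A quick way to see the difference, assuming a PyTorch build that exposes both dtypes:

```python
import torch

# OCP e4m3 (CUDA checkpoints) vs. the fnuz variant used on ROCm / MI300.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```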

@@ -91,13 +102,29 @@ def check_model(model):
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0

if current_platform.has_device_capability(89) and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
if current_platform.is_cuda():
if current_platform.has_device_capability(
89) and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
elif current_platform.is_rocm():
# Only MI300 and above support quantization='fp8'
if current_platform.has_device_capability(
94) and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fnuz
else: # unsupported ROCm platform
pytest.skip(
"Skip `test_load_fp16_model`. "
"It only runs on ROCm platform with FP8 compute."
" e.g. MI300X and above.")
else: # unsupported platform
pytest.skip("Skip `test_load_fp16_model`. "
"It only runs on CUDA and ROCm platform.")

llm.apply_model(check_model)

55 changes: 55 additions & 0 deletions tests/quantization/test_ptpc_fp8.py
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests whether PTPC w8a8 FP8 computation is enabled correctly.
Run `pytest tests/quantization/test_ptpc_fp8.py --forked`.
"""
import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
from vllm.model_executor.layers.quantization.ptpc_fp8 import (
PTPCFp8LinearMethod)
from vllm.platforms import current_platform


@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
reason="PTPC FP8 is not supported on this GPU type.")
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="This test is for ROCm GPU.")
@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:

try:
with vllm_runner("facebook/opt-125m",
dtype=dtype,
quantization="ptpc_fp8",
kv_cache_dtype=kv_cache_dtype) as llm:

model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
if kv_cache_dtype == "ptpc_fp8":
attn = model.model.decoder.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0

if current_platform.has_device_capability(94):
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fnuz
else:
pytest.skip()

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
except AssertionError as e:
if str(
e
) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.": # noqa: E501
# If the error message matches, the test passes
pass
else:
# If the error message does not match, re-raise the exception
raise
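
For reference, the same flags this test exercises can be used with the offline `LLM` API. The following is a minimal sketch, assuming an MI300-class ROCm GPU and a vLLM build that includes this commit; `bfloat16` is chosen because, as the caught assertion above notes, the rowwise `torch._scaled_mm` (hipBLASLt) path currently only supports bf16 output:

```python
from vllm import LLM, SamplingParams

# Per-token-activation, per-channel-weight FP8 inference on a small model.
llm = LLM(model="facebook/opt-125m",
          dtype="bfloat16",
          quantization="ptpc_fp8",
          kv_cache_dtype="fp8")

outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)
```
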
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -11,6 +11,7 @@
"deepspeedfp",
"tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
# The order of gptq methods is important for config.py iteration over
@@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .modelopt import ModelOptFp8Config
from .moe_wna16 import MoeWNA16Config
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
from .tpu_int8 import Int8TpuConfig

@@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"ptpc_fp8": PTPCFp8Config,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
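
A small sanity check of the registration above; this sketch assumes the method list shown is exported as `QUANTIZATION_METHODS`, as in recent vLLM versions:

```python
from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

# "ptpc_fp8" should now be a recognized method name and resolve to its config.
assert "ptpc_fp8" in QUANTIZATION_METHODS
print(get_quantization_config("ptpc_fp8").__name__)  # PTPCFp8Config
```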