xc-llm-ascend/vllm_ascend/patch/worker/patch_qwen3_c8.py

[v0.18.0]feat(quant): add C8 INT8 KV cache support for GQA attention models (#7474) (#8007)

backport of #7474

This PR adds C8 (INT8) KV cache quantization support for standard GQA attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel quantization scales to store the KV cache in INT8, reducing KV cache memory by ~50% compared to BF16 and enabling higher batch concurrency and longer context lengths on the same hardware.

**Key changes:**

1. **`attention_v1.py`**: New `AscendC8AttentionBackendImpl` subclass of `AscendAttentionBackendImpl`:
   - `_prepare_c8_scales`: Shards per-channel scales/offsets to the current TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time per layer).
   - `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before `reshape_and_cache`, using pre-cached inverse scales.
   - `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV and `perchannel` antiquant mode.
   - `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8) and prefill (FIA V1 TND float) into two kernel calls.
   - `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and `PrefillCacheHit` states.
2. **`quantization/methods/kv_c8.py`**: New `AscendC8KVCacheAttentionMethod` scheme:
   - Creates `k/v_cache_scale/offset` parameters via `_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and lazy resizing.
   - Sets `layer.kv_cache_torch_dtype = torch.int8` so `get_kv_cache_spec()` returns INT8 dtype automatically.
   - Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class surgery.
3. **`quantization/modelslim_config.py`**: C8 branch in `get_quant_method()` activates when `kv_cache_type == "C8"` in `quant_model_description.json`.
4. **`patch/worker/patch_qwen3_c8.py`**: Intercepts per-channel C8 scale/offset weights before `AutoWeightsLoader` discards them, routing them to the parameters created by `AscendC8KVCacheAttentionMethod`.
5. **`tests/ut/quantization/test_kv_c8.py`**: Unit tests covering `_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and `AscendC8AttentionBackendImpl` scale helpers.

User-facing change: yes. Users can now serve Qwen3-32B W8A8C8 quantized models with INT8 KV cache on Ascend NPU. The model checkpoint must contain a `quant_model_description.json` with `"kv_cache_type": "C8"` and per-channel scale/offset tensors in safetensors. There are no changes to the serving CLI: the feature activates automatically when the quantization config is detected.

Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`, `max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench` (input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):

```
============ Serving Benchmark Result ============
Successful requests:                     960
Failed requests:                         0
Maximum request concurrency:             192
Benchmark duration (s):                  1359.81
Total input tokens:                      9830400
Total generated tokens:                  1966080
Request throughput (req/s):              0.71
Output token throughput (tok/s):         1445.85
Peak output token throughput (tok/s):    2304.00
Total token throughput (tok/s):          8675.12
---------------Time to First Token----------------
Mean TTFT (ms):                          24598.51
Median TTFT (ms):                        23167.02
P50 TTFT (ms):                           23167.02
P90 TTFT (ms):                           47717.08
P99 TTFT (ms):                           84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          120.76
Median TPOT (ms):                        121.50
P50 TPOT (ms):                           121.50
P90 TPOT (ms):                           127.05
P99 TPOT (ms):                           130.13
---------------Inter-token Latency----------------
Mean ITL (ms):                           120.70
Median ITL (ms):                         90.34
P50 ITL (ms):                            90.34
P90 ITL (ms):                            93.79
P99 ITL (ms):                            101.80
==================================================
```

All attention states verified: `PrefillNoCache`, `PrefillCacheHit`, `ChunkedPrefill`, `DecodeOnly`.

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/8b6325758cce5f9c36d38f2462edbd368b97a07c

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>
2026-04-08 10:51:58 +08:00
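The commit message describes C8 as static per-channel INT8 quantization of the KV cache with roughly half the memory of BF16. The following is a minimal pure-Python sketch of that idea, not the Ascend kernel path: the function names, the `round(x / scale) + offset` convention, and the example shapes are all assumptions made for illustration, and the real implementation may use a different rounding mode or offset sign.

```python
def quantize_per_channel(rows, scales, offsets):
    """Quantize rows of floats to INT8 with one (scale, offset) per channel.

    Assumed convention: q = clamp(round(x / scale) + offset, -128, 127).
    """
    quantized = []
    for row in rows:
        q_row = []
        for x, scale, offset in zip(row, scales, offsets):
            q = round(x / scale) + offset
            q_row.append(max(-128, min(127, q)))  # saturate to the INT8 range
        quantized.append(q_row)
    return quantized


def dequantize_per_channel(quantized, scales, offsets):
    """Invert the assumed convention: x is approximately (q - offset) * scale."""
    return [
        [(q - offset) * scale for q, scale, offset in zip(row, scales, offsets)]
        for row in quantized
    ]


def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_element):
    """Size of K plus V for the given (hypothetical) model shape."""
    return 2 * num_tokens * num_layers * num_kv_heads * head_dim * bytes_per_element


# Two tokens, two channels; scales are static, i.e. fixed ahead of time.
scales = [0.25, 0.5]
offsets = [0, 0]
kv = [[1.0, -2.0], [100.0, 3.0]]

kv_int8 = quantize_per_channel(kv, scales, offsets)
kv_restored = dequantize_per_channel(kv_int8, scales, offsets)

# Hypothetical shape, chosen only to compare element widths.
bf16_bytes = kv_cache_bytes(4096, 64, 8, 128, bytes_per_element=2)
int8_bytes = kv_cache_bytes(4096, 64, 8, 128, bytes_per_element=1)
```

In this sketch the value `100.0` saturates at 127 and is restored as `31.75`, which shows why the per-channel scales have to cover the value range seen during calibration. The byte count ignores the scale and offset tensors themselves, which are small compared with the cache.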
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Iterable

import torch
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM

# Keep a reference to the original method so the patch can delegate to it.
_orig_qwen3_causal_lm_load_weights = Qwen3ForCausalLM.load_weights


def _patched_qwen3_causal_lm_load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    quant_config = self.quant_config
    # Nothing to intercept unless the quant config can map checkpoint names to cache scales.
    if quant_config is None or not callable(getattr(quant_config, "get_cache_scale", None)):
        return _orig_qwen3_causal_lm_load_weights(self, weights)

    params_dict = dict(self.named_parameters())
    c8_loaded_params: set[str] = set()

    def _intercept_c8_scales(
        raw_weights: Iterable[tuple[str, torch.Tensor]],
    ) -> Iterable[tuple[str, torch.Tensor]]:
        for name, loaded_weight in raw_weights:
            scale_name = quant_config.get_cache_scale(name)
            if scale_name is not None:
                # C8 scale/offset weight: load it here and do not forward it to the original loader.
                if scale_name in params_dict:
                    param = params_dict[scale_name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, loaded_weight.squeeze())
                    c8_loaded_params.add(scale_name)
            else:
                # Regular weight: pass it through unchanged.
                yield name, loaded_weight

    # The generator is consumed lazily by the original loader, so interception happens during loading.
    loaded_params = _orig_qwen3_causal_lm_load_weights(self, _intercept_c8_scales(weights))
    loaded_params.update(c8_loaded_params)
    return loaded_params


# Install the patch on the model class.
Qwen3ForCausalLM.load_weights = _patched_qwen3_causal_lm_load_weights
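
The interception pattern used by the patch above can be tried without vLLM or torch. The sketch below is a stand-in under stated assumptions: `toy_get_cache_scale`, the checkpoint key names, and the plain-dict parameter store are hypothetical and do not reflect the real `get_cache_scale` mapping or the real parameter objects.

```python
from collections.abc import Callable, Iterable, Iterator


def intercept_scales(
    weights: Iterable[tuple[str, object]],
    get_cache_scale: Callable[[str], "str | None"],
    params: dict[str, object],
    loaded: set[str],
) -> Iterator[tuple[str, object]]:
    """Yield ordinary weights; divert cache-scale weights into `params`."""
    for name, value in weights:
        scale_name = get_cache_scale(name)
        if scale_name is None:
            yield name, value  # ordinary weight, forwarded downstream
            continue
        if scale_name in params:
            params[scale_name] = value  # stand-in for weight_loader(param, tensor)
            loaded.add(scale_name)
        # Scale weights with no matching parameter are dropped, not forwarded.


def toy_get_cache_scale(name: str) -> "str | None":
    """Hypothetical name mapping, used only for this sketch."""
    for proj, target in ((".k_proj.kv_cache_scale", ".attn.k_cache_scale"),
                         (".v_proj.kv_cache_scale", ".attn.v_cache_scale")):
        if name.endswith(proj):
            return name[: -len(proj)] + target
    return None


checkpoint = [
    ("layers.0.self_attn.q_proj.weight", "W_q"),
    ("layers.0.self_attn.k_proj.kv_cache_scale", "S_k"),
    ("layers.0.self_attn.v_proj.kv_cache_scale", "S_v"),
]
params = {"layers.0.self_attn.attn.k_cache_scale": None}  # no v_cache_scale parameter
loaded: set[str] = set()

stream = intercept_scales(checkpoint, toy_get_cache_scale, params, loaded)
loaded_before_consumption = set(loaded)  # still empty: the generator has not run yet
forwarded = list(stream)  # consuming the stream triggers the interception
```

Because the filter is a generator, nothing is diverted until the downstream loader iterates over it, which is why the patch merges `c8_loaded_params` into the result only after the original `load_weights` call has returned.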