backport of #7474
This PR adds C8 (INT8) KV cache quantization support for standard GQA
attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel
quantization scales to store KV cache in INT8, reducing KV cache memory
by ~50% compared to BF16, enabling higher batch concurrency and longer
context lengths on the same hardware.
**Key changes:**
1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass
of `AscendAttentionBackendImpl`:
- `_prepare_c8_scales`: Shards per-channel scales/offsets to the current
TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time
per layer).
- `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before
`reshape_and_cache`, using pre-cached inverse scales (see the sketch after
this list).
- `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV
and `perchannel` antiquant mode.
- `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8)
and prefill (FIA V1 TND float) into two kernel calls.
- `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and
`PrefillCacheHit` states.
2. **`quantization/methods/kv_c8.py`** — New
`AscendC8KVCacheAttentionMethod` scheme:
- Creates `k/v_cache_scale/offset` parameters via
`_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and
lazy resizing.
- Sets `layer.kv_cache_torch_dtype = torch.int8` so
`get_kv_cache_spec()` returns INT8 dtype automatically.
- Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class
surgery.
3. **`quantization/modelslim_config.py`** — C8 branch in
`get_quant_method()` activates when `kv_cache_type == "C8"` in
`quant_model_description.json`.
4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8
scale/offset weights before `AutoWeightsLoader` discards them, routing
them to the parameters created by `AscendC8KVCacheAttentionMethod`.
5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering
`_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and
`AscendC8AttentionBackendImpl` scale helpers.
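For illustration, a minimal sketch of per-channel static INT8 quantization of K/V before the cache write, in the spirit of `_quantize_kv_to_int8` described above; the tensor shapes and function signature here are assumptions, not the real implementation:
```
import torch

def quantize_kv_to_int8(
    kv: torch.Tensor,         # [num_tokens, num_kv_heads * head_dim], BF16
    inv_scale: torch.Tensor,  # [num_kv_heads * head_dim], pre-cached 1 / scale
    offset: torch.Tensor,     # [num_kv_heads * head_dim], static zero point
) -> torch.Tensor:
    # Per-channel affine quantization: q = round(x / scale) + offset,
    # clamped to the INT8 range before the reshape_and_cache write.
    q = torch.round(kv.float() * inv_scale + offset)
    return q.clamp(-128, 127).to(torch.int8)
```
The real backend pre-computes the inverse scales once per layer (sharded to the current TP rank) and hands the scales/offsets to the fused attention kernel through the `perchannel` antiquant path.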
**User-facing change:** Yes. Users can now serve Qwen3-32B W8A8C8 quantized models with INT8 KV
cache on Ascend NPU. The model checkpoint must contain a
`quant_model_description.json` with `"kv_cache_type": "C8"` and
per-channel scale/offset tensors in safetensors.
No changes to the serving CLI — the feature activates automatically when
the quantization config is detected.
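For reference, a minimal sketch of the detection step; the file name and key come from this PR, while the helper function itself is purely illustrative:
```
import json
from pathlib import Path

def kv_cache_is_c8(model_dir: str) -> bool:
    # C8 activates when the checkpoint's quant_model_description.json
    # declares "kv_cache_type": "C8".
    desc = Path(model_dir) / "quant_model_description.json"
    if not desc.is_file():
        return False
    return json.loads(desc.read_text()).get("kv_cache_type") == "C8"
```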
Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`,
`max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench`
(input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):
```
============ Serving Benchmark Result ============
Successful requests: 960
Failed requests: 0
Maximum request concurrency: 192
Benchmark duration (s): 1359.81
Total input tokens: 9830400
Total generated tokens: 1966080
Request throughput (req/s): 0.71
Output token throughput (tok/s): 1445.85
Peak output token throughput (tok/s): 2304.00
Total token throughput (tok/s): 8675.12
---------------Time to First Token----------------
Mean TTFT (ms): 24598.51
Median TTFT (ms): 23167.02
P50 TTFT (ms): 23167.02
P90 TTFT (ms): 47717.08
P99 TTFT (ms): 84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 120.76
Median TPOT (ms): 121.50
P50 TPOT (ms): 121.50
P90 TPOT (ms): 127.05
P99 TPOT (ms): 130.13
---------------Inter-token Latency----------------
Mean ITL (ms): 120.70
Median ITL (ms): 90.34
P50 ITL (ms): 90.34
P90 ITL (ms): 93.79
P99 ITL (ms): 101.80
==================================================
```
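The aggregate figures above are consistent with the stated run parameters:
```
# Cross-check of the aggregate benchmark numbers from the run parameters.
prompts, input_len, output_len, duration_s = 960, 10240, 2048, 1359.81

total_input = prompts * input_len                    # 9,830,400
total_output = prompts * output_len                  # 1,966,080
req_per_s = prompts / duration_s                     # ~0.71
output_tok_per_s = total_output / duration_s         # ~1445.85
total_tok_per_s = (total_input + total_output) / duration_s  # ~8675.1
```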
All attention states verified: `PrefillNoCache`, `PrefillCacheHit`,
`ChunkedPrefill`, `DecodeOnly`.
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>
Diff excerpt from `quantization/modelslim_config.py` (change 3 above):
```
@@ -429,6 +429,21 @@ class AscendModelSlimConfig(QuantizationConfig):
         self._add_kvcache_quant_metadata()
         logger.info("Applied hf_to_vllm_mapper to quant_description keys")
 
+    def get_cache_scale(self, name: str) -> str | None:
+        """Map checkpoint C8 KV scale/offset names to vLLM parameter names."""
+        if self.quant_description.get("kv_cache_type") != "C8":
+            return None
+        _C8_SCALE_MAPPING = {
+            "k_proj.kv_cache_scale": "attn.k_cache_scale",
+            "k_proj.kv_cache_offset": "attn.k_cache_offset",
+            "v_proj.kv_cache_scale": "attn.v_cache_scale",
+            "v_proj.kv_cache_offset": "attn.v_cache_offset",
+        }
+        for src_suffix, dst_suffix in _C8_SCALE_MAPPING.items():
+            if name.endswith(src_suffix):
+                return name[: -len(src_suffix)] + dst_suffix
+        return None
+
     def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
         self.model_type = model_type
         return prefix
@@ -476,6 +491,10 @@ class AscendModelSlimConfig(QuantizationConfig):
         ):
             scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
             return AscendKVCacheMethod(scheme)
+        elif isinstance(layer, AttentionLayerBase) and self.quant_description.get("kv_cache_type") == "C8":
+            from .methods.kv_c8 import AscendC8KVCacheAttentionMethod
+
+            return AscendKVCacheMethod(AscendC8KVCacheAttentionMethod(self.quant_description, prefix))
         elif isinstance(layer, FusedMoE):
             if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
                 # Delayed import to avoid circular import
```
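For illustration, the suffix rewrite performed by `get_cache_scale` on a hypothetical checkpoint name:
```
# The layer prefix below is made up; only the suffix rewrite matters.
name = "model.layers.0.self_attn.k_proj.kv_cache_scale"
# get_cache_scale(name) -> "model.layers.0.self_attn.attn.k_cache_scale"
```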