Backport of #7474.
This PR adds C8 (INT8) KV cache quantization support for standard GQA
attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel
quantization scales to store KV cache in INT8, reducing KV cache memory
by ~50% compared to BF16, enabling higher batch concurrency and longer
context lengths on the same hardware.
**Key changes:**
1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass
of `AscendAttentionBackendImpl`:
- `_prepare_c8_scales`: Shards per-channel scales/offsets to the current
TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time
per layer).
- `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before
`reshape_and_cache`, using pre-cached inverse scales.
- `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV
and `perchannel` antiquant mode.
- `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8)
and prefill (FIA V1 TND float) into two kernel calls.
- `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and
`PrefillCacheHit` states.
2. **`quantization/methods/kv_c8.py`** — New
`AscendC8KVCacheAttentionMethod` scheme:
- Creates `k/v_cache_scale/offset` parameters via
`_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and
lazy resizing.
- Sets `layer.kv_cache_torch_dtype = torch.int8` so
`get_kv_cache_spec()` returns INT8 dtype automatically.
- Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class
surgery.
3. **`quantization/modelslim_config.py`** — C8 branch in
`get_quant_method()` activates when `kv_cache_type == "C8"` in
`quant_model_description.json`.
4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8
scale/offset weights before `AutoWeightsLoader` discards them, routing
them to the parameters created by `AscendC8KVCacheAttentionMethod`.
5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering
`_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and
`AscendC8AttentionBackendImpl` scale helpers.
**User-facing change:** Yes. Users can now serve Qwen3-32B W8A8C8 quantized models with INT8 KV
cache on Ascend NPU. The model checkpoint must contain a
`quant_model_description.json` with `"kv_cache_type": "C8"` and
per-channel scale/offset tensors in safetensors.
No changes to the serving CLI — the feature activates automatically when
the quantization config is detected.
Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`,
`max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench`
(input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):
```
============ Serving Benchmark Result ============
Successful requests: 960
Failed requests: 0
Maximum request concurrency: 192
Benchmark duration (s): 1359.81
Total input tokens: 9830400
Total generated tokens: 1966080
Request throughput (req/s): 0.71
Output token throughput (tok/s): 1445.85
Peak output token throughput (tok/s): 2304.00
Total token throughput (tok/s): 8675.12
---------------Time to First Token----------------
Mean TTFT (ms): 24598.51
Median TTFT (ms): 23167.02
P50 TTFT (ms): 23167.02
P90 TTFT (ms): 47717.08
P99 TTFT (ms): 84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 120.76
Median TPOT (ms): 121.50
P50 TPOT (ms): 121.50
P90 TPOT (ms): 127.05
P99 TPOT (ms): 130.13
---------------Inter-token Latency----------------
Mean ITL (ms): 120.70
Median ITL (ms): 90.34
P50 ITL (ms): 90.34
P90 ITL (ms): 93.79
P99 ITL (ms): 101.80
==================================================
```
All attention states verified: `PrefillNoCache`, `PrefillCacheHit`,
`ChunkedPrefill`, `DecodeOnly`.
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>
148 lines · 7.1 KiB · Python
import torch
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
|
|
|
|
from .base import AscendAttentionScheme
|
|
from .registry import register_scheme
|
|
|
|
|
|
def _fa_quant_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
    """Weight loader for MLA-based C8 (FAKQuant) models.

    If both ``param`` and ``loaded_weight`` hold exactly one element, the value
    is written with ``fill_``. Otherwise ``loaded_weight`` is split along dim 0
    into ``tp_size`` equal shards (floor division), and the shard belonging to
    this tensor-parallel rank is copied into ``param``.

    Args:
        param: Destination parameter, e.g. one registered by
            ``AscendFAQuantAttentionMethod.create_weights``.
        loaded_weight: Checkpoint tensor. Because it is sharded here, it is
            expected to be the full (unsharded) tensor.

    Raises:
        ValueError: If the selected shard's shape does not match ``param``.
    """
    if param.numel() == 1 and loaded_weight.numel() == 1:
        param.data.fill_(loaded_weight.item())
        return

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    shard_size = loaded_weight.shape[0] // tp_size
    loaded_weight = loaded_weight.narrow(0, shard_size * tp_rank, shard_size)
    # Explicit check rather than ``assert``. Under ``python -O`` an assert is
    # stripped, and ``copy_`` would then silently broadcast a mismatched shard
    # (e.g. a (1, 1) tensor into a (num_heads, 1) parameter at TP=1).
    if param.size() != loaded_weight.size():
        raise ValueError(
            f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()}) when TP is ({tp_size})"
        )
    param.data.copy_(loaded_weight)
|
|
|
|
|
|
@register_scheme("FAKQuant", "attention")
class AscendFAQuantAttentionMethod:
    """Attention quantization scheme for MLA-based C8 (FAKQuant) checkpoints.

    ``create_weights`` attaches ``fa_q``/``fa_k``/``fa_v`` submodules to the
    attention layer. Each submodule holds a ``(heads, 1)`` ``scale`` in the
    default dtype and a ``(heads, 1)`` int8 ``offset``, both loaded through
    ``_fa_quant_weight_loader``. ``process_weights_after_loading`` derives the
    ``fak_*`` and ``quant_kscale`` tensors on the layer from ``fa_k``.
    """

    def __init__(self):
        self.transpose_weight = True
        self.printFlag = False
        hf_config = get_current_vllm_config().model_config.hf_config
        # Both fall back to 0 when the HF config does not define them.
        self.kv_lora_rank = getattr(hf_config, "kv_lora_rank", 0)
        self.qk_rope_head_dim = getattr(hf_config, "qk_rope_head_dim", 0)

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Register uninitialized FAKQuant scale/offset parameters on ``layer``."""
        for submodule_name in ("fa_q", "fa_k", "fa_v"):
            setattr(layer, submodule_name, torch.nn.Module())

        scale_dtype = torch.get_default_dtype()
        rows_per_submodule = {
            "fa_q": layer.num_heads,
            "fa_k": layer.num_kv_heads,
            "fa_v": layer.num_kv_heads,
        }
        # All scales are registered first, then all offsets.
        for weight_name, weight_dtype in (("scale", scale_dtype), ("offset", torch.int8)):
            for submodule_name, num_rows in rows_per_submodule.items():
                new_param = torch.nn.Parameter(
                    torch.empty((num_rows, 1), dtype=weight_dtype),
                    requires_grad=False,
                )
                getattr(layer, submodule_name).register_parameter(weight_name, new_param)
                # The loader slices the checkpoint tensor for this TP rank.
                new_param.weight_loader = _fa_quant_weight_loader

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Derive the K-cache (de)quantization tensors from ``layer.fa_k``."""
        # Drop size-1 dims, then add a leading dim of size 1.
        k_scale = layer.fa_k.scale.squeeze().unsqueeze(0)
        k_offset = layer.fa_k.offset.squeeze().unsqueeze(0)

        layer.fak_descale_float = torch.nn.Parameter(k_scale.to(torch.float), requires_grad=False)
        layer.fak_descale = torch.nn.Parameter(k_scale, requires_grad=False)
        layer.fak_descale_reciprocal = 1.0 / torch.nn.Parameter(k_scale, requires_grad=False)
        # The offset is cast to the same dtype as the scale.
        layer.fak_offset = torch.nn.Parameter(k_offset.to(layer.fak_descale.dtype), requires_grad=False)

        # Tile the K scale kv_lora_rank times into a (1, kv_lora_rank) row and
        # store its float reciprocal.
        tiled_k_scale = k_scale.repeat(self.kv_lora_rank).view(1, self.kv_lora_rank)
        layer.quant_kscale = 1.0 / torch.nn.Parameter(tiled_k_scale.to(torch.float), requires_grad=False)
|
|
|
|
|
|
@register_scheme("INT8_DYNAMIC", "attention")
class AscendSFAQuantAttentionMethod:
    """Attention quantization scheme that adds the indexer rotation weights.

    ``create_weights`` attaches a fresh ``indexer`` submodule to the layer.
    It registers two uninitialized float32 parameters on it, ``q_rot`` and
    ``k_rot``, each of shape ``(index_head_dim, index_head_dim)``. No
    post-load processing is performed.
    """

    def __init__(self):
        hf_config = get_current_vllm_config().model_config.hf_config
        # Required attribute: there is no fallback if the HF config lacks it.
        self.index_head_dim = hf_config.index_head_dim

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Create ``layer.indexer.q_rot`` and ``layer.indexer.k_rot``."""
        # NOTE(review): this replaces any existing ``layer.indexer`` attribute.
        # Confirm that is intended for the layers this scheme is applied to.
        layer.indexer = torch.nn.Module()
        rot_shape = (self.index_head_dim, self.index_head_dim)
        for rot_name in ("q_rot", "k_rot"):
            rot_param = torch.nn.Parameter(
                torch.empty(rot_shape, dtype=torch.float32),
                requires_grad=False,
            )
            layer.indexer.register_parameter(rot_name, rot_param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Nothing to post-process for this scheme."""
|
|
|
|
|
|
def _c8_kv_scale_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
    """Weight loader for dense-attention C8 KV cache scales/offsets.

    The checkpoint tensor is squeezed first. If its shape matches
    ``param.data``, the values are copied in place. Otherwise the parameter's
    storage is replaced, which is a lazy resize: in this module the parameter
    starts as a 1-element placeholder. A 1-element checkpoint tensor squeezes
    to 0-d and therefore also takes the replacement path.

    The replacement keeps ``param``'s current device and dtype. This matches
    the in-place ``copy_`` branch, which never moves the parameter.

    Args:
        param: Destination parameter; its ``.data`` may be swapped out.
        loaded_weight: Tensor read from the checkpoint.
    """
    loaded_weight = loaded_weight.squeeze()
    if param.data.shape == loaded_weight.shape:
        param.data.copy_(loaded_weight)
        return
    # Lazy resize. Convert onto the parameter's own device/dtype so a CPU
    # checkpoint tensor cannot silently move an on-device parameter to CPU.
    # copy=True guarantees the parameter never aliases the checkpoint tensor.
    param.data = loaded_weight.to(device=param.device, dtype=param.dtype, copy=True)
|
|
|
|
|
|
class AscendC8KVCacheAttentionMethod(AscendAttentionScheme):
    """C8 INT8 KV cache quantization for dense-attention models (e.g. Qwen3).

    ``create_weights`` does three things:

    * sets the layer's KV cache dtype to ``torch.int8``;
    * swaps ``layer.impl`` to ``AscendC8AttentionBackendImpl`` when the layer
      has an ``impl``;
    * registers 1-element float32 placeholders ``k_cache_scale``,
      ``k_cache_offset``, ``v_cache_scale`` and ``v_cache_offset``, all loaded
      through ``_c8_kv_scale_weight_loader``.

    The attention computation itself runs in the backend, so ``apply`` always
    raises.
    """

    def __init__(self, quant_description: dict, prefix: str):
        self.prefix = prefix
        self.quant_description = quant_description

    def create_weights(self, layer: torch.nn.Module) -> None:
        # Override kv_cache_torch_dtype so Attention.get_kv_cache_spec returns int8 automatically.
        layer.kv_cache_torch_dtype = torch.int8

        # Class surgery: route this layer's forward through the C8 backend impl.
        if hasattr(layer, "impl"):
            from vllm_ascend.attention.attention_v1 import AscendC8AttentionBackendImpl

            layer.impl.__class__ = AscendC8AttentionBackendImpl

        # Scales start at 1 and offsets at 0. The loader replaces each
        # placeholder with the checkpoint tensor when the shapes differ.
        placeholder_specs = (
            ("k_cache_scale", torch.ones),
            ("k_cache_offset", torch.zeros),
            ("v_cache_scale", torch.ones),
            ("v_cache_offset", torch.zeros),
        )
        for attr_name, make_tensor in placeholder_specs:
            placeholder = torch.nn.Parameter(make_tensor(1, dtype=torch.float32), requires_grad=False)
            placeholder.weight_loader = _c8_kv_scale_weight_loader
            setattr(layer, attr_name, placeholder)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Loaded tensors may be 0-d (a squeezed single value) or multi-dim;
        # normalize every C8 scale/offset to 1-D.
        for attr_name in ("k_cache_scale", "k_cache_offset", "v_cache_scale", "v_cache_offset"):
            c8_param = getattr(layer, attr_name)
            c8_param.data = c8_param.data.flatten()

    def apply(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache,
        attn_metadata,
        attn_type,
        scale,
        output,
    ) -> torch.Tensor:
        """Always raise: C8 attention is executed by the attention backend."""
        raise RuntimeError(
            "AscendC8KVCacheAttentionMethod.apply should not be called. "
            "C8 KV cache quantization is handled by the attention backend."
        )
|