[v0.18.0]feat(quant): add C8 INT8 KV cache support for GQA attention models (#7474) (#8007)

backport of #7474

This PR adds C8 (INT8) KV cache quantization support for standard GQA
attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel
quantization scales to store the KV cache in INT8, reducing KV cache memory
by ~50% compared to BF16 and enabling higher batch concurrency and longer
context lengths on the same hardware.
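
For reference, the per-channel quantize/dequantize arithmetic that C8 relies on looks roughly like the sketch below. This is a minimal standalone illustration, not the PR's code; the tensor names, shapes, and calibration values are assumptions based on the description above.

```python
import torch

# Assumed KV layout: [num_tokens, num_kv_heads, head_size]; the static scale and
# offset are per-channel, i.e. one value per (kv_head, head_dim) channel.
num_tokens, num_kv_heads, head_size = 4, 8, 128
kv = torch.randn(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16)

# Static calibration artifacts shipped with the checkpoint (hypothetical values).
scale = torch.rand(1, num_kv_heads, head_size) + 0.01
offset = torch.zeros(1, num_kv_heads, head_size)

# Quantize: int8 = clamp(round(x / scale + offset), -128, 127).
# INT8 storage is 1 byte/element vs. 2 bytes for BF16, hence the ~50% KV saving.
kv_int8 = torch.clamp(torch.round(kv.float() / scale + offset), -128, 127).to(torch.int8)

# Dequantize on read: x ~= (int8 - offset) * scale.
kv_deq = ((kv_int8.float() - offset) * scale).to(torch.bfloat16)
print((kv - kv_deq).abs().max())  # per-channel quantization error stays small
```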

**Key changes:**

1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass
of `AscendAttentionBackendImpl`:
- `_prepare_c8_scales`: Shards per-channel scales/offsets to the current
TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time
per layer).
- `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before
`reshape_and_cache`, using pre-cached inverse scales.
- `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV
and `perchannel` antiquant mode.
- `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8)
and prefill (FIA V1 TND float) into two kernel calls.
- `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and
`PrefillCacheHit` states.

2. **`quantization/methods/kv_c8.py`** — New
`AscendC8KVCacheAttentionMethod` scheme (a sketch follows this list):
- Creates `k/v_cache_scale/offset` parameters via
`_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and
lazy resizing.
- Sets `layer.kv_cache_torch_dtype = torch.int8` so
`get_kv_cache_spec()` returns INT8 dtype automatically.
- Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class
surgery.

3. **`quantization/modelslim_config.py`** — C8 branch in
`get_quant_method()` activates when `kv_cache_type == "C8"` in
`quant_model_description.json` (a sketch follows this list).

4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8
scale/offset weights before `AutoWeightsLoader` discards them, routing
them to the parameters created by `AscendC8KVCacheAttentionMethod` (a
sketch follows this list).

5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering
`_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and
`AscendC8AttentionBackendImpl` scale helpers.
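
Item 2's scheme comes down to two ideas: register per-channel scale/offset parameters on the attention layer with a custom weight loader, and retarget the layer's `impl` class so the C8 forward path runs. A condensed, self-contained sketch of those two ideas follows; the names mirror the description above, but the real `kv_c8.py` differs in detail.

```python
import torch


class FakeAttentionLayer:
    """Stand-in for vLLM's attention layer, just enough for this sketch."""


def _c8_kv_scale_weight_loader(param: torch.nn.Parameter, loaded: torch.Tensor) -> None:
    # Per-channel scales may arrive with a different (unsharded) shape; resize the
    # placeholder lazily so the checkpoint shape always wins, then copy.
    if param.data.shape != loaded.shape:
        param.data = torch.empty_like(loaded)
    param.data.copy_(loaded)


def create_c8_kv_weights(layer, num_kv_heads: int, head_size: int) -> None:
    """Create k/v cache scale/offset parameters and mark the cache dtype as INT8."""
    for name in ("k_cache_scale", "k_cache_offset", "v_cache_scale", "v_cache_offset"):
        p = torch.nn.Parameter(torch.zeros(num_kv_heads * head_size), requires_grad=False)
        p.weight_loader = _c8_kv_scale_weight_loader
        setattr(layer, name, p)
    layer.kv_cache_torch_dtype = torch.int8  # so get_kv_cache_spec() reports INT8
    # "Class surgery": retarget the existing impl instance to the C8 subclass so its
    # forward() runs the INT8 path without rebuilding the layer, e.g.:
    # layer.impl.__class__ = AscendC8AttentionBackendImpl


layer = FakeAttentionLayer()
create_c8_kv_weights(layer, num_kv_heads=8, head_size=128)
_c8_kv_scale_weight_loader(layer.k_cache_scale, torch.rand(64 * 128))  # full, unsharded channels
print(layer.k_cache_scale.shape, layer.kv_cache_torch_dtype)
```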
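
Item 3 is essentially a dispatch rule: when the checkpoint's `quant_model_description.json` declares `"kv_cache_type": "C8"`, attention layers get the C8 KV cache method. A minimal sketch of that check, under the assumption that the description file sits in the model directory (the real `get_quant_method()` in `modelslim_config.py` handles more cases):

```python
import json
from pathlib import Path


def kv_cache_quant_type(model_dir: str) -> str | None:
    """Return "C8" if the checkpoint opts into INT8 KV cache, else None."""
    desc_path = Path(model_dir) / "quant_model_description.json"
    if not desc_path.exists():
        return None
    desc = json.loads(desc_path.read_text())
    # The C8 branch in get_quant_method() activates on this key; other keys keep
    # their existing W8A8 handling.
    return "C8" if desc.get("kv_cache_type") == "C8" else None
```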
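
Item 4 exists because the default loading path only feeds parameters it already knows about, so the extra per-channel scale/offset tensors would otherwise be dropped. A rough sketch of the interception idea; the weight-name suffixes and helper below are illustrative, not the patch's exact code.

```python
import torch

# Hypothetical suffixes of the C8 tensors stored in the checkpoint's safetensors.
C8_SUFFIXES = ("k_cache_scale", "k_cache_offset", "v_cache_scale", "v_cache_offset")


def route_c8_weights(named_params: dict, weights):
    """Load C8 scale/offset tensors directly; pass everything else through."""
    for name, tensor in weights:
        if name.endswith(C8_SUFFIXES):
            param = named_params.get(name)
            if param is not None:
                # Use the loader attached by AscendC8KVCacheAttentionMethod.
                param.weight_loader(param, tensor)
            continue  # never hand C8 tensors to the default loader
        yield name, tensor  # AutoWeightsLoader consumes the rest as usual
```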

Yes, there is a user-facing change: users can now serve Qwen3-32B W8A8C8
quantized models with an INT8 KV cache on Ascend NPU. The model checkpoint
must contain a `quant_model_description.json` with `"kv_cache_type": "C8"`
and per-channel scale/offset tensors in its safetensors files.

No changes to the serving CLI — the feature activates automatically when
the quantization config is detected.

Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`,
`max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench`
(input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):

```
============ Serving Benchmark Result ============
Successful requests:                     960
Failed requests:                         0
Maximum request concurrency:             192
Benchmark duration (s):                  1359.81
Total input tokens:                      9830400
Total generated tokens:                  1966080
Request throughput (req/s):              0.71
Output token throughput (tok/s):         1445.85
Peak output token throughput (tok/s):    2304.00
Total token throughput (tok/s):          8675.12
---------------Time to First Token----------------
Mean TTFT (ms):                          24598.51
Median TTFT (ms):                        23167.02
P50 TTFT (ms):                           23167.02
P90 TTFT (ms):                           47717.08
P99 TTFT (ms):                           84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          120.76
Median TPOT (ms):                        121.50
P50 TPOT (ms):                           121.50
P90 TPOT (ms):                           127.05
P99 TPOT (ms):                           130.13
---------------Inter-token Latency----------------
Mean ITL (ms):                           120.70
Median ITL (ms):                         90.34
P50 ITL (ms):                            90.34
P90 ITL (ms):                            93.79
P99 ITL (ms):                            101.80
==================================================
```
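
As a quick sanity check, the headline figures above are internally consistent (plain arithmetic on the reported numbers, not part of the benchmark tooling):

```python
prompts, input_len, output_len = 960, 10240, 2048
duration_s = 1359.81

total_input = prompts * input_len    # 9,830,400  -> "Total input tokens"
total_output = prompts * output_len  # 1,966,080  -> "Total generated tokens"
print(prompts / duration_s)                        # ~0.71 req/s
print(total_output / duration_s)                   # ~1445.8 output tok/s
print((total_input + total_output) / duration_s)   # ~8675 total tok/s
```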

All attention states verified: `PrefillNoCache`, `PrefillCacheHit`,
`ChunkedPrefill`, `DecodeOnly`.

- vLLM version: v0.17.0
- vLLM main: 8b6325758c

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>

@@ -22,6 +22,7 @@ import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import ( # type: ignore
    AttentionBackend,
@@ -978,3 +979,364 @@ class AscendAttentionBackendImpl(AttentionImpl):
        attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
        output[:num_tokens] = attn_output[:num_tokens]
        return output


class AscendC8AttentionBackendImpl(AscendAttentionBackendImpl):
    """Attention backend implementation for INT8 KV cache (C8/QuaRot) models.

    This subclass handles static per-channel INT8 KV cache quantization. It is
    activated via class surgery in AscendC8KVCacheAttentionMethod.create_weights
    (vllm_ascend/quantization/methods/kv_c8.py) so that C8 attention layers
    automatically use this forward path.
    """

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: tuple[torch.Tensor],
        attn_metadata: AscendMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for AscendC8AttentionBackendImpl")
        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.fill_(0)
        float_key, float_value = None, None
        if key is not None and value is not None:
            if attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
                float_key, float_value = key, value
            key, value = self._quantize_kv_to_int8(key, value, layer, attn_metadata.num_actual_tokens)
        query, key, value, _ = self.reshape_and_cache(query, key, value, kv_cache, attn_metadata, output)
        if attn_metadata.model_runner_type == "pooling":
            attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]
            return output
        self._prepare_c8_scales(layer, query.device)
        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            return self._forward_c8_decode(query, attn_metadata, output, layer)
        elif attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
            return self._forward_c8_chunked_prefill(query, float_key, float_value, attn_metadata, output, layer)
        else:
            return self._forward_c8_fused_infer_attention(
                query,
                float_key if float_key is not None else key,
                float_value if float_value is not None else value,
                attn_metadata,
                output,
                layer,
            )

    def _prepare_c8_scales(self, layer: AttentionLayer, device: torch.device) -> None:
        """Shard per-channel C8 scales/offsets to this TP rank and pre-compute
        BF16 BNSD antiquant tensors for the FIA V1 decode fast path.
        """
        if hasattr(layer, "_c8_scales_prepared"):
            return

        def _shard_and_reshape(raw: torch.Tensor) -> torch.Tensor:
            if raw.numel() == 1:
                return raw.to(device=device)
            expected = self.num_kv_heads * self.head_size
            if raw.numel() != expected:
                total_kv_heads = raw.numel() // self.head_size
                tp_rank = get_tensor_model_parallel_rank()
                tp_size = get_tensor_model_parallel_world_size()
                kv_head_start = tp_rank * total_kv_heads // tp_size
                raw = raw.view(total_kv_heads, self.head_size)[
                    kv_head_start : kv_head_start + self.num_kv_heads
                ].contiguous()
            return raw.view(1, self.num_kv_heads, self.head_size).to(device=device)

        layer._c8_k_scale = _shard_and_reshape(layer.k_cache_scale.data)
        layer._c8_k_offset = _shard_and_reshape(layer.k_cache_offset.data)
        layer._c8_v_scale = _shard_and_reshape(layer.v_cache_scale.data)
        layer._c8_v_offset = _shard_and_reshape(layer.v_cache_offset.data)
        bnsd = (1, self.num_kv_heads, 1, self.head_size)
        layer._c8_k_aq_scale = layer._c8_k_scale.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_k_aq_offset = layer._c8_k_offset.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_v_aq_scale = layer._c8_v_scale.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_v_aq_offset = layer._c8_v_offset.to(torch.bfloat16).view(bnsd).contiguous()
        layer._c8_k_inv_scale_bf16 = (1.0 / layer._c8_k_scale).to(torch.bfloat16)
        layer._c8_k_offset_bf16 = layer._c8_k_offset.to(torch.bfloat16)
        layer._c8_v_inv_scale_bf16 = (1.0 / layer._c8_v_scale).to(torch.bfloat16)
        layer._c8_v_offset_bf16 = layer._c8_v_offset.to(torch.bfloat16)
        layer._c8_scales_prepared = True

    def _dequant_paged_kv_to_dense(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        block_table: torch.Tensor,
        seq_lens: list,
        target_dtype: torch.dtype,
        layer,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Gather paged INT8 KV blocks and dequantize to target_dtype."""
        batch_size = block_table.shape[0]
        block_size = key.shape[1]
        H = key.shape[2]
        max_blocks_per_seq = block_table.shape[1]
        max_tokens_padded = max_blocks_per_seq * block_size
        flat_ids = block_table.reshape(-1)
        gathered_k = key[flat_ids].view(batch_size, max_tokens_padded, H)
        gathered_v = value[flat_ids].view(batch_size, max_tokens_padded, H)
        seq_lens_t = torch.tensor(seq_lens, dtype=torch.long, device=key.device)
        positions = torch.arange(max_tokens_padded, dtype=torch.long, device=key.device)
        valid_mask = (positions.unsqueeze(0) < seq_lens_t.unsqueeze(1)).view(-1)
        dense_k = gathered_k.view(-1, H)[valid_mask]
        dense_v = gathered_v.view(-1, H)[valid_mask]
        dense_k = dense_k.view(-1, self.num_kv_heads, self.head_size)
        dense_v = dense_v.view(-1, self.num_kv_heads, self.head_size)
        k_scale = layer._c8_k_scale.to(target_dtype)
        k_offset = layer._c8_k_offset.to(target_dtype)
        v_scale = layer._c8_v_scale.to(target_dtype)
        v_offset = layer._c8_v_offset.to(target_dtype)
        dense_k = (dense_k.to(target_dtype) - k_offset) * k_scale
        dense_v = (dense_v.to(target_dtype) - v_offset) * v_scale
        return dense_k, dense_v

    def _quantize_kv_to_int8(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        layer: AttentionLayer,
        num_actual_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize K/V from float to INT8 using static per-channel C8 scales."""
        self._prepare_c8_scales(layer, key.device)
        actual_key = key[:num_actual_tokens]
        actual_value = value[:num_actual_tokens]
        k_int8 = torch.clamp(
            torch.round(actual_key * layer._c8_k_inv_scale_bf16 + layer._c8_k_offset_bf16),
            -128,
            127,
        ).to(torch.int8)
        v_int8 = torch.clamp(
            torch.round(actual_value * layer._c8_v_inv_scale_bf16 + layer._c8_v_offset_bf16),
            -128,
            127,
        ).to(torch.int8)
        return k_int8, v_int8

    def _forward_c8_decode(
        self,
        query: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 decode via FIA V1 BNSD with native paged INT8 KV + perchannel antiquant."""
        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
        assert block_size % 32 == 0, f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
        key = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
        value = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
        batch_size = len(attn_metadata.seq_lens_list)
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query[:batch_size].unsqueeze(2),
            key,
            value,
            key_antiquant_scale=layer._c8_k_aq_scale,
            key_antiquant_offset=layer._c8_k_aq_offset,
            value_antiquant_scale=layer._c8_v_aq_scale,
            value_antiquant_offset=layer._c8_v_aq_offset,
            block_table=attn_metadata.block_tables,
            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
            num_heads=self.num_heads,
            num_key_value_heads=self.num_kv_heads,
            input_layout="BNSD",
            scale=self.scale,
            block_size=block_size,
            key_antiquant_mode=0,
            value_antiquant_mode=0,
            sparse_mode=0,
        )
        attn_output = attn_output.squeeze(2)
        output[:batch_size] = attn_output
        return output

    def _forward_c8_chunked_prefill(
        self,
        query: torch.Tensor,
        float_key: torch.Tensor | None,
        float_value: torch.Tensor | None,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        """C8 ChunkedPrefill: decode via FIA V1 BNSD paged INT8 (zero gather),
        prefill via FIA V1 TND with float KV (new) or gather+dequant (continuing).
        """
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_decodes = attn_metadata.num_decodes
        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]
        if num_decode_tokens > 0:
            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
            assert block_size % 32 == 0, (
                f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
            )
            kv_k = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
            kv_v = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
                query[:num_decode_tokens].unsqueeze(2),
                kv_k,
                kv_v,
                key_antiquant_scale=layer._c8_k_aq_scale,
                key_antiquant_offset=layer._c8_k_aq_offset,
                value_antiquant_scale=layer._c8_v_aq_scale,
                value_antiquant_offset=layer._c8_v_aq_offset,
                block_table=attn_metadata.block_tables[:num_decodes],
                actual_seq_lengths_kv=attn_metadata.seq_lens_list[:num_decodes],
                num_heads=self.num_heads,
                num_key_value_heads=self.num_kv_heads,
                input_layout="BNSD",
                scale=self.scale,
                block_size=block_size,
                key_antiquant_mode=0,
                value_antiquant_mode=0,
                sparse_mode=0,
            )
            output[:num_decode_tokens] = attn_out.squeeze(2)
        if attn_metadata.num_prefills > 0:
            prefill_q = query[num_decode_tokens:num_tokens]
            prefill_seq_qlen = [
                actual_seq_qlen[i] - num_decode_tokens for i in range(num_decodes, len(actual_seq_qlen))
            ]
            all_new_prefill = True
            for i in range(num_decodes, len(attn_metadata.seq_lens_list)):
                q_start = actual_seq_qlen[i - 1] if i > 0 else 0
                qlen_i = actual_seq_qlen[i] - q_start
                if attn_metadata.seq_lens_list[i] > qlen_i:
                    all_new_prefill = False
                    break
            if all_new_prefill and float_key is not None and float_value is not None:
                prefill_k = float_key[num_decode_tokens:num_tokens]
                prefill_v = float_value[num_decode_tokens:num_tokens]
                prefill_seq_kvlen = prefill_seq_qlen
            else:
                num_block, blk_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
                paged_k = self.key_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
                paged_v = self.value_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
                prefill_bt = attn_metadata.block_tables[num_decodes:]
                prefill_sl = attn_metadata.seq_lens_list[num_decodes:]
                prefill_k, prefill_v = self._dequant_paged_kv_to_dense(
                    paged_k, paged_v, prefill_bt, prefill_sl, query.dtype, layer
                )
                prefill_seq_kvlen = torch.tensor(prefill_sl, dtype=torch.int32).cumsum(dim=0)
            # block_table is None for prefill; FIA ignores block_size in this case.
            # Use cache block_size for consistency rather than a magic number.
            cache_block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
                query=prefill_q,
                key=prefill_k,
                value=prefill_v,
                atten_mask=attn_metadata.attn_mask,
                block_table=None,
                input_layout="TND",
                block_size=cache_block_size,
                actual_seq_lengths=prefill_seq_qlen,
                actual_seq_lengths_kv=prefill_seq_kvlen,
                num_key_value_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale=self.scale,
                sparse_mode=3,
            )
            n_prefill = num_tokens - num_decode_tokens
            attn_out = attn_out.view(n_prefill, self.num_heads, self.head_size)
            output[num_decode_tokens:num_tokens] = attn_out[:n_prefill]
        return output

    def _forward_c8_fused_infer_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AscendMetadata,
        output: torch.Tensor,
        layer: AttentionLayer,
    ):
        """C8 FIA V1 TND for prefill states (PrefillNoCache uses float KV directly,
        PrefillCacheHit gathers + dequants paged INT8 KV).
        """
        self._prepare_c8_scales(layer, query.device)
        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]
        query = query[:num_tokens]
        if (
            attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
            and self.attn_type != AttentionType.ENCODER_DECODER
        ):
            key = key[:num_tokens]
            value = value[:num_tokens]
        if key.dtype == torch.int8:
            if block_table is not None:
                seq_lens = (
                    actual_seq_lengths_kv if isinstance(actual_seq_lengths_kv, list) else actual_seq_lengths_kv.tolist()
                )
                key, value = self._dequant_paged_kv_to_dense(key, value, block_table, seq_lens, query.dtype, layer)
                block_table = None
                # block_table is None after dequant; FIA ignores block_size.
                # Use cache block_size for consistency rather than a magic number.
                block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
                actual_seq_lengths_kv = torch.tensor(seq_lens, dtype=torch.int32).cumsum(dim=0)
            else:
                qdt = query.dtype
                k_scale = layer._c8_k_scale.to(qdt)
                k_offset = layer._c8_k_offset.to(qdt)
                v_scale = layer._c8_v_scale.to(qdt)
                v_offset = layer._c8_v_offset.to(qdt)
                key = (key.to(qdt) - k_offset) * k_scale
                value = (value.to(qdt) - v_offset) * v_scale
        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
            query=query,
            key=key,
            value=value,
            atten_mask=attn_metadata.attn_mask,
            block_table=block_table,
            input_layout="TND",
            block_size=block_size,
            actual_seq_lengths=actual_seq_qlen,
            actual_seq_lengths_kv=actual_seq_lengths_kv,
            num_key_value_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale=self.scale,
            sparse_mode=3,
        )
        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
        output[:num_tokens] = attn_output
        return output