[v0.18.0] feat(quant): add C8 INT8 KV cache support for GQA attention models (#7474) (#8007)

Backport of #7474.

This PR adds C8 (INT8) KV cache quantization support for standard GQA
attention models (e.g., Qwen3-32B W8A8C8). C8 stores the KV cache in
INT8 using static per-channel quantization scales, reducing KV cache
memory by ~50% compared to BF16, which enables higher batch concurrency
and longer context lengths on the same hardware.
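
For orientation, here is a minimal PyTorch sketch of the static per-channel
quant/dequant round trip described above. Shapes and names are illustrative
only; on the NPU these steps run inside fused kernels, and the backend
pre-caches `1/scale` rather than dividing as shown here:

```
import torch

num_kv_heads, head_size = 8, 128

# Static per-channel parameters: one scale/offset per (kv_head, head_dim)
# channel, computed offline and shipped with the checkpoint.
k_scale = torch.rand(num_kv_heads, head_size) + 0.5
k_offset = torch.zeros(num_kv_heads, head_size)

k = torch.randn(16, num_kv_heads, head_size, dtype=torch.bfloat16)

# Quantize before writing to the paged cache: q = clamp(round(x / scale + offset)).
# INT8 is 1 byte/element vs. 2 for BF16 -> ~50% less KV cache memory.
k_int8 = torch.clamp(
    torch.round(k.float() / k_scale + k_offset), -128, 127
).to(torch.int8)

# Dequantize ("antiquant") when reading back: x ~= (q - offset) * scale.
k_deq = (k_int8.float() - k_offset) * k_scale
```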

**Key changes:**

1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass
of `AscendAttentionBackendImpl`:
- `_prepare_c8_scales`: Shards per-channel scales/offsets to the current
TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time
per layer); a standalone sketch of the sharding step follows this list.
- `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before
`reshape_and_cache`, using pre-cached inverse scales.
- `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV
and `perchannel` antiquant mode.
- `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8)
and prefill (FIA V1 TND float) into two kernel calls.
- `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and
`PrefillCacheHit` states.

2. **`quantization/methods/kv_c8.py`** — New
`AscendC8KVCacheAttentionMethod` scheme:
- Creates `k/v_cache_scale/offset` parameters via
`_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and
lazy resizing.
- Sets `layer.kv_cache_torch_dtype = torch.int8` so
`get_kv_cache_spec()` returns INT8 dtype automatically.
- Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class
surgery.

3. **`quantization/modelslim_config.py`** — C8 branch in
`get_quant_method()` activates when `kv_cache_type == "C8"` in
`quant_model_description.json`.

4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8
scale/offset weights before `AutoWeightsLoader` discards them, routing
them to the parameters created by `AscendC8KVCacheAttentionMethod`.

5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering
`_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and
`AscendC8AttentionBackendImpl` scale helpers.
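
To make the scale handling in item 1 concrete, here is a standalone sketch of
the TP-shard step — a free-function restating of the `_shard_and_reshape`
helper inside `_prepare_c8_scales`; the function name and signature below are
illustrative, assuming the checkpoint stores full-size per-channel scales of
shape `[total_kv_heads * head_size]`:

```
import torch

def shard_c8_scale(raw: torch.Tensor, tp_rank: int, tp_size: int,
                   num_kv_heads: int, head_size: int) -> torch.Tensor:
    """Slice a full per-channel scale down to this TP rank's KV heads."""
    if raw.numel() == 1:  # per-tensor scale: nothing to shard
        return raw
    total_kv_heads = raw.numel() // head_size
    start = tp_rank * total_kv_heads // tp_size
    return (raw.view(total_kv_heads, head_size)[start:start + num_kv_heads]
            .contiguous()
            .view(1, num_kv_heads, head_size))

# Example: 8 KV heads in the checkpoint, TP=2 -> each rank keeps 4 heads.
full = torch.arange(8 * 128, dtype=torch.float32)
print(shard_c8_scale(full, tp_rank=1, tp_size=2,
                     num_kv_heads=4, head_size=128).shape)  # (1, 4, 128)
```

Each rank thus keeps only the scale rows for its own KV heads, matching the
sharding of the K/V projection weights.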

**User-facing change:** Yes. Users can now serve Qwen3-32B W8A8C8 quantized models with INT8 KV
cache on Ascend NPU. The model checkpoint must contain a
`quant_model_description.json` with `"kv_cache_type": "C8"` and
per-channel scale/offset tensors in safetensors.
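
As a rough sketch of the activation check (the key name comes from this PR;
the helper below is hypothetical and only mirrors the `kv_cache_type` test in
`AscendModelSlimConfig`):

```
import json
from pathlib import Path

def uses_c8_kv_cache(model_dir: str) -> bool:
    """True if the checkpoint requests INT8 (C8) KV cache quantization."""
    desc_path = Path(model_dir) / "quant_model_description.json"
    if not desc_path.is_file():
        return False
    desc = json.loads(desc_path.read_text())
    return desc.get("kv_cache_type") == "C8"
```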

No changes to the serving CLI — the feature activates automatically when
the quantization config is detected.

Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`,
`max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench`
(input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):

```
============ Serving Benchmark Result ============
Successful requests:                     960
Failed requests:                         0
Maximum request concurrency:             192
Benchmark duration (s):                  1359.81
Total input tokens:                      9830400
Total generated tokens:                  1966080
Request throughput (req/s):              0.71
Output token throughput (tok/s):         1445.85
Peak output token throughput (tok/s):    2304.00
Total token throughput (tok/s):          8675.12
---------------Time to First Token----------------
Mean TTFT (ms):                          24598.51
Median TTFT (ms):                        23167.02
P50 TTFT (ms):                           23167.02
P90 TTFT (ms):                           47717.08
P99 TTFT (ms):                           84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          120.76
Median TPOT (ms):                        121.50
P50 TPOT (ms):                           121.50
P90 TPOT (ms):                           127.05
P99 TPOT (ms):                           130.13
---------------Inter-token Latency----------------
Mean ITL (ms):                           120.70
Median ITL (ms):                         90.34
P50 ITL (ms):                            90.34
P90 ITL (ms):                            93.79
P99 ITL (ms):                            101.80
==================================================
```

All attention states verified: `PrefillNoCache`, `PrefillCacheHit`,
`ChunkedPrefill`, `DecodeOnly`.

- vLLM version: v0.17.0
- vLLM main: 8b6325758c

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>
Mengqing Cao committed on 2026-04-08 10:51:58 +08:00 (via GitHub)
parent fbd5d0fd55 · commit 044d4c3974 · 8 changed files with 761 additions and 8 deletions

View File

@@ -1,7 +1,9 @@
import unittest
import torch
import torch.nn as nn
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
from tests.ut.base import TestBase
class TestWeightLoader(unittest.TestCase):
@@ -10,7 +12,7 @@ class TestWeightLoader(unittest.TestCase):
def setUp(self):
"""Set up test environment before each test"""
# Import the module under test
-from vllm_ascend.quantization.methods.kv_c8 import weight_loader
+from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader
self.weight_loader = weight_loader
# Mock distributed functions
@@ -295,7 +297,7 @@ class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase):
method.create_weights(self.layer)
# Import weight_loader for comparison
-from vllm_ascend.quantization.methods.kv_c8 import weight_loader
+from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader
# Verify each parameter exists and has weight_loader
self.assertTrue(hasattr(self.layer.fa_q, "scale"))
@@ -440,7 +442,7 @@ class TestIntegration(unittest.TestCase):
v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8)
# Load weights using weight_loader
-from vllm_ascend.quantization.methods.kv_c8 import weight_loader
+from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader
with torch.no_grad():
weight_loader(layer.fa_q.scale, q_scale)
@@ -464,5 +466,224 @@ class TestIntegration(unittest.TestCase):
self.assertTrue(hasattr(layer, "quant_kscale"))
class TestC8KVScaleWeightLoader(TestBase):
"""Tests for _c8_kv_scale_weight_loader in kv_c8.py."""
def setUp(self):
from vllm_ascend.quantization.methods.kv_c8 import _c8_kv_scale_weight_loader
self.loader = _c8_kv_scale_weight_loader
def test_shape_match_copies_value(self):
param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
loaded = torch.tensor([1.0, 2.0, 3.0, 4.0])
self.loader(param, loaded)
self.assertTrue(torch.allclose(param.data, loaded.float()))
def test_shape_mismatch_resizes_param(self):
param = nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)
loaded = torch.arange(8, dtype=torch.float32)
self.loader(param, loaded)
self.assertEqual(param.data.shape, (8,))
self.assertTrue(torch.allclose(param.data, loaded))
def test_squeeze_before_compare(self):
param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
loaded = torch.arange(4, dtype=torch.float32).unsqueeze(0) # shape [1, 4]
self.loader(param, loaded)
self.assertEqual(param.data.shape, (4,))
def test_dtype_preserved_as_param_dtype(self):
param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
loaded = torch.arange(4, dtype=torch.float16)
self.loader(param, loaded)
self.assertEqual(param.data.dtype, torch.float32)
class TestAscendC8KVCacheAttentionMethod(TestBase):
"""Tests for AscendC8KVCacheAttentionMethod in kv_c8.py."""
def _make_method(self):
from vllm_ascend.quantization.methods.kv_c8 import AscendC8KVCacheAttentionMethod
return AscendC8KVCacheAttentionMethod(quant_description={}, prefix="model.layers.0.self_attn.attn")
def _make_layer_with_impl(self):
layer = nn.Module()
layer.impl = MagicMock()
return layer
def test_create_weights_sets_kv_cache_torch_dtype(self):
method = self._make_method()
layer = self._make_layer_with_impl()
method.create_weights(layer)
self.assertEqual(layer.kv_cache_torch_dtype, torch.int8)
def test_create_weights_registers_scale_offset_params(self):
method = self._make_method()
layer = self._make_layer_with_impl()
method.create_weights(layer)
self.assertIsInstance(layer.k_cache_scale, nn.Parameter)
self.assertIsInstance(layer.k_cache_offset, nn.Parameter)
self.assertIsInstance(layer.v_cache_scale, nn.Parameter)
self.assertIsInstance(layer.v_cache_offset, nn.Parameter)
self.assertFalse(layer.k_cache_scale.requires_grad)
self.assertFalse(layer.v_cache_offset.requires_grad)
def test_create_weights_initial_values(self):
method = self._make_method()
layer = self._make_layer_with_impl()
method.create_weights(layer)
self.assertEqual(layer.k_cache_scale.data.item(), 1.0)
self.assertEqual(layer.v_cache_scale.data.item(), 1.0)
self.assertEqual(layer.k_cache_offset.data.item(), 0.0)
self.assertEqual(layer.v_cache_offset.data.item(), 0.0)
def test_create_weights_assigns_weight_loader(self):
from vllm_ascend.quantization.methods.kv_c8 import _c8_kv_scale_weight_loader
method = self._make_method()
layer = self._make_layer_with_impl()
method.create_weights(layer)
self.assertIs(layer.k_cache_scale.weight_loader, _c8_kv_scale_weight_loader)
self.assertIs(layer.v_cache_scale.weight_loader, _c8_kv_scale_weight_loader)
self.assertIs(layer.k_cache_offset.weight_loader, _c8_kv_scale_weight_loader)
self.assertIs(layer.v_cache_offset.weight_loader, _c8_kv_scale_weight_loader)
def test_process_weights_after_loading_flattens(self):
method = self._make_method()
layer = nn.Module()
layer.k_cache_scale = nn.Parameter(torch.ones(2, 4), requires_grad=False)
layer.k_cache_offset = nn.Parameter(torch.zeros(2, 4), requires_grad=False)
layer.v_cache_scale = nn.Parameter(torch.ones(2, 4), requires_grad=False)
layer.v_cache_offset = nn.Parameter(torch.zeros(2, 4), requires_grad=False)
method.process_weights_after_loading(layer)
self.assertEqual(layer.k_cache_scale.data.dim(), 1)
self.assertEqual(layer.k_cache_scale.data.shape[0], 8)
self.assertEqual(layer.v_cache_offset.data.dim(), 1)
def test_apply_raises_runtime_error(self):
method = self._make_method()
layer = MagicMock()
with self.assertRaises(RuntimeError):
method.apply(layer, MagicMock(), MagicMock(), MagicMock(), None, None, None, None, None)
class TestAscendC8AttentionBackendImplScales(TestBase):
"""Tests for AscendC8AttentionBackendImpl scale helpers."""
def _make_impl(self, num_kv_heads=4, head_size=8):
from vllm_ascend.attention.attention_v1 import AscendC8AttentionBackendImpl
impl = object.__new__(AscendC8AttentionBackendImpl)
impl.num_heads = num_kv_heads
impl.num_kv_heads = num_kv_heads
impl.head_size = head_size
impl.scale = 1.0
impl.key_cache = None
impl.value_cache = None
return impl
def _make_layer(self, num_kv_heads=4, head_size=8):
layer = nn.Module()
layer.k_cache_scale = nn.Parameter(
torch.ones(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
)
layer.k_cache_offset = nn.Parameter(
torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
)
layer.v_cache_scale = nn.Parameter(
torch.ones(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
)
layer.v_cache_offset = nn.Parameter(
torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
)
return layer
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
def test_prepare_c8_scales_runs_once(self, mock_tp_size, mock_tp_rank):
impl = self._make_impl()
layer = self._make_layer()
impl._prepare_c8_scales(layer, torch.device("cpu"))
self.assertTrue(hasattr(layer, "_c8_scales_prepared"))
self.assertTrue(layer._c8_scales_prepared)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
def test_prepare_c8_scales_idempotent(self, mock_tp_size, mock_tp_rank):
impl = self._make_impl()
layer = self._make_layer()
impl._prepare_c8_scales(layer, torch.device("cpu"))
k_scale_after_first = layer._c8_k_scale.clone()
layer.k_cache_scale.data = torch.ones(32, dtype=torch.float32) * 99
impl._prepare_c8_scales(layer, torch.device("cpu"))
self.assertTrue(torch.allclose(layer._c8_k_scale, k_scale_after_first))
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
def test_prepare_c8_scales_creates_bnsd_shape(self, mock_tp_size, mock_tp_rank):
num_kv_heads, head_size = 4, 8
impl = self._make_impl(num_kv_heads, head_size)
layer = self._make_layer(num_kv_heads, head_size)
impl._prepare_c8_scales(layer, torch.device("cpu"))
self.assertEqual(layer._c8_k_aq_scale.shape, (1, num_kv_heads, 1, head_size))
self.assertEqual(layer._c8_v_aq_scale.shape, (1, num_kv_heads, 1, head_size))
self.assertEqual(layer._c8_k_aq_scale.dtype, torch.bfloat16)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
def test_quantize_kv_to_int8_output_dtype(self, mock_tp_size, mock_tp_rank):
num_kv_heads, head_size = 4, 8
impl = self._make_impl(num_kv_heads, head_size)
layer = self._make_layer(num_kv_heads, head_size)
impl._prepare_c8_scales(layer, torch.device("cpu"))
num_tokens = 6
key = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16)
value = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16)
key_q, value_q = impl._quantize_kv_to_int8(key, value, layer, num_tokens)
self.assertEqual(key_q.dtype, torch.int8)
self.assertEqual(value_q.dtype, torch.int8)
self.assertEqual(key_q.shape, key.shape)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
def test_quantize_kv_to_int8_formula(self, mock_tp_size, mock_tp_rank):
"""With scale=2.0, offset=0: q = round(x / 2)."""
num_kv_heads, head_size = 1, 4
impl = self._make_impl(num_kv_heads, head_size)
layer = nn.Module()
scale_val = torch.full((num_kv_heads * head_size,), 2.0, dtype=torch.float32)
layer.k_cache_scale = nn.Parameter(scale_val.clone(), requires_grad=False)
layer.k_cache_offset = nn.Parameter(torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False)
layer.v_cache_scale = nn.Parameter(scale_val.clone(), requires_grad=False)
layer.v_cache_offset = nn.Parameter(torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False)
impl._prepare_c8_scales(layer, torch.device("cpu"))
key = torch.full((1, num_kv_heads, head_size), 4.0, dtype=torch.bfloat16)
value = torch.full((1, num_kv_heads, head_size), 4.0, dtype=torch.bfloat16)
key_q, _ = impl._quantize_kv_to_int8(key, value, layer, 1)
self.assertTrue(torch.all(key_q[0] == 2))
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
@patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
def test_dequant_paged_kv_to_dense_round_trip(self, mock_tp_size, mock_tp_rank):
"""With scale=1, offset=0: dequant(int8) == float(int8)."""
num_kv_heads, head_size = 2, 4
block_size = 32
num_blocks = 2
H = num_kv_heads * head_size
impl = self._make_impl(num_kv_heads, head_size)
layer = self._make_layer(num_kv_heads, head_size)
impl._prepare_c8_scales(layer, torch.device("cpu"))
key_int8 = torch.randint(-10, 10, (num_blocks, block_size, H), dtype=torch.int8)
value_int8 = torch.randint(-10, 10, (num_blocks, block_size, H), dtype=torch.int8)
seq_lens = [32, 32]
block_table = torch.tensor([[0], [1]], dtype=torch.long)
dense_k, dense_v = impl._dequant_paged_kv_to_dense(
key_int8, value_int8, block_table, seq_lens, torch.float32, layer
)
expected_k = key_int8.view(-1, num_kv_heads, head_size).float()
self.assertEqual(dense_k.shape, (64, num_kv_heads, head_size))
self.assertTrue(torch.allclose(dense_k, expected_k))
if __name__ == "__main__":
unittest.main(verbosity=2)

View File

@@ -16,6 +16,7 @@ from vllm_ascend.quantization.modelslim_config import (
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
class TestAscendModelSlimConfig(TestBase):
@@ -125,6 +126,19 @@ class TestAscendModelSlimConfig(TestBase):
attention_layer, "layers.1.attn")
self.assertIs(method, mock_ascend_kvcache.return_value)
def test_get_quant_method_for_c8_kv_cache_attention(self):
c8_config = AscendModelSlimConfig({"kv_cache_type": "C8"})
attention_layer = MagicMock(spec=AttentionLayerBase)
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.hf_config.model_type = None
with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_vllm_config), \
patch("vllm_ascend.quantization.method_adapters.AscendKVCacheMethod", return_value=MagicMock()) as mock_kvcache:
method = c8_config.get_quant_method(attention_layer, "model.layers.0.self_attn.attn")
self.assertIs(method, mock_kvcache.return_value)
args, _ = mock_kvcache.call_args
from vllm_ascend.quantization.methods.kv_c8 import AscendC8KVCacheAttentionMethod
self.assertIsInstance(args[0], AscendC8KVCacheAttentionMethod)
def test_get_quant_method_for_fused_moe(self):
fused_moe_layer = MagicMock(spec=FusedMoE)
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)

View File

@@ -22,6 +22,7 @@ import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import ( # type: ignore
AttentionBackend,
@@ -978,3 +979,364 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
class AscendC8AttentionBackendImpl(AscendAttentionBackendImpl):
"""Attention backend implementation for INT8 KV cache (C8/QuaRot) models.
This subclass handles static per-channel INT8 KV cache quantization.
It is activated via class surgery in AscendC8KVCacheAttentionMethod.create_weights
(vllm_ascend/quantization/methods/kv_c8.py)
so that C8 attention layers automatically use this forward path.
"""
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError("fused output quantization is not yet supported for AscendC8AttentionBackendImpl")
num_tokens = query.shape[0]
if attn_metadata is None:
return output.fill_(0)
float_key, float_value = None, None
if key is not None and value is not None:
if attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
float_key, float_value = key, value
key, value = self._quantize_kv_to_int8(key, value, layer, attn_metadata.num_actual_tokens)
query, key, value, _ = self.reshape_and_cache(query, key, value, kv_cache, attn_metadata, output)
if attn_metadata.model_runner_type == "pooling":
attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output
self._prepare_c8_scales(layer, query.device)
if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
return self._forward_c8_decode(query, attn_metadata, output, layer)
elif attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
return self._forward_c8_chunked_prefill(query, float_key, float_value, attn_metadata, output, layer)
else:
return self._forward_c8_fused_infer_attention(
query,
float_key if float_key is not None else key,
float_value if float_value is not None else value,
attn_metadata,
output,
layer,
)
def _prepare_c8_scales(self, layer: AttentionLayer, device: torch.device) -> None:
"""Shard per-channel C8 scales/offsets to this TP rank and pre-compute
BF16 BNSD antiquant tensors for FIA V1 decode fast path.
"""
if hasattr(layer, "_c8_scales_prepared"):
return
def _shard_and_reshape(raw: torch.Tensor) -> torch.Tensor:
if raw.numel() == 1:
return raw.to(device=device)
expected = self.num_kv_heads * self.head_size
if raw.numel() != expected:
total_kv_heads = raw.numel() // self.head_size
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
kv_head_start = tp_rank * total_kv_heads // tp_size
raw = raw.view(total_kv_heads, self.head_size)[
kv_head_start : kv_head_start + self.num_kv_heads
].contiguous()
return raw.view(1, self.num_kv_heads, self.head_size).to(device=device)
layer._c8_k_scale = _shard_and_reshape(layer.k_cache_scale.data)
layer._c8_k_offset = _shard_and_reshape(layer.k_cache_offset.data)
layer._c8_v_scale = _shard_and_reshape(layer.v_cache_scale.data)
layer._c8_v_offset = _shard_and_reshape(layer.v_cache_offset.data)
bnsd = (1, self.num_kv_heads, 1, self.head_size)
layer._c8_k_aq_scale = layer._c8_k_scale.to(torch.bfloat16).view(bnsd).contiguous()
layer._c8_k_aq_offset = layer._c8_k_offset.to(torch.bfloat16).view(bnsd).contiguous()
layer._c8_v_aq_scale = layer._c8_v_scale.to(torch.bfloat16).view(bnsd).contiguous()
layer._c8_v_aq_offset = layer._c8_v_offset.to(torch.bfloat16).view(bnsd).contiguous()
layer._c8_k_inv_scale_bf16 = (1.0 / layer._c8_k_scale).to(torch.bfloat16)
layer._c8_k_offset_bf16 = layer._c8_k_offset.to(torch.bfloat16)
layer._c8_v_inv_scale_bf16 = (1.0 / layer._c8_v_scale).to(torch.bfloat16)
layer._c8_v_offset_bf16 = layer._c8_v_offset.to(torch.bfloat16)
layer._c8_scales_prepared = True
def _dequant_paged_kv_to_dense(
self,
key: torch.Tensor,
value: torch.Tensor,
block_table: torch.Tensor,
seq_lens: list,
target_dtype: torch.dtype,
layer,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Gather paged INT8 KV blocks and dequantize to target_dtype."""
batch_size = block_table.shape[0]
block_size = key.shape[1]
H = key.shape[2]
max_blocks_per_seq = block_table.shape[1]
max_tokens_padded = max_blocks_per_seq * block_size
flat_ids = block_table.reshape(-1)
gathered_k = key[flat_ids].view(batch_size, max_tokens_padded, H)
gathered_v = value[flat_ids].view(batch_size, max_tokens_padded, H)
seq_lens_t = torch.tensor(seq_lens, dtype=torch.long, device=key.device)
positions = torch.arange(max_tokens_padded, dtype=torch.long, device=key.device)
valid_mask = (positions.unsqueeze(0) < seq_lens_t.unsqueeze(1)).view(-1)
dense_k = gathered_k.view(-1, H)[valid_mask]
dense_v = gathered_v.view(-1, H)[valid_mask]
dense_k = dense_k.view(-1, self.num_kv_heads, self.head_size)
dense_v = dense_v.view(-1, self.num_kv_heads, self.head_size)
k_scale = layer._c8_k_scale.to(target_dtype)
k_offset = layer._c8_k_offset.to(target_dtype)
v_scale = layer._c8_v_scale.to(target_dtype)
v_offset = layer._c8_v_offset.to(target_dtype)
dense_k = (dense_k.to(target_dtype) - k_offset) * k_scale
dense_v = (dense_v.to(target_dtype) - v_offset) * v_scale
return dense_k, dense_v
def _quantize_kv_to_int8(
self,
key: torch.Tensor,
value: torch.Tensor,
layer: AttentionLayer,
num_actual_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize K/V from float to INT8 using static per-channel C8 scales."""
self._prepare_c8_scales(layer, key.device)
actual_key = key[:num_actual_tokens]
actual_value = value[:num_actual_tokens]
k_int8 = torch.clamp(
torch.round(actual_key * layer._c8_k_inv_scale_bf16 + layer._c8_k_offset_bf16),
-128,
127,
).to(torch.int8)
v_int8 = torch.clamp(
torch.round(actual_value * layer._c8_v_inv_scale_bf16 + layer._c8_v_offset_bf16),
-128,
127,
).to(torch.int8)
return k_int8, v_int8
def _forward_c8_decode(
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
layer: AttentionLayer,
) -> torch.Tensor:
"""C8 decode via FIA V1 BNSD with native paged INT8 KV + perchannel antiquant."""
num_block, block_size, _, _ = self.key_cache.shape # type: ignore[attr-defined]
assert block_size % 32 == 0, f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
key = self.key_cache.view(num_block, block_size, -1) # type: ignore[attr-defined]
value = self.value_cache.view(num_block, block_size, -1) # type: ignore[attr-defined]
batch_size = len(attn_metadata.seq_lens_list)
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
query[:batch_size].unsqueeze(2),
key,
value,
key_antiquant_scale=layer._c8_k_aq_scale,
key_antiquant_offset=layer._c8_k_aq_offset,
value_antiquant_scale=layer._c8_v_aq_scale,
value_antiquant_offset=layer._c8_v_aq_offset,
block_table=attn_metadata.block_tables,
actual_seq_lengths_kv=attn_metadata.seq_lens_list,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BNSD",
scale=self.scale,
block_size=block_size,
key_antiquant_mode=0,
value_antiquant_mode=0,
sparse_mode=0,
)
attn_output = attn_output.squeeze(2)
output[:batch_size] = attn_output
return output
def _forward_c8_chunked_prefill(
self,
query: torch.Tensor,
float_key: torch.Tensor | None,
float_value: torch.Tensor | None,
attn_metadata: AscendMetadata,
output: torch.Tensor,
layer: AttentionLayer,
) -> torch.Tensor:
"""C8 ChunkedPrefill: decode via FIA V1 BNSD paged INT8 (zero gather),
prefill via FIA V1 TND with float KV (new) or gather+dequant (continuing).
"""
num_decode_tokens = attn_metadata.num_decode_tokens
num_decodes = attn_metadata.num_decodes
actual_seq_qlen = attn_metadata.actual_seq_lengths_q
num_tokens = int(actual_seq_qlen[-1]) # type: ignore[index]
if num_decode_tokens > 0:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore[attr-defined]
assert block_size % 32 == 0, (
f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
)
kv_k = self.key_cache.view(num_block, block_size, -1) # type: ignore[attr-defined]
kv_v = self.value_cache.view(num_block, block_size, -1) # type: ignore[attr-defined]
attn_out, _ = torch_npu.npu_fused_infer_attention_score(
query[:num_decode_tokens].unsqueeze(2),
kv_k,
kv_v,
key_antiquant_scale=layer._c8_k_aq_scale,
key_antiquant_offset=layer._c8_k_aq_offset,
value_antiquant_scale=layer._c8_v_aq_scale,
value_antiquant_offset=layer._c8_v_aq_offset,
block_table=attn_metadata.block_tables[:num_decodes],
actual_seq_lengths_kv=attn_metadata.seq_lens_list[:num_decodes],
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BNSD",
scale=self.scale,
block_size=block_size,
key_antiquant_mode=0,
value_antiquant_mode=0,
sparse_mode=0,
)
output[:num_decode_tokens] = attn_out.squeeze(2)
if attn_metadata.num_prefills > 0:
prefill_q = query[num_decode_tokens:num_tokens]
prefill_seq_qlen = [
actual_seq_qlen[i] - num_decode_tokens for i in range(num_decodes, len(actual_seq_qlen))
]
all_new_prefill = True
for i in range(num_decodes, len(attn_metadata.seq_lens_list)):
q_start = actual_seq_qlen[i - 1] if i > 0 else 0
qlen_i = actual_seq_qlen[i] - q_start
if attn_metadata.seq_lens_list[i] > qlen_i:
all_new_prefill = False
break
if all_new_prefill and float_key is not None and float_value is not None:
prefill_k = float_key[num_decode_tokens:num_tokens]
prefill_v = float_value[num_decode_tokens:num_tokens]
prefill_seq_kvlen = prefill_seq_qlen
else:
num_block, blk_size, _, _ = self.key_cache.shape # type: ignore[attr-defined]
paged_k = self.key_cache.view(num_block, blk_size, -1) # type: ignore[attr-defined]
paged_v = self.value_cache.view(num_block, blk_size, -1) # type: ignore[attr-defined]
prefill_bt = attn_metadata.block_tables[num_decodes:]
prefill_sl = attn_metadata.seq_lens_list[num_decodes:]
prefill_k, prefill_v = self._dequant_paged_kv_to_dense(
paged_k, paged_v, prefill_bt, prefill_sl, query.dtype, layer
)
prefill_seq_kvlen = torch.tensor(prefill_sl, dtype=torch.int32).cumsum(dim=0)
# block_table is None for prefill; FIA ignores block_size in this case.
# Use cache block_size for consistency rather than a magic number.
cache_block_size = self.key_cache.shape[1] # type: ignore[attr-defined]
attn_out, _ = torch_npu.npu_fused_infer_attention_score(
query=prefill_q,
key=prefill_k,
value=prefill_v,
atten_mask=attn_metadata.attn_mask,
block_table=None,
input_layout="TND",
block_size=cache_block_size,
actual_seq_lengths=prefill_seq_qlen,
actual_seq_lengths_kv=prefill_seq_kvlen,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
n_prefill = num_tokens - num_decode_tokens
attn_out = attn_out.view(n_prefill, self.num_heads, self.head_size)
output[num_decode_tokens:num_tokens] = attn_out[:n_prefill]
return output
def _forward_c8_fused_infer_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: torch.Tensor,
layer: AttentionLayer,
):
"""C8 FIA V1 TND for prefill states (PrefillNoCache uses float KV directly,
PrefillCacheHit gathers + dequants paged INT8 KV).
"""
self._prepare_c8_scales(layer, query.device)
key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
actual_seq_qlen = attn_metadata.actual_seq_lengths_q
num_tokens = int(actual_seq_qlen[-1]) # type: ignore[index]
query = query[:num_tokens]
if (
attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
and self.attn_type != AttentionType.ENCODER_DECODER
):
key = key[:num_tokens]
value = value[:num_tokens]
if key.dtype == torch.int8:
if block_table is not None:
seq_lens = (
actual_seq_lengths_kv if isinstance(actual_seq_lengths_kv, list) else actual_seq_lengths_kv.tolist()
)
key, value = self._dequant_paged_kv_to_dense(key, value, block_table, seq_lens, query.dtype, layer)
block_table = None
# block_table is None after dequant; FIA ignores block_size.
# Use cache block_size for consistency rather than a magic number.
block_size = self.key_cache.shape[1] # type: ignore[attr-defined]
actual_seq_lengths_kv = torch.tensor(seq_lens, dtype=torch.int32).cumsum(dim=0)
else:
qdt = query.dtype
k_scale = layer._c8_k_scale.to(qdt)
k_offset = layer._c8_k_offset.to(qdt)
v_scale = layer._c8_v_scale.to(qdt)
v_offset = layer._c8_v_offset.to(qdt)
key = (key.to(qdt) - k_offset) * k_scale
value = (value.to(qdt) - v_offset) * v_scale
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=actual_seq_qlen,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
output[:num_tokens] = attn_output
return output

View File

@@ -721,3 +721,27 @@
# override _get_deepstack_input_embeds method with the flash comm v1 implementation.
# Future Plan:
# Remove this patch when https://github.com/vllm-project/vllm-ascend/issues/5712 is completed.
#
# ** 29. File: worker/patch_qwen3_c8.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.qwen3.Qwen3ForCausalLM.load_weights`
# Why:
# The Qwen3 W8A8C8 model stores per-channel KV cache scales and offsets
# (k_cache_scale, k_cache_offset, v_cache_scale, v_cache_offset) under
# weight names that AutoWeightsLoader does not recognise and would
# silently discard. Without these scales the INT8 KV cache cannot be
# dequantised correctly at inference time.
# How:
# Wrap load_weights to intercept the C8 scale/offset tensors before they
# reach the base loader. Each intercepted tensor is routed to the
# corresponding nn.Parameter via its weight_loader, then excluded from
# the remaining weight stream so the base loader never sees it.
# Related PR (if no, explain why):
# This PR (Qwen3-32B W8A8C8 support). Upstream vLLM's weight-loading
# pipeline does not yet have a generic hook for hardware-plugin-defined
# KV cache parameters.
# Future Plan:
# Remove this patch when vLLM provides a first-class extension point
# for loading extra KV cache quantisation parameters in model load_weights,
# or when the Qwen3 model's weight names are aligned with the parameter
# names expected by the quantisation backend.

View File

@@ -51,3 +51,4 @@ import vllm_ascend.patch.worker.patch_v2.patch_model_state # noqa
import vllm_ascend.patch.worker.patch_v2.patch_block_table # noqa
import vllm_ascend.patch.worker.patch_qwen3vl # noqa
import vllm_ascend.patch.worker.patch_deepencoder2 # noqa
import vllm_ascend.patch.worker.patch_qwen3_c8 # noqa

View File

@@ -0,0 +1,54 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Iterable
import torch
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
_orig_qwen3_causal_lm_load_weights = Qwen3ForCausalLM.load_weights
def _patched_qwen3_causal_lm_load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
quant_config = self.quant_config
if quant_config is None or not callable(getattr(quant_config, "get_cache_scale", None)):
return _orig_qwen3_causal_lm_load_weights(self, weights)
params_dict = dict(self.named_parameters())
c8_loaded_params: set[str] = set()
def _intercept_c8_scales(
raw_weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
for name, loaded_weight in raw_weights:
scale_name = quant_config.get_cache_scale(name)
if scale_name is not None:
if scale_name in params_dict:
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight.squeeze())
c8_loaded_params.add(scale_name)
else:
yield name, loaded_weight
loaded_params = _orig_qwen3_causal_lm_load_weights(self, _intercept_c8_scales(weights))
loaded_params.update(c8_loaded_params)
return loaded_params
Qwen3ForCausalLM.load_weights = _patched_qwen3_causal_lm_load_weights

View File

@@ -2,11 +2,12 @@ import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from .base import AscendAttentionScheme
from .registry import register_scheme
-def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
-    """fa_q weight loader."""
+def _fa_quant_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
+    """Weight loader for MLA-based C8 (FAKQuant) models."""
if param.numel() == 1 and loaded_weight.numel() == 1:
param.data.fill_(loaded_weight.item())
else:
@@ -50,7 +51,7 @@ class AscendFAQuantAttentionMethod:
weight_param = torch.nn.Parameter(weight, requires_grad=False)
module.register_parameter(weight_name, weight_param)
# When loading weights, segment them according to TP
-weight_param.weight_loader = weight_loader
+weight_param.weight_loader = _fa_quant_weight_loader
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
fa_k_scale = torch.squeeze(layer.fa_k.scale).unsqueeze(0)
@@ -87,3 +88,60 @@ class AscendSFAQuantAttentionMethod:
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def _c8_kv_scale_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
"""Weight loader for dense-attention C8 KV cache scales/offsets."""
loaded_weight = loaded_weight.squeeze()
if param.data.shape != loaded_weight.shape:
param.data = loaded_weight.to(param.dtype).clone()
else:
param.data.copy_(loaded_weight)
class AscendC8KVCacheAttentionMethod(AscendAttentionScheme):
"""C8 INT8 KV cache quantization for dense-attention models (e.g. Qwen3)."""
def __init__(self, quant_description: dict, prefix: str):
self.quant_description = quant_description
self.prefix = prefix
def create_weights(self, layer: torch.nn.Module) -> None:
# Override kv_cache_torch_dtype so Attention.get_kv_cache_spec returns int8 automatically.
layer.kv_cache_torch_dtype = torch.int8
# Upgrade impl to the C8-specific subclass so the C8 forward path is always used.
if hasattr(layer, "impl"):
from vllm_ascend.attention.attention_v1 import AscendC8AttentionBackendImpl
layer.impl.__class__ = AscendC8AttentionBackendImpl
layer.k_cache_scale = torch.nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)
layer.k_cache_scale.weight_loader = _c8_kv_scale_weight_loader
layer.k_cache_offset = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=False)
layer.k_cache_offset.weight_loader = _c8_kv_scale_weight_loader
layer.v_cache_scale = torch.nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)
layer.v_cache_scale.weight_loader = _c8_kv_scale_weight_loader
layer.v_cache_offset = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=False)
layer.v_cache_offset.weight_loader = _c8_kv_scale_weight_loader
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.k_cache_scale.data = layer.k_cache_scale.data.flatten()
layer.k_cache_offset.data = layer.k_cache_offset.data.flatten()
layer.v_cache_scale.data = layer.v_cache_scale.data.flatten()
layer.v_cache_offset.data = layer.v_cache_offset.data.flatten()
def apply(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache,
attn_metadata,
attn_type,
scale,
output,
) -> torch.Tensor:
raise RuntimeError(
"AscendC8KVCacheAttentionMethod.apply should not be called. "
"C8 KV cache quantization is handled by the attention backend."
)

View File

@@ -429,6 +429,21 @@ class AscendModelSlimConfig(QuantizationConfig):
self._add_kvcache_quant_metadata()
logger.info("Applied hf_to_vllm_mapper to quant_description keys")
def get_cache_scale(self, name: str) -> str | None:
"""Map checkpoint C8 KV scale/offset names to vLLM parameter names."""
if self.quant_description.get("kv_cache_type") != "C8":
return None
_C8_SCALE_MAPPING = {
"k_proj.kv_cache_scale": "attn.k_cache_scale",
"k_proj.kv_cache_offset": "attn.k_cache_offset",
"v_proj.kv_cache_scale": "attn.v_cache_scale",
"v_proj.kv_cache_offset": "attn.v_cache_offset",
}
for src_suffix, dst_suffix in _C8_SCALE_MAPPING.items():
if name.endswith(src_suffix):
return name[: -len(src_suffix)] + dst_suffix
return None
def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
self.model_type = model_type
return prefix
@@ -476,6 +491,10 @@ class AscendModelSlimConfig(QuantizationConfig):
):
scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
return AscendKVCacheMethod(scheme)
elif isinstance(layer, AttentionLayerBase) and self.quant_description.get("kv_cache_type") == "C8":
from .methods.kv_c8 import AscendC8KVCacheAttentionMethod
return AscendKVCacheMethod(AscendC8KVCacheAttentionMethod(self.quant_description, prefix))
elif isinstance(layer, FusedMoE):
if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
# Delayed import to avoid circular import