[v0.18.0]feat(quant): add C8 INT8 KV cache support for GQA attention models (#7474) (#8007)

backport of #7474

This PR adds C8 (INT8) KV cache quantization support for standard GQA
attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel
quantization scales to store KV cache in INT8, reducing KV cache memory
by ~50% compared to BF16, enabling higher batch concurrency and longer
context lengths on the same hardware.
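
For intuition, here is the scheme in isolation: a minimal, runnable sketch whose quantize/dequantize formula matches `_quantize_kv_to_int8` and `_dequant_paged_kv_to_dense` in the diffs below. The scale and offset values are made up for illustration.

```python
import torch

# Static per-channel C8 quantization: each of the num_kv_heads * head_size
# channels has its own scale (and offset) calibrated offline.
scale = torch.tensor([0.02, 0.01, 0.04, 0.03])  # per-channel scales (illustrative)
offset = torch.zeros(4)                         # per-channel offsets (illustrative)

k = torch.tensor([0.513, -0.117, 1.01, -0.326])  # one key vector (illustrative)

# Quantize: q = clamp(round(x / scale + offset), -128, 127), stored as INT8.
q = torch.clamp(torch.round(k * (1.0 / scale) + offset), -128, 127).to(torch.int8)

# Dequantize: x_hat = (q - offset) * scale.
k_hat = (q.float() - offset) * scale

# Absent clamping, the per-channel rounding error is bounded by scale / 2.
assert torch.all((k - k_hat).abs() <= scale / 2 + 1e-6)
```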

**Key changes:**

1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass
of `AscendAttentionBackendImpl`:
- `_prepare_c8_scales`: Shards per-channel scales/offsets to the current
TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time
per layer).
- `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before
`reshape_and_cache`, using pre-cached inverse scales.
- `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV
and `perchannel` antiquant mode.
- `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8)
and prefill (FIA V1 TND float) into two kernel calls.
- `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and
`PrefillCacheHit` states.

2. **`quantization/methods/kv_c8.py`** — New
`AscendC8KVCacheAttentionMethod` scheme:
- Creates `k/v_cache_scale/offset` parameters via
`_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and
lazy resizing.
- Sets `layer.kv_cache_torch_dtype = torch.int8` so
`get_kv_cache_spec()` returns INT8 dtype automatically.
- Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class
surgery (see the sketch after this list).

3. **`quantization/modelslim_config.py`** — C8 branch in
`get_quant_method()` activates when `kv_cache_type == "C8"` in
`quant_model_description.json`.

4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8
scale/offset weights before `AutoWeightsLoader` discards them, routing
them to the parameters created by `AscendC8KVCacheAttentionMethod`.

5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering
`_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and
`AscendC8AttentionBackendImpl` scale helpers.
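
The "class surgery" mentioned in item 2 is ordinary Python `__class__` reassignment: the already-constructed backend impl keeps all of its state, but its methods are rebound to the C8 subclass. A toy sketch with made-up names:

```python
import torch

class BaseImpl:
    def __init__(self) -> None:
        self.scale = 0.125  # state created before the quant scheme is known

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale

class C8Impl(BaseImpl):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # C8-specific forward path; inherits all BaseImpl state.
        return (x * self.scale).to(torch.int8)

impl = BaseImpl()        # constructed by the generic attention layer
impl.__class__ = C8Impl  # "class surgery": rebind methods in place, keep attributes
assert impl.scale == 0.125
assert impl.forward(torch.ones(2)).dtype == torch.int8
```

The advantage over constructing a fresh impl is that every reference already holding the object (the `Attention` layer, any captured closures) sees the upgraded behavior without re-wiring.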

**User-facing change:** Yes. Users can now serve Qwen3-32B W8A8C8 quantized models with INT8 KV
cache on Ascend NPU. The model checkpoint must contain a
`quant_model_description.json` with `"kv_cache_type": "C8"` and
per-channel scale/offset tensors in safetensors.
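
As a quick pre-flight check, a checkpoint can be inspected for these artifacts before serving. This is a hypothetical snippet: the tensor-name suffixes come from the `get_cache_scale` mapping in `modelslim_config.py` below, while the shard filename is purely illustrative.

```python
import json
from safetensors import safe_open

with open("quant_model_description.json") as f:
    desc = json.load(f)
assert desc.get("kv_cache_type") == "C8", "C8 KV cache not enabled for this checkpoint"

# Per-channel scale/offset tensors live next to the projection weights, e.g.
# model.layers.0.self_attn.k_proj.kv_cache_scale (suffixes per get_cache_scale).
with safe_open("model-00001-of-00017.safetensors", framework="pt") as f:  # illustrative shard name
    c8_keys = [k for k in f.keys() if k.endswith(("kv_cache_scale", "kv_cache_offset"))]
print(f"found {len(c8_keys)} C8 scale/offset tensors")
```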

No changes to the serving CLI — the feature activates automatically when
the quantization config is detected.

Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`,
`max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench`
(input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):

```
============ Serving Benchmark Result ============
Successful requests:                     960
Failed requests:                         0
Maximum request concurrency:             192
Benchmark duration (s):                  1359.81
Total input tokens:                      9830400
Total generated tokens:                  1966080
Request throughput (req/s):              0.71
Output token throughput (tok/s):         1445.85
Peak output token throughput (tok/s):    2304.00
Total token throughput (tok/s):          8675.12
---------------Time to First Token----------------
Mean TTFT (ms):                          24598.51
Median TTFT (ms):                        23167.02
P50 TTFT (ms):                           23167.02
P90 TTFT (ms):                           47717.08
P99 TTFT (ms):                           84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          120.76
Median TPOT (ms):                        121.50
P50 TPOT (ms):                           121.50
P90 TPOT (ms):                           127.05
P99 TPOT (ms):                           130.13
---------------Inter-token Latency----------------
Mean ITL (ms):                           120.70
Median ITL (ms):                         90.34
P50 ITL (ms):                            90.34
P90 ITL (ms):                            93.79
P99 ITL (ms):                            101.80
==================================================
```
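
The headline rates follow directly from the totals; a quick arithmetic cross-check of the table:

```python
duration = 1359.81          # benchmark duration (s)
total_input = 960 * 10240   # = 9,830,400 input tokens
total_output = 960 * 2048   # = 1,966,080 generated tokens

print(960 / duration)                           # ~0.71 req/s
print(total_output / duration)                  # ~1445.8 output tok/s
print((total_input + total_output) / duration)  # ~8675 total tok/s
```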

All attention states verified: `PrefillNoCache`, `PrefillCacheHit`,
`ChunkedPrefill`, `DecodeOnly`.

- vLLM version: v0.17.0
- vLLM main: 8b6325758c

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>
Committed by Mengqing Cao on 2026-04-08 10:51:58 +08:00 (committed by GitHub)
parent fbd5d0fd55, commit 044d4c3974
8 changed files with 761 additions and 8 deletions

**`tests/ut/quantization/test_kv_c8.py`**

```diff
@@ -1,7 +1,9 @@
 import unittest
 import torch
 import torch.nn as nn
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
+
+from tests.ut.base import TestBase
 
 
 class TestWeightLoader(unittest.TestCase):
@@ -10,7 +12,7 @@ class TestWeightLoader(unittest.TestCase):
     def setUp(self):
         """Set up test environment before each test"""
         # Import the module under test
-        from vllm_ascend.quantization.methods.kv_c8 import weight_loader
+        from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader
         self.weight_loader = weight_loader
 
         # Mock distributed functions
@@ -295,7 +297,7 @@ class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase):
         method.create_weights(self.layer)
 
         # Import weight_loader for comparison
-        from vllm_ascend.quantization.methods.kv_c8 import weight_loader
+        from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader
 
         # Verify each parameter exists and has weight_loader
         self.assertTrue(hasattr(self.layer.fa_q, "scale"))
@@ -440,7 +442,7 @@ class TestIntegration(unittest.TestCase):
         v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8)
 
         # Load weights using weight_loader
-        from vllm_ascend.quantization.methods.kv_c8 import weight_loader
+        from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader
 
         with torch.no_grad():
             weight_loader(layer.fa_q.scale, q_scale)
@@ -464,5 +466,224 @@ class TestIntegration(unittest.TestCase):
         self.assertTrue(hasattr(layer, "quant_kscale"))
 
 
+class TestC8KVScaleWeightLoader(TestBase):
+    """Tests for _c8_kv_scale_weight_loader in kv_c8.py."""
+
+    def setUp(self):
+        from vllm_ascend.quantization.methods.kv_c8 import _c8_kv_scale_weight_loader
+        self.loader = _c8_kv_scale_weight_loader
+
+    def test_shape_match_copies_value(self):
+        param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
+        loaded = torch.tensor([1.0, 2.0, 3.0, 4.0])
+        self.loader(param, loaded)
+        self.assertTrue(torch.allclose(param.data, loaded.float()))
+
+    def test_shape_mismatch_resizes_param(self):
+        param = nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)
+        loaded = torch.arange(8, dtype=torch.float32)
+        self.loader(param, loaded)
+        self.assertEqual(param.data.shape, (8,))
+        self.assertTrue(torch.allclose(param.data, loaded))
+
+    def test_squeeze_before_compare(self):
+        param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
+        loaded = torch.arange(4, dtype=torch.float32).unsqueeze(0)  # shape [1, 4]
+        self.loader(param, loaded)
+        self.assertEqual(param.data.shape, (4,))
+
+    def test_dtype_preserved_as_param_dtype(self):
+        param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
+        loaded = torch.arange(4, dtype=torch.float16)
+        self.loader(param, loaded)
+        self.assertEqual(param.data.dtype, torch.float32)
+
+
+class TestAscendC8KVCacheAttentionMethod(TestBase):
+    """Tests for AscendC8KVCacheAttentionMethod in kv_c8.py."""
+
+    def _make_method(self):
+        from vllm_ascend.quantization.methods.kv_c8 import AscendC8KVCacheAttentionMethod
+        return AscendC8KVCacheAttentionMethod(quant_description={}, prefix="model.layers.0.self_attn.attn")
+
+    def _make_layer_with_impl(self):
+        layer = nn.Module()
+        layer.impl = MagicMock()
+        return layer
+
+    def test_create_weights_sets_kv_cache_torch_dtype(self):
+        method = self._make_method()
+        layer = self._make_layer_with_impl()
+        method.create_weights(layer)
+        self.assertEqual(layer.kv_cache_torch_dtype, torch.int8)
+
+    def test_create_weights_registers_scale_offset_params(self):
+        method = self._make_method()
+        layer = self._make_layer_with_impl()
+        method.create_weights(layer)
+        self.assertIsInstance(layer.k_cache_scale, nn.Parameter)
+        self.assertIsInstance(layer.k_cache_offset, nn.Parameter)
+        self.assertIsInstance(layer.v_cache_scale, nn.Parameter)
+        self.assertIsInstance(layer.v_cache_offset, nn.Parameter)
+        self.assertFalse(layer.k_cache_scale.requires_grad)
+        self.assertFalse(layer.v_cache_offset.requires_grad)
+
+    def test_create_weights_initial_values(self):
+        method = self._make_method()
+        layer = self._make_layer_with_impl()
+        method.create_weights(layer)
+        self.assertEqual(layer.k_cache_scale.data.item(), 1.0)
+        self.assertEqual(layer.v_cache_scale.data.item(), 1.0)
+        self.assertEqual(layer.k_cache_offset.data.item(), 0.0)
+        self.assertEqual(layer.v_cache_offset.data.item(), 0.0)
+
+    def test_create_weights_assigns_weight_loader(self):
+        from vllm_ascend.quantization.methods.kv_c8 import _c8_kv_scale_weight_loader
+        method = self._make_method()
+        layer = self._make_layer_with_impl()
+        method.create_weights(layer)
+        self.assertIs(layer.k_cache_scale.weight_loader, _c8_kv_scale_weight_loader)
+        self.assertIs(layer.v_cache_scale.weight_loader, _c8_kv_scale_weight_loader)
+        self.assertIs(layer.k_cache_offset.weight_loader, _c8_kv_scale_weight_loader)
+        self.assertIs(layer.v_cache_offset.weight_loader, _c8_kv_scale_weight_loader)
+
+    def test_process_weights_after_loading_flattens(self):
+        method = self._make_method()
+        layer = nn.Module()
+        layer.k_cache_scale = nn.Parameter(torch.ones(2, 4), requires_grad=False)
+        layer.k_cache_offset = nn.Parameter(torch.zeros(2, 4), requires_grad=False)
+        layer.v_cache_scale = nn.Parameter(torch.ones(2, 4), requires_grad=False)
+        layer.v_cache_offset = nn.Parameter(torch.zeros(2, 4), requires_grad=False)
+        method.process_weights_after_loading(layer)
+        self.assertEqual(layer.k_cache_scale.data.dim(), 1)
+        self.assertEqual(layer.k_cache_scale.data.shape[0], 8)
+        self.assertEqual(layer.v_cache_offset.data.dim(), 1)
+
+    def test_apply_raises_runtime_error(self):
+        method = self._make_method()
+        layer = MagicMock()
+        with self.assertRaises(RuntimeError):
+            method.apply(layer, MagicMock(), MagicMock(), MagicMock(), None, None, None, None, None)
+
+
+class TestAscendC8AttentionBackendImplScales(TestBase):
+    """Tests for AscendC8AttentionBackendImpl scale helpers."""
+
+    def _make_impl(self, num_kv_heads=4, head_size=8):
+        from vllm_ascend.attention.attention_v1 import AscendC8AttentionBackendImpl
+        impl = object.__new__(AscendC8AttentionBackendImpl)
+        impl.num_heads = num_kv_heads
+        impl.num_kv_heads = num_kv_heads
+        impl.head_size = head_size
+        impl.scale = 1.0
+        impl.key_cache = None
+        impl.value_cache = None
+        return impl
+
+    def _make_layer(self, num_kv_heads=4, head_size=8):
+        layer = nn.Module()
+        layer.k_cache_scale = nn.Parameter(
+            torch.ones(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
+        )
+        layer.k_cache_offset = nn.Parameter(
+            torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_cache_scale = nn.Parameter(
+            torch.ones(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
+        )
+        layer.v_cache_offset = nn.Parameter(
+            torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
+        )
+        return layer
+
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
+    def test_prepare_c8_scales_runs_once(self, mock_tp_size, mock_tp_rank):
+        impl = self._make_impl()
+        layer = self._make_layer()
+        impl._prepare_c8_scales(layer, torch.device("cpu"))
+        self.assertTrue(hasattr(layer, "_c8_scales_prepared"))
+        self.assertTrue(layer._c8_scales_prepared)
+
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
+    def test_prepare_c8_scales_idempotent(self, mock_tp_size, mock_tp_rank):
+        impl = self._make_impl()
+        layer = self._make_layer()
+        impl._prepare_c8_scales(layer, torch.device("cpu"))
+        k_scale_after_first = layer._c8_k_scale.clone()
+        layer.k_cache_scale.data = torch.ones(32, dtype=torch.float32) * 99
+        impl._prepare_c8_scales(layer, torch.device("cpu"))
+        self.assertTrue(torch.allclose(layer._c8_k_scale, k_scale_after_first))
+
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
+    def test_prepare_c8_scales_creates_bnsd_shape(self, mock_tp_size, mock_tp_rank):
+        num_kv_heads, head_size = 4, 8
+        impl = self._make_impl(num_kv_heads, head_size)
+        layer = self._make_layer(num_kv_heads, head_size)
+        impl._prepare_c8_scales(layer, torch.device("cpu"))
+        self.assertEqual(layer._c8_k_aq_scale.shape, (1, num_kv_heads, 1, head_size))
+        self.assertEqual(layer._c8_v_aq_scale.shape, (1, num_kv_heads, 1, head_size))
+        self.assertEqual(layer._c8_k_aq_scale.dtype, torch.bfloat16)
+
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
+    def test_quantize_kv_to_int8_output_dtype(self, mock_tp_size, mock_tp_rank):
+        num_kv_heads, head_size = 4, 8
+        impl = self._make_impl(num_kv_heads, head_size)
+        layer = self._make_layer(num_kv_heads, head_size)
+        impl._prepare_c8_scales(layer, torch.device("cpu"))
+        num_tokens = 6
+        key = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16)
+        value = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16)
+        key_q, value_q = impl._quantize_kv_to_int8(key, value, layer, num_tokens)
+        self.assertEqual(key_q.dtype, torch.int8)
+        self.assertEqual(value_q.dtype, torch.int8)
+        self.assertEqual(key_q.shape, key.shape)
+
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
+    def test_quantize_kv_to_int8_formula(self, mock_tp_size, mock_tp_rank):
+        """With scale=2.0, offset=0: q = round(x / 2)."""
+        num_kv_heads, head_size = 1, 4
+        impl = self._make_impl(num_kv_heads, head_size)
+        layer = nn.Module()
+        scale_val = torch.full((num_kv_heads * head_size,), 2.0, dtype=torch.float32)
+        layer.k_cache_scale = nn.Parameter(scale_val.clone(), requires_grad=False)
+        layer.k_cache_offset = nn.Parameter(torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False)
+        layer.v_cache_scale = nn.Parameter(scale_val.clone(), requires_grad=False)
+        layer.v_cache_offset = nn.Parameter(torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False)
+        impl._prepare_c8_scales(layer, torch.device("cpu"))
+        key = torch.full((1, num_kv_heads, head_size), 4.0, dtype=torch.bfloat16)
+        value = torch.full((1, num_kv_heads, head_size), 4.0, dtype=torch.bfloat16)
+        key_q, _ = impl._quantize_kv_to_int8(key, value, layer, 1)
+        self.assertTrue(torch.all(key_q[0] == 2))
+
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
+    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
+    def test_dequant_paged_kv_to_dense_round_trip(self, mock_tp_size, mock_tp_rank):
+        """With scale=1, offset=0: dequant(int8) == float(int8)."""
+        num_kv_heads, head_size = 2, 4
+        block_size = 32
+        num_blocks = 2
+        H = num_kv_heads * head_size
+        impl = self._make_impl(num_kv_heads, head_size)
+        layer = self._make_layer(num_kv_heads, head_size)
+        impl._prepare_c8_scales(layer, torch.device("cpu"))
+        key_int8 = torch.randint(-10, 10, (num_blocks, block_size, H), dtype=torch.int8)
+        value_int8 = torch.randint(-10, 10, (num_blocks, block_size, H), dtype=torch.int8)
+        seq_lens = [32, 32]
+        block_table = torch.tensor([[0], [1]], dtype=torch.long)
+        dense_k, dense_v = impl._dequant_paged_kv_to_dense(
+            key_int8, value_int8, block_table, seq_lens, torch.float32, layer
+        )
+        expected_k = key_int8.view(-1, num_kv_heads, head_size).float()
+        self.assertEqual(dense_k.shape, (64, num_kv_heads, head_size))
+        self.assertTrue(torch.allclose(dense_k, expected_k))
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)
```

**Unit tests for `AscendModelSlimConfig`**

```diff
@@ -16,6 +16,7 @@ from vllm_ascend.quantization.modelslim_config import (
 from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
 from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 
 
 class TestAscendModelSlimConfig(TestBase):
@@ -125,6 +126,19 @@ class TestAscendModelSlimConfig(TestBase):
                                              attention_layer, "layers.1.attn")
         self.assertIs(method, mock_ascend_kvcache.return_value)
 
+    def test_get_quant_method_for_c8_kv_cache_attention(self):
+        c8_config = AscendModelSlimConfig({"kv_cache_type": "C8"})
+        attention_layer = MagicMock(spec=AttentionLayerBase)
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.hf_config.model_type = None
+        with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_vllm_config), \
+             patch("vllm_ascend.quantization.method_adapters.AscendKVCacheMethod", return_value=MagicMock()) as mock_kvcache:
+            method = c8_config.get_quant_method(attention_layer, "model.layers.0.self_attn.attn")
+        self.assertIs(method, mock_kvcache.return_value)
+        args, _ = mock_kvcache.call_args
+        from vllm_ascend.quantization.methods.kv_c8 import AscendC8KVCacheAttentionMethod
+        self.assertIsInstance(args[0], AscendC8KVCacheAttentionMethod)
+
     def test_get_quant_method_for_fused_moe(self):
         fused_moe_layer = MagicMock(spec=FusedMoE)
         fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
```

**`vllm_ascend/attention/attention_v1.py`**

```diff
@@ -22,6 +22,7 @@ import torch
 import torch_npu
 import vllm.envs as envs_vllm
 from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backend import (  # type: ignore
     AttentionBackend,
@@ -978,3 +979,364 @@ class AscendAttentionBackendImpl(AttentionImpl):
         attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
         output[:num_tokens] = attn_output[:num_tokens]
         return output
+
+
+class AscendC8AttentionBackendImpl(AscendAttentionBackendImpl):
+    """Attention backend implementation for INT8 KV cache (C8/QuaRot) models.
+
+    This subclass handles static per-channel INT8 KV cache quantization.
+    It is activated via class surgery in
+    AscendC8KVCacheAttentionMethod.create_weights
+    (vllm_ascend/quantization/methods/kv_c8.py) so that C8 attention layers
+    automatically use this forward path.
+    """
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: tuple[torch.Tensor],
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert output is not None, "Output tensor must be provided."
+        if output_scale is not None or output_block_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported for AscendC8AttentionBackendImpl")
+        num_tokens = query.shape[0]
+        if attn_metadata is None:
+            return output.fill_(0)
+        float_key, float_value = None, None
+        if key is not None and value is not None:
+            if attn_metadata.attn_state != AscendAttentionState.DecodeOnly:
+                float_key, float_value = key, value
+            key, value = self._quantize_kv_to_int8(key, value, layer, attn_metadata.num_actual_tokens)
+        query, key, value, _ = self.reshape_and_cache(query, key, value, kv_cache, attn_metadata, output)
+        if attn_metadata.model_runner_type == "pooling":
+            attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
+            output[:num_tokens] = attn_output[:num_tokens]
+            return output
+        self._prepare_c8_scales(layer, query.device)
+        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            return self._forward_c8_decode(query, attn_metadata, output, layer)
+        elif attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
+            return self._forward_c8_chunked_prefill(query, float_key, float_value, attn_metadata, output, layer)
+        else:
+            return self._forward_c8_fused_infer_attention(
+                query,
+                float_key if float_key is not None else key,
+                float_value if float_value is not None else value,
+                attn_metadata,
+                output,
+                layer,
+            )
+
+    def _prepare_c8_scales(self, layer: AttentionLayer, device: torch.device) -> None:
+        """Shard per-channel C8 scales/offsets to this TP rank and pre-compute
+        BF16 BNSD antiquant tensors for the FIA V1 decode fast path.
+        """
+        if hasattr(layer, "_c8_scales_prepared"):
+            return
+
+        def _shard_and_reshape(raw: torch.Tensor) -> torch.Tensor:
+            if raw.numel() == 1:
+                return raw.to(device=device)
+            expected = self.num_kv_heads * self.head_size
+            if raw.numel() != expected:
+                total_kv_heads = raw.numel() // self.head_size
+                tp_rank = get_tensor_model_parallel_rank()
+                tp_size = get_tensor_model_parallel_world_size()
+                kv_head_start = tp_rank * total_kv_heads // tp_size
+                raw = raw.view(total_kv_heads, self.head_size)[
+                    kv_head_start:kv_head_start + self.num_kv_heads
+                ].contiguous()
+            return raw.view(1, self.num_kv_heads, self.head_size).to(device=device)
+
+        layer._c8_k_scale = _shard_and_reshape(layer.k_cache_scale.data)
+        layer._c8_k_offset = _shard_and_reshape(layer.k_cache_offset.data)
+        layer._c8_v_scale = _shard_and_reshape(layer.v_cache_scale.data)
+        layer._c8_v_offset = _shard_and_reshape(layer.v_cache_offset.data)
+        bnsd = (1, self.num_kv_heads, 1, self.head_size)
+        layer._c8_k_aq_scale = layer._c8_k_scale.to(torch.bfloat16).view(bnsd).contiguous()
+        layer._c8_k_aq_offset = layer._c8_k_offset.to(torch.bfloat16).view(bnsd).contiguous()
+        layer._c8_v_aq_scale = layer._c8_v_scale.to(torch.bfloat16).view(bnsd).contiguous()
+        layer._c8_v_aq_offset = layer._c8_v_offset.to(torch.bfloat16).view(bnsd).contiguous()
+        layer._c8_k_inv_scale_bf16 = (1.0 / layer._c8_k_scale).to(torch.bfloat16)
+        layer._c8_k_offset_bf16 = layer._c8_k_offset.to(torch.bfloat16)
+        layer._c8_v_inv_scale_bf16 = (1.0 / layer._c8_v_scale).to(torch.bfloat16)
+        layer._c8_v_offset_bf16 = layer._c8_v_offset.to(torch.bfloat16)
+        layer._c8_scales_prepared = True
+
+    def _dequant_paged_kv_to_dense(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        block_table: torch.Tensor,
+        seq_lens: list,
+        target_dtype: torch.dtype,
+        layer,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Gather paged INT8 KV blocks and dequantize to target_dtype."""
+        batch_size = block_table.shape[0]
+        block_size = key.shape[1]
+        H = key.shape[2]
+        max_blocks_per_seq = block_table.shape[1]
+        max_tokens_padded = max_blocks_per_seq * block_size
+        flat_ids = block_table.reshape(-1)
+        gathered_k = key[flat_ids].view(batch_size, max_tokens_padded, H)
+        gathered_v = value[flat_ids].view(batch_size, max_tokens_padded, H)
+        seq_lens_t = torch.tensor(seq_lens, dtype=torch.long, device=key.device)
+        positions = torch.arange(max_tokens_padded, dtype=torch.long, device=key.device)
+        valid_mask = (positions.unsqueeze(0) < seq_lens_t.unsqueeze(1)).view(-1)
+        dense_k = gathered_k.view(-1, H)[valid_mask]
+        dense_v = gathered_v.view(-1, H)[valid_mask]
+        dense_k = dense_k.view(-1, self.num_kv_heads, self.head_size)
+        dense_v = dense_v.view(-1, self.num_kv_heads, self.head_size)
+        k_scale = layer._c8_k_scale.to(target_dtype)
+        k_offset = layer._c8_k_offset.to(target_dtype)
+        v_scale = layer._c8_v_scale.to(target_dtype)
+        v_offset = layer._c8_v_offset.to(target_dtype)
+        dense_k = (dense_k.to(target_dtype) - k_offset) * k_scale
+        dense_v = (dense_v.to(target_dtype) - v_offset) * v_scale
+        return dense_k, dense_v
+
+    def _quantize_kv_to_int8(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        layer: AttentionLayer,
+        num_actual_tokens: int,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Quantize K/V from float to INT8 using static per-channel C8 scales."""
+        self._prepare_c8_scales(layer, key.device)
+        actual_key = key[:num_actual_tokens]
+        actual_value = value[:num_actual_tokens]
+        k_int8 = torch.clamp(
+            torch.round(actual_key * layer._c8_k_inv_scale_bf16 + layer._c8_k_offset_bf16),
+            -128,
+            127,
+        ).to(torch.int8)
+        v_int8 = torch.clamp(
+            torch.round(actual_value * layer._c8_v_inv_scale_bf16 + layer._c8_v_offset_bf16),
+            -128,
+            127,
+        ).to(torch.int8)
+        return k_int8, v_int8
+
+    def _forward_c8_decode(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
+        layer: AttentionLayer,
+    ) -> torch.Tensor:
+        """C8 decode via FIA V1 BNSD with native paged INT8 KV + perchannel antiquant."""
+        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
+        assert block_size % 32 == 0, f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
+        key = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
+        value = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
+        batch_size = len(attn_metadata.seq_lens_list)
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+            query[:batch_size].unsqueeze(2),
+            key,
+            value,
+            key_antiquant_scale=layer._c8_k_aq_scale,
+            key_antiquant_offset=layer._c8_k_aq_offset,
+            value_antiquant_scale=layer._c8_v_aq_scale,
+            value_antiquant_offset=layer._c8_v_aq_offset,
+            block_table=attn_metadata.block_tables,
+            actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+            num_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            input_layout="BNSD",
+            scale=self.scale,
+            block_size=block_size,
+            key_antiquant_mode=0,
+            value_antiquant_mode=0,
+            sparse_mode=0,
+        )
+        attn_output = attn_output.squeeze(2)
+        output[:batch_size] = attn_output
+        return output
+
+    def _forward_c8_chunked_prefill(
+        self,
+        query: torch.Tensor,
+        float_key: torch.Tensor | None,
+        float_value: torch.Tensor | None,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
+        layer: AttentionLayer,
+    ) -> torch.Tensor:
+        """C8 ChunkedPrefill: decode via FIA V1 BNSD paged INT8 (zero gather),
+        prefill via FIA V1 TND with float KV (new) or gather+dequant (continuing).
+        """
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_decodes = attn_metadata.num_decodes
+        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
+        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]
+        if num_decode_tokens > 0:
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
+            assert block_size % 32 == 0, (
+                f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}"
+            )
+            kv_k = self.key_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
+            kv_v = self.value_cache.view(num_block, block_size, -1)  # type: ignore[attr-defined]
+            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
+                query[:num_decode_tokens].unsqueeze(2),
+                kv_k,
+                kv_v,
+                key_antiquant_scale=layer._c8_k_aq_scale,
+                key_antiquant_offset=layer._c8_k_aq_offset,
+                value_antiquant_scale=layer._c8_v_aq_scale,
+                value_antiquant_offset=layer._c8_v_aq_offset,
+                block_table=attn_metadata.block_tables[:num_decodes],
+                actual_seq_lengths_kv=attn_metadata.seq_lens_list[:num_decodes],
+                num_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="BNSD",
+                scale=self.scale,
+                block_size=block_size,
+                key_antiquant_mode=0,
+                value_antiquant_mode=0,
+                sparse_mode=0,
+            )
+            output[:num_decode_tokens] = attn_out.squeeze(2)
+        if attn_metadata.num_prefills > 0:
+            prefill_q = query[num_decode_tokens:num_tokens]
+            prefill_seq_qlen = [
+                actual_seq_qlen[i] - num_decode_tokens
+                for i in range(num_decodes, len(actual_seq_qlen))
+            ]
+            all_new_prefill = True
+            for i in range(num_decodes, len(attn_metadata.seq_lens_list)):
+                q_start = actual_seq_qlen[i - 1] if i > 0 else 0
+                qlen_i = actual_seq_qlen[i] - q_start
+                if attn_metadata.seq_lens_list[i] > qlen_i:
+                    all_new_prefill = False
+                    break
+            if all_new_prefill and float_key is not None and float_value is not None:
+                prefill_k = float_key[num_decode_tokens:num_tokens]
+                prefill_v = float_value[num_decode_tokens:num_tokens]
+                prefill_seq_kvlen = prefill_seq_qlen
+            else:
+                num_block, blk_size, _, _ = self.key_cache.shape  # type: ignore[attr-defined]
+                paged_k = self.key_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
+                paged_v = self.value_cache.view(num_block, blk_size, -1)  # type: ignore[attr-defined]
+                prefill_bt = attn_metadata.block_tables[num_decodes:]
+                prefill_sl = attn_metadata.seq_lens_list[num_decodes:]
+                prefill_k, prefill_v = self._dequant_paged_kv_to_dense(
+                    paged_k, paged_v, prefill_bt, prefill_sl, query.dtype, layer
+                )
+                prefill_seq_kvlen = torch.tensor(prefill_sl, dtype=torch.int32).cumsum(dim=0)
+            # block_table is None for prefill; FIA ignores block_size in this case.
+            # Use cache block_size for consistency rather than a magic number.
+            cache_block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
+            attn_out, _ = torch_npu.npu_fused_infer_attention_score(
+                query=prefill_q,
+                key=prefill_k,
+                value=prefill_v,
+                atten_mask=attn_metadata.attn_mask,
+                block_table=None,
+                input_layout="TND",
+                block_size=cache_block_size,
+                actual_seq_lengths=prefill_seq_qlen,
+                actual_seq_lengths_kv=prefill_seq_kvlen,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
+            n_prefill = num_tokens - num_decode_tokens
+            attn_out = attn_out.view(n_prefill, self.num_heads, self.head_size)
+            output[num_decode_tokens:num_tokens] = attn_out[:n_prefill]
+        return output
+
+    def _forward_c8_fused_infer_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: torch.Tensor,
+        layer: AttentionLayer,
+    ):
+        """C8 FIA V1 TND for prefill states (PrefillNoCache uses float KV directly,
+        PrefillCacheHit gathers + dequants paged INT8 KV).
+        """
+        self._prepare_c8_scales(layer, query.device)
+        key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata)
+        actual_seq_qlen = attn_metadata.actual_seq_lengths_q
+        num_tokens = int(actual_seq_qlen[-1])  # type: ignore[index]
+        query = query[:num_tokens]
+        if (
+            attn_metadata.attn_state == AscendAttentionState.PrefillNoCache
+            and self.attn_type != AttentionType.ENCODER_DECODER
+        ):
+            key = key[:num_tokens]
+            value = value[:num_tokens]
+        if key.dtype == torch.int8:
+            if block_table is not None:
+                seq_lens = (
+                    actual_seq_lengths_kv
+                    if isinstance(actual_seq_lengths_kv, list)
+                    else actual_seq_lengths_kv.tolist()
+                )
+                key, value = self._dequant_paged_kv_to_dense(key, value, block_table, seq_lens, query.dtype, layer)
+                block_table = None
+                # block_table is None after dequant; FIA ignores block_size.
+                # Use cache block_size for consistency rather than a magic number.
+                block_size = self.key_cache.shape[1]  # type: ignore[attr-defined]
+                actual_seq_lengths_kv = torch.tensor(seq_lens, dtype=torch.int32).cumsum(dim=0)
+            else:
+                qdt = query.dtype
+                k_scale = layer._c8_k_scale.to(qdt)
+                k_offset = layer._c8_k_offset.to(qdt)
+                v_scale = layer._c8_v_scale.to(qdt)
+                v_offset = layer._c8_v_offset.to(qdt)
+                key = (key.to(qdt) - k_offset) * k_scale
+                value = (value.to(qdt) - v_offset) * v_scale
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
+            query=query,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask,
+            block_table=block_table,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=actual_seq_qlen,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            num_key_value_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale=self.scale,
+            sparse_mode=3,
+        )
+        attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size)
+        output[:num_tokens] = attn_output
+        return output
```

**Patch registry documentation**

```diff
@@ -721,3 +721,27 @@
 #    override _get_deepstack_input_embeds method with the flash comm v1 implementation.
 #    Future Plan:
 #    Remove this patch when https://github.com/vllm-project/vllm-ascend/issues/5712 is completed.
+#
+# ** 29. File: worker/patch_qwen3_c8.py**
+#   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.model_executor.models.qwen3.Qwen3ForCausalLM.load_weights`
+#    Why:
+#    The Qwen3 W8A8C8 model stores per-channel KV cache scales and offsets
+#    (k_cache_scale, k_cache_offset, v_cache_scale, v_cache_offset) under
+#    weight names that AutoWeightsLoader does not recognise and would
+#    silently discard. Without these scales the INT8 KV cache cannot be
+#    dequantised correctly at inference time.
+#    How:
+#    Wrap load_weights to intercept the C8 scale/offset tensors before they
+#    reach the base loader. Each intercepted tensor is routed to the
+#    corresponding nn.Parameter via its weight_loader, then excluded from
+#    the remaining weight stream so the base loader never sees it.
+#    Related PR (if no, explain why):
+#    This PR (Qwen3-32B W8A8C8 support). Upstream vLLM's weight-loading
+#    pipeline does not yet have a generic hook for hardware-plugin-defined
+#    KV cache parameters.
+#    Future Plan:
+#    Remove this patch when vLLM provides a first-class extension point
+#    for loading extra KV cache quantisation parameters in model load_weights,
+#    or when the Qwen3 model's weight names are aligned with the parameter
+#    names expected by the quantisation backend.
```

**Worker patch imports**

```diff
@@ -51,3 +51,4 @@ import vllm_ascend.patch.worker.patch_v2.patch_model_state  # noqa
 import vllm_ascend.patch.worker.patch_v2.patch_block_table  # noqa
 import vllm_ascend.patch.worker.patch_qwen3vl  # noqa
 import vllm_ascend.patch.worker.patch_deepencoder2  # noqa
+import vllm_ascend.patch.worker.patch_qwen3_c8  # noqa
```

**`vllm_ascend/patch/worker/patch_qwen3_c8.py` (new file)**

```diff
@@ -0,0 +1,54 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections.abc import Iterable
+
+import torch
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
+
+_orig_qwen3_causal_lm_load_weights = Qwen3ForCausalLM.load_weights
+
+
+def _patched_qwen3_causal_lm_load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+    quant_config = self.quant_config
+    if quant_config is None or not callable(getattr(quant_config, "get_cache_scale", None)):
+        return _orig_qwen3_causal_lm_load_weights(self, weights)
+
+    params_dict = dict(self.named_parameters())
+    c8_loaded_params: set[str] = set()
+
+    def _intercept_c8_scales(
+        raw_weights: Iterable[tuple[str, torch.Tensor]],
+    ) -> Iterable[tuple[str, torch.Tensor]]:
+        for name, loaded_weight in raw_weights:
+            scale_name = quant_config.get_cache_scale(name)
+            if scale_name is not None:
+                if scale_name in params_dict:
+                    param = params_dict[scale_name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, loaded_weight.squeeze())
+                    c8_loaded_params.add(scale_name)
+            else:
+                yield name, loaded_weight
+
+    loaded_params = _orig_qwen3_causal_lm_load_weights(self, _intercept_c8_scales(weights))
+    loaded_params.update(c8_loaded_params)
+    return loaded_params
+
+
+Qwen3ForCausalLM.load_weights = _patched_qwen3_causal_lm_load_weights
```

**`vllm_ascend/quantization/methods/kv_c8.py`**

```diff
@@ -2,11 +2,12 @@ import torch
 from vllm.config import get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
 
+from .base import AscendAttentionScheme
 from .registry import register_scheme
 
 
-def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
-    """fa_q weight loader."""
+def _fa_quant_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
+    """Weight loader for MLA-based C8 (FAKQuant) models."""
     if param.numel() == 1 and loaded_weight.numel() == 1:
         param.data.fill_(loaded_weight.item())
     else:
@@ -50,7 +51,7 @@ class AscendFAQuantAttentionMethod:
             weight_param = torch.nn.Parameter(weight, requires_grad=False)
             module.register_parameter(weight_name, weight_param)
             # When loading weights, segment them according to TP
-            weight_param.weight_loader = weight_loader
+            weight_param.weight_loader = _fa_quant_weight_loader
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         fa_k_scale = torch.squeeze(layer.fa_k.scale).unsqueeze(0)
@@ -87,3 +88,60 @@ class AscendSFAQuantAttentionMethod:
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         pass
+
+
+def _c8_kv_scale_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
+    """Weight loader for dense-attention C8 KV cache scales/offsets."""
+    loaded_weight = loaded_weight.squeeze()
+    if param.data.shape != loaded_weight.shape:
+        param.data = loaded_weight.to(param.dtype).clone()
+    else:
+        param.data.copy_(loaded_weight)
+
+
+class AscendC8KVCacheAttentionMethod(AscendAttentionScheme):
+    """C8 INT8 KV cache quantization for dense-attention models (e.g. Qwen3)."""
+
+    def __init__(self, quant_description: dict, prefix: str):
+        self.quant_description = quant_description
+        self.prefix = prefix
+
+    def create_weights(self, layer: torch.nn.Module) -> None:
+        # Override kv_cache_torch_dtype so Attention.get_kv_cache_spec returns int8 automatically.
+        layer.kv_cache_torch_dtype = torch.int8
+        # Upgrade impl to the C8-specific subclass so the C8 forward path is always used.
+        if hasattr(layer, "impl"):
+            from vllm_ascend.attention.attention_v1 import AscendC8AttentionBackendImpl
+            layer.impl.__class__ = AscendC8AttentionBackendImpl
+        layer.k_cache_scale = torch.nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)
+        layer.k_cache_scale.weight_loader = _c8_kv_scale_weight_loader
+        layer.k_cache_offset = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=False)
+        layer.k_cache_offset.weight_loader = _c8_kv_scale_weight_loader
+        layer.v_cache_scale = torch.nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)
+        layer.v_cache_scale.weight_loader = _c8_kv_scale_weight_loader
+        layer.v_cache_offset = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=False)
+        layer.v_cache_offset.weight_loader = _c8_kv_scale_weight_loader
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.k_cache_scale.data = layer.k_cache_scale.data.flatten()
+        layer.k_cache_offset.data = layer.k_cache_offset.data.flatten()
+        layer.v_cache_scale.data = layer.v_cache_scale.data.flatten()
+        layer.v_cache_offset.data = layer.v_cache_offset.data.flatten()
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache,
+        attn_metadata,
+        attn_type,
+        scale,
+        output,
+    ) -> torch.Tensor:
+        raise RuntimeError(
+            "AscendC8KVCacheAttentionMethod.apply should not be called. "
+            "C8 KV cache quantization is handled by the attention backend."
+        )
```

**`vllm_ascend/quantization/modelslim_config.py`**

```diff
@@ -429,6 +429,21 @@ class AscendModelSlimConfig(QuantizationConfig):
             self._add_kvcache_quant_metadata()
             logger.info("Applied hf_to_vllm_mapper to quant_description keys")
 
+    def get_cache_scale(self, name: str) -> str | None:
+        """Map checkpoint C8 KV scale/offset names to vLLM parameter names."""
+        if self.quant_description.get("kv_cache_type") != "C8":
+            return None
+        _C8_SCALE_MAPPING = {
+            "k_proj.kv_cache_scale": "attn.k_cache_scale",
+            "k_proj.kv_cache_offset": "attn.k_cache_offset",
+            "v_proj.kv_cache_scale": "attn.v_cache_scale",
+            "v_proj.kv_cache_offset": "attn.v_cache_offset",
+        }
+        for src_suffix, dst_suffix in _C8_SCALE_MAPPING.items():
+            if name.endswith(src_suffix):
+                return name[: -len(src_suffix)] + dst_suffix
+        return None
+
     def quant_prefix_mapper(self, model_type: str, prefix: str) -> str:
         self.model_type = model_type
         return prefix
@@ -476,6 +491,10 @@ class AscendModelSlimConfig(QuantizationConfig):
         ):
             scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
             return AscendKVCacheMethod(scheme)
+        elif isinstance(layer, AttentionLayerBase) and self.quant_description.get("kv_cache_type") == "C8":
+            from .methods.kv_c8 import AscendC8KVCacheAttentionMethod
+            return AscendKVCacheMethod(AscendC8KVCacheAttentionMethod(self.quant_description, prefix))
         elif isinstance(layer, FusedMoE):
             if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
                 # Delayed import to avoid circular import
```