From 044d4c39742df98ef83fb208e19deab6fbd1ca40 Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Wed, 8 Apr 2026 10:51:58 +0800
Subject: [PATCH] [v0.18.0]feat(quant): add C8 INT8 KV cache support for GQA attention models (#7474) (#8007)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Backport of #7474.

This PR adds C8 (INT8) KV cache quantization support for standard GQA attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel quantization scales to store the KV cache in INT8, reducing KV cache memory by ~50% compared to BF16, which enables higher batch concurrency and longer context lengths on the same hardware.

**Key changes:**

1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass of `AscendAttentionBackendImpl`:
   - `_prepare_c8_scales`: Shards per-channel scales/offsets to the current TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (one-time per layer).
   - `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before `reshape_and_cache`, using pre-cached inverse scales.
   - `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV and `perchannel` antiquant mode.
   - `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8) and prefill (FIA V1 TND float) into two kernel calls.
   - `_forward_c8_fused_infer_attention`: Handles the `PrefillNoCache` and `PrefillCacheHit` states.
2. **`quantization/methods/kv_c8.py`** — New `AscendC8KVCacheAttentionMethod` scheme:
   - Creates `k/v_cache_scale/offset` parameters via `_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and lazy resizing.
   - Sets `layer.kv_cache_torch_dtype = torch.int8` so `get_kv_cache_spec()` returns the INT8 dtype automatically.
   - Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class surgery.
3. **`quantization/modelslim_config.py`** — A C8 branch in `get_quant_method()` activates when `kv_cache_type == "C8"` is set in `quant_model_description.json`.
4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8 scale/offset weights before `AutoWeightsLoader` discards them, routing them to the parameters created by `AscendC8KVCacheAttentionMethod`.
5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering `_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and the `AscendC8AttentionBackendImpl` scale helpers.

Yes, this is a user-facing change: users can now serve Qwen3-32B W8A8C8 quantized models with INT8 KV cache on Ascend NPU. The model checkpoint must contain a `quant_model_description.json` with `"kv_cache_type": "C8"` and per-channel scale/offset tensors in safetensors. No changes to the serving CLI — the feature activates automatically when the quantization config is detected.

Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`, `max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench` (input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):

```
============ Serving Benchmark Result ============
Successful requests:                     960
Failed requests:                         0
Maximum request concurrency:             192
Benchmark duration (s):                  1359.81
Total input tokens:                      9830400
Total generated tokens:                  1966080
Request throughput (req/s):              0.71
Output token throughput (tok/s):         1445.85
Peak output token throughput (tok/s):    2304.00
Total token throughput (tok/s):          8675.12
---------------Time to First Token----------------
Mean TTFT (ms):                          24598.51
Median TTFT (ms):                        23167.02
P50 TTFT (ms):                           23167.02
P90 TTFT (ms):                           47717.08
P99 TTFT (ms):                           84402.61
-----Time per Output Token (excl. 
1st token)------ Mean TPOT (ms): 120.76 Median TPOT (ms): 121.50 P50 TPOT (ms): 121.50 P90 TPOT (ms): 127.05 P99 TPOT (ms): 130.13 ---------------Inter-token Latency---------------- Mean ITL (ms): 120.70 Median ITL (ms): 90.34 P50 ITL (ms): 90.34 P90 ITL (ms): 93.79 P99 ITL (ms): 101.80 ================================================== ``` All attention states verified: `PrefillNoCache`, `PrefillCacheHit`, `ChunkedPrefill`, `DecodeOnly`. - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/8b6325758cce5f9c36d38f2462edbd368b97a07c Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com> --- tests/ut/quantization/test_kv_c8.py | 231 ++++++++++- .../ut/quantization/test_modelslim_config.py | 14 + vllm_ascend/attention/attention_v1.py | 362 ++++++++++++++++++ vllm_ascend/patch/__init__.py | 24 ++ vllm_ascend/patch/worker/__init__.py | 1 + vllm_ascend/patch/worker/patch_qwen3_c8.py | 54 +++ vllm_ascend/quantization/methods/kv_c8.py | 64 +++- vllm_ascend/quantization/modelslim_config.py | 19 + 8 files changed, 761 insertions(+), 8 deletions(-) create mode 100644 vllm_ascend/patch/worker/patch_qwen3_c8.py diff --git a/tests/ut/quantization/test_kv_c8.py b/tests/ut/quantization/test_kv_c8.py index cfafefc9..2173b999 100644 --- a/tests/ut/quantization/test_kv_c8.py +++ b/tests/ut/quantization/test_kv_c8.py @@ -1,7 +1,9 @@ import unittest import torch import torch.nn as nn -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch + +from tests.ut.base import TestBase class TestWeightLoader(unittest.TestCase): @@ -10,7 +12,7 @@ class TestWeightLoader(unittest.TestCase): def setUp(self): """Set up test environment before each test""" # Import the module under test - from vllm_ascend.quantization.methods.kv_c8 import weight_loader + from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader self.weight_loader = weight_loader # Mock distributed functions @@ -295,7 +297,7 @@ class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase): method.create_weights(self.layer) # Import weight_loader for comparison - from vllm_ascend.quantization.methods.kv_c8 import weight_loader + from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader # Verify each parameter exists and has weight_loader self.assertTrue(hasattr(self.layer.fa_q, "scale")) @@ -440,7 +442,7 @@ class TestIntegration(unittest.TestCase): v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8) # Load weights using weight_loader - from vllm_ascend.quantization.methods.kv_c8 import weight_loader + from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader with torch.no_grad(): weight_loader(layer.fa_q.scale, q_scale) @@ -464,5 +466,224 @@ class TestIntegration(unittest.TestCase): self.assertTrue(hasattr(layer, "quant_kscale")) +class TestC8KVScaleWeightLoader(TestBase): + """Tests for _c8_kv_scale_weight_loader in kv_c8.py.""" + + def setUp(self): + from vllm_ascend.quantization.methods.kv_c8 import _c8_kv_scale_weight_loader + self.loader = _c8_kv_scale_weight_loader + + def test_shape_match_copies_value(self): + param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False) + loaded = torch.tensor([1.0, 2.0, 3.0, 4.0]) + self.loader(param, loaded) + self.assertTrue(torch.allclose(param.data, loaded.float())) + + def test_shape_mismatch_resizes_param(self): + param = nn.Parameter(torch.ones(1, 
dtype=torch.float32), requires_grad=False) + loaded = torch.arange(8, dtype=torch.float32) + self.loader(param, loaded) + self.assertEqual(param.data.shape, (8,)) + self.assertTrue(torch.allclose(param.data, loaded)) + + def test_squeeze_before_compare(self): + param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False) + loaded = torch.arange(4, dtype=torch.float32).unsqueeze(0) # shape [1, 4] + self.loader(param, loaded) + self.assertEqual(param.data.shape, (4,)) + + def test_dtype_preserved_as_param_dtype(self): + param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False) + loaded = torch.arange(4, dtype=torch.float16) + self.loader(param, loaded) + self.assertEqual(param.data.dtype, torch.float32) + + +class TestAscendC8KVCacheAttentionMethod(TestBase): + """Tests for AscendC8KVCacheAttentionMethod in kv_c8.py.""" + + def _make_method(self): + from vllm_ascend.quantization.methods.kv_c8 import AscendC8KVCacheAttentionMethod + return AscendC8KVCacheAttentionMethod(quant_description={}, prefix="model.layers.0.self_attn.attn") + + def _make_layer_with_impl(self): + layer = nn.Module() + layer.impl = MagicMock() + return layer + + def test_create_weights_sets_kv_cache_torch_dtype(self): + method = self._make_method() + layer = self._make_layer_with_impl() + method.create_weights(layer) + self.assertEqual(layer.kv_cache_torch_dtype, torch.int8) + + def test_create_weights_registers_scale_offset_params(self): + method = self._make_method() + layer = self._make_layer_with_impl() + method.create_weights(layer) + self.assertIsInstance(layer.k_cache_scale, nn.Parameter) + self.assertIsInstance(layer.k_cache_offset, nn.Parameter) + self.assertIsInstance(layer.v_cache_scale, nn.Parameter) + self.assertIsInstance(layer.v_cache_offset, nn.Parameter) + self.assertFalse(layer.k_cache_scale.requires_grad) + self.assertFalse(layer.v_cache_offset.requires_grad) + + def test_create_weights_initial_values(self): + method = self._make_method() + layer = self._make_layer_with_impl() + method.create_weights(layer) + self.assertEqual(layer.k_cache_scale.data.item(), 1.0) + self.assertEqual(layer.v_cache_scale.data.item(), 1.0) + self.assertEqual(layer.k_cache_offset.data.item(), 0.0) + self.assertEqual(layer.v_cache_offset.data.item(), 0.0) + + def test_create_weights_assigns_weight_loader(self): + from vllm_ascend.quantization.methods.kv_c8 import _c8_kv_scale_weight_loader + method = self._make_method() + layer = self._make_layer_with_impl() + method.create_weights(layer) + self.assertIs(layer.k_cache_scale.weight_loader, _c8_kv_scale_weight_loader) + self.assertIs(layer.v_cache_scale.weight_loader, _c8_kv_scale_weight_loader) + self.assertIs(layer.k_cache_offset.weight_loader, _c8_kv_scale_weight_loader) + self.assertIs(layer.v_cache_offset.weight_loader, _c8_kv_scale_weight_loader) + + def test_process_weights_after_loading_flattens(self): + method = self._make_method() + layer = nn.Module() + layer.k_cache_scale = nn.Parameter(torch.ones(2, 4), requires_grad=False) + layer.k_cache_offset = nn.Parameter(torch.zeros(2, 4), requires_grad=False) + layer.v_cache_scale = nn.Parameter(torch.ones(2, 4), requires_grad=False) + layer.v_cache_offset = nn.Parameter(torch.zeros(2, 4), requires_grad=False) + method.process_weights_after_loading(layer) + self.assertEqual(layer.k_cache_scale.data.dim(), 1) + self.assertEqual(layer.k_cache_scale.data.shape[0], 8) + self.assertEqual(layer.v_cache_offset.data.dim(), 1) + + def test_apply_raises_runtime_error(self): + method = 
self._make_method() + layer = MagicMock() + with self.assertRaises(RuntimeError): + method.apply(layer, MagicMock(), MagicMock(), MagicMock(), None, None, None, None, None) + + +class TestAscendC8AttentionBackendImplScales(TestBase): + """Tests for AscendC8AttentionBackendImpl scale helpers.""" + + def _make_impl(self, num_kv_heads=4, head_size=8): + from vllm_ascend.attention.attention_v1 import AscendC8AttentionBackendImpl + impl = object.__new__(AscendC8AttentionBackendImpl) + impl.num_heads = num_kv_heads + impl.num_kv_heads = num_kv_heads + impl.head_size = head_size + impl.scale = 1.0 + impl.key_cache = None + impl.value_cache = None + return impl + + def _make_layer(self, num_kv_heads=4, head_size=8): + layer = nn.Module() + layer.k_cache_scale = nn.Parameter( + torch.ones(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False + ) + layer.k_cache_offset = nn.Parameter( + torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False + ) + layer.v_cache_scale = nn.Parameter( + torch.ones(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False + ) + layer.v_cache_offset = nn.Parameter( + torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False + ) + return layer + + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0) + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1) + def test_prepare_c8_scales_runs_once(self, mock_tp_size, mock_tp_rank): + impl = self._make_impl() + layer = self._make_layer() + impl._prepare_c8_scales(layer, torch.device("cpu")) + self.assertTrue(hasattr(layer, "_c8_scales_prepared")) + self.assertTrue(layer._c8_scales_prepared) + + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0) + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1) + def test_prepare_c8_scales_idempotent(self, mock_tp_size, mock_tp_rank): + impl = self._make_impl() + layer = self._make_layer() + impl._prepare_c8_scales(layer, torch.device("cpu")) + k_scale_after_first = layer._c8_k_scale.clone() + layer.k_cache_scale.data = torch.ones(32, dtype=torch.float32) * 99 + impl._prepare_c8_scales(layer, torch.device("cpu")) + self.assertTrue(torch.allclose(layer._c8_k_scale, k_scale_after_first)) + + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0) + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1) + def test_prepare_c8_scales_creates_bnsd_shape(self, mock_tp_size, mock_tp_rank): + num_kv_heads, head_size = 4, 8 + impl = self._make_impl(num_kv_heads, head_size) + layer = self._make_layer(num_kv_heads, head_size) + impl._prepare_c8_scales(layer, torch.device("cpu")) + self.assertEqual(layer._c8_k_aq_scale.shape, (1, num_kv_heads, 1, head_size)) + self.assertEqual(layer._c8_v_aq_scale.shape, (1, num_kv_heads, 1, head_size)) + self.assertEqual(layer._c8_k_aq_scale.dtype, torch.bfloat16) + + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0) + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1) + def test_quantize_kv_to_int8_output_dtype(self, mock_tp_size, mock_tp_rank): + num_kv_heads, head_size = 4, 8 + impl = self._make_impl(num_kv_heads, head_size) + layer = self._make_layer(num_kv_heads, head_size) + impl._prepare_c8_scales(layer, torch.device("cpu")) + num_tokens = 6 + key = 
torch.zeros(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16) + value = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16) + key_q, value_q = impl._quantize_kv_to_int8(key, value, layer, num_tokens) + self.assertEqual(key_q.dtype, torch.int8) + self.assertEqual(value_q.dtype, torch.int8) + self.assertEqual(key_q.shape, key.shape) + + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0) + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1) + def test_quantize_kv_to_int8_formula(self, mock_tp_size, mock_tp_rank): + """With scale=2.0, offset=0: q = round(x / 2).""" + num_kv_heads, head_size = 1, 4 + impl = self._make_impl(num_kv_heads, head_size) + layer = nn.Module() + scale_val = torch.full((num_kv_heads * head_size,), 2.0, dtype=torch.float32) + layer.k_cache_scale = nn.Parameter(scale_val.clone(), requires_grad=False) + layer.k_cache_offset = nn.Parameter(torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False) + layer.v_cache_scale = nn.Parameter(scale_val.clone(), requires_grad=False) + layer.v_cache_offset = nn.Parameter(torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False) + impl._prepare_c8_scales(layer, torch.device("cpu")) + key = torch.full((1, num_kv_heads, head_size), 4.0, dtype=torch.bfloat16) + value = torch.full((1, num_kv_heads, head_size), 4.0, dtype=torch.bfloat16) + key_q, _ = impl._quantize_kv_to_int8(key, value, layer, 1) + self.assertTrue(torch.all(key_q[0] == 2)) + + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0) + @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1) + def test_dequant_paged_kv_to_dense_round_trip(self, mock_tp_size, mock_tp_rank): + """With scale=1, offset=0: dequant(int8) == float(int8).""" + num_kv_heads, head_size = 2, 4 + block_size = 32 + num_blocks = 2 + H = num_kv_heads * head_size + impl = self._make_impl(num_kv_heads, head_size) + layer = self._make_layer(num_kv_heads, head_size) + impl._prepare_c8_scales(layer, torch.device("cpu")) + + key_int8 = torch.randint(-10, 10, (num_blocks, block_size, H), dtype=torch.int8) + value_int8 = torch.randint(-10, 10, (num_blocks, block_size, H), dtype=torch.int8) + seq_lens = [32, 32] + block_table = torch.tensor([[0], [1]], dtype=torch.long) + + dense_k, dense_v = impl._dequant_paged_kv_to_dense( + key_int8, value_int8, block_table, seq_lens, torch.float32, layer + ) + expected_k = key_int8.view(-1, num_kv_heads, head_size).float() + self.assertEqual(dense_k.shape, (64, num_kv_heads, head_size)) + self.assertTrue(torch.allclose(dense_k, expected_k)) + + if __name__ == "__main__": - unittest.main(verbosity=2) \ No newline at end of file + unittest.main(verbosity=2) diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py index 250c576a..8970596f 100644 --- a/tests/ut/quantization/test_modelslim_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -16,6 +16,7 @@ from vllm_ascend.quantization.modelslim_config import ( from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase class TestAscendModelSlimConfig(TestBase): @@ -125,6 +126,19 @@ class TestAscendModelSlimConfig(TestBase): attention_layer, "layers.1.attn") self.assertIs(method, 
mock_ascend_kvcache.return_value) + def test_get_quant_method_for_c8_kv_cache_attention(self): + c8_config = AscendModelSlimConfig({"kv_cache_type": "C8"}) + attention_layer = MagicMock(spec=AttentionLayerBase) + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.hf_config.model_type = None + with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_vllm_config), \ + patch("vllm_ascend.quantization.method_adapters.AscendKVCacheMethod", return_value=MagicMock()) as mock_kvcache: + method = c8_config.get_quant_method(attention_layer, "model.layers.0.self_attn.attn") + self.assertIs(method, mock_kvcache.return_value) + args, _ = mock_kvcache.call_args + from vllm_ascend.quantization.methods.kv_c8 import AscendC8KVCacheAttentionMethod + self.assertIsInstance(args[0], AscendC8KVCacheAttentionMethod) + def test_get_quant_method_for_fused_moe(self): fused_moe_layer = MagicMock(spec=FusedMoE) fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 689084f4..443e0451 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -22,6 +22,7 @@ import torch import torch_npu import vllm.envs as envs_vllm from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import ( # type: ignore AttentionBackend, @@ -978,3 +979,364 @@ class AscendAttentionBackendImpl(AttentionImpl): attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output) output[:num_tokens] = attn_output[:num_tokens] return output + + +class AscendC8AttentionBackendImpl(AscendAttentionBackendImpl): + """Attention backend implementation for INT8 KV cache (C8/QuaRot) models. + + This subclass handles static per-channel INT8 KV cache quantization. + It is activated via class surgery in AscendC8KVCacheAttentionMethod.create_weights + (vllm_ascend/quantization/methods/kv_c8.py) + so that C8 attention layers automatically use this forward path. + """ + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: tuple[torch.Tensor], + attn_metadata: AscendMetadata, + output: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." 
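+        # C8 forward flow: for prefill-type states a float copy of K/V is kept for
+        # the attention kernel, K/V are quantized to INT8 with the static
+        # per-channel scales before reshape_and_cache, and the request is then
+        # dispatched by attn_state (DecodeOnly / ChunkedPrefill / prefill states)
+        # to the C8 helpers defined below.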
+ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError("fused output quantization is not yet supported for AscendC8AttentionBackendImpl") + + num_tokens = query.shape[0] + if attn_metadata is None: + return output.fill_(0) + + float_key, float_value = None, None + if key is not None and value is not None: + if attn_metadata.attn_state != AscendAttentionState.DecodeOnly: + float_key, float_value = key, value + key, value = self._quantize_kv_to_int8(key, value, layer, attn_metadata.num_actual_tokens) + query, key, value, _ = self.reshape_and_cache(query, key, value, kv_cache, attn_metadata, output) + + if attn_metadata.model_runner_type == "pooling": + attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output) + output[:num_tokens] = attn_output[:num_tokens] + return output + + self._prepare_c8_scales(layer, query.device) + if attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + return self._forward_c8_decode(query, attn_metadata, output, layer) + elif attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill: + return self._forward_c8_chunked_prefill(query, float_key, float_value, attn_metadata, output, layer) + else: + return self._forward_c8_fused_infer_attention( + query, + float_key if float_key is not None else key, + float_value if float_value is not None else value, + attn_metadata, + output, + layer, + ) + + def _prepare_c8_scales(self, layer: AttentionLayer, device: torch.device) -> None: + """Shard per-channel C8 scales/offsets to this TP rank and pre-compute + BF16 BNSD antiquant tensors for FIA V1 decode fast path. + """ + if hasattr(layer, "_c8_scales_prepared"): + return + + def _shard_and_reshape(raw: torch.Tensor) -> torch.Tensor: + if raw.numel() == 1: + return raw.to(device=device) + expected = self.num_kv_heads * self.head_size + if raw.numel() != expected: + total_kv_heads = raw.numel() // self.head_size + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + kv_head_start = tp_rank * total_kv_heads // tp_size + raw = raw.view(total_kv_heads, self.head_size)[ + kv_head_start : kv_head_start + self.num_kv_heads + ].contiguous() + return raw.view(1, self.num_kv_heads, self.head_size).to(device=device) + + layer._c8_k_scale = _shard_and_reshape(layer.k_cache_scale.data) + layer._c8_k_offset = _shard_and_reshape(layer.k_cache_offset.data) + layer._c8_v_scale = _shard_and_reshape(layer.v_cache_scale.data) + layer._c8_v_offset = _shard_and_reshape(layer.v_cache_offset.data) + + bnsd = (1, self.num_kv_heads, 1, self.head_size) + layer._c8_k_aq_scale = layer._c8_k_scale.to(torch.bfloat16).view(bnsd).contiguous() + layer._c8_k_aq_offset = layer._c8_k_offset.to(torch.bfloat16).view(bnsd).contiguous() + layer._c8_v_aq_scale = layer._c8_v_scale.to(torch.bfloat16).view(bnsd).contiguous() + layer._c8_v_aq_offset = layer._c8_v_offset.to(torch.bfloat16).view(bnsd).contiguous() + + layer._c8_k_inv_scale_bf16 = (1.0 / layer._c8_k_scale).to(torch.bfloat16) + layer._c8_k_offset_bf16 = layer._c8_k_offset.to(torch.bfloat16) + layer._c8_v_inv_scale_bf16 = (1.0 / layer._c8_v_scale).to(torch.bfloat16) + layer._c8_v_offset_bf16 = layer._c8_v_offset.to(torch.bfloat16) + + layer._c8_scales_prepared = True + + def _dequant_paged_kv_to_dense( + self, + key: torch.Tensor, + value: torch.Tensor, + block_table: torch.Tensor, + seq_lens: list, + target_dtype: torch.dtype, + layer, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Gather paged INT8 KV blocks and dequantize to 
target_dtype.""" + batch_size = block_table.shape[0] + block_size = key.shape[1] + H = key.shape[2] + max_blocks_per_seq = block_table.shape[1] + max_tokens_padded = max_blocks_per_seq * block_size + + flat_ids = block_table.reshape(-1) + gathered_k = key[flat_ids].view(batch_size, max_tokens_padded, H) + gathered_v = value[flat_ids].view(batch_size, max_tokens_padded, H) + + seq_lens_t = torch.tensor(seq_lens, dtype=torch.long, device=key.device) + positions = torch.arange(max_tokens_padded, dtype=torch.long, device=key.device) + valid_mask = (positions.unsqueeze(0) < seq_lens_t.unsqueeze(1)).view(-1) + + dense_k = gathered_k.view(-1, H)[valid_mask] + dense_v = gathered_v.view(-1, H)[valid_mask] + + dense_k = dense_k.view(-1, self.num_kv_heads, self.head_size) + dense_v = dense_v.view(-1, self.num_kv_heads, self.head_size) + k_scale = layer._c8_k_scale.to(target_dtype) + k_offset = layer._c8_k_offset.to(target_dtype) + v_scale = layer._c8_v_scale.to(target_dtype) + v_offset = layer._c8_v_offset.to(target_dtype) + dense_k = (dense_k.to(target_dtype) - k_offset) * k_scale + dense_v = (dense_v.to(target_dtype) - v_offset) * v_scale + return dense_k, dense_v + + def _quantize_kv_to_int8( + self, + key: torch.Tensor, + value: torch.Tensor, + layer: AttentionLayer, + num_actual_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize K/V from float to INT8 using static per-channel C8 scales.""" + self._prepare_c8_scales(layer, key.device) + + actual_key = key[:num_actual_tokens] + actual_value = value[:num_actual_tokens] + + k_int8 = torch.clamp( + torch.round(actual_key * layer._c8_k_inv_scale_bf16 + layer._c8_k_offset_bf16), + -128, + 127, + ).to(torch.int8) + v_int8 = torch.clamp( + torch.round(actual_value * layer._c8_v_inv_scale_bf16 + layer._c8_v_offset_bf16), + -128, + 127, + ).to(torch.int8) + return k_int8, v_int8 + + def _forward_c8_decode( + self, + query: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + layer: AttentionLayer, + ) -> torch.Tensor: + """C8 decode via FIA V1 BNSD with native paged INT8 KV + perchannel antiquant.""" + num_block, block_size, _, _ = self.key_cache.shape # type: ignore[attr-defined] + assert block_size % 32 == 0, f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}" + key = self.key_cache.view(num_block, block_size, -1) # type: ignore[attr-defined] + value = self.value_cache.view(num_block, block_size, -1) # type: ignore[attr-defined] + batch_size = len(attn_metadata.seq_lens_list) + + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + query[:batch_size].unsqueeze(2), + key, + value, + key_antiquant_scale=layer._c8_k_aq_scale, + key_antiquant_offset=layer._c8_k_aq_offset, + value_antiquant_scale=layer._c8_v_aq_scale, + value_antiquant_offset=layer._c8_v_aq_offset, + block_table=attn_metadata.block_tables, + actual_seq_lengths_kv=attn_metadata.seq_lens_list, + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BNSD", + scale=self.scale, + block_size=block_size, + key_antiquant_mode=0, + value_antiquant_mode=0, + sparse_mode=0, + ) + attn_output = attn_output.squeeze(2) + output[:batch_size] = attn_output + return output + + def _forward_c8_chunked_prefill( + self, + query: torch.Tensor, + float_key: torch.Tensor | None, + float_value: torch.Tensor | None, + attn_metadata: AscendMetadata, + output: torch.Tensor, + layer: AttentionLayer, + ) -> torch.Tensor: + """C8 ChunkedPrefill: decode via FIA V1 BNSD paged INT8 (zero gather), + prefill via FIA V1 TND 
with float KV (new) or gather+dequant (continuing). + """ + num_decode_tokens = attn_metadata.num_decode_tokens + num_decodes = attn_metadata.num_decodes + actual_seq_qlen = attn_metadata.actual_seq_lengths_q + num_tokens = int(actual_seq_qlen[-1]) # type: ignore[index] + + if num_decode_tokens > 0: + num_block, block_size, _, _ = self.key_cache.shape # type: ignore[attr-defined] + assert block_size % 32 == 0, ( + f"C8 INT8 KV cache requires block_size to be a multiple of 32, got {block_size}" + ) + kv_k = self.key_cache.view(num_block, block_size, -1) # type: ignore[attr-defined] + kv_v = self.value_cache.view(num_block, block_size, -1) # type: ignore[attr-defined] + + attn_out, _ = torch_npu.npu_fused_infer_attention_score( + query[:num_decode_tokens].unsqueeze(2), + kv_k, + kv_v, + key_antiquant_scale=layer._c8_k_aq_scale, + key_antiquant_offset=layer._c8_k_aq_offset, + value_antiquant_scale=layer._c8_v_aq_scale, + value_antiquant_offset=layer._c8_v_aq_offset, + block_table=attn_metadata.block_tables[:num_decodes], + actual_seq_lengths_kv=attn_metadata.seq_lens_list[:num_decodes], + num_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + input_layout="BNSD", + scale=self.scale, + block_size=block_size, + key_antiquant_mode=0, + value_antiquant_mode=0, + sparse_mode=0, + ) + output[:num_decode_tokens] = attn_out.squeeze(2) + + if attn_metadata.num_prefills > 0: + prefill_q = query[num_decode_tokens:num_tokens] + + prefill_seq_qlen = [ + actual_seq_qlen[i] - num_decode_tokens for i in range(num_decodes, len(actual_seq_qlen)) + ] + + all_new_prefill = True + for i in range(num_decodes, len(attn_metadata.seq_lens_list)): + q_start = actual_seq_qlen[i - 1] if i > 0 else 0 + qlen_i = actual_seq_qlen[i] - q_start + if attn_metadata.seq_lens_list[i] > qlen_i: + all_new_prefill = False + break + + if all_new_prefill and float_key is not None and float_value is not None: + prefill_k = float_key[num_decode_tokens:num_tokens] + prefill_v = float_value[num_decode_tokens:num_tokens] + prefill_seq_kvlen = prefill_seq_qlen + else: + num_block, blk_size, _, _ = self.key_cache.shape # type: ignore[attr-defined] + paged_k = self.key_cache.view(num_block, blk_size, -1) # type: ignore[attr-defined] + paged_v = self.value_cache.view(num_block, blk_size, -1) # type: ignore[attr-defined] + prefill_bt = attn_metadata.block_tables[num_decodes:] + prefill_sl = attn_metadata.seq_lens_list[num_decodes:] + prefill_k, prefill_v = self._dequant_paged_kv_to_dense( + paged_k, paged_v, prefill_bt, prefill_sl, query.dtype, layer + ) + prefill_seq_kvlen = torch.tensor(prefill_sl, dtype=torch.int32).cumsum(dim=0) + + # block_table is None for prefill; FIA ignores block_size in this case. + # Use cache block_size for consistency rather than a magic number. 
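+            # In the TND layout, actual_seq_lengths / actual_seq_lengths_kv are
+            # cumulative per-request lengths; the KV passed to the kernel is either
+            # the current step's float K/V (when every prefill is new) or the paged
+            # INT8 cache gathered and dequantized by _dequant_paged_kv_to_dense.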
+ cache_block_size = self.key_cache.shape[1] # type: ignore[attr-defined] + attn_out, _ = torch_npu.npu_fused_infer_attention_score( + query=prefill_q, + key=prefill_k, + value=prefill_v, + atten_mask=attn_metadata.attn_mask, + block_table=None, + input_layout="TND", + block_size=cache_block_size, + actual_seq_lengths=prefill_seq_qlen, + actual_seq_lengths_kv=prefill_seq_kvlen, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=3, + ) + n_prefill = num_tokens - num_decode_tokens + attn_out = attn_out.view(n_prefill, self.num_heads, self.head_size) + output[num_decode_tokens:num_tokens] = attn_out[:n_prefill] + + return output + + def _forward_c8_fused_infer_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AscendMetadata, + output: torch.Tensor, + layer: AttentionLayer, + ): + """C8 FIA V1 TND for prefill states (PrefillNoCache uses float KV directly, + PrefillCacheHit gathers + dequants paged INT8 KV). + """ + self._prepare_c8_scales(layer, query.device) + key, value, block_size, block_table, actual_seq_lengths_kv = self._get_fia_params(key, value, attn_metadata) + + actual_seq_qlen = attn_metadata.actual_seq_lengths_q + num_tokens = int(actual_seq_qlen[-1]) # type: ignore[index] + query = query[:num_tokens] + + if ( + attn_metadata.attn_state == AscendAttentionState.PrefillNoCache + and self.attn_type != AttentionType.ENCODER_DECODER + ): + key = key[:num_tokens] + value = value[:num_tokens] + + if key.dtype == torch.int8: + if block_table is not None: + seq_lens = ( + actual_seq_lengths_kv if isinstance(actual_seq_lengths_kv, list) else actual_seq_lengths_kv.tolist() + ) + key, value = self._dequant_paged_kv_to_dense(key, value, block_table, seq_lens, query.dtype, layer) + block_table = None + # block_table is None after dequant; FIA ignores block_size. + # Use cache block_size for consistency rather than a magic number. + block_size = self.key_cache.shape[1] # type: ignore[attr-defined] + actual_seq_lengths_kv = torch.tensor(seq_lens, dtype=torch.int32).cumsum(dim=0) + else: + qdt = query.dtype + k_scale = layer._c8_k_scale.to(qdt) + k_offset = layer._c8_k_offset.to(qdt) + v_scale = layer._c8_v_scale.to(qdt) + v_offset = layer._c8_v_offset.to(qdt) + key = (key.to(qdt) - k_offset) * k_scale + value = (value.to(qdt) - v_offset) * v_scale + + attn_output, _ = torch_npu.npu_fused_infer_attention_score( + query=query, + key=key, + value=value, + atten_mask=attn_metadata.attn_mask, + block_table=block_table, + input_layout="TND", + block_size=block_size, + actual_seq_lengths=actual_seq_qlen, + actual_seq_lengths_kv=actual_seq_lengths_kv, + num_key_value_heads=self.num_kv_heads, + num_heads=self.num_heads, + scale=self.scale, + sparse_mode=3, + ) + attn_output = attn_output.view(num_tokens, self.num_heads, self.head_size) + output[:num_tokens] = attn_output + return output diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index c8da9365..4d8c086d 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -721,3 +721,27 @@ # override _get_deepstack_input_embeds method with the flash comm v1 implementation. # Future Plan: # Remove this patch when https://github.com/vllm-project/vllm-ascend/issues/5712 is completed. +# +# ** 29. File: worker/patch_qwen3_c8.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. 
`vllm.model_executor.models.qwen3.Qwen3ForCausalLM.load_weights` +# Why: +# The Qwen3 W8A8C8 model stores per-channel KV cache scales and offsets +# (k_cache_scale, k_cache_offset, v_cache_scale, v_cache_offset) under +# weight names that AutoWeightsLoader does not recognise and would +# silently discard. Without these scales the INT8 KV cache cannot be +# dequantised correctly at inference time. +# How: +# Wrap load_weights to intercept the C8 scale/offset tensors before they +# reach the base loader. Each intercepted tensor is routed to the +# corresponding nn.Parameter via its weight_loader, then excluded from +# the remaining weight stream so the base loader never sees it. +# Related PR (if no, explain why): +# This PR (Qwen3-32B W8A8C8 support). Upstream vLLM's weight-loading +# pipeline does not yet have a generic hook for hardware-plugin-defined +# KV cache parameters. +# Future Plan: +# Remove this patch when vLLM provides a first-class extension point +# for loading extra KV cache quantisation parameters in model load_weights, +# or when the Qwen3 model's weight names are aligned with the parameter +# names expected by the quantisation backend. diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index b79ad7c0..afb10063 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -51,3 +51,4 @@ import vllm_ascend.patch.worker.patch_v2.patch_model_state # noqa import vllm_ascend.patch.worker.patch_v2.patch_block_table # noqa import vllm_ascend.patch.worker.patch_qwen3vl # noqa import vllm_ascend.patch.worker.patch_deepencoder2 # noqa +import vllm_ascend.patch.worker.patch_qwen3_c8 # noqa diff --git a/vllm_ascend/patch/worker/patch_qwen3_c8.py b/vllm_ascend/patch/worker/patch_qwen3_c8.py new file mode 100644 index 00000000..674f82e3 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_qwen3_c8.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections.abc import Iterable + +import torch +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM + +_orig_qwen3_causal_lm_load_weights = Qwen3ForCausalLM.load_weights + + +def _patched_qwen3_causal_lm_load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + quant_config = self.quant_config + if quant_config is None or not callable(getattr(quant_config, "get_cache_scale", None)): + return _orig_qwen3_causal_lm_load_weights(self, weights) + + params_dict = dict(self.named_parameters()) + c8_loaded_params: set[str] = set() + + def _intercept_c8_scales( + raw_weights: Iterable[tuple[str, torch.Tensor]], + ) -> Iterable[tuple[str, torch.Tensor]]: + for name, loaded_weight in raw_weights: + scale_name = quant_config.get_cache_scale(name) + if scale_name is not None: + if scale_name in params_dict: + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight.squeeze()) + c8_loaded_params.add(scale_name) + else: + yield name, loaded_weight + + loaded_params = _orig_qwen3_causal_lm_load_weights(self, _intercept_c8_scales(weights)) + loaded_params.update(c8_loaded_params) + return loaded_params + + +Qwen3ForCausalLM.load_weights = _patched_qwen3_causal_lm_load_weights diff --git a/vllm_ascend/quantization/methods/kv_c8.py b/vllm_ascend/quantization/methods/kv_c8.py index 10056741..6e1ab302 100644 --- a/vllm_ascend/quantization/methods/kv_c8.py +++ b/vllm_ascend/quantization/methods/kv_c8.py @@ -2,11 +2,12 @@ import torch from vllm.config import get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size +from .base import AscendAttentionScheme from .registry import register_scheme -def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor): - """fa_q weight loader.""" +def _fa_quant_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor): + """Weight loader for MLA-based C8 (FAKQuant) models.""" if param.numel() == 1 and loaded_weight.numel() == 1: param.data.fill_(loaded_weight.item()) else: @@ -50,7 +51,7 @@ class AscendFAQuantAttentionMethod: weight_param = torch.nn.Parameter(weight, requires_grad=False) module.register_parameter(weight_name, weight_param) # When loading weights, segment them according to TP - weight_param.weight_loader = weight_loader + weight_param.weight_loader = _fa_quant_weight_loader def process_weights_after_loading(self, layer: torch.nn.Module) -> None: fa_k_scale = torch.squeeze(layer.fa_k.scale).unsqueeze(0) @@ -87,3 +88,60 @@ class AscendSFAQuantAttentionMethod: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass + + +def _c8_kv_scale_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: + """Weight loader for dense-attention C8 KV cache scales/offsets.""" + loaded_weight = loaded_weight.squeeze() + if param.data.shape != loaded_weight.shape: + param.data = loaded_weight.to(param.dtype).clone() + else: + param.data.copy_(loaded_weight) + + +class AscendC8KVCacheAttentionMethod(AscendAttentionScheme): + """C8 INT8 KV cache quantization for dense-attention models (e.g. 
Qwen3).""" + + def __init__(self, quant_description: dict, prefix: str): + self.quant_description = quant_description + self.prefix = prefix + + def create_weights(self, layer: torch.nn.Module) -> None: + # Override kv_cache_torch_dtype so Attention.get_kv_cache_spec returns int8 automatically. + layer.kv_cache_torch_dtype = torch.int8 + # Upgrade impl to the C8-specific subclass so the C8 forward path is always used. + if hasattr(layer, "impl"): + from vllm_ascend.attention.attention_v1 import AscendC8AttentionBackendImpl + + layer.impl.__class__ = AscendC8AttentionBackendImpl + layer.k_cache_scale = torch.nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False) + layer.k_cache_scale.weight_loader = _c8_kv_scale_weight_loader + layer.k_cache_offset = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=False) + layer.k_cache_offset.weight_loader = _c8_kv_scale_weight_loader + layer.v_cache_scale = torch.nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False) + layer.v_cache_scale.weight_loader = _c8_kv_scale_weight_loader + layer.v_cache_offset = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=False) + layer.v_cache_offset.weight_loader = _c8_kv_scale_weight_loader + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.k_cache_scale.data = layer.k_cache_scale.data.flatten() + layer.k_cache_offset.data = layer.k_cache_offset.data.flatten() + layer.v_cache_scale.data = layer.v_cache_scale.data.flatten() + layer.v_cache_offset.data = layer.v_cache_offset.data.flatten() + + def apply( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache, + attn_metadata, + attn_type, + scale, + output, + ) -> torch.Tensor: + raise RuntimeError( + "AscendC8KVCacheAttentionMethod.apply should not be called. " + "C8 KV cache quantization is handled by the attention backend." 
+ ) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index da7e03e4..b6ca0c83 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -429,6 +429,21 @@ class AscendModelSlimConfig(QuantizationConfig): self._add_kvcache_quant_metadata() logger.info("Applied hf_to_vllm_mapper to quant_description keys") + def get_cache_scale(self, name: str) -> str | None: + """Map checkpoint C8 KV scale/offset names to vLLM parameter names.""" + if self.quant_description.get("kv_cache_type") != "C8": + return None + _C8_SCALE_MAPPING = { + "k_proj.kv_cache_scale": "attn.k_cache_scale", + "k_proj.kv_cache_offset": "attn.k_cache_offset", + "v_proj.kv_cache_scale": "attn.v_cache_scale", + "v_proj.kv_cache_offset": "attn.v_cache_offset", + } + for src_suffix, dst_suffix in _C8_SCALE_MAPPING.items(): + if name.endswith(src_suffix): + return name[: -len(src_suffix)] + dst_suffix + return None + def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: self.model_type = model_type return prefix @@ -476,6 +491,10 @@ class AscendModelSlimConfig(QuantizationConfig): ): scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping) return AscendKVCacheMethod(scheme) + elif isinstance(layer, AttentionLayerBase) and self.quant_description.get("kv_cache_type") == "C8": + from .methods.kv_c8 import AscendC8KVCacheAttentionMethod + + return AscendKVCacheMethod(AscendC8KVCacheAttentionMethod(self.quant_description, prefix)) elif isinstance(layer, FusedMoE): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): # Delayed import to avoid circular import