Backport of #7474.
This PR adds C8 (INT8) KV cache quantization support for standard GQA
attention models (e.g., Qwen3-32B W8A8C8). C8 uses static per-channel
quantization scales to store the KV cache in INT8 at 1 byte per element
instead of BF16's 2, reducing KV cache memory by ~50% and enabling higher
batch concurrency and longer context lengths on the same hardware.
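For orientation, the arithmetic behind static per-channel INT8 (de)quantization is small enough to sketch. This is an illustrative sketch, not the PR's kernel code; it assumes the common asymmetric convention `q = round(x / scale) + offset` (the unit tests at the end of this page confirm `q = round(x / scale)` for `offset = 0`):

```python
import torch

def quantize_per_channel(x: torch.Tensor, scale: torch.Tensor,
                         offset: torch.Tensor) -> torch.Tensor:
    """Quantize BF16 K/V to INT8 with static per-channel scales.

    x:      [num_tokens, num_kv_heads, head_size] (bf16)
    scale:  [num_kv_heads, head_size] per-channel scales (fp32)
    offset: [num_kv_heads, head_size] per-channel zero points (fp32)
    """
    q = torch.round(x.float() / scale + offset)
    return q.clamp(-128, 127).to(torch.int8)

def dequantize_per_channel(q: torch.Tensor, scale: torch.Tensor,
                           offset: torch.Tensor) -> torch.Tensor:
    """Inverse mapping back to float for attention in higher precision."""
    return (q.float() - offset) * scale
```

Since the scales are static (computed offline during quantization), runtime quantization is a single elementwise pass before the cache write; no calibration happens during serving.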
**Key changes:**
1. **`attention_v1.py`** — New `AscendC8AttentionBackendImpl` subclass
of `AscendAttentionBackendImpl`:
- `_prepare_c8_scales`: Shards per-channel scales/offsets to the current
TP rank and pre-computes BF16 BNSD-shaped antiquant tensors (once per
layer; a sketch of this sharding pattern follows the list).
- `_quantize_kv_to_int8`: Quantizes BF16 K/V to INT8 before
`reshape_and_cache`, using pre-cached inverse scales.
- `_forward_c8_decode`: FIA V1 BNSD paged attention with native INT8 KV
and `perchannel` antiquant mode.
- `_forward_c8_chunked_prefill`: Splits decode (FIA V1 BNSD paged INT8)
and prefill (FIA V1 TND float) into two kernel calls.
- `_forward_c8_fused_infer_attention`: Handles `PrefillNoCache` and
`PrefillCacheHit` states.
2. **`quantization/methods/kv_c8.py`** — New
`AscendC8KVCacheAttentionMethod` scheme:
- Creates `k/v_cache_scale/offset` parameters via
`_c8_kv_scale_weight_loader`, which handles per-channel scale shapes and
lazy resizing.
- Sets `layer.kv_cache_torch_dtype = torch.int8` so
`get_kv_cache_spec()` returns INT8 dtype automatically.
- Upgrades `layer.impl` to `AscendC8AttentionBackendImpl` via class
surgery.
3. **`quantization/modelslim_config.py`** — C8 branch in
`get_quant_method()` activates when `kv_cache_type == "C8"` in
`quant_model_description.json`.
4. **`patch/worker/patch_qwen3_c8.py`** — Intercepts per-channel C8
scale/offset weights before `AutoWeightsLoader` discards them, routing
them to the parameters created by `AscendC8KVCacheAttentionMethod`.
5. **`tests/ut/quantization/test_kv_c8.py`** — Unit tests covering
`_c8_kv_scale_weight_loader`, `AscendC8KVCacheAttentionMethod`, and
`AscendC8AttentionBackendImpl` scale helpers.
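For item 1's scale sharding (exercised by the TP sharding tests in item 5), the slicing follows the standard tensor-parallel pattern: each rank takes a contiguous `narrow` slice along dim 0, starting at `shard_size * rank`. A minimal sketch with illustrative names:

```python
import torch

def shard_for_rank(full: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Slice a full per-channel scale/offset tensor to this TP rank's share.

    full: per-channel tensor whose dim 0 is split evenly across TP ranks.
    """
    shard_size = full.shape[0] // tp_size
    return full.narrow(0, shard_size * tp_rank, shard_size)

# Example: rank 2 of 4 over 8 rows takes rows 4..5, i.e. narrow(0, 4, 2).
full = torch.arange(8.0).unsqueeze(1).expand(8, 5)
assert torch.equal(shard_for_rank(full, 2, 4)[:, 0], torch.tensor([4.0, 5.0]))
```

The unit tests assert exactly this call shape, e.g. `narrow(0, 4, 2)` for rank 2 of 4.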
**User-facing change:** Yes. Users can now serve Qwen3-32B W8A8C8
quantized models with an INT8 KV cache on Ascend NPU. The model
checkpoint must contain a `quant_model_description.json` with
`"kv_cache_type": "C8"` and per-channel scale/offset tensors in
safetensors.
No changes to the serving CLI — the feature activates automatically when
the quantization config is detected.
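As an illustrative sketch of that detection (not the PR's exact code; only the file name and `kv_cache_type` key come from this PR):

```python
import json
from pathlib import Path

def is_c8_checkpoint(model_dir: str) -> bool:
    """True if the checkpoint's quant description requests INT8 (C8) KV cache."""
    desc_path = Path(model_dir) / "quant_model_description.json"
    if not desc_path.exists():
        return False
    desc = json.loads(desc_path.read_text())
    return desc.get("kv_cache_type") == "C8"
```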
Benchmarked with `vllm serve` (TP=8, `max_num_seqs=256`,
`max_model_len=131072`, `enable_chunked_prefill=true`) + `random_bench`
(input_len=10240, output_len=2048, 960 prompts, max_concurrency=192):
```
============ Serving Benchmark Result ============
Successful requests: 960
Failed requests: 0
Maximum request concurrency: 192
Benchmark duration (s): 1359.81
Total input tokens: 9830400
Total generated tokens: 1966080
Request throughput (req/s): 0.71
Output token throughput (tok/s): 1445.85
Peak output token throughput (tok/s): 2304.00
Total token throughput (tok/s): 8675.12
---------------Time to First Token----------------
Mean TTFT (ms): 24598.51
Median TTFT (ms): 23167.02
P50 TTFT (ms): 23167.02
P90 TTFT (ms): 47717.08
P99 TTFT (ms): 84402.61
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 120.76
Median TPOT (ms): 121.50
P50 TPOT (ms): 121.50
P90 TPOT (ms): 127.05
P99 TPOT (ms): 130.13
---------------Inter-token Latency----------------
Mean ITL (ms): 120.70
Median ITL (ms): 90.34
P50 ITL (ms): 90.34
P90 ITL (ms): 93.79
P99 ITL (ms): 101.80
==================================================
```
All attention states verified: `PrefillNoCache`, `PrefillCacheHit`,
`ChunkedPrefill`, `DecodeOnly`.
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: LICO67373 <110013619+LICO1314@users.noreply.github.com>
**`tests/ut/quantization/test_kv_c8.py`** (690 lines, 30 KiB, Python):
```python
import unittest
from unittest.mock import MagicMock, Mock, patch

import torch
import torch.nn as nn

from tests.ut.base import TestBase


class TestWeightLoader(unittest.TestCase):
    """Test cases for weight_loader function in kv_c8.py"""

    def setUp(self):
        """Set up test environment before each test"""
        # Import the module under test
        from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader
        self.weight_loader = weight_loader

        # Mock distributed functions
        self.tp_rank_patch = patch(
            "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank"
        )
        self.tp_size_patch = patch(
            "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size"
        )
        self.mock_tp_rank = self.tp_rank_patch.start()
        self.mock_tp_size = self.tp_size_patch.start()

    def tearDown(self):
        """Clean up after each test"""
        self.tp_rank_patch.stop()
        self.tp_size_patch.stop()

    def test_weight_loader_single_element(self):
        """Test weight_loader when both tensors contain a single element"""
        # Create tensors with single element
        param = torch.tensor([0.0])
        loaded_weight = torch.tensor([5.0])

        # Call weight_loader
        self.weight_loader(param, loaded_weight)

        # Verify the value was filled correctly
        self.assertEqual(param.item(), 5.0)
        self.assertEqual(param.dtype, torch.float32)

    def test_weight_loader_single_element_int(self):
        """Test weight_loader with integer tensors"""
        param = torch.tensor([0], dtype=torch.int32)
        loaded_weight = torch.tensor([10], dtype=torch.int32)

        self.weight_loader(param, loaded_weight)

        self.assertEqual(param.item(), 10)

    def test_weight_loader_tp_sharding_first_rank(self):
        """Test weight_loader with tensor parallelism sharding for first rank"""
        # Configure mocks for rank 0 of 4
        self.mock_tp_rank.return_value = 0
        self.mock_tp_size.return_value = 4

        # Create test tensors
        param = torch.zeros(2, 5)  # Target param shape (2, 5)
        loaded_weight = torch.ones(8, 5)  # Full weight (8, 5)

        # Mock narrow to track the call
        with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow:
            self.weight_loader(param, loaded_weight)

        # Verify narrow was called correctly: narrow(dim=0, start=0, length=2)
        mock_narrow.assert_called_once_with(0, 0, 2)

        # Verify data was copied
        self.assertTrue(torch.all(param == 1))

    def test_weight_loader_tp_sharding_middle_rank(self):
        """Test weight_loader with tensor parallelism sharding for middle rank"""
        # Configure mocks for rank 2 of 4
        self.mock_tp_rank.return_value = 2
        self.mock_tp_size.return_value = 4

        param = torch.zeros(2, 5)
        loaded_weight = torch.ones(8, 5)

        with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow:
            self.weight_loader(param, loaded_weight)

        # Verify narrow was called correctly: start = shard_size * rank = 2 * 2 = 4
        mock_narrow.assert_called_once_with(0, 4, 2)

        self.assertTrue(torch.all(param == 1))

    def test_weight_loader_tp_sharding_last_rank(self):
        """Test weight_loader with tensor parallelism sharding for last rank"""
        # Configure mocks for rank 3 of 4
        self.mock_tp_rank.return_value = 3
        self.mock_tp_size.return_value = 4

        param = torch.zeros(2, 5)
        loaded_weight = torch.ones(8, 5)

        with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow:
            self.weight_loader(param, loaded_weight)

        # Verify narrow was called correctly: start = 2 * 3 = 6
        mock_narrow.assert_called_once_with(0, 6, 2)

        self.assertTrue(torch.all(param == 1))

    def test_weight_loader_shape_mismatch(self):
        """Test weight_loader raises assertion error on shape mismatch"""
        self.mock_tp_rank.return_value = 0
        self.mock_tp_size.return_value = 2

        param = torch.zeros(2, 3)
        loaded_weight = torch.ones(4, 4)  # Different shape after sharding

        # Mock narrow to return tensor with wrong shape
        with patch.object(loaded_weight, 'narrow', return_value=torch.ones(2, 4)):
            with self.assertRaises(AssertionError) as context:
                self.weight_loader(param, loaded_weight)

        # Verify error message contains expected information
        self.assertIn("Attempted to load weight", str(context.exception))
        self.assertIn("into parameter", str(context.exception))

    def test_weight_loader_with_different_dtypes(self):
        """Test weight_loader handles different dtypes correctly"""
        self.mock_tp_rank.return_value = 0
        self.mock_tp_size.return_value = 1  # No sharding

        param = torch.zeros(2, 3, dtype=torch.float32)
        loaded_weight = torch.ones(2, 3, dtype=torch.float16)

        self.weight_loader(param, loaded_weight)

        # Verify data was copied and converted
        self.assertTrue(torch.all(param == 1))
        self.assertEqual(param.dtype, torch.float32)


class TestAscendFAQuantAttentionMethodInit(unittest.TestCase):
    """Test cases for AscendFAQuantAttentionMethod initialization"""

    def setUp(self):
        """Set up test environment"""
        # Mock vllm_config
        self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
        self.mock_get_config = self.config_patch.start()

        # Create mock config with attributes
        self.mock_config = Mock()
        self.mock_hf_config = Mock()
        self.mock_hf_config.kv_lora_rank = 128
        self.mock_hf_config.qk_rope_head_dim = 64
        self.mock_config.model_config.hf_config = self.mock_hf_config
        self.mock_get_config.return_value = self.mock_config

        # Import the class after patching
        from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod
        self.method_class = AscendFAQuantAttentionMethod

    def tearDown(self):
        """Clean up after each test"""
        self.config_patch.stop()

    def test_init_with_full_config(self):
        """Test initialization when config has all attributes"""
        method = self.method_class()

        self.assertTrue(method.transpose_weight)
        self.assertFalse(method.printFlag)
        self.assertEqual(method.kv_lora_rank, 128)
        self.assertEqual(method.qk_rope_head_dim, 64)

    def test_init_without_kv_lora_rank(self):
        """Test initialization when config lacks kv_lora_rank"""
        delattr(self.mock_hf_config, "kv_lora_rank")

        method = self.method_class()

        self.assertEqual(method.kv_lora_rank, 0)
        self.assertEqual(method.qk_rope_head_dim, 64)

    def test_init_without_qk_rope_head_dim(self):
        """Test initialization when config lacks qk_rope_head_dim"""
        delattr(self.mock_hf_config, "qk_rope_head_dim")

        method = self.method_class()

        self.assertEqual(method.kv_lora_rank, 128)
        self.assertEqual(method.qk_rope_head_dim, 0)

    def test_init_without_both_attributes(self):
        """Test initialization when config lacks both attributes"""
        delattr(self.mock_hf_config, "kv_lora_rank")
        delattr(self.mock_hf_config, "qk_rope_head_dim")

        method = self.method_class()

        self.assertEqual(method.kv_lora_rank, 0)
        self.assertEqual(method.qk_rope_head_dim, 0)


class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase):
    """Test cases for create_weights method"""

    def setUp(self):
        """Set up test environment"""
        # Mock vllm_config
        self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
        self.mock_get_config = self.config_patch.start()

        self.mock_config = Mock()
        self.mock_hf_config = Mock()
        self.mock_hf_config.kv_lora_rank = 128
        self.mock_hf_config.qk_rope_head_dim = 64
        self.mock_config.model_config.hf_config = self.mock_hf_config
        self.mock_get_config.return_value = self.mock_config

        # Import the class
        from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod
        self.method_class = AscendFAQuantAttentionMethod

        # Mock torch functions
        self.default_dtype_patch = patch("torch.get_default_dtype", return_value=torch.float32)
        self.mock_default_dtype = self.default_dtype_patch.start()

        # Create a real nn.Module for testing
        self.layer = nn.Module()
        self.layer.num_heads = 32
        self.layer.num_kv_heads = 1

    def tearDown(self):
        """Clean up after each test"""
        self.config_patch.stop()
        self.default_dtype_patch.stop()

    def test_create_weights_adds_submodules(self):
        """Test that create_weights adds fa_q, fa_k, fa_v submodules"""
        method = self.method_class()

        with patch("torch.empty") as mock_empty:
            mock_empty.return_value = torch.zeros(1, 1)

            method.create_weights(self.layer)

        # Verify submodules were added
        self.assertTrue(hasattr(self.layer, "fa_q"))
        self.assertTrue(hasattr(self.layer, "fa_k"))
        self.assertTrue(hasattr(self.layer, "fa_v"))

        # Verify they are instances of nn.Module
        self.assertIsInstance(self.layer.fa_q, nn.Module)
        self.assertIsInstance(self.layer.fa_k, nn.Module)
        self.assertIsInstance(self.layer.fa_v, nn.Module)

    def test_create_weights_creates_correct_tensors(self):
        """Test that create_weights creates tensors with correct shapes and dtypes"""
        method = self.method_class()

        # Track torch.empty calls
        empty_calls = []

        def mock_empty(size, dtype=None):
            empty_calls.append((size, dtype))
            return torch.zeros(size, dtype=dtype if dtype else torch.float32)

        with patch("torch.empty", side_effect=mock_empty):
            method.create_weights(self.layer)

        # Verify tensor creations
        expected_calls = [
            ((32, 1), torch.float32),  # fa_q.scale
            ((1, 1), torch.float32),  # fa_k.scale
            ((1, 1), torch.float32),  # fa_v.scale
            ((32, 1), torch.int8),  # fa_q.offset
            ((1, 1), torch.int8),  # fa_k.offset
            ((1, 1), torch.int8),  # fa_v.offset
        ]

        # Compare without considering order
        self.assertEqual(len(empty_calls), len(expected_calls))
        for call in expected_calls:
            self.assertIn(call, empty_calls)

    def test_create_weights_registers_parameters(self):
        """Test that create_weights registers parameters with correct attributes"""
        method = self.method_class()

        # Create real tensors for testing
        def create_tensor(*args, **kwargs):
            size = args[0] if args else kwargs.get('size', (1,))
            dtype = kwargs.get('dtype', torch.float32)
            return torch.zeros(*size, dtype=dtype)

        with patch("torch.empty", side_effect=create_tensor):
            method.create_weights(self.layer)

        # Import weight_loader for comparison
        from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader

        # Verify each parameter exists and has weight_loader
        self.assertTrue(hasattr(self.layer.fa_q, "scale"))
        self.assertTrue(hasattr(self.layer.fa_q.scale, "weight_loader"))
        self.assertEqual(self.layer.fa_q.scale.weight_loader, weight_loader)
        self.assertFalse(self.layer.fa_q.scale.requires_grad)

        self.assertTrue(hasattr(self.layer.fa_k, "scale"))
        self.assertTrue(hasattr(self.layer.fa_k.scale, "weight_loader"))

        self.assertTrue(hasattr(self.layer.fa_v, "scale"))
        self.assertTrue(hasattr(self.layer.fa_v.scale, "weight_loader"))

        self.assertTrue(hasattr(self.layer.fa_q, "offset"))
        self.assertTrue(hasattr(self.layer.fa_q.offset, "weight_loader"))
        self.assertEqual(self.layer.fa_q.offset.dtype, torch.int8)


class TestAscendFAQuantAttentionMethodProcessWeights(unittest.TestCase):
    """Test cases for process_weights_after_loading method"""

    def setUp(self):
        """Set up test environment"""
        # Mock vllm_config
        self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
        self.mock_get_config = self.config_patch.start()

        self.mock_config = Mock()
        self.mock_hf_config = Mock()
        self.mock_hf_config.kv_lora_rank = 64
        self.mock_hf_config.qk_rope_head_dim = 32
        self.mock_config.model_config.hf_config = self.mock_hf_config
        self.mock_get_config.return_value = self.mock_config

        # Import the class
        from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod
        self.method_class = AscendFAQuantAttentionMethod

        # Create method instance with real layer
        self.method = self.method_class()

        # Create a real nn.Module for testing
        self.layer = nn.Module()

        # Create real tensors for fa_k
        self.fa_k_scale = torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float16)  # Shape (1, 3)
        self.fa_k_offset = torch.tensor([[1, 2, 3]], dtype=torch.int8)  # Shape (1, 3)

        # Create fa_k module with parameters
        self.layer.fa_k = nn.Module()
        self.layer.fa_k.scale = nn.Parameter(self.fa_k_scale, requires_grad=False)
        self.layer.fa_k.offset = nn.Parameter(self.fa_k_offset, requires_grad=False)

    def tearDown(self):
        """Clean up after each test"""
        self.config_patch.stop()

    def test_process_weights_with_single_value_scale(self):
        """Test process_weights with single value scale"""
        # Create new layer with single value scale
        layer = nn.Module()
        layer.fa_k = nn.Module()
        layer.fa_k.scale = nn.Parameter(torch.tensor([[2.0]], dtype=torch.float16), requires_grad=False)
        layer.fa_k.offset = nn.Parameter(torch.tensor([[1]], dtype=torch.int8), requires_grad=False)

        self.method.kv_lora_rank = 4
        self.method.process_weights_after_loading(layer)

        self.assertEqual(layer.quant_kscale.shape, (1, 4))
        self.assertEqual(layer.quant_kscale.dtype, torch.float32)


class TestIntegration(unittest.TestCase):
    """Integration tests for the complete kv_c8 functionality"""

    def setUp(self):
        """Set up test environment"""
        # Mock vllm_config
        self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
        self.mock_get_config = self.config_patch.start()

        self.mock_config = Mock()
        self.mock_hf_config = Mock()
        self.mock_hf_config.kv_lora_rank = 64
        self.mock_hf_config.qk_rope_head_dim = 32
        self.mock_config.model_config.hf_config = self.mock_hf_config
        self.mock_get_config.return_value = self.mock_config

        # Mock distributed functions
        self.tp_rank_patch = patch(
            "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank"
        )
        self.tp_size_patch = patch(
            "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size"
        )
        self.mock_tp_rank = self.tp_rank_patch.start()
        self.mock_tp_size = self.tp_size_patch.start()

    def tearDown(self):
        """Clean up after each test"""
        self.config_patch.stop()
        self.tp_rank_patch.stop()
        self.tp_size_patch.stop()

    def test_complete_workflow(self):
        """Test complete workflow from weight creation to processing"""
        from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod

        # Create method instance
        method = AscendFAQuantAttentionMethod()

        # Create real layer
        layer = nn.Module()
        layer.num_heads = 32
        layer.num_kv_heads = 1

        # Step 1: Create weights
        method.create_weights(layer)

        # Verify weights were created with correct structure
        self.assertTrue(hasattr(layer, "fa_q"))
        self.assertTrue(hasattr(layer, "fa_k"))
        self.assertTrue(hasattr(layer, "fa_v"))

        self.assertTrue(hasattr(layer.fa_q, "scale"))
        self.assertTrue(hasattr(layer.fa_q, "offset"))
        self.assertTrue(hasattr(layer.fa_k, "scale"))
        self.assertTrue(hasattr(layer.fa_k, "offset"))
        self.assertTrue(hasattr(layer.fa_v, "scale"))
        self.assertTrue(hasattr(layer.fa_v, "offset"))

        # Step 2: Simulate weight loading
        self.mock_tp_rank.return_value = 0
        self.mock_tp_size.return_value = 1

        # Create dummy weights
        q_scale = torch.randn(32, 1)
        k_scale = torch.randn(1, 1)
        v_scale = torch.randn(1, 1)
        q_offset = torch.randint(-128, 127, (32, 1), dtype=torch.int8)
        k_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8)
        v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8)

        # Load weights using weight_loader
        from vllm_ascend.quantization.methods.kv_c8 import _fa_quant_weight_loader as weight_loader

        with torch.no_grad():
            weight_loader(layer.fa_q.scale, q_scale)
            weight_loader(layer.fa_k.scale, k_scale)
            weight_loader(layer.fa_v.scale, v_scale)
            weight_loader(layer.fa_q.offset, q_offset)
            weight_loader(layer.fa_k.offset, k_offset)
            weight_loader(layer.fa_v.offset, v_offset)

        # Verify weights were loaded correctly
        self.assertTrue(torch.all(layer.fa_q.scale == q_scale))
        self.assertTrue(torch.all(layer.fa_k.scale == k_scale))
        self.assertTrue(torch.all(layer.fa_v.scale == v_scale))

        # Step 3: Process after loading
        method.process_weights_after_loading(layer)

        # Verify processed parameters
        self.assertTrue(hasattr(layer, "fak_descale"))
        self.assertTrue(hasattr(layer, "fak_offset"))
        self.assertTrue(hasattr(layer, "quant_kscale"))


class TestC8KVScaleWeightLoader(TestBase):
    """Tests for _c8_kv_scale_weight_loader in kv_c8.py."""

    def setUp(self):
        from vllm_ascend.quantization.methods.kv_c8 import _c8_kv_scale_weight_loader
        self.loader = _c8_kv_scale_weight_loader

    def test_shape_match_copies_value(self):
        param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
        loaded = torch.tensor([1.0, 2.0, 3.0, 4.0])
        self.loader(param, loaded)
        self.assertTrue(torch.allclose(param.data, loaded.float()))

    def test_shape_mismatch_resizes_param(self):
        param = nn.Parameter(torch.ones(1, dtype=torch.float32), requires_grad=False)
        loaded = torch.arange(8, dtype=torch.float32)
        self.loader(param, loaded)
        self.assertEqual(param.data.shape, (8,))
        self.assertTrue(torch.allclose(param.data, loaded))

    def test_squeeze_before_compare(self):
        param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
        loaded = torch.arange(4, dtype=torch.float32).unsqueeze(0)  # shape [1, 4]
        self.loader(param, loaded)
        self.assertEqual(param.data.shape, (4,))

    def test_dtype_preserved_as_param_dtype(self):
        param = nn.Parameter(torch.ones(4, dtype=torch.float32), requires_grad=False)
        loaded = torch.arange(4, dtype=torch.float16)
        self.loader(param, loaded)
        self.assertEqual(param.data.dtype, torch.float32)


class TestAscendC8KVCacheAttentionMethod(TestBase):
    """Tests for AscendC8KVCacheAttentionMethod in kv_c8.py."""

    def _make_method(self):
        from vllm_ascend.quantization.methods.kv_c8 import AscendC8KVCacheAttentionMethod
        return AscendC8KVCacheAttentionMethod(quant_description={}, prefix="model.layers.0.self_attn.attn")

    def _make_layer_with_impl(self):
        layer = nn.Module()
        layer.impl = MagicMock()
        return layer

    def test_create_weights_sets_kv_cache_torch_dtype(self):
        method = self._make_method()
        layer = self._make_layer_with_impl()
        method.create_weights(layer)
        self.assertEqual(layer.kv_cache_torch_dtype, torch.int8)

    def test_create_weights_registers_scale_offset_params(self):
        method = self._make_method()
        layer = self._make_layer_with_impl()
        method.create_weights(layer)
        self.assertIsInstance(layer.k_cache_scale, nn.Parameter)
        self.assertIsInstance(layer.k_cache_offset, nn.Parameter)
        self.assertIsInstance(layer.v_cache_scale, nn.Parameter)
        self.assertIsInstance(layer.v_cache_offset, nn.Parameter)
        self.assertFalse(layer.k_cache_scale.requires_grad)
        self.assertFalse(layer.v_cache_offset.requires_grad)

    def test_create_weights_initial_values(self):
        method = self._make_method()
        layer = self._make_layer_with_impl()
        method.create_weights(layer)
        self.assertEqual(layer.k_cache_scale.data.item(), 1.0)
        self.assertEqual(layer.v_cache_scale.data.item(), 1.0)
        self.assertEqual(layer.k_cache_offset.data.item(), 0.0)
        self.assertEqual(layer.v_cache_offset.data.item(), 0.0)

    def test_create_weights_assigns_weight_loader(self):
        from vllm_ascend.quantization.methods.kv_c8 import _c8_kv_scale_weight_loader
        method = self._make_method()
        layer = self._make_layer_with_impl()
        method.create_weights(layer)
        self.assertIs(layer.k_cache_scale.weight_loader, _c8_kv_scale_weight_loader)
        self.assertIs(layer.v_cache_scale.weight_loader, _c8_kv_scale_weight_loader)
        self.assertIs(layer.k_cache_offset.weight_loader, _c8_kv_scale_weight_loader)
        self.assertIs(layer.v_cache_offset.weight_loader, _c8_kv_scale_weight_loader)

    def test_process_weights_after_loading_flattens(self):
        method = self._make_method()
        layer = nn.Module()
        layer.k_cache_scale = nn.Parameter(torch.ones(2, 4), requires_grad=False)
        layer.k_cache_offset = nn.Parameter(torch.zeros(2, 4), requires_grad=False)
        layer.v_cache_scale = nn.Parameter(torch.ones(2, 4), requires_grad=False)
        layer.v_cache_offset = nn.Parameter(torch.zeros(2, 4), requires_grad=False)
        method.process_weights_after_loading(layer)
        self.assertEqual(layer.k_cache_scale.data.dim(), 1)
        self.assertEqual(layer.k_cache_scale.data.shape[0], 8)
        self.assertEqual(layer.v_cache_offset.data.dim(), 1)

    def test_apply_raises_runtime_error(self):
        method = self._make_method()
        layer = MagicMock()
        with self.assertRaises(RuntimeError):
            method.apply(layer, MagicMock(), MagicMock(), MagicMock(), None, None, None, None, None)


class TestAscendC8AttentionBackendImplScales(TestBase):
    """Tests for AscendC8AttentionBackendImpl scale helpers."""

    def _make_impl(self, num_kv_heads=4, head_size=8):
        from vllm_ascend.attention.attention_v1 import AscendC8AttentionBackendImpl
        impl = object.__new__(AscendC8AttentionBackendImpl)
        impl.num_heads = num_kv_heads
        impl.num_kv_heads = num_kv_heads
        impl.head_size = head_size
        impl.scale = 1.0
        impl.key_cache = None
        impl.value_cache = None
        return impl

    def _make_layer(self, num_kv_heads=4, head_size=8):
        layer = nn.Module()
        layer.k_cache_scale = nn.Parameter(
            torch.ones(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
        )
        layer.k_cache_offset = nn.Parameter(
            torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
        )
        layer.v_cache_scale = nn.Parameter(
            torch.ones(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
        )
        layer.v_cache_offset = nn.Parameter(
            torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False
        )
        return layer

    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
    def test_prepare_c8_scales_runs_once(self, mock_tp_size, mock_tp_rank):
        impl = self._make_impl()
        layer = self._make_layer()
        impl._prepare_c8_scales(layer, torch.device("cpu"))
        self.assertTrue(hasattr(layer, "_c8_scales_prepared"))
        self.assertTrue(layer._c8_scales_prepared)

    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
    def test_prepare_c8_scales_idempotent(self, mock_tp_size, mock_tp_rank):
        impl = self._make_impl()
        layer = self._make_layer()
        impl._prepare_c8_scales(layer, torch.device("cpu"))
        k_scale_after_first = layer._c8_k_scale.clone()
        layer.k_cache_scale.data = torch.ones(32, dtype=torch.float32) * 99
        impl._prepare_c8_scales(layer, torch.device("cpu"))
        self.assertTrue(torch.allclose(layer._c8_k_scale, k_scale_after_first))

    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
    def test_prepare_c8_scales_creates_bnsd_shape(self, mock_tp_size, mock_tp_rank):
        num_kv_heads, head_size = 4, 8
        impl = self._make_impl(num_kv_heads, head_size)
        layer = self._make_layer(num_kv_heads, head_size)
        impl._prepare_c8_scales(layer, torch.device("cpu"))
        self.assertEqual(layer._c8_k_aq_scale.shape, (1, num_kv_heads, 1, head_size))
        self.assertEqual(layer._c8_v_aq_scale.shape, (1, num_kv_heads, 1, head_size))
        self.assertEqual(layer._c8_k_aq_scale.dtype, torch.bfloat16)

    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
    def test_quantize_kv_to_int8_output_dtype(self, mock_tp_size, mock_tp_rank):
        num_kv_heads, head_size = 4, 8
        impl = self._make_impl(num_kv_heads, head_size)
        layer = self._make_layer(num_kv_heads, head_size)
        impl._prepare_c8_scales(layer, torch.device("cpu"))
        num_tokens = 6
        key = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16)
        value = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=torch.bfloat16)
        key_q, value_q = impl._quantize_kv_to_int8(key, value, layer, num_tokens)
        self.assertEqual(key_q.dtype, torch.int8)
        self.assertEqual(value_q.dtype, torch.int8)
        self.assertEqual(key_q.shape, key.shape)

    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
    def test_quantize_kv_to_int8_formula(self, mock_tp_size, mock_tp_rank):
        """With scale=2.0, offset=0: q = round(x / 2)."""
        num_kv_heads, head_size = 1, 4
        impl = self._make_impl(num_kv_heads, head_size)
        layer = nn.Module()
        scale_val = torch.full((num_kv_heads * head_size,), 2.0, dtype=torch.float32)
        layer.k_cache_scale = nn.Parameter(scale_val.clone(), requires_grad=False)
        layer.k_cache_offset = nn.Parameter(torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False)
        layer.v_cache_scale = nn.Parameter(scale_val.clone(), requires_grad=False)
        layer.v_cache_offset = nn.Parameter(torch.zeros(num_kv_heads * head_size, dtype=torch.float32), requires_grad=False)
        impl._prepare_c8_scales(layer, torch.device("cpu"))
        key = torch.full((1, num_kv_heads, head_size), 4.0, dtype=torch.bfloat16)
        value = torch.full((1, num_kv_heads, head_size), 4.0, dtype=torch.bfloat16)
        key_q, _ = impl._quantize_kv_to_int8(key, value, layer, 1)
        self.assertTrue(torch.all(key_q[0] == 2))

    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_rank", return_value=0)
    @patch("vllm_ascend.attention.attention_v1.get_tensor_model_parallel_world_size", return_value=1)
    def test_dequant_paged_kv_to_dense_round_trip(self, mock_tp_size, mock_tp_rank):
        """With scale=1, offset=0: dequant(int8) == float(int8)."""
        num_kv_heads, head_size = 2, 4
        block_size = 32
        num_blocks = 2
        H = num_kv_heads * head_size
        impl = self._make_impl(num_kv_heads, head_size)
        layer = self._make_layer(num_kv_heads, head_size)
        impl._prepare_c8_scales(layer, torch.device("cpu"))

        key_int8 = torch.randint(-10, 10, (num_blocks, block_size, H), dtype=torch.int8)
        value_int8 = torch.randint(-10, 10, (num_blocks, block_size, H), dtype=torch.int8)
        seq_lens = [32, 32]
        block_table = torch.tensor([[0], [1]], dtype=torch.long)

        dense_k, dense_v = impl._dequant_paged_kv_to_dense(
            key_int8, value_int8, block_table, seq_lens, torch.float32, layer
        )
        expected_k = key_int8.view(-1, num_kv_heads, head_size).float()
        self.assertEqual(dense_k.shape, (64, num_kv_heads, head_size))
        self.assertTrue(torch.allclose(dense_k, expected_k))


if __name__ == "__main__":
    unittest.main(verbosity=2)
```