Files
xc-llm-ascend/tests/ut/quantization/test_kv_c8.py
pichangping 3f39ac9c8d [Feature]Supports DSv3.1 PD separation and C8 quantization (#7222)
Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>

### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints: 
1. Only the PD separation mode can be used and
MooncakeLayerwiseConnector can be used to run the model.
2. Currently, only the activation value supports dynamic quantization,
and the KV cache supports static quantization. C8 quantization with MTP
is not supported. You can use ModelSlim for quantization. The
quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path
<path/quant_weight>
--anti_dataset../common/deepseek_anti_prompt_50_v3_1.json
--calib_dataset../common/deepseek_calib_prompt_50_v3_1.json --rot
--trust_remote_code True --fa_quant --dynamic --anti_method m6

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
2026-03-16 22:49:05 +08:00

468 lines
18 KiB
Python

import unittest
import torch
import torch.nn as nn
from unittest.mock import Mock, patch
class TestWeightLoader(unittest.TestCase):
"""Test cases for weight_loader function in kv_c8.py"""
def setUp(self):
"""Set up test environment before each test"""
# Import the module under test
from vllm_ascend.quantization.methods.kv_c8 import weight_loader
self.weight_loader = weight_loader
# Mock distributed functions
self.tp_rank_patch = patch(
"vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank"
)
self.tp_size_patch = patch(
"vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size"
)
self.mock_tp_rank = self.tp_rank_patch.start()
self.mock_tp_size = self.tp_size_patch.start()
def tearDown(self):
"""Clean up after each test"""
self.tp_rank_patch.stop()
self.tp_size_patch.stop()
def test_weight_loader_single_element(self):
"""Test weight_loader when both tensors contain a single element"""
# Create tensors with single element
param = torch.tensor([0.0])
loaded_weight = torch.tensor([5.0])
# Call weight_loader
self.weight_loader(param, loaded_weight)
# Verify the value was filled correctly
self.assertEqual(param.item(), 5.0)
self.assertEqual(param.dtype, torch.float32)
def test_weight_loader_single_element_int(self):
"""Test weight_loader with integer tensors"""
param = torch.tensor([0], dtype=torch.int32)
loaded_weight = torch.tensor([10], dtype=torch.int32)
self.weight_loader(param, loaded_weight)
self.assertEqual(param.item(), 10)
def test_weight_loader_tp_sharding_first_rank(self):
"""Test weight_loader with tensor parallelism sharding for first rank"""
# Configure mocks for rank 0 of 4
self.mock_tp_rank.return_value = 0
self.mock_tp_size.return_value = 4
# Create test tensors
param = torch.zeros(2, 5) # Target param shape (2,5)
loaded_weight = torch.ones(8, 5) # Full weight (8,5)
# Mock narrow to track the call
with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow:
self.weight_loader(param, loaded_weight)
# Verify narrow was called correctly: narrow(dim=0, start=0, length=2)
mock_narrow.assert_called_once_with(0, 0, 2)
# Verify data was copied
self.assertTrue(torch.all(param == 1))
def test_weight_loader_tp_sharding_middle_rank(self):
"""Test weight_loader with tensor parallelism sharding for middle rank"""
# Configure mocks for rank 2 of 4
self.mock_tp_rank.return_value = 2
self.mock_tp_size.return_value = 4
param = torch.zeros(2, 5)
loaded_weight = torch.ones(8, 5)
with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow:
self.weight_loader(param, loaded_weight)
# Verify narrow was called correctly: start = shard_size * rank = 2 * 2 = 4
mock_narrow.assert_called_once_with(0, 4, 2)
self.assertTrue(torch.all(param == 1))
def test_weight_loader_tp_sharding_last_rank(self):
"""Test weight_loader with tensor parallelism sharding for last rank"""
# Configure mocks for rank 3 of 4
self.mock_tp_rank.return_value = 3
self.mock_tp_size.return_value = 4
param = torch.zeros(2, 5)
loaded_weight = torch.ones(8, 5)
with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow:
self.weight_loader(param, loaded_weight)
# Verify narrow was called correctly: start = 2 * 3 = 6
mock_narrow.assert_called_once_with(0, 6, 2)
self.assertTrue(torch.all(param == 1))
def test_weight_loader_shape_mismatch(self):
"""Test weight_loader raises assertion error on shape mismatch"""
self.mock_tp_rank.return_value = 0
self.mock_tp_size.return_value = 2
param = torch.zeros(2, 3)
loaded_weight = torch.ones(4, 4) # Different shape after sharding
# Mock narrow to return tensor with wrong shape
with patch.object(loaded_weight, 'narrow', return_value=torch.ones(2, 4)):
with self.assertRaises(AssertionError) as context:
self.weight_loader(param, loaded_weight)
# Verify error message contains expected information
self.assertIn("Attempted to load weight", str(context.exception))
self.assertIn("into parameter", str(context.exception))
def test_weight_loader_with_different_dtypes(self):
"""Test weight_loader handles different dtypes correctly"""
self.mock_tp_rank.return_value = 0
self.mock_tp_size.return_value = 1 # No sharding
param = torch.zeros(2, 3, dtype=torch.float32)
loaded_weight = torch.ones(2, 3, dtype=torch.float16)
self.weight_loader(param, loaded_weight)
# Verify data was copied and converted
self.assertTrue(torch.all(param == 1))
self.assertEqual(param.dtype, torch.float32)
class TestAscendFAQuantAttentionMethodInit(unittest.TestCase):
"""Test cases for AscendFAQuantAttentionMethod initialization"""
def setUp(self):
"""Set up test environment"""
# Mock vllm_config
self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
self.mock_get_config = self.config_patch.start()
# Create mock config with attributes
self.mock_config = Mock()
self.mock_hf_config = Mock()
self.mock_hf_config.kv_lora_rank = 128
self.mock_hf_config.qk_rope_head_dim = 64
self.mock_config.model_config.hf_config = self.mock_hf_config
self.mock_get_config.return_value = self.mock_config
# Import the class after patching
from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod
self.method_class = AscendFAQuantAttentionMethod
def tearDown(self):
"""Clean up after each test"""
self.config_patch.stop()
def test_init_with_full_config(self):
"""Test initialization when config has all attributes"""
method = self.method_class()
self.assertTrue(method.transpose_weight)
self.assertFalse(method.printFlag)
self.assertEqual(method.kv_lora_rank, 128)
self.assertEqual(method.qk_rope_head_dim, 64)
def test_init_without_kv_lora_rank(self):
"""Test initialization when config lacks kv_lora_rank"""
delattr(self.mock_hf_config, "kv_lora_rank")
method = self.method_class()
self.assertEqual(method.kv_lora_rank, 0)
self.assertEqual(method.qk_rope_head_dim, 64)
def test_init_without_qk_rope_head_dim(self):
"""Test initialization when config lacks qk_rope_head_dim"""
delattr(self.mock_hf_config, "qk_rope_head_dim")
method = self.method_class()
self.assertEqual(method.kv_lora_rank, 128)
self.assertEqual(method.qk_rope_head_dim, 0)
def test_init_without_both_attributes(self):
"""Test initialization when config lacks both attributes"""
delattr(self.mock_hf_config, "kv_lora_rank")
delattr(self.mock_hf_config, "qk_rope_head_dim")
method = self.method_class()
self.assertEqual(method.kv_lora_rank, 0)
self.assertEqual(method.qk_rope_head_dim, 0)
class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase):
"""Test cases for create_weights method"""
def setUp(self):
"""Set up test environment"""
# Mock vllm_config
self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
self.mock_get_config = self.config_patch.start()
self.mock_config = Mock()
self.mock_hf_config = Mock()
self.mock_hf_config.kv_lora_rank = 128
self.mock_hf_config.qk_rope_head_dim = 64
self.mock_config.model_config.hf_config = self.mock_hf_config
self.mock_get_config.return_value = self.mock_config
# Import the class
from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod
self.method_class = AscendFAQuantAttentionMethod
# Mock torch functions
self.default_dtype_patch = patch("torch.get_default_dtype", return_value=torch.float32)
self.mock_default_dtype = self.default_dtype_patch.start()
# Create a real nn.Module for testing
self.layer = nn.Module()
self.layer.num_heads = 32
self.layer.num_kv_heads = 1
def tearDown(self):
"""Clean up after each test"""
self.config_patch.stop()
self.default_dtype_patch.stop()
def test_create_weights_adds_submodules(self):
"""Test that create_weights adds fa_q, fa_k, fa_v submodules"""
method = self.method_class()
with patch("torch.empty") as mock_empty:
mock_empty.return_value = torch.zeros(1, 1)
method.create_weights(self.layer)
# Verify submodules were added
self.assertTrue(hasattr(self.layer, "fa_q"))
self.assertTrue(hasattr(self.layer, "fa_k"))
self.assertTrue(hasattr(self.layer, "fa_v"))
# Verify they are instances of nn.Module
self.assertIsInstance(self.layer.fa_q, nn.Module)
self.assertIsInstance(self.layer.fa_k, nn.Module)
self.assertIsInstance(self.layer.fa_v, nn.Module)
def test_create_weights_creates_correct_tensors(self):
"""Test that create_weights creates tensors with correct shapes and dtypes"""
method = self.method_class()
# Track torch.empty calls
empty_calls = []
def mock_empty(size, dtype=None):
empty_calls.append((size, dtype))
return torch.zeros(size, dtype=dtype if dtype else torch.float32)
with patch("torch.empty", side_effect=mock_empty):
method.create_weights(self.layer)
# Verify tensor creations
expected_calls = [
((32, 1), torch.float32), # fa_q.scale
((1, 1), torch.float32), # fa_k.scale
((1, 1), torch.float32), # fa_v.scale
((32, 1), torch.int8), # fa_q.offset
((1, 1), torch.int8), # fa_k.offset
((1, 1), torch.int8), # fa_v.offset
]
# Compare without considering order
self.assertEqual(len(empty_calls), len(expected_calls))
for call in expected_calls:
self.assertIn(call, empty_calls)
def test_create_weights_registers_parameters(self):
"""Test that create_weights registers parameters with correct attributes"""
method = self.method_class()
# Create real tensors for testing
def create_tensor(*args, **kwargs):
size = args[0] if args else kwargs.get('size', (1,))
dtype = kwargs.get('dtype', torch.float32)
return torch.zeros(*size, dtype=dtype)
with patch("torch.empty", side_effect=create_tensor):
method.create_weights(self.layer)
# Import weight_loader for comparison
from vllm_ascend.quantization.methods.kv_c8 import weight_loader
# Verify each parameter exists and has weight_loader
self.assertTrue(hasattr(self.layer.fa_q, "scale"))
self.assertTrue(hasattr(self.layer.fa_q.scale, "weight_loader"))
self.assertEqual(self.layer.fa_q.scale.weight_loader, weight_loader)
self.assertFalse(self.layer.fa_q.scale.requires_grad)
self.assertTrue(hasattr(self.layer.fa_k, "scale"))
self.assertTrue(hasattr(self.layer.fa_k.scale, "weight_loader"))
self.assertTrue(hasattr(self.layer.fa_v, "scale"))
self.assertTrue(hasattr(self.layer.fa_v.scale, "weight_loader"))
self.assertTrue(hasattr(self.layer.fa_q, "offset"))
self.assertTrue(hasattr(self.layer.fa_q.offset, "weight_loader"))
self.assertEqual(self.layer.fa_q.offset.dtype, torch.int8)
class TestAscendFAQuantAttentionMethodProcessWeights(unittest.TestCase):
"""Test cases for process_weights_after_loading method"""
def setUp(self):
"""Set up test environment"""
# Mock vllm_config
self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
self.mock_get_config = self.config_patch.start()
self.mock_config = Mock()
self.mock_hf_config = Mock()
self.mock_hf_config.kv_lora_rank = 64
self.mock_hf_config.qk_rope_head_dim = 32
self.mock_config.model_config.hf_config = self.mock_hf_config
self.mock_get_config.return_value = self.mock_config
# Import the class
from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod
self.method_class = AscendFAQuantAttentionMethod
# Create method instance with real layer
self.method = self.method_class()
# Create a real nn.Module for testing
self.layer = nn.Module()
# Create real tensors for fa_k
self.fa_k_scale = torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float16) # Shape (1,3)
self.fa_k_offset = torch.tensor([[1, 2, 3]], dtype=torch.int8) # Shape (1,3)
# Create fa_k module with parameters
self.layer.fa_k = nn.Module()
self.layer.fa_k.scale = nn.Parameter(self.fa_k_scale, requires_grad=False)
self.layer.fa_k.offset = nn.Parameter(self.fa_k_offset, requires_grad=False)
def tearDown(self):
"""Clean up after each test"""
self.config_patch.stop()
def test_process_weights_with_single_value_scale(self):
"""Test process_weights with single value scale"""
# Create new layer with single value scale
layer = nn.Module()
layer.fa_k = nn.Module()
layer.fa_k.scale = nn.Parameter(torch.tensor([[2.0]], dtype=torch.float16), requires_grad=False)
layer.fa_k.offset = nn.Parameter(torch.tensor([[1]], dtype=torch.int8), requires_grad=False)
self.method.kv_lora_rank = 4
self.method.process_weights_after_loading(layer)
self.assertEqual(layer.quant_kscale.shape, (1, 4))
self.assertEqual(layer.quant_kscale.dtype, torch.float32)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete kv_c8 functionality"""
def setUp(self):
"""Set up test environment"""
# Mock vllm_config
self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
self.mock_get_config = self.config_patch.start()
self.mock_config = Mock()
self.mock_hf_config = Mock()
self.mock_hf_config.kv_lora_rank = 64
self.mock_hf_config.qk_rope_head_dim = 32
self.mock_config.model_config.hf_config = self.mock_hf_config
self.mock_get_config.return_value = self.mock_config
# Mock distributed functions
self.tp_rank_patch = patch(
"vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank"
)
self.tp_size_patch = patch(
"vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size"
)
self.mock_tp_rank = self.tp_rank_patch.start()
self.mock_tp_size = self.tp_size_patch.start()
def tearDown(self):
"""Clean up after each test"""
self.config_patch.stop()
self.tp_rank_patch.stop()
self.tp_size_patch.stop()
def test_complete_workflow(self):
"""Test complete workflow from weight creation to processing"""
from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod
# Create method instance
method = AscendFAQuantAttentionMethod()
# Create real layer
layer = nn.Module()
layer.num_heads = 32
layer.num_kv_heads = 1
# Step 1: Create weights
method.create_weights(layer)
# Verify weights were created with correct structure
self.assertTrue(hasattr(layer, "fa_q"))
self.assertTrue(hasattr(layer, "fa_k"))
self.assertTrue(hasattr(layer, "fa_v"))
self.assertTrue(hasattr(layer.fa_q, "scale"))
self.assertTrue(hasattr(layer.fa_q, "offset"))
self.assertTrue(hasattr(layer.fa_k, "scale"))
self.assertTrue(hasattr(layer.fa_k, "offset"))
self.assertTrue(hasattr(layer.fa_v, "scale"))
self.assertTrue(hasattr(layer.fa_v, "offset"))
# Step 2: Simulate weight loading
self.mock_tp_rank.return_value = 0
self.mock_tp_size.return_value = 1
# Create dummy weights
q_scale = torch.randn(32, 1)
k_scale = torch.randn(1, 1)
v_scale = torch.randn(1, 1)
q_offset = torch.randint(-128, 127, (32, 1), dtype=torch.int8)
k_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8)
v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8)
# Load weights using weight_loader
from vllm_ascend.quantization.methods.kv_c8 import weight_loader
with torch.no_grad():
weight_loader(layer.fa_q.scale, q_scale)
weight_loader(layer.fa_k.scale, k_scale)
weight_loader(layer.fa_v.scale, v_scale)
weight_loader(layer.fa_q.offset, q_offset)
weight_loader(layer.fa_k.offset, k_offset)
weight_loader(layer.fa_v.offset, v_offset)
# Verify weights were loaded correctly
self.assertTrue(torch.all(layer.fa_q.scale == q_scale))
self.assertTrue(torch.all(layer.fa_k.scale == k_scale))
self.assertTrue(torch.all(layer.fa_v.scale == v_scale))
# Step 3: Process after loading
method.process_weights_after_loading(layer)
# Verify processed parameters
self.assertTrue(hasattr(layer, "fak_descale"))
self.assertTrue(hasattr(layer, "fak_offset"))
self.assertTrue(hasattr(layer, "quant_kscale"))
if __name__ == "__main__":
unittest.main(verbosity=2)