[Feature]Supports DSv3.1 PD separation and C8 quantization (#7222)

Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>

### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints: 
1. The model can only be run in PD separation mode, using
MooncakeLayerwiseConnector.
2. Currently, dynamic quantization is supported only for activations,
while the KV cache uses static quantization. C8 quantization combined
with MTP is not supported. You can use ModelSlim to quantize the model.
The quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path <path/quant_weight> \
  --anti_dataset ../common/deepseek_anti_prompt_50_v3_1.json \
  --calib_dataset ../common/deepseek_calib_prompt_50_v3_1.json --rot \
  --trust_remote_code True --fa_quant --dynamic --anti_method m6

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
pichangping
2026-03-16 22:49:05 +08:00
committed by GitHub
parent a6f6e919e6
commit 3f39ac9c8d
15 changed files with 1112 additions and 161 deletions

View File

@@ -807,6 +807,7 @@ class TestAscendMLAImpl(TestBase):
attn_type=None, attn_type=None,
kv_sharing_target_layer_name=None, kv_sharing_target_layer_name=None,
**kwargs) **kwargs)
self.impl.fa_quant_layer = False
def test_init(self): def test_init(self):
self.assertEqual(self.impl.num_heads, 256) self.assertEqual(self.impl.num_heads, 256)
@@ -938,9 +939,9 @@ class TestAscendMLAImpl(TestBase):
@patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
@patch("torch_npu.npu_fused_infer_attention_score") @patch("torch_npu.npu_fused_infer_attention_score_v2")
def test_forward_decode_without_graph(self, def test_forward_decode_without_graph(self,
mock_npu_fused_infer_attention_score, mock_npu_fused_infer_attention_score_v2,
mock_up_proj, mock_up_proj,
mock_get_forward_context): mock_get_forward_context):
num_tokens = 100 num_tokens = 100
@@ -956,8 +957,8 @@ class TestAscendMLAImpl(TestBase):
metadata = MagicMock() metadata = MagicMock()
metadata.decode = MagicMock() metadata.decode = MagicMock()
metadata.decode.block_table = MagicMock() metadata.decode.block_table = MagicMock()
metadata.decode.seq_lens = 10 metadata.decode.actual_seq_lengths = 10
mock_npu_fused_infer_attention_score.return_value = [ mock_npu_fused_infer_attention_score_v2.return_value = [
torch.randn(num_tokens, self.impl.num_heads, torch.randn(num_tokens, self.impl.num_heads,
self.impl.kv_lora_rank), None self.impl.kv_lora_rank), None
] ]
@@ -971,7 +972,7 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(result.shape[1], self.impl.num_heads) self.assertEqual(result.shape[1], self.impl.num_heads)
self.assertEqual(result.shape[2], self.impl.v_head_dim) self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once() mock_up_proj.assert_called_once()
mock_npu_fused_infer_attention_score.assert_called_once() mock_npu_fused_infer_attention_score_v2.assert_called_once()
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad") @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method", @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
@@ -1103,8 +1104,8 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank) self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
@patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@patch("torch_npu.npu_fused_infer_attention_score") @patch("torch_npu.npu_fused_infer_attention_score_v2")
def test_forward_decode(self, mock_npu_fused_infer_attention_score, def test_forward_decode(self, mock_npu_fused_infer_attention_score_v2,
mock_get_forward_context): mock_get_forward_context):
B = 2 B = 2
N = self.impl.num_kv_heads N = self.impl.num_kv_heads
@@ -1121,11 +1122,11 @@ class TestAscendMLAImpl(TestBase):
attn_metadata = MagicMock() attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.SpecDecoding attn_metadata.attn_state = AscendAttentionState.SpecDecoding
attn_metadata.decode = MagicMock() attn_metadata.decode = MagicMock()
attn_metadata.decode.actual_seq_lengths_q = MagicMock() attn_metadata.decode.actual_seq_qlen = MagicMock()
attn_metadata.decode.seq_lens_list = MagicMock() attn_metadata.decode.actual_seq_kvlen = MagicMock()
self.impl.enable_kv_nz = True self.impl.enable_kv_nz = True
mock_npu_fused_infer_attention_score.return_value = [ mock_npu_fused_infer_attention_score_v2.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank), None torch.randn(B, N, self.impl.kv_lora_rank), None
] ]
mock_get_forward_context.return_value = MagicMock(capturing=False) mock_get_forward_context.return_value = MagicMock(capturing=False)

View File

@@ -0,0 +1,468 @@
import unittest
import torch
import torch.nn as nn
from unittest.mock import Mock, patch
class TestWeightLoader(unittest.TestCase):
    """Test cases for the ``weight_loader`` function in ``kv_c8.py``.

    The tensor-parallel rank and world-size lookups used by the module are
    patched so each test can choose the rank/size it needs.
    """

    def setUp(self):
        """Import the function under test and patch the TP rank/size lookups."""
        from vllm_ascend.quantization.methods.kv_c8 import weight_loader

        self.weight_loader = weight_loader

        self.tp_rank_patch = patch(
            "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank"
        )
        self.tp_size_patch = patch(
            "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size"
        )
        # Register each stop() right after its start() so a patch is undone
        # even when a later step of setUp raises (tearDown is skipped then).
        self.mock_tp_rank = self.tp_rank_patch.start()
        self.addCleanup(self.tp_rank_patch.stop)
        self.mock_tp_size = self.tp_size_patch.start()
        self.addCleanup(self.tp_size_patch.stop)

    def _assert_shard_loaded(self, tp_rank, tp_size, expected_start):
        """Load an (8, 5) weight into a (2, 5) param and check the shard used.

        Every row of the full weight is distinct, so the test fails if the
        rows of a different rank end up in ``param``.
        """
        self.mock_tp_rank.return_value = tp_rank
        self.mock_tp_size.return_value = tp_size

        shard_rows = 2
        param = torch.zeros(shard_rows, 5)
        loaded_weight = torch.arange(40, dtype=torch.float32).reshape(8, 5)
        expected = loaded_weight[expected_start:expected_start + shard_rows].clone()

        # wraps= keeps the real narrow() behaviour while recording the call.
        with patch.object(loaded_weight, "narrow", wraps=loaded_weight.narrow) as mock_narrow:
            self.weight_loader(param, loaded_weight)

        # narrow(dim=0, start=shard_size * rank, length=shard_size)
        mock_narrow.assert_called_once_with(0, expected_start, shard_rows)
        self.assertTrue(torch.equal(param, expected))

    def test_weight_loader_single_element(self):
        """Single-element tensors: the loaded value ends up in ``param``."""
        param = torch.tensor([0.0])
        loaded_weight = torch.tensor([5.0])

        self.weight_loader(param, loaded_weight)

        self.assertEqual(param.item(), 5.0)
        self.assertEqual(param.dtype, torch.float32)

    def test_weight_loader_single_element_int(self):
        """Single-element integer tensors are loaded as well."""
        param = torch.tensor([0], dtype=torch.int32)
        loaded_weight = torch.tensor([10], dtype=torch.int32)

        self.weight_loader(param, loaded_weight)

        self.assertEqual(param.item(), 10)

    def test_weight_loader_tp_sharding_first_rank(self):
        """Rank 0 of 4 reads its shard starting at row 0."""
        self._assert_shard_loaded(tp_rank=0, tp_size=4, expected_start=0)

    def test_weight_loader_tp_sharding_middle_rank(self):
        """Rank 2 of 4 reads its shard starting at row 2 * 2 = 4."""
        self._assert_shard_loaded(tp_rank=2, tp_size=4, expected_start=4)

    def test_weight_loader_tp_sharding_last_rank(self):
        """Rank 3 of 4 reads its shard starting at row 2 * 3 = 6."""
        self._assert_shard_loaded(tp_rank=3, tp_size=4, expected_start=6)

    def test_weight_loader_shape_mismatch(self):
        """A shard whose shape differs from ``param`` raises AssertionError."""
        self.mock_tp_rank.return_value = 0
        self.mock_tp_size.return_value = 2

        param = torch.zeros(2, 3)
        loaded_weight = torch.ones(4, 4)

        # Force narrow() to hand back a tensor that cannot match param's shape.
        with patch.object(loaded_weight, "narrow", return_value=torch.ones(2, 4)):
            with self.assertRaises(AssertionError) as context:
                self.weight_loader(param, loaded_weight)

        # Checked outside the assertRaises block: statements placed after the
        # raising call inside that block would never execute.
        message = str(context.exception)
        self.assertIn("Attempted to load weight", message)
        self.assertIn("into parameter", message)

    def test_weight_loader_with_different_dtypes(self):
        """A float16 weight is copied into a float32 param, which keeps its dtype."""
        self.mock_tp_rank.return_value = 0
        self.mock_tp_size.return_value = 1  # No sharding

        param = torch.zeros(2, 3, dtype=torch.float32)
        loaded_weight = torch.ones(2, 3, dtype=torch.float16)

        self.weight_loader(param, loaded_weight)

        self.assertTrue(torch.all(param == 1))
        self.assertEqual(param.dtype, torch.float32)
class TestAscendFAQuantAttentionMethodInit(unittest.TestCase):
    """Checks the attributes set when AscendFAQuantAttentionMethod is constructed."""

    def setUp(self):
        """Patch ``get_current_vllm_config`` to return a mocked configuration."""
        self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
        self.mock_get_config = self.config_patch.start()

        # hf_config stand-in exposing the two attributes the tests inspect.
        self.mock_hf_config = Mock()
        self.mock_hf_config.kv_lora_rank = 128
        self.mock_hf_config.qk_rope_head_dim = 64

        self.mock_config = Mock()
        self.mock_config.model_config.hf_config = self.mock_hf_config
        self.mock_get_config.return_value = self.mock_config

        # Imported only once the patch is active.
        from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod

        self.method_class = AscendFAQuantAttentionMethod

    def tearDown(self):
        """Undo the configuration patch."""
        self.config_patch.stop()

    def _build_without(self, *missing_attrs):
        """Remove the named attributes from the mocked hf_config, then build the method."""
        for attr_name in missing_attrs:
            delattr(self.mock_hf_config, attr_name)
        return self.method_class()

    def test_init_with_full_config(self):
        """Both config attributes present: values are taken from the config."""
        quant_method = self._build_without()

        self.assertTrue(quant_method.transpose_weight)
        self.assertFalse(quant_method.printFlag)
        self.assertEqual(quant_method.kv_lora_rank, 128)
        self.assertEqual(quant_method.qk_rope_head_dim, 64)

    def test_init_without_kv_lora_rank(self):
        """Missing ``kv_lora_rank`` yields 0; ``qk_rope_head_dim`` is kept."""
        quant_method = self._build_without("kv_lora_rank")

        self.assertEqual(quant_method.kv_lora_rank, 0)
        self.assertEqual(quant_method.qk_rope_head_dim, 64)

    def test_init_without_qk_rope_head_dim(self):
        """Missing ``qk_rope_head_dim`` yields 0; ``kv_lora_rank`` is kept."""
        quant_method = self._build_without("qk_rope_head_dim")

        self.assertEqual(quant_method.kv_lora_rank, 128)
        self.assertEqual(quant_method.qk_rope_head_dim, 0)

    def test_init_without_both_attributes(self):
        """Both attributes missing: both values are 0."""
        quant_method = self._build_without("kv_lora_rank", "qk_rope_head_dim")

        self.assertEqual(quant_method.kv_lora_rank, 0)
        self.assertEqual(quant_method.qk_rope_head_dim, 0)
class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase):
    """Test cases for ``AscendFAQuantAttentionMethod.create_weights``."""

    def setUp(self):
        """Patch the vLLM config lookup and ``torch.get_default_dtype``; build a layer."""
        self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
        self.mock_get_config = self.config_patch.start()
        # addCleanup guarantees the patch is undone even if the rest of
        # setUp raises, in which case tearDown would not run.
        self.addCleanup(self.config_patch.stop)

        self.mock_config = Mock()
        self.mock_hf_config = Mock()
        self.mock_hf_config.kv_lora_rank = 128
        self.mock_hf_config.qk_rope_head_dim = 64
        self.mock_config.model_config.hf_config = self.mock_hf_config
        self.mock_get_config.return_value = self.mock_config

        from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod

        self.method_class = AscendFAQuantAttentionMethod

        # Pin the default dtype so the expected scale dtype is deterministic.
        self.default_dtype_patch = patch("torch.get_default_dtype", return_value=torch.float32)
        self.mock_default_dtype = self.default_dtype_patch.start()
        self.addCleanup(self.default_dtype_patch.stop)

        # A real nn.Module so that submodule registration behaves normally.
        self.layer = nn.Module()
        self.layer.num_heads = 32
        self.layer.num_kv_heads = 1

    def test_create_weights_adds_submodules(self):
        """create_weights adds ``fa_q``, ``fa_k`` and ``fa_v`` submodules."""
        method = self.method_class()

        with patch("torch.empty") as mock_empty:
            mock_empty.return_value = torch.zeros(1, 1)
            method.create_weights(self.layer)

        for name in ("fa_q", "fa_k", "fa_v"):
            self.assertTrue(hasattr(self.layer, name))
            self.assertIsInstance(getattr(self.layer, name), nn.Module)

    def test_create_weights_creates_correct_tensors(self):
        """create_weights requests tensors with the expected shapes and dtypes."""
        method = self.method_class()

        # Record every torch.empty call as (size, dtype).
        empty_calls = []

        def mock_empty(size, dtype=None):
            empty_calls.append((size, dtype))
            return torch.zeros(size, dtype=dtype if dtype else torch.float32)

        with patch("torch.empty", side_effect=mock_empty):
            method.create_weights(self.layer)

        expected_calls = [
            ((32, 1), torch.float32),  # fa_q.scale
            ((1, 1), torch.float32),  # fa_k.scale
            ((1, 1), torch.float32),  # fa_v.scale
            ((32, 1), torch.int8),  # fa_q.offset
            ((1, 1), torch.int8),  # fa_k.offset
            ((1, 1), torch.int8),  # fa_v.offset
        ]

        # Order-insensitive but multiplicity-aware: the expected list holds
        # duplicate entries, which a plain membership check would not count.
        self.assertCountEqual(empty_calls, expected_calls)

    def test_create_weights_registers_parameters(self):
        """Registered parameters carry ``weight_loader`` and the expected dtype."""
        method = self.method_class()

        def create_tensor(*args, **kwargs):
            size = args[0] if args else kwargs.get('size', (1,))
            dtype = kwargs.get('dtype', torch.float32)
            return torch.zeros(*size, dtype=dtype)

        with patch("torch.empty", side_effect=create_tensor):
            method.create_weights(self.layer)

        from vllm_ascend.quantization.methods.kv_c8 import weight_loader

        self.assertTrue(hasattr(self.layer.fa_q, "scale"))
        self.assertTrue(hasattr(self.layer.fa_q.scale, "weight_loader"))
        self.assertEqual(self.layer.fa_q.scale.weight_loader, weight_loader)
        self.assertFalse(self.layer.fa_q.scale.requires_grad)

        self.assertTrue(hasattr(self.layer.fa_k, "scale"))
        self.assertTrue(hasattr(self.layer.fa_k.scale, "weight_loader"))

        self.assertTrue(hasattr(self.layer.fa_v, "scale"))
        self.assertTrue(hasattr(self.layer.fa_v.scale, "weight_loader"))

        self.assertTrue(hasattr(self.layer.fa_q, "offset"))
        self.assertTrue(hasattr(self.layer.fa_q.offset, "weight_loader"))
        self.assertEqual(self.layer.fa_q.offset.dtype, torch.int8)
class TestAscendFAQuantAttentionMethodProcessWeights(unittest.TestCase):
    """Checks ``process_weights_after_loading`` of AscendFAQuantAttentionMethod."""

    def setUp(self):
        """Patch the vLLM config lookup and prepare a method instance and a layer."""
        self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
        self.mock_get_config = self.config_patch.start()

        self.mock_hf_config = Mock()
        self.mock_hf_config.kv_lora_rank = 64
        self.mock_hf_config.qk_rope_head_dim = 32
        self.mock_config = Mock()
        self.mock_config.model_config.hf_config = self.mock_hf_config
        self.mock_get_config.return_value = self.mock_config

        from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod

        self.method_class = AscendFAQuantAttentionMethod
        self.method = self.method_class()

        # Shared fixture: a real module whose fa_k submodule holds (1, 3)
        # scale and offset parameters.
        self.fa_k_scale = torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float16)
        self.fa_k_offset = torch.tensor([[1, 2, 3]], dtype=torch.int8)

        self.layer = nn.Module()
        self.layer.fa_k = nn.Module()
        self.layer.fa_k.scale = nn.Parameter(self.fa_k_scale, requires_grad=False)
        self.layer.fa_k.offset = nn.Parameter(self.fa_k_offset, requires_grad=False)

    def tearDown(self):
        """Undo the configuration patch."""
        self.config_patch.stop()

    def test_process_weights_with_single_value_scale(self):
        """A (1, 1) scale with kv_lora_rank=4 yields a float32 (1, 4) quant_kscale."""
        single_scale = torch.tensor([[2.0]], dtype=torch.float16)
        single_offset = torch.tensor([[1]], dtype=torch.int8)

        layer = nn.Module()
        layer.fa_k = nn.Module()
        layer.fa_k.scale = nn.Parameter(single_scale, requires_grad=False)
        layer.fa_k.offset = nn.Parameter(single_offset, requires_grad=False)

        self.method.kv_lora_rank = 4
        self.method.process_weights_after_loading(layer)

        produced = layer.quant_kscale
        self.assertEqual(produced.shape, (1, 4))
        self.assertEqual(produced.dtype, torch.float32)
class TestIntegration(unittest.TestCase):
    """Integration tests for the complete kv_c8 functionality."""

    def setUp(self):
        """Patch the vLLM config lookup and the tensor-parallel rank/size lookups."""
        self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config")
        self.mock_get_config = self.config_patch.start()
        # Each stop() is registered right after its start(): if a later
        # start() raises, tearDown is skipped but cleanups still run.
        self.addCleanup(self.config_patch.stop)

        self.mock_config = Mock()
        self.mock_hf_config = Mock()
        self.mock_hf_config.kv_lora_rank = 64
        self.mock_hf_config.qk_rope_head_dim = 32
        self.mock_config.model_config.hf_config = self.mock_hf_config
        self.mock_get_config.return_value = self.mock_config

        self.tp_rank_patch = patch(
            "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank"
        )
        self.tp_size_patch = patch(
            "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size"
        )
        self.mock_tp_rank = self.tp_rank_patch.start()
        self.addCleanup(self.tp_rank_patch.stop)
        self.mock_tp_size = self.tp_size_patch.start()
        self.addCleanup(self.tp_size_patch.stop)

    def test_complete_workflow(self):
        """Run create_weights, weight loading and post-load processing in sequence."""
        from vllm_ascend.quantization.methods.kv_c8 import (
            AscendFAQuantAttentionMethod,
            weight_loader,
        )

        method = AscendFAQuantAttentionMethod()

        layer = nn.Module()
        layer.num_heads = 32
        layer.num_kv_heads = 1

        # Step 1: create weights and check the resulting structure.
        method.create_weights(layer)

        for module_name in ("fa_q", "fa_k", "fa_v"):
            self.assertTrue(hasattr(layer, module_name))
        for module_name in ("fa_q", "fa_k", "fa_v"):
            submodule = getattr(layer, module_name)
            self.assertTrue(hasattr(submodule, "scale"))
            self.assertTrue(hasattr(submodule, "offset"))

        # Step 2: simulate weight loading on a single rank (no sharding).
        self.mock_tp_rank.return_value = 0
        self.mock_tp_size.return_value = 1

        q_scale = torch.randn(32, 1)
        k_scale = torch.randn(1, 1)
        v_scale = torch.randn(1, 1)
        q_offset = torch.randint(-128, 127, (32, 1), dtype=torch.int8)
        k_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8)
        v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8)

        with torch.no_grad():
            weight_loader(layer.fa_q.scale, q_scale)
            weight_loader(layer.fa_k.scale, k_scale)
            weight_loader(layer.fa_v.scale, v_scale)
            weight_loader(layer.fa_q.offset, q_offset)
            weight_loader(layer.fa_k.offset, k_offset)
            weight_loader(layer.fa_v.offset, v_offset)

        self.assertTrue(torch.all(layer.fa_q.scale == q_scale))
        self.assertTrue(torch.all(layer.fa_k.scale == k_scale))
        self.assertTrue(torch.all(layer.fa_v.scale == v_scale))
        # The offsets are loaded through the same loader, so verify them too.
        self.assertTrue(torch.all(layer.fa_q.offset == q_offset))
        self.assertTrue(torch.all(layer.fa_k.offset == k_offset))
        self.assertTrue(torch.all(layer.fa_v.offset == v_offset))

        # Step 3: post-load processing attaches the derived attributes.
        method.process_weights_after_loading(layer)

        self.assertTrue(hasattr(layer, "fak_descale"))
        self.assertTrue(hasattr(layer, "fak_offset"))
        self.assertTrue(hasattr(layer, "quant_kscale"))
# Allow running this test module directly; verbosity=2 prints each test name.
if __name__ == "__main__":
    unittest.main(verbosity=2)

View File

@@ -24,6 +24,7 @@ class TestAscendModelSlimConfig(TestBase):
self.sample_config = { self.sample_config = {
"weight": "INT8", "weight": "INT8",
"fa_quant_type": "C8", "fa_quant_type": "C8",
"layers.1.fa_k.scale": "C8",
"layer1.weight": "INT8", "layer1.weight": "INT8",
"layer2.weight": "FLOAT", "layer2.weight": "FLOAT",
"fused_layer.weight": "FLOAT", "fused_layer.weight": "FLOAT",
@@ -119,6 +120,9 @@ class TestAscendModelSlimConfig(TestBase):
# Test with fa_quant_type # Test with fa_quant_type
method = self.ascend_config.get_quant_method( method = self.ascend_config.get_quant_method(
attention_layer, ".attn") attention_layer, ".attn")
self.assertIs(method, None)
method = self.ascend_config.get_quant_method(
attention_layer, "layers.1.attn")
self.assertIs(method, mock_ascend_kvcache.return_value) self.assertIs(method, mock_ascend_kvcache.return_value)
def test_get_quant_method_for_fused_moe(self): def test_get_quant_method_for_fused_moe(self):

View File

@@ -612,6 +612,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
k_pe: torch.Tensor, k_pe: torch.Tensor,
block_size: int, block_size: int,
attn_metadata: AscendMLAMetadata, attn_metadata: AscendMLAMetadata,
dequant_scale_q_nope=None,
) -> torch.Tensor: ) -> torch.Tensor:
decode_meta = attn_metadata.decode decode_meta = attn_metadata.decode
assert decode_meta is not None assert decode_meta is not None

View File

@@ -47,7 +47,7 @@ from vllm_ascend.ops.layer_shard_linear import (
register_all_layers_to_shard_weight_series, register_all_layers_to_shard_weight_series,
) )
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors
from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -658,6 +658,7 @@ class DecodeMLAPreprocessResult(NamedTuple):
k_nope: torch.Tensor | None = None k_nope: torch.Tensor | None = None
k_pe: torch.Tensor | None = None k_pe: torch.Tensor | None = None
decode_q_wo_k_up: torch.Tensor | None = None decode_q_wo_k_up: torch.Tensor | None = None
dequant_scale_q_nope: torch.Tensor | None = None
class PrefillMLAPreprocessResult(NamedTuple): class PrefillMLAPreprocessResult(NamedTuple):
@@ -725,6 +726,12 @@ class AscendMLAImpl(MLAAttentionImpl):
self.is_kv_producer = ( self.is_kv_producer = (
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
) )
self.layer_name = kwargs.get("layer_name")
quant_config = self.vllm_config.quant_config
self.fa_quant_layer = (
quant_config.enabling_fa_quant(self.vllm_config, self.layer_name) if quant_config is not None else False
)
self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype
self.layer_sharding_kwargs = [] self.layer_sharding_kwargs = []
for layer_name in get_ascend_config().layer_sharding or []: for layer_name in get_ascend_config().layer_sharding or []:
if layer_name in kwargs: if layer_name in kwargs:
@@ -775,6 +782,8 @@ class AscendMLAImpl(MLAAttentionImpl):
actual_seq_lengths, actual_seq_lengths,
attn_output, attn_output,
softmax_lse, softmax_lse,
dequant_scale_q_nope,
fak_descale_float,
) = param ) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model: if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model:
@@ -793,26 +802,35 @@ class AscendMLAImpl(MLAAttentionImpl):
seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list)) seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out( extra_args = {}
if dequant_scale_q_nope is not None:
extra_args = {
"query_quant_mode": 3,
"key_quant_mode": 0,
"value_quant_mode": 0,
"dequant_scale_query": dequant_scale_q_nope,
"dequant_scale_key": fak_descale_float,
"dequant_scale_value": fak_descale_float,
}
torch_npu.npu_fused_infer_attention_score_v2.out(
q_nope, q_nope,
k_nope, k_nope,
k_nope, k_nope,
query_rope=q_pe, query_rope=q_pe,
key_rope=k_pe, key_rope=k_pe,
num_heads=num_heads, num_query_heads=num_heads,
num_key_value_heads=num_kv_heads, num_key_value_heads=num_kv_heads,
input_layout=input_layout, input_layout=input_layout,
atten_mask=attn_mask, atten_mask=attn_mask,
sparse_mode=sparse_mode, sparse_mode=sparse_mode,
scale=scale, softmax_scale=scale,
antiquant_mode=0,
antiquant_scale=None,
block_table=block_table, block_table=block_table,
block_size=block_size, block_size=block_size,
actual_seq_lengths_kv=seq_lens_list, actual_seq_kvlen=seq_lens_list,
actual_seq_lengths=actual_seq_lengths, actual_seq_qlen=actual_seq_lengths,
workspace=graph_params.workspaces.get(num_tokens), workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse], out=[attn_output, softmax_lse],
**extra_args,
) )
torch.npu.graph_task_update_end(update_stream) torch.npu.graph_task_update_end(update_stream)
@@ -887,6 +905,8 @@ class AscendMLAImpl(MLAAttentionImpl):
) )
if self.enable_mlapo: if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype) self._process_weights_for_fused_mlapo(act_dtype)
elif self.fa_quant_layer:
self._process_weights_for_fused_fa_quant()
else: else:
# if mlapo, W_UK_T can't trans nz # if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T) self.W_UK_T = maybe_trans_nz(self.W_UK_T)
@@ -895,6 +915,32 @@ class AscendMLAImpl(MLAAttentionImpl):
if is_hidden_layer(layer): if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer) post_process_after_loading_for_shard_weight_series(layer)
def _process_weights_for_fused_fa_quant(self):
    """Cache the weights and dequant scales used when FA quant is enabled.

    Stores as attributes: the two layer-norm weights, the ``q_proj`` weight,
    the ``fused_qkv_a_proj`` weight split along its last dimension at
    ``q_lora_rank``, the matching weight scales (reshaped to one row and
    cast to float), and ``quant_kscale`` / ``fak_descale_float`` read from
    the layer registered under ``self.layer_name`` in the static forward
    context.
    """
    # Layer-norm weights.
    self.gamma1 = self.q_a_layernorm.weight.data  # type: ignore[union-attr]
    self.gamma2 = self.kv_a_layernorm.weight.data  # type: ignore[union-attr]

    # q_proj weight, kept as-is.
    self.wu_q = self.q_proj.weight.data

    # Split the fused weight: leading q_lora_rank columns vs. the remainder.
    self.wd_q = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()  # type: ignore[union-attr]
    self.wd_kv = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()  # type: ignore[union-attr]

    # Weight scales, each viewed as a single row and cast to float.
    self.dequant_scale_w_uq_qr = self.q_proj.weight_scale.data.view(1, -1).to(torch.float)
    head_scale = self.fused_qkv_a_proj.weight_scale[: self.q_lora_rank].contiguous()  # type: ignore[union-attr]
    self.dequant_scale_w_dq = head_scale.view(1, -1).to(torch.float)
    tail_scale = self.fused_qkv_a_proj.weight_scale[self.q_lora_rank :].contiguous()  # type: ignore[union-attr]
    self.dequant_scale_w_dkv_kr = tail_scale.view(1, -1).to(torch.float)

    # KV-cache quantization parameters come from the registered layer.
    registered_layer = self.vllm_config.compilation_config.static_forward_context[self.layer_name]
    self.quant_kscale = registered_layer.quant_kscale
    self.fak_descale_float = registered_layer.fak_descale_float
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
assert self.fused_qkv_a_proj is not None assert self.fused_qkv_a_proj is not None
assert self.q_a_layernorm is not None assert self.q_a_layernorm is not None
@@ -1236,6 +1282,7 @@ class AscendMLAImpl(MLAAttentionImpl):
k_pe: torch.Tensor, k_pe: torch.Tensor,
block_size: int, block_size: int,
attn_metadata: AscendMLAMetadata, attn_metadata: AscendMLAMetadata,
dequant_scale_q_nope=None,
) -> torch.Tensor: ) -> torch.Tensor:
decode_meta = attn_metadata.decode decode_meta = attn_metadata.decode
assert decode_meta is not None assert decode_meta is not None
@@ -1243,7 +1290,15 @@ class AscendMLAImpl(MLAAttentionImpl):
# shape of knope/k_pe for npu graph mode should be: # shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
actual_seq_lengths = None actual_seq_lengths = None
if self.enable_kv_nz: if self.fa_quant_layer:
nz_fmt_last_dim = 16
k_nope = k_nope.view(
-1, self.num_kv_heads, self.kv_lora_rank // (nz_fmt_last_dim * 2), block_size, nz_fmt_last_dim * 2
)
k_pe = k_pe.view(
-1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim
)
elif self.enable_kv_nz:
nz_fmt_last_dim = 16 nz_fmt_last_dim = 16
k_nope = k_nope.view( k_nope = k_nope.view(
-1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim -1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim
@@ -1278,6 +1333,15 @@ class AscendMLAImpl(MLAAttentionImpl):
sparse_mode = 3 sparse_mode = 3
attn_mask = attn_metadata.decode.attn_mask # type:ignore attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q actual_seq_lengths = decode_meta.actual_seq_lengths_q
elif self.fa_quant_layer:
attn_mask = None
input_layout = "BSND_NBSD"
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous()
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1).contiguous()
dequant_scale_q_nope = dequant_scale_q_nope.view(num_tokens, 1, self.num_heads)
sparse_mode = 0
actual_seq_lengths = None
attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank)
else: else:
# The output layout is set to NBSD to eliminate the need for a # The output layout is set to NBSD to eliminate the need for a
# transpose operation after attention. # transpose operation after attention.
@@ -1299,19 +1363,27 @@ class AscendMLAImpl(MLAAttentionImpl):
common_kwargs = { common_kwargs = {
"query_rope": q_pe, "query_rope": q_pe,
"key_rope": k_pe, "key_rope": k_pe,
"num_heads": self.num_heads, "num_query_heads": self.num_heads,
"num_key_value_heads": self.num_kv_heads, "num_key_value_heads": self.num_kv_heads,
"input_layout": input_layout, "input_layout": input_layout,
"atten_mask": attn_mask, "atten_mask": attn_mask,
"sparse_mode": sparse_mode, "sparse_mode": sparse_mode,
"scale": self.scale, "softmax_scale": self.scale,
"antiquant_mode": 0,
"antiquant_scale": None,
"block_table": decode_meta.block_table, "block_table": decode_meta.block_table,
"block_size": block_size, "block_size": block_size,
"actual_seq_lengths": actual_seq_lengths, "actual_seq_qlen": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.seq_lens_list, "actual_seq_kvlen": decode_meta.seq_lens_list,
} }
if self.fa_quant_layer:
extra_fa_args = {
"query_quant_mode": 3,
"key_quant_mode": 0,
"value_quant_mode": 0,
"dequant_scale_query": dequant_scale_q_nope,
"dequant_scale_key": self.fak_descale_float,
"dequant_scale_value": self.fak_descale_float,
}
common_kwargs.update(extra_fa_args)
if _EXTRA_CTX.is_draft_model: if _EXTRA_CTX.is_draft_model:
graph_params = get_draft_graph_params() graph_params = get_draft_graph_params()
else: else:
@@ -1325,8 +1397,33 @@ class AscendMLAImpl(MLAAttentionImpl):
graph_params.events[num_tokens].append(event) graph_params.events[num_tokens].append(event)
workspace = graph_params.workspaces.get(num_tokens) workspace = graph_params.workspaces.get(num_tokens)
attn_output = torch.empty(attn_output_shape, dtype=q_pe.dtype, device=q_pe.device)
softmax_lse = torch.empty(num_tokens, dtype=q_pe.dtype, device=q_pe.device)
attn_params = (
weak_ref_tensors(q_nope),
weak_ref_tensors(k_nope),
weak_ref_tensors(q_pe),
weak_ref_tensors(k_pe),
self.num_heads,
self.num_kv_heads,
input_layout,
weak_ref_tensors(attn_mask) if attn_mask is not None else None,
sparse_mode,
self.scale,
decode_meta.block_table,
block_size,
decode_meta.seq_lens_list,
actual_seq_lengths,
weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse),
)
if self.fa_quant_layer:
attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) # type: ignore
else:
attn_params = attn_params + (None, None) # type: ignore
if workspace is None: if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( workspace = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace(
q_nope, k_nope, k_nope, **common_kwargs q_nope, k_nope, k_nope, **common_kwargs
) )
if _EXTRA_CTX.is_draft_model: if _EXTRA_CTX.is_draft_model:
@@ -1334,38 +1431,16 @@ class AscendMLAImpl(MLAAttentionImpl):
else: else:
update_graph_params_workspaces(num_tokens, workspace) update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device) graph_params.attn_params[num_tokens].append(attn_params)
softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device)
graph_params.attn_params[num_tokens].append(
(
weak_ref_tensors(q_nope),
weak_ref_tensors(k_nope),
weak_ref_tensors(q_pe),
weak_ref_tensors(k_pe),
self.num_heads,
self.num_kv_heads,
input_layout,
weak_ref_tensors(attn_mask) if attn_mask is not None else None,
sparse_mode,
self.scale,
decode_meta.block_table,
block_size,
decode_meta.seq_lens_list,
actual_seq_lengths,
weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse),
)
)
torch.npu.graph_task_group_begin(stream) torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out( torch_npu.npu_fused_infer_attention_score_v2.out(
q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse] q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse]
) )
handle = torch.npu.graph_task_group_end(stream) handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle) graph_params.handles[num_tokens].append(handle)
else: else:
attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs) attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(q_nope, k_nope, k_nope, **common_kwargs)
return self._v_up_proj(attn_output) return self._v_up_proj(attn_output)
@@ -1381,55 +1456,81 @@ class AscendMLAImpl(MLAAttentionImpl):
sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1] decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
decode_q_nope = torch.empty( dequant_scale_q_nope = None
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), if self.fa_quant_layer:
dtype=hidden_states.dtype, quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
device=hidden_states.device, decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope = torch_npu.npu_mla_prolog_v2(
) quantized_x,
decode_q_pe = torch.empty( self.wd_q,
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), self.wu_q,
dtype=hidden_states.dtype, self.W_UK_T,
device=hidden_states.device, self.wd_kv,
) self.gamma1,
self.gamma2,
sin,
cos,
attn_metadata.slot_mapping[:bsz].to(torch.int64),
decode_k_nope,
decode_k_pe,
dequant_scale_x=pertoken_scale.view(-1, 1),
dequant_scale_w_dq=self.dequant_scale_w_dq,
dequant_scale_w_uq_qr=self.dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=self.dequant_scale_w_dkv_kr,
quant_scale_ckv=self.quant_kscale,
cache_mode="PA_NZ",
)
else:
decode_q_nope = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
decode_q_pe = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops._C_ascend.mla_preprocess( torch.ops._C_ascend.mla_preprocess(
hidden_states, hidden_states,
self.wd_qkv, self.wd_qkv,
self.deq_scale_qkv, self.deq_scale_qkv,
self.gamma1, self.gamma1,
self.beta1, self.beta1,
self.wu_q, self.wu_q,
self.qb_deq_scl, self.qb_deq_scl,
self.gamma2, self.gamma2,
cos, cos,
sin, sin,
self.W_UK_T, self.W_UK_T,
decode_k_nope, decode_k_nope,
decode_k_pe, decode_k_pe,
attn_metadata.slot_mapping[:bsz], attn_metadata.slot_mapping[:bsz],
quant_scale0=self.quant_scale0, quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0, quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv, bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1, quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1, quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias, bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale, ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale, q_nope_scale=self.q_nope_scale,
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
quant_mode="per_tensor_quant_asymm", quant_mode="per_tensor_quant_asymm",
q_out0=decode_q_nope, q_out0=decode_q_nope,
kv_cache_out0=decode_k_nope, kv_cache_out0=decode_k_nope,
q_out1=decode_q_pe, q_out1=decode_q_pe,
kv_cache_out1=decode_k_pe, kv_cache_out1=decode_k_pe,
enable_inner_out=False, enable_inner_out=False,
inner_out=torch.tensor([], device=hidden_states.device), inner_out=torch.tensor([], device=hidden_states.device),
) )
decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe) decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe)
decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) decode_preprocess_res = DecodeMLAPreprocessResult(
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope=dequant_scale_q_nope
)
return decode_preprocess_res, None return decode_preprocess_res, None
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata): def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata):
@@ -1576,7 +1677,7 @@ class AscendMLAImpl(MLAAttentionImpl):
o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device) o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device)
# MLA Preprocess # MLA Preprocess
if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: if self.fa_quant_layer or (self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS):
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv hidden_states.contiguous(), need_gather_q_kv
) )
@@ -1596,6 +1697,7 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_preprocess_res.k_pe, decode_preprocess_res.k_pe,
kv_cache[0].shape[1], kv_cache[0].shape[1],
attn_metadata, attn_metadata,
decode_preprocess_res.dequant_scale_q_nope,
) )
o_proj_input[:num_decode_tokens] = output_decode o_proj_input[:num_decode_tokens] = output_decode

View File

@@ -66,7 +66,7 @@ from vllm_ascend.distributed.kv_transfer.utils.utils import (
kv_alltoall_and_rearrange, kv_alltoall_and_rearrange,
parallel_info, parallel_info,
) )
from vllm_ascend.utils import npu_stream_switch from vllm_ascend.utils import npu_stream_switch, trans_nd_to_nz
# isort: off # isort: off
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -124,6 +124,9 @@ class SendTask:
# pd_head_ratio > 1 use # pd_head_ratio > 1 use
k_cache: torch.Tensor | None = None k_cache: torch.Tensor | None = None
v_cache: torch.Tensor | None = None v_cache: torch.Tensor | None = None
# kv cache quantization layer use
k_quant_cache: torch.Tensor | None = None
v_quant_cache: torch.Tensor | None = None
layer_idx: int = 0 layer_idx: int = 0
layer_name: str = "" layer_name: str = ""
# trans block info # trans block info
@@ -210,6 +213,9 @@ class KVCacheSendingLayerThread(threading.Thread):
use_mla: bool, use_mla: bool,
k_buffer: torch.Tensor, k_buffer: torch.Tensor,
v_buffer: torch.Tensor, v_buffer: torch.Tensor,
enable_kv_quant: bool,
k_quant_buffer: torch.Tensor | None,
v_quant_buffer: torch.Tensor | None,
resharding_stream: torch.npu.Stream, resharding_stream: torch.npu.Stream,
callback_func: Callable[..., None] = lambda x: None, callback_func: Callable[..., None] = lambda x: None,
): ):
@@ -232,6 +238,9 @@ class KVCacheSendingLayerThread(threading.Thread):
self.send_queue = queue.Queue[SendTask]() self.send_queue = queue.Queue[SendTask]()
self.k_buffer = k_buffer self.k_buffer = k_buffer
self.v_buffer = v_buffer self.v_buffer = v_buffer
self.enable_kv_quant = enable_kv_quant
self.k_quant_buffer = k_quant_buffer
self.v_quant_buffer = v_quant_buffer
self.ready_event = ready_event self.ready_event = ready_event
self.callback_func = callback_func self.callback_func = callback_func
@@ -325,19 +334,43 @@ class KVCacheSendingLayerThread(threading.Thread):
grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous( grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous(
remote_block_ids, local_block_ids remote_block_ids, local_block_ids
) )
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( # kv cache quantization scenario
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) if self.enable_kv_quant and send_task.k_quant_cache is not None:
): assert len(block_lens) == 2, "Quantization block length must be 2!"
block_len = block_lens[k] quant_block_lens = [block_lens[0] // 2, block_lens[1]]
for group_remote_block_id, group_local_block_id in zip( layer_local_quant_kv_addr = [self.k_quant_buffer.data_ptr(), self.v_quant_buffer.data_ptr()]
grouped_remote_block_ids, grouped_local_block_ids rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx]
# eg:[5,6,7,9] -> {5:0, 6:1, 7:2, 9:3}
rearrange_block_dict = {
value: index
for index, value in enumerate(rearrange_block_ids) # type:ignore
}
for block_len, src_layer_base_addr, dst_layer_base_addr in zip(
quant_block_lens, layer_local_quant_kv_addr, layer_remote_kv_base_addr
): ):
src = src_layer_base_addr + group_local_block_id[0] * block_len for group_remote_block_id, group_local_block_id in zip(
dst = dst_layer_base_addr + group_remote_block_id[0] * block_len grouped_remote_block_ids, grouped_local_block_ids
length = len(group_local_block_id) * block_len ):
src_list.append(src) src = src_layer_base_addr + rearrange_block_dict[group_local_block_id[0]] * block_len
dst_list.append(dst) dst = dst_layer_base_addr + group_remote_block_id[0] * block_len
length_list.append(length) length = len(group_local_block_id) * block_len
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
else:
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)
):
block_len = block_lens[k]
for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids
):
src = src_layer_base_addr + group_local_block_id[0] * block_len
dst = dst_layer_base_addr + group_remote_block_id[0] * block_len
length = len(group_local_block_id) * block_len
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
else: else:
rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx] rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx]
rearrange_block_dict = { rearrange_block_dict = {
@@ -380,6 +413,14 @@ class KVCacheSendingLayerThread(threading.Thread):
value = value.view(-1, key.shape[-1]) # type:ignore value = value.view(-1, key.shape[-1]) # type:ignore
self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] -> self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] ->
self.v_buffer[: value.shape[0]].copy_(value) self.v_buffer[: value.shape[0]].copy_(value)
if send_task.k_quant_cache is not None:
with npu_stream_switch(self.resharding_stream):
key_quant = send_task.k_quant_cache
key_quant = key_quant.view(-1, key_quant.shape[-1]) # type:ignore
self.k_quant_buffer[: key_quant.shape[0]].copy_(key_quant)
value_quant = send_task.v_quant_cache
value_quant = value_quant.view(-1, value_quant.shape[-1]) # type:ignore
self.v_quant_buffer[: value_quant.shape[0]].copy_(value_quant)
# Merge transmission tasks of the same session # Merge transmission tasks of the same session
session_meta: dict[str, TransferMeta] = {} session_meta: dict[str, TransferMeta] = {}
@@ -395,7 +436,9 @@ class KVCacheSendingLayerThread(threading.Thread):
session_meta[session_id].length.extend(length_list) session_meta[session_id].length.extend(length_list)
session_meta[session_id].req_ids.append(req_id) session_meta[session_id].req_ids.append(req_id)
if self.pd_head_ratio == 1: if send_task.k_quant_cache is not None:
self.resharding_stream.synchronize()
elif self.pd_head_ratio == 1:
""" """
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1. This issue will be fixed in CANN version 8.5.rc1.
@@ -628,7 +671,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1, SupportsHMA):
self.connector_worker.wait_for_layer_load(layer_name) self.connector_worker.wait_for_layer_load(layer_name)
def save_kv_layer( def save_kv_layer(
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs self, layer_name: str, kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", **kwargs
) -> None: ) -> None:
"""MooncakeLayerwiseConnector does not save explicitly.""" """MooncakeLayerwiseConnector does not save explicitly."""
assert self.connector_worker is not None assert self.connector_worker is not None
@@ -962,10 +1005,13 @@ class MooncakeLayerwiseConnectorWorker:
self.layer_metadata: dict[str, LayerMetadata] = {} self.layer_metadata: dict[str, LayerMetadata] = {}
self.attn_resharding_group_idx = set[int]() self.attn_resharding_group_idx = set[int]()
self.enable_kv_quant = (
vllm_config.quant_config.enable_fa_quant if vllm_config.quant_config is not None else False
)
self.pd_head_ratio = get_ascend_config().pd_head_ratio self.pd_head_ratio = get_ascend_config().pd_head_ratio
self.num_head_replica = get_ascend_config().num_head_replica self.num_head_replica = get_ascend_config().num_head_replica
self.resharding_stream = None self.resharding_stream = None
if self.pd_head_ratio > 1: if self.pd_head_ratio > 1 or self.enable_kv_quant:
self.resharding_stream = torch.npu.Stream() self.resharding_stream = torch.npu.Stream()
self.remote_poller = zmq.Poller() # type: ignore self.remote_poller = zmq.Poller() # type: ignore
@@ -985,11 +1031,16 @@ class MooncakeLayerwiseConnectorWorker:
self.timeout = 1.0 # seconds self.timeout = 1.0 # seconds
self.k_buffer: torch.Tensor | None = None self.k_buffer: torch.Tensor | None = None
self.v_buffer: torch.Tensor | None = None self.v_buffer: torch.Tensor | None = None
# TODO(kunpengW-code): Reuse k_buffer, v_buffer
self.k_quant_buffer: torch.Tensor | None = None
self.v_quant_buffer: torch.Tensor | None = None
def create_kv_buffer(self, first_kv_cache): def create_kv_buffer(self, first_kv_cache_tuple):
alignment = 2 * 1024 * 1024
buffer_list = []
first_kv_cache = first_kv_cache_tuple[0]
if self.pd_head_ratio > 1: if self.pd_head_ratio > 1:
# regesit kv buffer for tp inequal # regesit kv buffer for tp inequal
alignment = 2 * 1024 * 1024
self.k_buffer = torch.zeros( self.k_buffer = torch.zeros(
first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device
) )
@@ -1002,18 +1053,34 @@ class MooncakeLayerwiseConnectorWorker:
self.v_buffer = align_memory(self.v_buffer, alignment)[: first_kv_cache.numel()].view( self.v_buffer = align_memory(self.v_buffer, alignment)[: first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1] -1, first_kv_cache.shape[-1]
) )
buffer_list.append(self.k_buffer)
buffer_list.append(self.v_buffer)
if self.enable_kv_quant:
quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2
self.k_quant_buffer = torch.zeros(
quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device
)
self.k_quant_buffer = align_memory(self.k_quant_buffer, alignment)[:quant_k_cache_numel].view(
-1, first_kv_cache.shape[-1]
)
quant_v_cache_numel = first_kv_cache_tuple[1].numel()
self.v_quant_buffer = torch.zeros(
quant_v_cache_numel + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device
)
self.v_quant_buffer = align_memory(self.v_quant_buffer, alignment)[:quant_v_cache_numel].view(
-1, first_kv_cache_tuple[1].shape[-1]
)
buffer_list.append(self.k_quant_buffer)
buffer_list.append(self.v_quant_buffer)
for tensor in (self.k_buffer, self.v_buffer): for tensor in buffer_list:
assert tensor.data_ptr() % alignment == 0, ( assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
"The address of the registered kv cache should be aligned to 2M" ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
) logger.info(
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel()) f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
logger.info( )
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} " if ret_value != 0:
f"{tensor.numel()} {ret_value=}" raise RuntimeError("Mooncake memory registration failed. ")
)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data.""" """Register the KV Cache data."""
@@ -1042,8 +1109,8 @@ class MooncakeLayerwiseConnectorWorker:
ptrs = [] ptrs = []
lengths = [] lengths = []
use_resharding_buffer = False use_kv_buffer = False
resharding_buffer = None kv_buffer = None
for layer_name, kv_cache_tuple in kv_caches.items(): for layer_name, kv_cache_tuple in kv_caches.items():
if isinstance(kv_cache_tuple, (list, tuple)) is False: if isinstance(kv_cache_tuple, (list, tuple)) is False:
kv_cache_tuple = [kv_cache_tuple] kv_cache_tuple = [kv_cache_tuple]
@@ -1051,12 +1118,13 @@ class MooncakeLayerwiseConnectorWorker:
layer_kv_cache_spec = kv_cache_groups[layer_kv_group_id].kv_cache_spec layer_kv_cache_spec = kv_cache_groups[layer_kv_group_id].kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
if self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))): if (
self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec)))
) or self.enable_kv_quant:
self.attn_resharding_group_idx.add(layer_kv_group_id) self.attn_resharding_group_idx.add(layer_kv_group_id)
if use_resharding_buffer is False: if use_kv_buffer is False:
use_resharding_buffer = True use_kv_buffer = True
resharding_buffer = kv_cache_tuple[0] kv_buffer = kv_cache_tuple
self.resharding_stream = torch.npu.Stream()
single_layer_meta = LayerMetadata([], [], [], []) single_layer_meta = LayerMetadata([], [], [], [])
for single_kv_cache in kv_cache_tuple: for single_kv_cache in kv_cache_tuple:
block_start_rank = 1 block_start_rank = 1
@@ -1092,8 +1160,8 @@ class MooncakeLayerwiseConnectorWorker:
lengths.append(kv_cache_tensor.size) lengths.append(kv_cache_tensor.size)
global_te.register_buffer(ptrs, lengths) global_te.register_buffer(ptrs, lengths)
if use_resharding_buffer: if use_kv_buffer:
self.create_kv_buffer(resharding_buffer) self.create_kv_buffer(kv_buffer)
num_attn_module = 2 if self.vllm_config.model_config.hf_text_config.model_type == "longcat_flash" else 1 num_attn_module = 2 if self.vllm_config.model_config.hf_text_config.model_type == "longcat_flash" else 1
mtp_layer_name = "" mtp_layer_name = ""
@@ -1133,6 +1201,9 @@ class MooncakeLayerwiseConnectorWorker:
use_mla=self.use_mla, use_mla=self.use_mla,
k_buffer=self.k_buffer, k_buffer=self.k_buffer,
v_buffer=self.v_buffer, v_buffer=self.v_buffer,
enable_kv_quant=self.enable_kv_quant,
k_quant_buffer=self.k_quant_buffer,
v_quant_buffer=self.v_quant_buffer,
resharding_stream=self.resharding_stream, resharding_stream=self.resharding_stream,
callback_func=self.send_done_send_signal, callback_func=self.send_done_send_signal,
) )
@@ -1380,7 +1451,7 @@ class MooncakeLayerwiseConnectorWorker:
metadata.requests[req_id] = update_metadata[req_id] metadata.requests[req_id] = update_metadata[req_id]
# update send task trans block info # update send task trans block info
if self.pd_head_ratio != 1: if self.pd_head_ratio != 1 or self.enable_kv_quant:
send_task = metadata.send_task send_task = metadata.send_task
send_task.group_rearrange_block_ids = [[] for _ in range(self.num_kv_cache_groups)] send_task.group_rearrange_block_ids = [[] for _ in range(self.num_kv_cache_groups)]
send_task.group_num_blocks = [0 for _ in range(self.num_kv_cache_groups)] send_task.group_num_blocks = [0 for _ in range(self.num_kv_cache_groups)]
@@ -1388,7 +1459,7 @@ class MooncakeLayerwiseConnectorWorker:
send_task.group_block_table = [None for _ in range(self.num_kv_cache_groups)] send_task.group_block_table = [None for _ in range(self.num_kv_cache_groups)]
send_task.group_block_len_tensor = [None for _ in range(self.num_kv_cache_groups)] send_task.group_block_len_tensor = [None for _ in range(self.num_kv_cache_groups)]
send_task.group_seq_start_tensor = [None for _ in range(self.num_kv_cache_groups)] send_task.group_seq_start_tensor = [None for _ in range(self.num_kv_cache_groups)]
device = self.k_buffer.device # type: ignore device = self.k_buffer.device if self.k_buffer is not None else self.k_quant_buffer.device # type: ignore
for i in self.attn_resharding_group_idx: for i in self.attn_resharding_group_idx:
send_task.group_rearrange_block_ids[i].extend( send_task.group_rearrange_block_ids[i].extend(
sorted( sorted(
@@ -1415,7 +1486,7 @@ class MooncakeLayerwiseConnectorWorker:
def save_kv_layer( def save_kv_layer(
self, self,
layer_name: str, layer_name: str,
kv_layer: tuple[torch.Tensor, torch.Tensor], kv_layer: list[torch.Tensor],
attn_metadata: "AttentionMetadata", attn_metadata: "AttentionMetadata",
connector_metadata: MooncakeLayerwiseConnectorMetadata, connector_metadata: MooncakeLayerwiseConnectorMetadata,
**kwargs, **kwargs,
@@ -1490,12 +1561,51 @@ class MooncakeLayerwiseConnectorWorker:
values = values.reshape(-1, *kv_layer[1].shape[2:]) values = values.reshape(-1, *kv_layer[1].shape[2:])
(keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values) (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values)
quant_keys = None
quant_values = None
if self.enable_kv_quant and self.current_layer in self.vllm_config.quant_config.kvcache_quant_layers:
assert self.resharding_stream is not None
with npu_stream_switch(self.resharding_stream):
reshape_cache_event.wait()
device = self.k_quant_buffer.device # type: ignore
layer = self.vllm_config.compilation_config.static_forward_context[layer_name]
# Initialize buffers
# [num_tokens, kv_head, head_dim]
quant_key = torch.empty(
(send_task.group_num_tokens[layer_group_idx], *kv_layer[0].size()[-2:]),
dtype=kv_layer[0].dtype,
device=device,
)
quant_values = torch.empty(
(send_task.group_num_tokens[layer_group_idx], *kv_layer[1].size()[-2:]),
dtype=kv_layer[1].dtype,
device=device,
)
# Load cache data into buffers
torch_npu.atb.npu_paged_cache_load(
kv_layer[0],
kv_layer[1],
send_task.group_block_table[layer_group_idx],
send_task.group_block_len_tensor[layer_group_idx],
seq_starts=send_task.group_seq_start_tensor[layer_group_idx],
key=quant_key,
value=quant_values,
)
quant_keys = torch.ops.vllm.quantize(
quant_key, layer.fak_descale, layer.fak_descale_reciprocal, layer.fak_offset
)
quant_keys = self.get_nz_cache(quant_keys, layer_group_idx)
quant_values = self.get_nz_cache(quant_values, layer_group_idx)
assert self.kv_send_layer_thread is not None assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None assert reshape_cache_event is not None
layer_send_task = SendTask( layer_send_task = SendTask(
wait_event=reshape_cache_event, wait_event=reshape_cache_event,
k_cache=keys, k_cache=keys,
v_cache=values, v_cache=values,
k_quant_cache=quant_keys,
v_quant_cache=quant_values,
layer_idx=self.current_layer, layer_idx=self.current_layer,
layer_name=layer_name, layer_name=layer_name,
group_rearrange_block_ids=send_task.group_rearrange_block_ids, group_rearrange_block_ids=send_task.group_rearrange_block_ids,
@@ -1510,6 +1620,15 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_send_layer_thread.send_queue.put(layer_send_task) self.kv_send_layer_thread.send_queue.put(layer_send_task)
self.current_layer += 1 self.current_layer += 1
# NOTE: The FIA operator expects the KV cache in ND format but with an NZ shape.
# npu_format_cast only changes the memory layout, so the tensor is converted to
# the NZ shape manually here.
def get_nz_cache(self, cache_tensor: torch.Tensor, layer_group_idx: int):
    """Return ``cache_tensor`` rearranged block-by-block through ``trans_nd_to_nz``.

    The last two dimensions of ``cache_tensor`` are read as
    ``(head_num, head_dim)``. The tensor is viewed as
    ``(-1, block_size, head_num * head_dim)`` using the block size of the
    given KV cache group, passed through ``trans_nd_to_nz``, and reshaped
    back to ``(-1, head_num, head_dim)``.

    Args:
        cache_tensor: Tensor whose trailing dimensions are
            ``(head_num, head_dim)``.
        layer_group_idx: Index into ``self.block_size`` selecting the block
            size for this layer's KV cache group.

    Returns:
        The converted tensor with shape ``(-1, head_num, head_dim)``.
    """
    head_num, head_dim = cache_tensor.shape[-2:]
    # One row per cache block so that the conversion is applied per block.
    blocked = cache_tensor.view(-1, self.block_size[layer_group_idx], head_num * head_dim)
    return trans_nd_to_nz(blocked).reshape(-1, head_num, head_dim)
def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore
"""Get a socket to the remote host.""" """Get a socket to the remote host."""
remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)

View File

@@ -120,6 +120,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm, kv_a_layernorm=mla_modules.kv_a_layernorm,
o_proj=mla_modules.o_proj, o_proj=mla_modules.o_proj,
layer_name=f"{prefix}.attn",
) )
original_process_weights = self.mla_attn.process_weights_after_loading original_process_weights = self.mla_attn.process_weights_after_loading

View File

@@ -528,3 +528,17 @@
# Future Plan: # Future Plan:
# Remove this patch when: # Remove this patch when:
# design a dispatch mechanism for batch_memcpy_kernel. # design a dispatch mechanism for batch_memcpy_kernel.
#
# ** 23. File: worker/patch_weight_utils.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.deepseek_v2.DeepseekV2ForCausalLM.load_weights`
# Why:
# C8 weights quantized by ModelSlim change the model structure: they add
# the scale and offset parameters required for KV cache quantization.
# In addition, the names of these quantization parameters differ from
# those used by the community (upstream vLLM).
# How:
# We wrap the maybe_remap_kv_scale_name function so that it also remaps
# the fa_q / fa_k / fa_v scale and offset parameter names.
# Future Plan:
# Remove this patch once the community's maybe_remap_kv_scale_name function
# is refactored to support multiple backends.

View File

@@ -22,6 +22,7 @@ if HAS_TRITON:
# isort: off # isort: off
import vllm_ascend.patch.worker.patch_weight_utils # noqa
import vllm_ascend.patch.platform.patch_sched_yield # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa
import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa
import vllm_ascend.patch.worker.patch_bert # noqa import vllm_ascend.patch.worker.patch_bert # noqa

View File

@@ -0,0 +1,86 @@
import builtins
import sys
from typing import Any

from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
logger = init_logger(__name__)
class ImportPatchDecorator:
    """Registry of per-module patch callables, keyed by module name.

    A patch callable takes the module object as its only argument and is
    expected to modify it in place.
    """

    # Shared registry: module name -> patch callable.
    _patches: dict[str, Any] = {}

    @classmethod
    def register(cls, module_name):
        """Return a decorator that records the function as the patch for ``module_name``.

        The decorated function is returned unchanged. Registering a second
        function under the same name replaces the first.
        """

        def decorator(func):
            cls._patches[module_name] = func
            return func

        return decorator

    @classmethod
    def apply_patches(cls):
        """Run every registered patch whose target module is already in ``sys.modules``.

        Modules that have not been imported yet are skipped. A failing patch
        is logged and does not stop the remaining patches from running.
        """
        for module_name, patch_func in cls._patches.items():
            if module_name not in sys.modules:
                continue
            loaded_module = sys.modules[module_name]
            try:
                patch_func(loaded_module)
            except Exception as e:
                # Best effort: one broken patch must not abort the others.
                logger.error(f"Patch application failed {module_name}: {e}")
@ImportPatchDecorator.register("vllm.model_executor.models.deepseek_v2")
def patch_deepseek(module):
    """Install an extended ``maybe_remap_kv_scale_name`` on the DeepSeek-V2 module.

    The replacement first delegates to the ``maybe_remap_kv_scale_name``
    imported by this file, then additionally remaps parameter names ending in
    one of the ``fa_{q,k,v}.{scale,offset}`` suffixes.

    Args:
        module: The module object to patch. It is only modified when it
            already exposes ``maybe_remap_kv_scale_name``.
    """
    ori_maybe_remap_kv_scale_name = maybe_remap_kv_scale_name
    # Suffixes of the extra KV cache quantization parameters handled here.
    replace_scale_names = (
        "fa_q.scale",
        "fa_k.scale",
        "fa_v.scale",
        "fa_q.offset",
        "fa_k.offset",
        "fa_v.offset",
    )

    def new_remap(name: str, params_dict: dict):
        """Remap ``name``; return ``None`` when the wrapped function returns ``None``."""
        name = ori_maybe_remap_kv_scale_name(name, params_dict)
        if name is None:
            # The wrapped function returns None to tell the loader to skip this
            # weight; propagate it instead of failing on None.endswith().
            return None
        for scale_name in replace_scale_names:
            if name.endswith(scale_name):
                remap_name = name.replace(scale_name, f"mla_attn.mla_attn.{scale_name}")
                if remap_name in params_dict:
                    return remap_name
                # Fallback: strips every ".mla_attn" occurrence from the
                # candidate name, including the two segments inserted above.
                return remap_name.replace(".mla_attn", "")
        return name

    if hasattr(module, "maybe_remap_kv_scale_name"):
        # This patch may run more than once for the same module. Record the
        # original only the first time so it is never overwritten by an
        # already-patched function.
        if not hasattr(module, "_original_maybe_remap_kv_scale_name"):
            module._original_maybe_remap_kv_scale_name = module.maybe_remap_kv_scale_name
        module.maybe_remap_kv_scale_name = new_remap
@ImportPatchDecorator.register("vllm.model_executor.model_loader.weight_utils")
def patch_weight_utils(module):
    """Copy the DeepSeek-V2 module's ``maybe_remap_kv_scale_name`` onto ``module``.

    Does nothing unless ``vllm.model_executor.models.deepseek_v2`` is already
    present in ``sys.modules`` and exposes ``maybe_remap_kv_scale_name``.

    Args:
        module: The module object whose ``maybe_remap_kv_scale_name``
            attribute is overwritten.
    """
    deepseek = sys.modules.get("vllm.model_executor.models.deepseek_v2")
    if deepseek is None:
        return
    if hasattr(deepseek, "maybe_remap_kv_scale_name"):
        module.maybe_remap_kv_scale_name = deepseek.maybe_remap_kv_scale_name
# Use the ``builtins`` module rather than ``__builtins__``: the latter is a
# CPython implementation detail and is a module (not a dict) when this file is
# executed as ``__main__``.
original_import = builtins.__import__


def patched_import(name, globals=None, locals=None, fromlist=(), level=0):
    """Import hook that applies the registered patch after importing ``name``.

    The signature mirrors the built-in ``__import__`` and the return value is
    exactly what the original ``__import__`` returned, so import semantics
    are unchanged for callers.

    A failing patch is logged and never propagates into the import statement.

    NOTE: only imports that go through ``builtins.__import__`` (import
    statements) reach this hook.
    """
    module = original_import(name, globals, locals, fromlist, level)
    patch_func = ImportPatchDecorator._patches.get(name)
    if patch_func is not None:
        # For ``import a.b.c`` with an empty fromlist, __import__ returns the
        # top-level package ``a``. Take the module actually named by ``name``
        # from sys.modules so the patch receives the right object.
        target = sys.modules.get(name, module)
        try:
            patch_func(target)
        except Exception as e:
            logger.error(f"Patch application failed during import {name}: {e}")
    return module


# Install the hook, then patch any target modules that were imported before it.
builtins.__import__ = patched_import
ImportPatchDecorator.apply_patches()

View File

@@ -32,10 +32,11 @@ from typing import Any
# Import base classes # Import base classes
from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType
# Import all scheme classes for external access
from .kv_c8 import AscendFAQuantAttentionMethod
# Import registry functions # Import registry functions
from .registry import get_scheme_class, register_scheme from .registry import get_scheme_class, register_scheme
# Import all scheme classes for external access
from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod
@@ -77,4 +78,5 @@ __all__ = [
"AscendW4A16FusedMoEMethod", "AscendW4A16FusedMoEMethod",
"AscendW4A4FlatQuantDynamicLinearMethod", "AscendW4A4FlatQuantDynamicLinearMethod",
"AscendW4A4LaosDynamicLinearMethod", "AscendW4A4LaosDynamicLinearMethod",
"AscendFAQuantAttentionMethod",
] ]

View File

@@ -0,0 +1,65 @@
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from .registry import register_scheme
def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
    """Load an FA-quant scale/offset tensor into ``param``.

    A single-element checkpoint tensor going into a single-element parameter
    is written as a scalar fill. Every other case is treated as sharded along
    dim 0: the slice belonging to the current tensor-parallel rank is taken
    from ``loaded_weight`` and copied into ``param``.

    Args:
        param: Destination parameter; modified in place through ``.data``.
        loaded_weight: Tensor read from the checkpoint.

    Raises:
        AssertionError: If the rank-local slice does not match ``param``'s
            shape.
    """
    both_scalar = param.numel() == 1 and loaded_weight.numel() == 1
    if both_scalar:
        param.data.fill_(loaded_weight.item())
        return

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    # Equal-sized shards along dim 0; this rank owns the tp_rank-th one.
    rows_per_rank = loaded_weight.shape[0] // tp_size
    start_row = rows_per_rank * tp_rank
    loaded_weight = loaded_weight.narrow(0, start_row, rows_per_rank)
    assert param.size() == loaded_weight.size(), (
        f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()}) when TP is ({tp_size})"
    )
    param.data.copy_(loaded_weight)
@register_scheme("FAKQuant", "attention")
class AscendFAQuantAttentionMethod:
    """Attention quantization scheme registered as ``("FAKQuant", "attention")``.

    The scheme attaches per-head scale/offset parameters for Q, K and V to an
    attention layer and, once the checkpoint is loaded, derives the K-cache
    (de)quantization tensors from the ``fa_k`` parameters.
    """

    def __init__(self):
        self.transpose_weight = True
        self.printFlag = False
        vllm_config = get_current_vllm_config()
        config = vllm_config.model_config.hf_config
        # Both values default to 0 when the HF config does not define them.
        self.kv_lora_rank = getattr(config, "kv_lora_rank", 0)
        self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Register empty ``fa_q``/``fa_k``/``fa_v`` scale and offset parameters.

        Three bare ``torch.nn.Module`` containers are attached to ``layer``.
        Each receives a ``scale`` parameter in the current default dtype and
        an int8 ``offset`` parameter, shaped ``(layer.num_heads, 1)`` for Q and
        ``(layer.num_kv_heads, 1)`` for K and V. Every parameter gets the
        module-level ``weight_loader`` attached.

        Args:
            layer: Attention layer exposing ``num_heads`` and ``num_kv_heads``.
        """
        extra_module_names = ["fa_q", "fa_k", "fa_v"]
        for name in extra_module_names:
            setattr(layer, name, torch.nn.Module())

        dtype = torch.get_default_dtype()
        params_dict = {
            "fa_q.scale": torch.empty((layer.num_heads, 1), dtype=dtype),
            "fa_k.scale": torch.empty((layer.num_kv_heads, 1), dtype=dtype),
            "fa_v.scale": torch.empty((layer.num_kv_heads, 1), dtype=dtype),
            "fa_q.offset": torch.empty((layer.num_heads, 1), dtype=torch.int8),
            "fa_k.offset": torch.empty((layer.num_kv_heads, 1), dtype=torch.int8),
            "fa_v.offset": torch.empty((layer.num_kv_heads, 1), dtype=torch.int8),
        }

        for name, weight in params_dict.items():
            module_name, weight_name = name.rsplit(".", 1)
            module = getattr(layer, module_name)
            weight_param = torch.nn.Parameter(weight, requires_grad=False)
            module.register_parameter(weight_name, weight_param)
            # When loading weights, segment them according to TP
            weight_param.weight_loader = weight_loader

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Derive the K-cache quantization tensors from the loaded ``fa_k`` weights.

        Sets on ``layer``:

        * ``fak_descale_float`` / ``fak_descale`` -- the K scale as parameters
          (float32 and original dtype respectively).
        * ``fak_descale_reciprocal`` -- ``1 / scale`` as a plain tensor.
        * ``fak_offset`` -- the K offset cast to ``fak_descale``'s dtype.
        * ``quant_kscale`` -- float32 ``1 / scale`` tiled to
          ``(1, kv_lora_rank)`` as a plain tensor.

        The tiling only reshapes cleanly when ``fa_k.scale`` holds a single
        element (i.e. ``num_kv_heads == 1``).

        Args:
            layer: Attention layer previously prepared by ``create_weights``.
        """
        fa_k_scale = torch.squeeze(layer.fa_k.scale).unsqueeze(0)
        layer.fak_descale_float = torch.nn.Parameter(fa_k_scale.to(torch.float), requires_grad=False)
        layer.fak_descale = torch.nn.Parameter(fa_k_scale, requires_grad=False)
        # The reciprocal is stored as a plain tensor; wrapping the operand in a
        # Parameter first would only create a throwaway object.
        layer.fak_descale_reciprocal = 1.0 / fa_k_scale

        fa_k_offset = torch.squeeze(layer.fa_k.offset).unsqueeze(0)
        layer.fak_offset = torch.nn.Parameter(fa_k_offset.to(layer.fak_descale.dtype), requires_grad=False)

        # Broadcast the single K scale across the latent dimension, then invert
        # in float32.
        tiled_kscale = fa_k_scale.repeat(self.kv_lora_rank).view(1, self.kv_lora_rank)
        layer.quant_kscale = 1.0 / tiled_kscale.to(torch.float)

View File

@@ -24,6 +24,7 @@ configs generated by the ModelSlim tool, along with model-specific mappings.
import glob import glob
import json import json
import os import os
import re
from collections.abc import Mapping from collections.abc import Mapping
from types import MappingProxyType from types import MappingProxyType
from typing import Any, Optional from typing import Any, Optional
@@ -31,6 +32,7 @@ from typing import Any, Optional
import torch import torch
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import register_quantization_config from vllm.model_executor.layers.quantization import register_quantization_config
@@ -38,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding
from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.models.utils import WeightsMapper
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, calc_split_factor
from .methods import get_scheme_class from .methods import get_scheme_class
@@ -438,6 +440,7 @@ class AscendModelSlimConfig(QuantizationConfig):
new_k = k.replace("weight_packed", "weight") new_k = k.replace("weight_packed", "weight")
extra_quant_dict[new_k] = self.quant_description[k] extra_quant_dict[new_k] = self.quant_description[k]
self.quant_description.update(extra_quant_dict) self.quant_description.update(extra_quant_dict)
self._add_kvcache_quant_metadata()
def __repr__(self) -> str: def __repr__(self) -> str:
return "AscendModelSlimConfig:\n" + super().__repr__() return "AscendModelSlimConfig:\n" + super().__repr__()
@@ -509,8 +512,6 @@ class AscendModelSlimConfig(QuantizationConfig):
self.packed_modules_mapping = packed_modules_model_mapping[model_type] self.packed_modules_mapping = packed_modules_model_mapping[model_type]
prefix = self.quant_prefix_mapper(model_type, prefix) prefix = self.quant_prefix_mapper(model_type, prefix)
from vllm.model_executor.layers.attention import Attention
if model_type != "kimi_k2": if model_type != "kimi_k2":
if prefix.startswith("language_model"): if prefix.startswith("language_model"):
prefix = prefix.split(".", 1)[-1] prefix = prefix.split(".", 1)[-1]
@@ -522,11 +523,7 @@ class AscendModelSlimConfig(QuantizationConfig):
return AscendUnquantizedLinearMethod() return AscendUnquantizedLinearMethod()
scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping) scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
return AscendLinearMethod(scheme) return AscendLinearMethod(scheme)
elif ( elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix):
isinstance(layer, Attention)
and "fa_quant_type" in self.quant_description
and self.quant_description["fa_quant_type"] is not None
):
scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping) scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
return AscendKVCacheMethod(scheme) return AscendKVCacheMethod(scheme)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
@@ -573,6 +570,39 @@ class AscendModelSlimConfig(QuantizationConfig):
assert is_skipped is not None assert is_skipped is not None
return is_skipped return is_skipped
def is_fa_quant_layer(self, prefix):
    """Return whether the layer named ``prefix`` has a quantized KV cache.

    The layer id is built by concatenating every purely numeric,
    dot-delimited component of ``prefix`` and is looked up in
    ``self.kvcache_quant_layers``. Always ``False`` when
    ``self.enable_fa_quant`` is falsy or ``prefix`` has no such component.

    Args:
        prefix: Dotted layer name.

    Returns:
        ``True`` if the layer id is listed in ``self.kvcache_quant_layers``.
    """
    if not self.enable_fa_quant:
        return False
    digits = "".join(re.findall(r"\.(\d+)\.", prefix))
    # An empty string (no numeric component) fails isdigit() and never reaches int().
    return digits.isdigit() and int(digits) in self.kvcache_quant_layers
def enabling_fa_quant(self, vllm_config, layer_name) -> bool:
    """Tell whether FA quantization applies to ``layer_name`` on this instance.

    It applies only when the instance is a pure KV consumer (a KV-transfer
    config exists, ``is_kv_consumer`` is set and ``is_kv_producer`` is not)
    and ``self.is_fa_quant_layer(layer_name)`` is truthy.

    Args:
        vllm_config: Object exposing ``kv_transfer_config``.
        layer_name: Dotted layer name passed on to ``is_fa_quant_layer``.

    Returns:
        ``True`` when both conditions hold, otherwise ``False``.
    """
    kv_transfer = vllm_config.kv_transfer_config
    if kv_transfer is None:
        return False
    # Instances that do not consume KV, or that also produce it, are excluded.
    if not kv_transfer.is_kv_consumer or kv_transfer.is_kv_producer:
        return False
    return bool(self.is_fa_quant_layer(layer_name))
def get_kv_quant_dtype(self, layer_name, cache_dtype, model_config):
    """Pick the K and V cache dtypes for ``layer_name``.

    Args:
        layer_name: Dotted layer name checked with ``is_fa_quant_layer``.
        cache_dtype: Dtype used for both caches when the layer is not
            FA-quantized.
        model_config: Object exposing ``dtype`` and ``use_mla``.

    Returns:
        A ``(k_dtype, v_dtype)`` tuple: ``(cache_dtype, cache_dtype)`` for
        non-quantized layers, ``(torch.int8, model_config.dtype)`` for
        quantized MLA layers, and ``(torch.int8, torch.int8)`` otherwise.
    """
    quantized = self.enable_fa_quant and self.is_fa_quant_layer(layer_name)
    if not quantized:
        return cache_dtype, cache_dtype

    ori_dtype = model_config.dtype
    quant_dtype = torch.int8
    # For MLA models like deepseek, we only quantify K cache to ensure accuracy
    v_dtype = ori_dtype if model_config.use_mla else quant_dtype
    return quant_dtype, v_dtype
def get_kv_quant_split_factor(self, layer_name, kv_head_dim_list):
    """Compute the K/V split factors for ``layer_name``.

    For FA-quantized layers the second entry of ``kv_head_dim_list`` is
    doubled before the factors are computed; other layers use the list as
    given.

    Args:
        layer_name: Dotted layer name checked with ``is_fa_quant_layer``.
        kv_head_dim_list: Head dims; index 0 is K, index 1 is V.

    Returns:
        The result of ``calc_split_factor`` on the (possibly adjusted) list.
    """
    quantized = self.enable_fa_quant and self.is_fa_quant_layer(layer_name)
    if quantized:
        # NOTE(review): the x2 presumably reflects V staying in a 2-byte dtype
        # while K is stored as int8 -- confirm against the cache allocation.
        k_dim = kv_head_dim_list[0]
        v_dim = kv_head_dim_list[1] * 2
        kv_head_dim_list = [k_dim, v_dim]
    return calc_split_factor(kv_head_dim_list)
def maybe_update_config(self, model_name: str, revision: str | None = None) -> None: def maybe_update_config(self, model_name: str, revision: str | None = None) -> None:
"""Load the ModelSlim quantization config from model directory. """Load the ModelSlim quantization config from model directory.
@@ -606,6 +636,7 @@ class AscendModelSlimConfig(QuantizationConfig):
with open(config_path) as f: with open(config_path) as f:
self.quant_description = json.load(f) self.quant_description = json.load(f)
self._apply_extra_quant_adaptations() self._apply_extra_quant_adaptations()
self._add_kvcache_quant_metadata()
return return
# Collect diagnostic info for the error message # Collect diagnostic info for the error message
@@ -678,3 +709,13 @@ class AscendModelSlimConfig(QuantizationConfig):
def get_scaled_act_names(self) -> list[str]: def get_scaled_act_names(self) -> list[str]:
return [] return []
def _add_kvcache_quant_metadata(self):
    """Derive the KV-cache quantization flags from ``self.quant_description``.

    Sets two attributes:

    * ``enable_fa_quant`` -- ``True`` when ``"fa_quant_type"`` is present and
      is neither ``None`` nor the empty string.
    * ``kvcache_quant_layers`` -- list of integer layer ids, one per key that
      contains ``"fa_k.scale"``. The id is built the same way
      ``is_fa_quant_layer`` builds it: by concatenating every purely numeric,
      dot-delimited component of the key.

    Keys without any numeric component are skipped rather than raising.
    """
    fa_quant_type = self.quant_description.get("fa_quant_type")
    # A JSON ``null`` means "not configured", same as a missing/empty entry.
    self.enable_fa_quant = fa_quant_type is not None and fa_quant_type != ""
    self.kvcache_quant_layers = []
    if not self.enable_fa_quant:
        return

    for key in self.quant_description:
        if "fa_k.scale" not in key:
            continue
        layer_id = "".join(re.findall(r"\.(\d+)\.", key))
        # Guard against int("") for keys that carry no layer index.
        if layer_id.isdigit():
            self.kvcache_quant_layers.append(int(layer_id))

View File

@@ -1211,3 +1211,36 @@ def get_rope_dim(vllm_config):
rope_dim = int(model_config.hf_text_config.rotary_dim) rope_dim = int(model_config.hf_text_config.rotary_dim)
return rope_dim return rope_dim
def calc_split_factor(num_list: list[int]) -> list[float]:
    """Return, for each entry, the ratio of the list's total to that entry.

    Args:
        num_list: Sizes to compare against their sum.

    Returns:
        A list of the same length where element ``i`` equals
        ``sum(num_list) / num_list[i]``. An empty input yields an empty list.

    Raises:
        ZeroDivisionError: If any entry is zero.
    """
    total = sum(num_list)
    return [total / num for num in num_list]
# NOTE: The last two dimensions of ND are transferred to NZ
def trans_nd_to_nz(cache_tensor: torch.Tensor):
    """Rearrange the last two dims of ``cache_tensor`` into blocked NZ layout.

    A trailing ``(a, b)`` matrix is cut into ``a0 x b0`` tiles (``16 x 32``
    for int8, ``16 x 16`` for every other dtype) and returned with trailing
    shape ``(b // b0, a // a0, a0, b0)``, so that
    ``out[..., j, i, r, c] == inp[..., i * a0 + r, j * b0 + c]``.
    Leading (batch) dimensions are kept in place.

    No padding is performed: ``a`` and ``b`` must be multiples of the tile
    size, otherwise the underlying ``reshape`` raises.

    Args:
        cache_tensor: Tensor with at least two dimensions.

    Returns:
        The permuted tensor produced by ``reshape`` followed by ``permute``.
    """
    assert len(cache_tensor.shape) >= 2
    *batch, a, b = cache_tensor.shape
    # int8 tiles are twice as wide as the tiles used for other dtypes.
    a0, b0 = (16, 32) if cache_tensor.dtype == torch.int8 else (16, 16)
    m1, n1 = math.ceil(a / a0), math.ceil(b / b0)

    # Split rows into (m1, a0) and columns into (n1, b0) ...
    tiled = cache_tensor.reshape(batch + [m1, a0, n1, b0])
    # ... then move the column-tile axis in front of the row-tile axis,
    # leaving the batch axes untouched.
    offset = len(batch)
    axis_order = list(range(offset)) + [axis + offset for axis in (2, 0, 1, 3)]
    return tiled.permute(*axis_order)

View File

@@ -117,6 +117,7 @@ from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer
from vllm_ascend.utils import ( from vllm_ascend.utils import (
calc_split_factor,
check_gdn_layer, check_gdn_layer,
enable_sp, enable_sp,
enable_sp_by_pass, enable_sp_by_pass,
@@ -2683,12 +2684,6 @@ class NPUModelRunner(GPUModelRunner):
# as it only support the 0-dim of kv_cache is `num_blocks`. # as it only support the 0-dim of kv_cache is `num_blocks`.
# For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim
# and rope head dim. # and rope head dim.
if self.model_config.use_mla:
head_size = (
self.model_config.hf_text_config.qk_rope_head_dim
+ self.model_config.hf_text_config.kv_lora_rank
)
if not self.model_config.use_mla: if not self.model_config.use_mla:
# for non-mla model, use FullAttentionSpec # for non-mla model, use FullAttentionSpec
k_tensor_split_factor = 2.0 k_tensor_split_factor = 2.0
@@ -2703,8 +2698,16 @@ class NPUModelRunner(GPUModelRunner):
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
else: else:
# for other deepseek models, use MLAAttentionSpec # for other deepseek models, use MLAAttentionSpec
k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank kv_head_dim_list = [
v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
]
if self.is_kv_consumer and self.vllm_config.quant_config is not None:
k_tensor_split_factor, v_tensor_split_factor = (
self.vllm_config.quant_config.get_kv_quant_split_factor(layer_name, kv_head_dim_list)
)
else:
k_tensor_split_factor, v_tensor_split_factor = calc_split_factor(kv_head_dim_list)
k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor) k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor) v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
@@ -2881,8 +2884,13 @@ class NPUModelRunner(GPUModelRunner):
num_kv_heads, num_kv_heads,
self.model_config.hf_text_config.qk_rope_head_dim, self.model_config.hf_text_config.qk_rope_head_dim,
] ]
k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape) k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype
v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape) if self.is_kv_consumer and self.vllm_config.quant_config is not None:
k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype(
layer_name, kv_cache_spec.dtype, self.model_config
)
k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape)
v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape)
if self.use_sparse: if self.use_sparse:
dsa_k_cache_shape = ( dsa_k_cache_shape = (
@@ -3199,12 +3207,17 @@ class NPUModelRunner(GPUModelRunner):
elif spec := attn_module.get_kv_cache_spec(self.vllm_config): elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
assert isinstance(spec, MLAAttentionSpec) assert isinstance(spec, MLAAttentionSpec)
from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
if getattr(attn_module.impl, "fa_quant_layer", False):
head_size = attn_module.head_size + attn_module.qk_rope_head_dim
dtype, cache_dtype_str = attn_module.impl.dtype, None
else:
head_size, dtype, cache_dtype_str = spec.head_size, spec.dtype, spec.cache_dtype_str
kv_cache_spec[layer_name] = AscendMLAAttentionSpec( kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
block_size=spec.block_size, block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads, num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size, head_size=head_size,
dtype=spec.dtype, dtype=dtype,
cache_dtype_str=spec.cache_dtype_str, cache_dtype_str=cache_dtype_str,
) )
elif isinstance(attn_module, MambaBase): elif isinstance(attn_module, MambaBase):