From 3f39ac9c8d12e576487bf435af146e3244187507 Mon Sep 17 00:00:00 2001 From: pichangping <1337510399@qq.com> Date: Mon, 16 Mar 2026 22:49:05 +0800 Subject: [PATCH] [Feature]Supports DSv3.1 PD separation and C8 quantization (#7222) Co-authored-by: kunpengW-code <1289706727@qq.com> Co-authored-by: linsheng1 <1950916997@qq.com> ### What this PR does / why we need it? Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8 supports only the PD separation scenario. C8 refers to quantizing the KV cache to int8, which aims to reduce the GPU memory usage of the KV cache and improve the inference throughput. Constraints: 1. Only the PD separation mode can be used and MooncakeLayerwiseConnector can be used to run the model. 2. Currently, only the activation value supports dynamic quantization, and the KV cache supports static quantization. C8 quantization with MTP is not supported. You can use ModelSlim for quantization. The quantization procedure is as follows: pip install transformers==4.48.2 git clone https://gitcode.com/Ascend/msmodelslim.git cd msmodelslim bash install.sh cd example/DeepSeek/ python3 quant_deepseek_w8a8.py --model_path --save_path --anti_dataset../common/deepseek_anti_prompt_50_v3_1.json --calib_dataset../common/deepseek_calib_prompt_50_v3_1.json --rot --trust_remote_code True --fa_quant --dynamic --anti_method m6 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d --------- Signed-off-by: pichangping <1337510399@qq.com> Signed-off-by: Wang Kunpeng <1289706727@qq.com> Co-authored-by: Wang Kunpeng <1289706727@qq.com> --- tests/ut/attention/test_mla_v1.py | 21 +- tests/ut/quantization/test_kv_c8.py | 468 ++++++++++++++++++ .../ut/quantization/test_modelslim_config.py | 4 + .../attention/context_parallel/mla_cp.py | 1 + vllm_ascend/attention/mla_v1.py | 276 +++++++---- .../kv_p2p/mooncake_layerwise_connector.py | 201 ++++++-- vllm_ascend/ops/mla.py | 1 + vllm_ascend/patch/__init__.py | 14 + vllm_ascend/patch/worker/__init__.py | 1 + .../patch/worker/patch_weight_utils.py | 86 ++++ vllm_ascend/quantization/methods/__init__.py | 6 +- vllm_ascend/quantization/methods/kv_c8.py | 65 +++ vllm_ascend/quantization/modelslim_config.py | 57 ++- vllm_ascend/utils.py | 33 ++ vllm_ascend/worker/model_runner_v1.py | 39 +- 15 files changed, 1112 insertions(+), 161 deletions(-) create mode 100644 tests/ut/quantization/test_kv_c8.py create mode 100644 vllm_ascend/patch/worker/patch_weight_utils.py create mode 100644 vllm_ascend/quantization/methods/kv_c8.py diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index c625969f..1cf661bd 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -807,6 +807,7 @@ class TestAscendMLAImpl(TestBase): attn_type=None, kv_sharing_target_layer_name=None, **kwargs) + self.impl.fa_quant_layer = False def test_init(self): self.assertEqual(self.impl.num_heads, 256) @@ -938,9 +939,9 @@ class TestAscendMLAImpl(TestBase): @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") - @patch("torch_npu.npu_fused_infer_attention_score") + @patch("torch_npu.npu_fused_infer_attention_score_v2") def test_forward_decode_without_graph(self, - mock_npu_fused_infer_attention_score, + mock_npu_fused_infer_attention_score_v2, mock_up_proj, mock_get_forward_context): num_tokens = 100 @@ -956,8 +957,8 @@ class TestAscendMLAImpl(TestBase): metadata = MagicMock() metadata.decode = MagicMock() metadata.decode.block_table = MagicMock() - metadata.decode.seq_lens = 10 - mock_npu_fused_infer_attention_score.return_value = [ + metadata.decode.actual_seq_lengths = 10 + mock_npu_fused_infer_attention_score_v2.return_value = [ torch.randn(num_tokens, self.impl.num_heads, self.impl.kv_lora_rank), None ] @@ -971,7 +972,7 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(result.shape[1], self.impl.num_heads) self.assertEqual(result.shape[2], self.impl.v_head_dim) mock_up_proj.assert_called_once() - mock_npu_fused_infer_attention_score.assert_called_once() + mock_npu_fused_infer_attention_score_v2.assert_called_once() @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad") @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method", @@ -1103,8 +1104,8 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank) @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch("torch_npu.npu_fused_infer_attention_score") - def test_forward_decode(self, mock_npu_fused_infer_attention_score, + @patch("torch_npu.npu_fused_infer_attention_score_v2") + def test_forward_decode(self, mock_npu_fused_infer_attention_score_v2, mock_get_forward_context): B = 2 N = self.impl.num_kv_heads @@ -1121,11 +1122,11 @@ class TestAscendMLAImpl(TestBase): attn_metadata = MagicMock() attn_metadata.attn_state = AscendAttentionState.SpecDecoding attn_metadata.decode = MagicMock() - attn_metadata.decode.actual_seq_lengths_q = MagicMock() - attn_metadata.decode.seq_lens_list = MagicMock() + attn_metadata.decode.actual_seq_qlen = MagicMock() + attn_metadata.decode.actual_seq_kvlen = MagicMock() self.impl.enable_kv_nz = True - mock_npu_fused_infer_attention_score.return_value = [ + mock_npu_fused_infer_attention_score_v2.return_value = [ torch.randn(B, N, self.impl.kv_lora_rank), None ] mock_get_forward_context.return_value = MagicMock(capturing=False) diff --git a/tests/ut/quantization/test_kv_c8.py b/tests/ut/quantization/test_kv_c8.py new file mode 100644 index 00000000..cfafefc9 --- /dev/null +++ b/tests/ut/quantization/test_kv_c8.py @@ -0,0 +1,468 @@ +import unittest +import torch +import torch.nn as nn +from unittest.mock import Mock, patch + + +class TestWeightLoader(unittest.TestCase): + """Test cases for weight_loader function in kv_c8.py""" + + def setUp(self): + """Set up test environment before each test""" + # Import the module under test + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + self.weight_loader = weight_loader + + # Mock distributed functions + self.tp_rank_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank" + ) + self.tp_size_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size" + ) + self.mock_tp_rank = self.tp_rank_patch.start() + self.mock_tp_size = self.tp_size_patch.start() + + def tearDown(self): + """Clean up after each test""" + self.tp_rank_patch.stop() + self.tp_size_patch.stop() + + def test_weight_loader_single_element(self): + """Test weight_loader when both tensors contain a single element""" + # Create tensors with single element + param = torch.tensor([0.0]) + loaded_weight = torch.tensor([5.0]) + + # Call weight_loader + self.weight_loader(param, loaded_weight) + + # Verify the value was filled correctly + self.assertEqual(param.item(), 5.0) + self.assertEqual(param.dtype, torch.float32) + + def test_weight_loader_single_element_int(self): + """Test weight_loader with integer tensors""" + param = torch.tensor([0], dtype=torch.int32) + loaded_weight = torch.tensor([10], dtype=torch.int32) + + self.weight_loader(param, loaded_weight) + + self.assertEqual(param.item(), 10) + + def test_weight_loader_tp_sharding_first_rank(self): + """Test weight_loader with tensor parallelism sharding for first rank""" + # Configure mocks for rank 0 of 4 + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 4 + + # Create test tensors + param = torch.zeros(2, 5) # Target param shape (2,5) + loaded_weight = torch.ones(8, 5) # Full weight (8,5) + + # Mock narrow to track the call + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: narrow(dim=0, start=0, length=2) + mock_narrow.assert_called_once_with(0, 0, 2) + + # Verify data was copied + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_tp_sharding_middle_rank(self): + """Test weight_loader with tensor parallelism sharding for middle rank""" + # Configure mocks for rank 2 of 4 + self.mock_tp_rank.return_value = 2 + self.mock_tp_size.return_value = 4 + + param = torch.zeros(2, 5) + loaded_weight = torch.ones(8, 5) + + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: start = shard_size * rank = 2 * 2 = 4 + mock_narrow.assert_called_once_with(0, 4, 2) + + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_tp_sharding_last_rank(self): + """Test weight_loader with tensor parallelism sharding for last rank""" + # Configure mocks for rank 3 of 4 + self.mock_tp_rank.return_value = 3 + self.mock_tp_size.return_value = 4 + + param = torch.zeros(2, 5) + loaded_weight = torch.ones(8, 5) + + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: start = 2 * 3 = 6 + mock_narrow.assert_called_once_with(0, 6, 2) + + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_shape_mismatch(self): + """Test weight_loader raises assertion error on shape mismatch""" + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 2 + + param = torch.zeros(2, 3) + loaded_weight = torch.ones(4, 4) # Different shape after sharding + + # Mock narrow to return tensor with wrong shape + with patch.object(loaded_weight, 'narrow', return_value=torch.ones(2, 4)): + with self.assertRaises(AssertionError) as context: + self.weight_loader(param, loaded_weight) + + # Verify error message contains expected information + self.assertIn("Attempted to load weight", str(context.exception)) + self.assertIn("into parameter", str(context.exception)) + + def test_weight_loader_with_different_dtypes(self): + """Test weight_loader handles different dtypes correctly""" + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 1 # No sharding + + param = torch.zeros(2, 3, dtype=torch.float32) + loaded_weight = torch.ones(2, 3, dtype=torch.float16) + + self.weight_loader(param, loaded_weight) + + # Verify data was copied and converted + self.assertTrue(torch.all(param == 1)) + self.assertEqual(param.dtype, torch.float32) + + +class TestAscendFAQuantAttentionMethodInit(unittest.TestCase): + """Test cases for AscendFAQuantAttentionMethod initialization""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + # Create mock config with attributes + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 128 + self.mock_hf_config.qk_rope_head_dim = 64 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class after patching + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + self.method_class = AscendFAQuantAttentionMethod + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + + def test_init_with_full_config(self): + """Test initialization when config has all attributes""" + method = self.method_class() + + self.assertTrue(method.transpose_weight) + self.assertFalse(method.printFlag) + self.assertEqual(method.kv_lora_rank, 128) + self.assertEqual(method.qk_rope_head_dim, 64) + + def test_init_without_kv_lora_rank(self): + """Test initialization when config lacks kv_lora_rank""" + delattr(self.mock_hf_config, "kv_lora_rank") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 0) + self.assertEqual(method.qk_rope_head_dim, 64) + + def test_init_without_qk_rope_head_dim(self): + """Test initialization when config lacks qk_rope_head_dim""" + delattr(self.mock_hf_config, "qk_rope_head_dim") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 128) + self.assertEqual(method.qk_rope_head_dim, 0) + + def test_init_without_both_attributes(self): + """Test initialization when config lacks both attributes""" + delattr(self.mock_hf_config, "kv_lora_rank") + delattr(self.mock_hf_config, "qk_rope_head_dim") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 0) + self.assertEqual(method.qk_rope_head_dim, 0) + + +class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase): + """Test cases for create_weights method""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 128 + self.mock_hf_config.qk_rope_head_dim = 64 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + self.method_class = AscendFAQuantAttentionMethod + + # Mock torch functions + self.default_dtype_patch = patch("torch.get_default_dtype", return_value=torch.float32) + self.mock_default_dtype = self.default_dtype_patch.start() + + # Create a real nn.Module for testing + self.layer = nn.Module() + self.layer.num_heads = 32 + self.layer.num_kv_heads = 1 + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + self.default_dtype_patch.stop() + + def test_create_weights_adds_submodules(self): + """Test that create_weights adds fa_q, fa_k, fa_v submodules""" + method = self.method_class() + + with patch("torch.empty") as mock_empty: + mock_empty.return_value = torch.zeros(1, 1) + + method.create_weights(self.layer) + + # Verify submodules were added + self.assertTrue(hasattr(self.layer, "fa_q")) + self.assertTrue(hasattr(self.layer, "fa_k")) + self.assertTrue(hasattr(self.layer, "fa_v")) + + # Verify they are instances of nn.Module + self.assertIsInstance(self.layer.fa_q, nn.Module) + self.assertIsInstance(self.layer.fa_k, nn.Module) + self.assertIsInstance(self.layer.fa_v, nn.Module) + + def test_create_weights_creates_correct_tensors(self): + """Test that create_weights creates tensors with correct shapes and dtypes""" + method = self.method_class() + + # Track torch.empty calls + empty_calls = [] + + def mock_empty(size, dtype=None): + empty_calls.append((size, dtype)) + return torch.zeros(size, dtype=dtype if dtype else torch.float32) + + with patch("torch.empty", side_effect=mock_empty): + method.create_weights(self.layer) + + # Verify tensor creations + expected_calls = [ + ((32, 1), torch.float32), # fa_q.scale + ((1, 1), torch.float32), # fa_k.scale + ((1, 1), torch.float32), # fa_v.scale + ((32, 1), torch.int8), # fa_q.offset + ((1, 1), torch.int8), # fa_k.offset + ((1, 1), torch.int8), # fa_v.offset + ] + + # Compare without considering order + self.assertEqual(len(empty_calls), len(expected_calls)) + for call in expected_calls: + self.assertIn(call, empty_calls) + + def test_create_weights_registers_parameters(self): + """Test that create_weights registers parameters with correct attributes""" + method = self.method_class() + + # Create real tensors for testing + def create_tensor(*args, **kwargs): + size = args[0] if args else kwargs.get('size', (1,)) + dtype = kwargs.get('dtype', torch.float32) + return torch.zeros(*size, dtype=dtype) + + with patch("torch.empty", side_effect=create_tensor): + method.create_weights(self.layer) + + # Import weight_loader for comparison + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + + # Verify each parameter exists and has weight_loader + self.assertTrue(hasattr(self.layer.fa_q, "scale")) + self.assertTrue(hasattr(self.layer.fa_q.scale, "weight_loader")) + self.assertEqual(self.layer.fa_q.scale.weight_loader, weight_loader) + self.assertFalse(self.layer.fa_q.scale.requires_grad) + + self.assertTrue(hasattr(self.layer.fa_k, "scale")) + self.assertTrue(hasattr(self.layer.fa_k.scale, "weight_loader")) + + self.assertTrue(hasattr(self.layer.fa_v, "scale")) + self.assertTrue(hasattr(self.layer.fa_v.scale, "weight_loader")) + + self.assertTrue(hasattr(self.layer.fa_q, "offset")) + self.assertTrue(hasattr(self.layer.fa_q.offset, "weight_loader")) + self.assertEqual(self.layer.fa_q.offset.dtype, torch.int8) + + +class TestAscendFAQuantAttentionMethodProcessWeights(unittest.TestCase): + """Test cases for process_weights_after_loading method""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 64 + self.mock_hf_config.qk_rope_head_dim = 32 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + self.method_class = AscendFAQuantAttentionMethod + + # Create method instance with real layer + self.method = self.method_class() + + # Create a real nn.Module for testing + self.layer = nn.Module() + + # Create real tensors for fa_k + self.fa_k_scale = torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float16) # Shape (1,3) + self.fa_k_offset = torch.tensor([[1, 2, 3]], dtype=torch.int8) # Shape (1,3) + + # Create fa_k module with parameters + self.layer.fa_k = nn.Module() + self.layer.fa_k.scale = nn.Parameter(self.fa_k_scale, requires_grad=False) + self.layer.fa_k.offset = nn.Parameter(self.fa_k_offset, requires_grad=False) + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + + def test_process_weights_with_single_value_scale(self): + """Test process_weights with single value scale""" + # Create new layer with single value scale + layer = nn.Module() + layer.fa_k = nn.Module() + layer.fa_k.scale = nn.Parameter(torch.tensor([[2.0]], dtype=torch.float16), requires_grad=False) + layer.fa_k.offset = nn.Parameter(torch.tensor([[1]], dtype=torch.int8), requires_grad=False) + + self.method.kv_lora_rank = 4 + self.method.process_weights_after_loading(layer) + + self.assertEqual(layer.quant_kscale.shape, (1, 4)) + self.assertEqual(layer.quant_kscale.dtype, torch.float32) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete kv_c8 functionality""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 64 + self.mock_hf_config.qk_rope_head_dim = 32 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Mock distributed functions + self.tp_rank_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank" + ) + self.tp_size_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size" + ) + self.mock_tp_rank = self.tp_rank_patch.start() + self.mock_tp_size = self.tp_size_patch.start() + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + self.tp_rank_patch.stop() + self.tp_size_patch.stop() + + def test_complete_workflow(self): + """Test complete workflow from weight creation to processing""" + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + + # Create method instance + method = AscendFAQuantAttentionMethod() + + # Create real layer + layer = nn.Module() + layer.num_heads = 32 + layer.num_kv_heads = 1 + + # Step 1: Create weights + method.create_weights(layer) + + # Verify weights were created with correct structure + self.assertTrue(hasattr(layer, "fa_q")) + self.assertTrue(hasattr(layer, "fa_k")) + self.assertTrue(hasattr(layer, "fa_v")) + + self.assertTrue(hasattr(layer.fa_q, "scale")) + self.assertTrue(hasattr(layer.fa_q, "offset")) + self.assertTrue(hasattr(layer.fa_k, "scale")) + self.assertTrue(hasattr(layer.fa_k, "offset")) + self.assertTrue(hasattr(layer.fa_v, "scale")) + self.assertTrue(hasattr(layer.fa_v, "offset")) + + # Step 2: Simulate weight loading + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 1 + + # Create dummy weights + q_scale = torch.randn(32, 1) + k_scale = torch.randn(1, 1) + v_scale = torch.randn(1, 1) + q_offset = torch.randint(-128, 127, (32, 1), dtype=torch.int8) + k_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8) + v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8) + + # Load weights using weight_loader + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + + with torch.no_grad(): + weight_loader(layer.fa_q.scale, q_scale) + weight_loader(layer.fa_k.scale, k_scale) + weight_loader(layer.fa_v.scale, v_scale) + weight_loader(layer.fa_q.offset, q_offset) + weight_loader(layer.fa_k.offset, k_offset) + weight_loader(layer.fa_v.offset, v_offset) + + # Verify weights were loaded correctly + self.assertTrue(torch.all(layer.fa_q.scale == q_scale)) + self.assertTrue(torch.all(layer.fa_k.scale == k_scale)) + self.assertTrue(torch.all(layer.fa_v.scale == v_scale)) + + # Step 3: Process after loading + method.process_weights_after_loading(layer) + + # Verify processed parameters + self.assertTrue(hasattr(layer, "fak_descale")) + self.assertTrue(hasattr(layer, "fak_offset")) + self.assertTrue(hasattr(layer, "quant_kscale")) + + +if __name__ == "__main__": + unittest.main(verbosity=2) \ No newline at end of file diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py index 556c8a4a..745b4736 100644 --- a/tests/ut/quantization/test_modelslim_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -24,6 +24,7 @@ class TestAscendModelSlimConfig(TestBase): self.sample_config = { "weight": "INT8", "fa_quant_type": "C8", + "layers.1.fa_k.scale": "C8", "layer1.weight": "INT8", "layer2.weight": "FLOAT", "fused_layer.weight": "FLOAT", @@ -119,6 +120,9 @@ class TestAscendModelSlimConfig(TestBase): # Test with fa_quant_type method = self.ascend_config.get_quant_method( attention_layer, ".attn") + self.assertIs(method, None) + method = self.ascend_config.get_quant_method( + attention_layer, "layers.1.attn") self.assertIs(method, mock_ascend_kvcache.return_value) def test_get_quant_method_for_fused_moe(self): diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index c1bf1d36..c1550a0b 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -612,6 +612,7 @@ class AscendMlaCPImpl(AscendMLAImpl): k_pe: torch.Tensor, block_size: int, attn_metadata: AscendMLAMetadata, + dequant_scale_q_nope=None, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 3cb9615b..dd6b018b 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -47,7 +47,7 @@ from vllm_ascend.ops.layer_shard_linear import ( register_all_layers_to_shard_weight_series, ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla -from vllm_ascend.quantization.methods import AscendW8A8LinearMethod +from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors from vllm_ascend.worker.npu_input_batch import NPUInputBatch @@ -658,6 +658,7 @@ class DecodeMLAPreprocessResult(NamedTuple): k_nope: torch.Tensor | None = None k_pe: torch.Tensor | None = None decode_q_wo_k_up: torch.Tensor | None = None + dequant_scale_q_nope: torch.Tensor | None = None class PrefillMLAPreprocessResult(NamedTuple): @@ -725,6 +726,12 @@ class AscendMLAImpl(MLAAttentionImpl): self.is_kv_producer = ( self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer ) + self.layer_name = kwargs.get("layer_name") + quant_config = self.vllm_config.quant_config + self.fa_quant_layer = ( + quant_config.enabling_fa_quant(self.vllm_config, self.layer_name) if quant_config is not None else False + ) + self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype self.layer_sharding_kwargs = [] for layer_name in get_ascend_config().layer_sharding or []: if layer_name in kwargs: @@ -775,6 +782,8 @@ class AscendMLAImpl(MLAAttentionImpl): actual_seq_lengths, attn_output, softmax_lse, + dequant_scale_q_nope, + fak_descale_float, ) = param seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model: @@ -793,26 +802,35 @@ class AscendMLAImpl(MLAAttentionImpl): seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list)) torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu.npu_fused_infer_attention_score.out( + extra_args = {} + if dequant_scale_q_nope is not None: + extra_args = { + "query_quant_mode": 3, + "key_quant_mode": 0, + "value_quant_mode": 0, + "dequant_scale_query": dequant_scale_q_nope, + "dequant_scale_key": fak_descale_float, + "dequant_scale_value": fak_descale_float, + } + torch_npu.npu_fused_infer_attention_score_v2.out( q_nope, k_nope, k_nope, query_rope=q_pe, key_rope=k_pe, - num_heads=num_heads, + num_query_heads=num_heads, num_key_value_heads=num_kv_heads, input_layout=input_layout, atten_mask=attn_mask, sparse_mode=sparse_mode, - scale=scale, - antiquant_mode=0, - antiquant_scale=None, + softmax_scale=scale, block_table=block_table, block_size=block_size, - actual_seq_lengths_kv=seq_lens_list, - actual_seq_lengths=actual_seq_lengths, + actual_seq_kvlen=seq_lens_list, + actual_seq_qlen=actual_seq_lengths, workspace=graph_params.workspaces.get(num_tokens), out=[attn_output, softmax_lse], + **extra_args, ) torch.npu.graph_task_update_end(update_stream) @@ -887,6 +905,8 @@ class AscendMLAImpl(MLAAttentionImpl): ) if self.enable_mlapo: self._process_weights_for_fused_mlapo(act_dtype) + elif self.fa_quant_layer: + self._process_weights_for_fused_fa_quant() else: # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) @@ -895,6 +915,32 @@ class AscendMLAImpl(MLAAttentionImpl): if is_hidden_layer(layer): post_process_after_loading_for_shard_weight_series(layer) + def _process_weights_for_fused_fa_quant(self): + self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr] + self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr] + + wu_q = self.q_proj.weight.data + + self.wu_q = wu_q + + q_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr] + + self.wd_q = q_a_proj_fa3 + + kv_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr] + + self.wd_kv = kv_a_proj_fa3 + + self.dequant_scale_w_uq_qr = self.q_proj.weight_scale.data.view(1, -1).to(torch.float) + q_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr] + self.dequant_scale_w_dq = q_a_proj_deq_scl.view(1, -1).to(torch.float) + kv_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr] + self.dequant_scale_w_dkv_kr = kv_a_proj_deq_scl.view(1, -1).to(torch.float) + + layer = self.vllm_config.compilation_config.static_forward_context[self.layer_name] + self.quant_kscale = layer.quant_kscale + self.fak_descale_float = layer.fak_descale_float + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): assert self.fused_qkv_a_proj is not None assert self.q_a_layernorm is not None @@ -1236,6 +1282,7 @@ class AscendMLAImpl(MLAAttentionImpl): k_pe: torch.Tensor, block_size: int, attn_metadata: AscendMLAMetadata, + dequant_scale_q_nope=None, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None @@ -1243,7 +1290,15 @@ class AscendMLAImpl(MLAAttentionImpl): # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] actual_seq_lengths = None - if self.enable_kv_nz: + if self.fa_quant_layer: + nz_fmt_last_dim = 16 + k_nope = k_nope.view( + -1, self.num_kv_heads, self.kv_lora_rank // (nz_fmt_last_dim * 2), block_size, nz_fmt_last_dim * 2 + ) + k_pe = k_pe.view( + -1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim + ) + elif self.enable_kv_nz: nz_fmt_last_dim = 16 k_nope = k_nope.view( -1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim @@ -1278,6 +1333,15 @@ class AscendMLAImpl(MLAAttentionImpl): sparse_mode = 3 attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q + elif self.fa_quant_layer: + attn_mask = None + input_layout = "BSND_NBSD" + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous() + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1).contiguous() + dequant_scale_q_nope = dequant_scale_q_nope.view(num_tokens, 1, self.num_heads) + sparse_mode = 0 + actual_seq_lengths = None + attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank) else: # The output layout is set to NBSD to eliminate the need for a # transpose operation after attention. @@ -1299,19 +1363,27 @@ class AscendMLAImpl(MLAAttentionImpl): common_kwargs = { "query_rope": q_pe, "key_rope": k_pe, - "num_heads": self.num_heads, + "num_query_heads": self.num_heads, "num_key_value_heads": self.num_kv_heads, "input_layout": input_layout, "atten_mask": attn_mask, "sparse_mode": sparse_mode, - "scale": self.scale, - "antiquant_mode": 0, - "antiquant_scale": None, + "softmax_scale": self.scale, "block_table": decode_meta.block_table, "block_size": block_size, - "actual_seq_lengths": actual_seq_lengths, - "actual_seq_lengths_kv": decode_meta.seq_lens_list, + "actual_seq_qlen": actual_seq_lengths, + "actual_seq_kvlen": decode_meta.seq_lens_list, } + if self.fa_quant_layer: + extra_fa_args = { + "query_quant_mode": 3, + "key_quant_mode": 0, + "value_quant_mode": 0, + "dequant_scale_query": dequant_scale_q_nope, + "dequant_scale_key": self.fak_descale_float, + "dequant_scale_value": self.fak_descale_float, + } + common_kwargs.update(extra_fa_args) if _EXTRA_CTX.is_draft_model: graph_params = get_draft_graph_params() else: @@ -1325,8 +1397,33 @@ class AscendMLAImpl(MLAAttentionImpl): graph_params.events[num_tokens].append(event) workspace = graph_params.workspaces.get(num_tokens) + attn_output = torch.empty(attn_output_shape, dtype=q_pe.dtype, device=q_pe.device) + softmax_lse = torch.empty(num_tokens, dtype=q_pe.dtype, device=q_pe.device) + attn_params = ( + weak_ref_tensors(q_nope), + weak_ref_tensors(k_nope), + weak_ref_tensors(q_pe), + weak_ref_tensors(k_pe), + self.num_heads, + self.num_kv_heads, + input_layout, + weak_ref_tensors(attn_mask) if attn_mask is not None else None, + sparse_mode, + self.scale, + decode_meta.block_table, + block_size, + decode_meta.seq_lens_list, + actual_seq_lengths, + weak_ref_tensors(attn_output), + weak_ref_tensors(softmax_lse), + ) + if self.fa_quant_layer: + attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) # type: ignore + else: + attn_params = attn_params + (None, None) # type: ignore + if workspace is None: - workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + workspace = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace( q_nope, k_nope, k_nope, **common_kwargs ) if _EXTRA_CTX.is_draft_model: @@ -1334,38 +1431,16 @@ class AscendMLAImpl(MLAAttentionImpl): else: update_graph_params_workspaces(num_tokens, workspace) - attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device) - softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device) - - graph_params.attn_params[num_tokens].append( - ( - weak_ref_tensors(q_nope), - weak_ref_tensors(k_nope), - weak_ref_tensors(q_pe), - weak_ref_tensors(k_pe), - self.num_heads, - self.num_kv_heads, - input_layout, - weak_ref_tensors(attn_mask) if attn_mask is not None else None, - sparse_mode, - self.scale, - decode_meta.block_table, - block_size, - decode_meta.seq_lens_list, - actual_seq_lengths, - weak_ref_tensors(attn_output), - weak_ref_tensors(softmax_lse), - ) - ) + graph_params.attn_params[num_tokens].append(attn_params) torch.npu.graph_task_group_begin(stream) - torch_npu.npu_fused_infer_attention_score.out( + torch_npu.npu_fused_infer_attention_score_v2.out( q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse] ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) else: - attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs) + attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(q_nope, k_nope, k_nope, **common_kwargs) return self._v_up_proj(attn_output) @@ -1381,55 +1456,81 @@ class AscendMLAImpl(MLAAttentionImpl): sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1] - decode_q_nope = torch.empty( - (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - decode_q_pe = torch.empty( - (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + dequant_scale_q_nope = None + if self.fa_quant_layer: + quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope = torch_npu.npu_mla_prolog_v2( + quantized_x, + self.wd_q, + self.wu_q, + self.W_UK_T, + self.wd_kv, + self.gamma1, + self.gamma2, + sin, + cos, + attn_metadata.slot_mapping[:bsz].to(torch.int64), + decode_k_nope, + decode_k_pe, + dequant_scale_x=pertoken_scale.view(-1, 1), + dequant_scale_w_dq=self.dequant_scale_w_dq, + dequant_scale_w_uq_qr=self.dequant_scale_w_uq_qr, + dequant_scale_w_dkv_kr=self.dequant_scale_w_dkv_kr, + quant_scale_ckv=self.quant_kscale, + cache_mode="PA_NZ", + ) + else: + decode_q_nope = torch.empty( + (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + decode_q_pe = torch.empty( + (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) - torch.ops._C_ascend.mla_preprocess( - hidden_states, - self.wd_qkv, - self.deq_scale_qkv, - self.gamma1, - self.beta1, - self.wu_q, - self.qb_deq_scl, - self.gamma2, - cos, - sin, - self.W_UK_T, - decode_k_nope, - decode_k_pe, - attn_metadata.slot_mapping[:bsz], - quant_scale0=self.quant_scale0, - quant_offset0=self.quant_offset0, - bias0=self.quant_bias_qkv, - quant_scale1=self.quant_scale1, - quant_offset1=self.quant_offset1, - bias1=self.qb_qt_bias, - ctkv_scale=self.ctkv_scale, - q_nope_scale=self.q_nope_scale, - cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", - quant_mode="per_tensor_quant_asymm", - q_out0=decode_q_nope, - kv_cache_out0=decode_k_nope, - q_out1=decode_q_pe, - kv_cache_out1=decode_k_pe, - enable_inner_out=False, - inner_out=torch.tensor([], device=hidden_states.device), - ) - decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) - decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + torch.ops._C_ascend.mla_preprocess( + hidden_states, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.W_UK_T, + decode_k_nope, + decode_k_pe, + attn_metadata.slot_mapping[:bsz], + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=self.ctkv_scale, + q_nope_scale=self.q_nope_scale, + cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", + quant_mode="per_tensor_quant_asymm", + q_out0=decode_q_nope, + kv_cache_out0=decode_k_nope, + q_out1=decode_q_pe, + kv_cache_out1=decode_k_pe, + enable_inner_out=False, + inner_out=torch.tensor([], device=hidden_states.device), + ) + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe) - decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) + decode_preprocess_res = DecodeMLAPreprocessResult( + decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope=dequant_scale_q_nope + ) return decode_preprocess_res, None def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata): @@ -1576,7 +1677,7 @@ class AscendMLAImpl(MLAAttentionImpl): o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device) # MLA Preprocess - if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: + if self.fa_quant_layer or (self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS): hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states.contiguous(), need_gather_q_kv ) @@ -1596,6 +1697,7 @@ class AscendMLAImpl(MLAAttentionImpl): decode_preprocess_res.k_pe, kv_cache[0].shape[1], attn_metadata, + decode_preprocess_res.dequant_scale_q_nope, ) o_proj_input[:num_decode_tokens] = output_decode diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 415bb6c9..ea82322b 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -66,7 +66,7 @@ from vllm_ascend.distributed.kv_transfer.utils.utils import ( kv_alltoall_and_rearrange, parallel_info, ) -from vllm_ascend.utils import npu_stream_switch +from vllm_ascend.utils import npu_stream_switch, trans_nd_to_nz # isort: off if TYPE_CHECKING: @@ -124,6 +124,9 @@ class SendTask: # pd_head_ratio > 1 use k_cache: torch.Tensor | None = None v_cache: torch.Tensor | None = None + # kv cache quantization layer use + k_quant_cache: torch.Tensor | None = None + v_quant_cache: torch.Tensor | None = None layer_idx: int = 0 layer_name: str = "" # trans block info @@ -210,6 +213,9 @@ class KVCacheSendingLayerThread(threading.Thread): use_mla: bool, k_buffer: torch.Tensor, v_buffer: torch.Tensor, + enable_kv_quant: bool, + k_quant_buffer: torch.Tensor | None, + v_quant_buffer: torch.Tensor | None, resharding_stream: torch.npu.Stream, callback_func: Callable[..., None] = lambda x: None, ): @@ -232,6 +238,9 @@ class KVCacheSendingLayerThread(threading.Thread): self.send_queue = queue.Queue[SendTask]() self.k_buffer = k_buffer self.v_buffer = v_buffer + self.enable_kv_quant = enable_kv_quant + self.k_quant_buffer = k_quant_buffer + self.v_quant_buffer = v_quant_buffer self.ready_event = ready_event self.callback_func = callback_func @@ -325,19 +334,43 @@ class KVCacheSendingLayerThread(threading.Thread): grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous( remote_block_ids, local_block_ids ) - for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( - zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) - ): - block_len = block_lens[k] - for group_remote_block_id, group_local_block_id in zip( - grouped_remote_block_ids, grouped_local_block_ids + # kv cache quantization scenario + if self.enable_kv_quant and send_task.k_quant_cache is not None: + assert len(block_lens) == 2, "Quantization block length must be 2!" + quant_block_lens = [block_lens[0] // 2, block_lens[1]] + layer_local_quant_kv_addr = [self.k_quant_buffer.data_ptr(), self.v_quant_buffer.data_ptr()] + rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx] + # eg:[5,6,7,9] -> {5:0, 6:1, 7:2, 9:3} + rearrange_block_dict = { + value: index + for index, value in enumerate(rearrange_block_ids) # type:ignore + } + for block_len, src_layer_base_addr, dst_layer_base_addr in zip( + quant_block_lens, layer_local_quant_kv_addr, layer_remote_kv_base_addr ): - src = src_layer_base_addr + group_local_block_id[0] * block_len - dst = dst_layer_base_addr + group_remote_block_id[0] * block_len - length = len(group_local_block_id) * block_len - src_list.append(src) - dst_list.append(dst) - length_list.append(length) + for group_remote_block_id, group_local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids + ): + src = src_layer_base_addr + rearrange_block_dict[group_local_block_id[0]] * block_len + dst = dst_layer_base_addr + group_remote_block_id[0] * block_len + length = len(group_local_block_id) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) + else: + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) + ): + block_len = block_lens[k] + for group_remote_block_id, group_local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids + ): + src = src_layer_base_addr + group_local_block_id[0] * block_len + dst = dst_layer_base_addr + group_remote_block_id[0] * block_len + length = len(group_local_block_id) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) else: rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx] rearrange_block_dict = { @@ -380,6 +413,14 @@ class KVCacheSendingLayerThread(threading.Thread): value = value.view(-1, key.shape[-1]) # type:ignore self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] -> self.v_buffer[: value.shape[0]].copy_(value) + if send_task.k_quant_cache is not None: + with npu_stream_switch(self.resharding_stream): + key_quant = send_task.k_quant_cache + key_quant = key_quant.view(-1, key_quant.shape[-1]) # type:ignore + self.k_quant_buffer[: key_quant.shape[0]].copy_(key_quant) + value_quant = send_task.v_quant_cache + value_quant = value_quant.view(-1, value_quant.shape[-1]) # type:ignore + self.v_quant_buffer[: value_quant.shape[0]].copy_(value_quant) # Merge transmission tasks of the same session session_meta: dict[str, TransferMeta] = {} @@ -395,7 +436,9 @@ class KVCacheSendingLayerThread(threading.Thread): session_meta[session_id].length.extend(length_list) session_meta[session_id].req_ids.append(req_id) - if self.pd_head_ratio == 1: + if send_task.k_quant_cache is not None: + self.resharding_stream.synchronize() + elif self.pd_head_ratio == 1: """ Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. This issue will be fixed in CANN version 8.5.rc1. @@ -628,7 +671,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1, SupportsHMA): self.connector_worker.wait_for_layer_load(layer_name) def save_kv_layer( - self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs + self, layer_name: str, kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", **kwargs ) -> None: """MooncakeLayerwiseConnector does not save explicitly.""" assert self.connector_worker is not None @@ -962,10 +1005,13 @@ class MooncakeLayerwiseConnectorWorker: self.layer_metadata: dict[str, LayerMetadata] = {} self.attn_resharding_group_idx = set[int]() + self.enable_kv_quant = ( + vllm_config.quant_config.enable_fa_quant if vllm_config.quant_config is not None else False + ) self.pd_head_ratio = get_ascend_config().pd_head_ratio self.num_head_replica = get_ascend_config().num_head_replica self.resharding_stream = None - if self.pd_head_ratio > 1: + if self.pd_head_ratio > 1 or self.enable_kv_quant: self.resharding_stream = torch.npu.Stream() self.remote_poller = zmq.Poller() # type: ignore @@ -985,11 +1031,16 @@ class MooncakeLayerwiseConnectorWorker: self.timeout = 1.0 # seconds self.k_buffer: torch.Tensor | None = None self.v_buffer: torch.Tensor | None = None + # TODO(kunpengW-code): Reuse k_buffer, v_buffer + self.k_quant_buffer: torch.Tensor | None = None + self.v_quant_buffer: torch.Tensor | None = None - def create_kv_buffer(self, first_kv_cache): + def create_kv_buffer(self, first_kv_cache_tuple): + alignment = 2 * 1024 * 1024 + buffer_list = [] + first_kv_cache = first_kv_cache_tuple[0] if self.pd_head_ratio > 1: # regesit kv buffer for tp inequal - alignment = 2 * 1024 * 1024 self.k_buffer = torch.zeros( first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device ) @@ -1002,18 +1053,34 @@ class MooncakeLayerwiseConnectorWorker: self.v_buffer = align_memory(self.v_buffer, alignment)[: first_kv_cache.numel()].view( -1, first_kv_cache.shape[-1] ) + buffer_list.append(self.k_buffer) + buffer_list.append(self.v_buffer) + if self.enable_kv_quant: + quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2 + self.k_quant_buffer = torch.zeros( + quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device + ) + self.k_quant_buffer = align_memory(self.k_quant_buffer, alignment)[:quant_k_cache_numel].view( + -1, first_kv_cache.shape[-1] + ) + quant_v_cache_numel = first_kv_cache_tuple[1].numel() + self.v_quant_buffer = torch.zeros( + quant_v_cache_numel + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device + ) + self.v_quant_buffer = align_memory(self.v_quant_buffer, alignment)[:quant_v_cache_numel].view( + -1, first_kv_cache_tuple[1].shape[-1] + ) + buffer_list.append(self.k_quant_buffer) + buffer_list.append(self.v_quant_buffer) - for tensor in (self.k_buffer, self.v_buffer): - assert tensor.data_ptr() % alignment == 0, ( - "The address of the registered kv cache should be aligned to 2M" - ) - ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel()) - logger.info( - f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} " - f"{tensor.numel()} {ret_value=}" - ) - if ret_value != 0: - raise RuntimeError("Mooncake memory registration failed. ") + for tensor in buffer_list: + assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel()) + logger.info( + f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}" + ) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed. ") def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data.""" @@ -1042,8 +1109,8 @@ class MooncakeLayerwiseConnectorWorker: ptrs = [] lengths = [] - use_resharding_buffer = False - resharding_buffer = None + use_kv_buffer = False + kv_buffer = None for layer_name, kv_cache_tuple in kv_caches.items(): if isinstance(kv_cache_tuple, (list, tuple)) is False: kv_cache_tuple = [kv_cache_tuple] @@ -1051,12 +1118,13 @@ class MooncakeLayerwiseConnectorWorker: layer_kv_cache_spec = kv_cache_groups[layer_kv_group_id].kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] - if self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))): + if ( + self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))) + ) or self.enable_kv_quant: self.attn_resharding_group_idx.add(layer_kv_group_id) - if use_resharding_buffer is False: - use_resharding_buffer = True - resharding_buffer = kv_cache_tuple[0] - self.resharding_stream = torch.npu.Stream() + if use_kv_buffer is False: + use_kv_buffer = True + kv_buffer = kv_cache_tuple single_layer_meta = LayerMetadata([], [], [], []) for single_kv_cache in kv_cache_tuple: block_start_rank = 1 @@ -1092,8 +1160,8 @@ class MooncakeLayerwiseConnectorWorker: lengths.append(kv_cache_tensor.size) global_te.register_buffer(ptrs, lengths) - if use_resharding_buffer: - self.create_kv_buffer(resharding_buffer) + if use_kv_buffer: + self.create_kv_buffer(kv_buffer) num_attn_module = 2 if self.vllm_config.model_config.hf_text_config.model_type == "longcat_flash" else 1 mtp_layer_name = "" @@ -1133,6 +1201,9 @@ class MooncakeLayerwiseConnectorWorker: use_mla=self.use_mla, k_buffer=self.k_buffer, v_buffer=self.v_buffer, + enable_kv_quant=self.enable_kv_quant, + k_quant_buffer=self.k_quant_buffer, + v_quant_buffer=self.v_quant_buffer, resharding_stream=self.resharding_stream, callback_func=self.send_done_send_signal, ) @@ -1380,7 +1451,7 @@ class MooncakeLayerwiseConnectorWorker: metadata.requests[req_id] = update_metadata[req_id] # update send task trans block info - if self.pd_head_ratio != 1: + if self.pd_head_ratio != 1 or self.enable_kv_quant: send_task = metadata.send_task send_task.group_rearrange_block_ids = [[] for _ in range(self.num_kv_cache_groups)] send_task.group_num_blocks = [0 for _ in range(self.num_kv_cache_groups)] @@ -1388,7 +1459,7 @@ class MooncakeLayerwiseConnectorWorker: send_task.group_block_table = [None for _ in range(self.num_kv_cache_groups)] send_task.group_block_len_tensor = [None for _ in range(self.num_kv_cache_groups)] send_task.group_seq_start_tensor = [None for _ in range(self.num_kv_cache_groups)] - device = self.k_buffer.device # type: ignore + device = self.k_buffer.device if self.k_buffer is not None else self.k_quant_buffer.device # type: ignore for i in self.attn_resharding_group_idx: send_task.group_rearrange_block_ids[i].extend( sorted( @@ -1415,7 +1486,7 @@ class MooncakeLayerwiseConnectorWorker: def save_kv_layer( self, layer_name: str, - kv_layer: tuple[torch.Tensor, torch.Tensor], + kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", connector_metadata: MooncakeLayerwiseConnectorMetadata, **kwargs, @@ -1490,12 +1561,51 @@ class MooncakeLayerwiseConnectorWorker: values = values.reshape(-1, *kv_layer[1].shape[2:]) (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values) + quant_keys = None + quant_values = None + if self.enable_kv_quant and self.current_layer in self.vllm_config.quant_config.kvcache_quant_layers: + assert self.resharding_stream is not None + with npu_stream_switch(self.resharding_stream): + reshape_cache_event.wait() + device = self.k_quant_buffer.device # type: ignore + layer = self.vllm_config.compilation_config.static_forward_context[layer_name] + # Initialize buffers + # [num_tokens, kv_head, head_dim] + quant_key = torch.empty( + (send_task.group_num_tokens[layer_group_idx], *kv_layer[0].size()[-2:]), + dtype=kv_layer[0].dtype, + device=device, + ) + quant_values = torch.empty( + (send_task.group_num_tokens[layer_group_idx], *kv_layer[1].size()[-2:]), + dtype=kv_layer[1].dtype, + device=device, + ) + + # Load cache data into buffers + torch_npu.atb.npu_paged_cache_load( + kv_layer[0], + kv_layer[1], + send_task.group_block_table[layer_group_idx], + send_task.group_block_len_tensor[layer_group_idx], + seq_starts=send_task.group_seq_start_tensor[layer_group_idx], + key=quant_key, + value=quant_values, + ) + quant_keys = torch.ops.vllm.quantize( + quant_key, layer.fak_descale, layer.fak_descale_reciprocal, layer.fak_offset + ) + quant_keys = self.get_nz_cache(quant_keys, layer_group_idx) + quant_values = self.get_nz_cache(quant_values, layer_group_idx) + assert self.kv_send_layer_thread is not None assert reshape_cache_event is not None layer_send_task = SendTask( wait_event=reshape_cache_event, k_cache=keys, v_cache=values, + k_quant_cache=quant_keys, + v_quant_cache=quant_values, layer_idx=self.current_layer, layer_name=layer_name, group_rearrange_block_ids=send_task.group_rearrange_block_ids, @@ -1510,6 +1620,15 @@ class MooncakeLayerwiseConnectorWorker: self.kv_send_layer_thread.send_queue.put(layer_send_task) self.current_layer += 1 + # NOTE: Due to the FIA operator constraints, the expected kv cache is ND format, NZ shape, + # while the npu_format_cast method only modifies the memory layout, we manually convert it to NZ shape here + def get_nz_cache(self, cache_tensor: torch.Tensor, layer_group_idx: int): + head_num, head_dim = cache_tensor.shape[-2], cache_tensor.shape[-1] + cache_tensor = cache_tensor.view(-1, self.block_size[layer_group_idx], head_num * head_dim) + cache_tensor = trans_nd_to_nz(cache_tensor) + cache_tensor = cache_tensor.reshape(-1, head_num, head_dim) + return cache_tensor + def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore """Get a socket to the remote host.""" remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index c7ae6046..de2e2fa4 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -120,6 +120,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, kv_a_layernorm=mla_modules.kv_a_layernorm, o_proj=mla_modules.o_proj, + layer_name=f"{prefix}.attn", ) original_process_weights = self.mla_attn.process_weights_after_loading diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 41463c7b..ad6d1d9b 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -528,3 +528,17 @@ # Future Plan: # Remove this patch when: # design a dispatch mechanism for batch_memcpy_kernel. +# +# ** 23. File: worker/patch_weight_utils.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.model_executor.models.deepseek_v2.DeepseekV2ForCausalLM.load_weights` +# Why: +# The C8 weight quantized by modelslim will modify the model structure, +# and the scale and offset required for kvcache quantization will increase. +# In addition, the names of the quantization parameters are different from +# those in the community. +# How: +# we have enhanced the maybe_remap_kv_scale_name function. +# Future Plan: +# The maybe_remap_kv_scale_name function of the community is reconstructed to support +# multiple backends. diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index 06aa5d2a..11982294 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -22,6 +22,7 @@ if HAS_TRITON: # isort: off +import vllm_ascend.patch.worker.patch_weight_utils # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa import vllm_ascend.patch.worker.patch_bert # noqa diff --git a/vllm_ascend/patch/worker/patch_weight_utils.py b/vllm_ascend/patch/worker/patch_weight_utils.py new file mode 100644 index 00000000..f2e50c43 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_weight_utils.py @@ -0,0 +1,86 @@ +import sys +from typing import Any + +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name + +logger = init_logger(__name__) + + +class ImportPatchDecorator: + """Import patch decorator""" + + _patches: dict[str, Any] = {} + + @classmethod + def register(cls, module_name): + """Decorator for registering module patches""" + + def decorator(func): + cls._patches[module_name] = func + return func + + return decorator + + @classmethod + def apply_patches(cls): + """Apply all patches""" + for module_name, patch_func in cls._patches.items(): + if module_name in sys.modules: + module = sys.modules[module_name] + try: + patch_func(module) + except Exception as e: + logger.error(f"Patch application failed {module_name}: {e}") + + +@ImportPatchDecorator.register("vllm.model_executor.models.deepseek_v2") +def patch_deepseek(module): + ori_maybe_remap_kv_scale_name = maybe_remap_kv_scale_name + + def new_remap(name: str, params_dict: dict): + name = ori_maybe_remap_kv_scale_name(name, params_dict) + + replace_scale_names = ["fa_q.scale", "fa_k.scale", "fa_v.scale", "fa_q.offset", "fa_k.offset", "fa_v.offset"] + + for scale_name in replace_scale_names: + if name.endswith(scale_name): + remap_name = name.replace(scale_name, f"mla_attn.mla_attn.{scale_name}") + if remap_name in params_dict: + return remap_name + else: + return remap_name.replace(".mla_attn", "") + + return name + + if hasattr(module, "maybe_remap_kv_scale_name"): + module._original_maybe_remap_kv_scale_name = module.maybe_remap_kv_scale_name + module.maybe_remap_kv_scale_name = new_remap + + +@ImportPatchDecorator.register("vllm.model_executor.model_loader.weight_utils") +def patch_weight_utils(module): + if "vllm.model_executor.models.deepseek_v2" in sys.modules: + deepseek = sys.modules["vllm.model_executor.models.deepseek_v2"] + if hasattr(deepseek, "maybe_remap_kv_scale_name"): + module.maybe_remap_kv_scale_name = deepseek.maybe_remap_kv_scale_name + + +original_import = __builtins__["__import__"] # type: ignore + + +def patched_import(name, globals=None, locals=None, fromlist=(), level=0): + module = original_import(name, globals, locals, fromlist, level) + + if name in ImportPatchDecorator._patches: + try: + ImportPatchDecorator._patches[name](module) + except Exception as e: + logger.error(f"Patch application failed during import {name}: {e}") + + return module + + +__builtins__["__import__"] = patched_import + +ImportPatchDecorator.apply_patches() diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index 38840295..2e91ba2d 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -32,10 +32,11 @@ from typing import Any # Import base classes from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType +# Import all scheme classes for external access +from .kv_c8 import AscendFAQuantAttentionMethod + # Import registry functions from .registry import get_scheme_class, register_scheme - -# Import all scheme classes for external access from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod @@ -77,4 +78,5 @@ __all__ = [ "AscendW4A16FusedMoEMethod", "AscendW4A4FlatQuantDynamicLinearMethod", "AscendW4A4LaosDynamicLinearMethod", + "AscendFAQuantAttentionMethod", ] diff --git a/vllm_ascend/quantization/methods/kv_c8.py b/vllm_ascend/quantization/methods/kv_c8.py new file mode 100644 index 00000000..8a700484 --- /dev/null +++ b/vllm_ascend/quantization/methods/kv_c8.py @@ -0,0 +1,65 @@ +import torch +from vllm.config import get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size + +from .registry import register_scheme + + +def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor): + """fa_q weight loader.""" + if param.numel() == 1 and loaded_weight.numel() == 1: + param.data.fill_(loaded_weight.item()) + else: + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + shard_size = loaded_weight.shape[0] // tp_size + loaded_weight = loaded_weight.narrow(0, shard_size * tp_rank, shard_size) + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()}) when TP is ({tp_size})" + ) + + param.data.copy_(loaded_weight) + + +@register_scheme("FAKQuant", "attention") +class AscendFAQuantAttentionMethod: + def __init__(self): + self.transpose_weight = True + self.printFlag = False + vllm_config = get_current_vllm_config() + config = vllm_config.model_config.hf_config + self.kv_lora_rank = getattr(config, "kv_lora_rank", 0) + self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + + def create_weights(self, layer: torch.nn.Module) -> None: + extra_module_names = ["fa_q", "fa_k", "fa_v"] + for name in extra_module_names: + setattr(layer, name, torch.nn.Module()) + params_dict = {} + dtype = torch.get_default_dtype() + params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1), dtype=dtype) + params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype) + params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype) + params_dict["fa_q.offset"] = torch.empty((layer.num_heads, 1), dtype=torch.int8) + params_dict["fa_k.offset"] = torch.empty((layer.num_kv_heads, 1), dtype=torch.int8) + params_dict["fa_v.offset"] = torch.empty((layer.num_kv_heads, 1), dtype=torch.int8) + + for name, weight in params_dict.items(): + module_name, weight_name = name.rsplit(".", 1) + module = getattr(layer, module_name) + weight_param = torch.nn.Parameter(weight, requires_grad=False) + module.register_parameter(weight_name, weight_param) + # When loading weights, segment them according to TP + weight_param.weight_loader = weight_loader + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + fa_k_scale = torch.squeeze(layer.fa_k.scale).unsqueeze(0) + layer.fak_descale_float = torch.nn.Parameter(fa_k_scale.to(torch.float), requires_grad=False) + layer.fak_descale = torch.nn.Parameter(fa_k_scale, requires_grad=False) + layer.fak_descale_reciprocal = 1.0 / torch.nn.Parameter(fa_k_scale, requires_grad=False) + fa_k_offset = torch.squeeze(layer.fa_k.offset).unsqueeze(0) + layer.fak_offset = torch.nn.Parameter(fa_k_offset.to(layer.fak_descale.dtype), requires_grad=False) + + repeated_quant_kscale = fa_k_scale.repeat(self.kv_lora_rank) + layer.quant_kscale = repeated_quant_kscale.view(1, self.kv_lora_rank) + layer.quant_kscale = 1.0 / torch.nn.Parameter(layer.quant_kscale.to(torch.float), requires_grad=False) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index c682856b..82b3d279 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -24,6 +24,7 @@ configs generated by the ModelSlim tool, along with model-specific mappings. import glob import json import os +import re from collections.abc import Mapping from types import MappingProxyType from typing import Any, Optional @@ -31,6 +32,7 @@ from typing import Any, Optional import torch from vllm.config import get_current_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization import register_quantization_config @@ -38,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding from vllm.model_executor.models.utils import WeightsMapper -from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD +from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, calc_split_factor from .methods import get_scheme_class @@ -438,6 +440,7 @@ class AscendModelSlimConfig(QuantizationConfig): new_k = k.replace("weight_packed", "weight") extra_quant_dict[new_k] = self.quant_description[k] self.quant_description.update(extra_quant_dict) + self._add_kvcache_quant_metadata() def __repr__(self) -> str: return "AscendModelSlimConfig:\n" + super().__repr__() @@ -509,8 +512,6 @@ class AscendModelSlimConfig(QuantizationConfig): self.packed_modules_mapping = packed_modules_model_mapping[model_type] prefix = self.quant_prefix_mapper(model_type, prefix) - from vllm.model_executor.layers.attention import Attention - if model_type != "kimi_k2": if prefix.startswith("language_model"): prefix = prefix.split(".", 1)[-1] @@ -522,11 +523,7 @@ class AscendModelSlimConfig(QuantizationConfig): return AscendUnquantizedLinearMethod() scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping) return AscendLinearMethod(scheme) - elif ( - isinstance(layer, Attention) - and "fa_quant_type" in self.quant_description - and self.quant_description["fa_quant_type"] is not None - ): + elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix): scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping) return AscendKVCacheMethod(scheme) elif isinstance(layer, FusedMoE): @@ -573,6 +570,39 @@ class AscendModelSlimConfig(QuantizationConfig): assert is_skipped is not None return is_skipped + def is_fa_quant_layer(self, prefix): + if self.enable_fa_quant: + layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix)) + if layer_id_str.isdigit() and int(layer_id_str) in self.kvcache_quant_layers: + return True + return False + + def enabling_fa_quant(self, vllm_config, layer_name) -> bool: + is_decode_instance = ( + vllm_config.kv_transfer_config is not None + and vllm_config.kv_transfer_config.is_kv_consumer + and not vllm_config.kv_transfer_config.is_kv_producer + ) + return bool(is_decode_instance and self.is_fa_quant_layer(layer_name)) + + def get_kv_quant_dtype(self, layer_name, cache_dtype, model_config): + if self.enable_fa_quant and self.is_fa_quant_layer(layer_name): + ori_dtype = model_config.dtype + quant_dtype = torch.int8 + # For MLA models like deepseek, we only quantify K cache to ensure accuracy + if model_config.use_mla: + return quant_dtype, ori_dtype + else: + return quant_dtype, quant_dtype + return cache_dtype, cache_dtype + + def get_kv_quant_split_factor(self, layer_name, kv_head_dim_list): + if self.enable_fa_quant and self.is_fa_quant_layer(layer_name): + k_quant_head_dim = kv_head_dim_list[0] + v_quant_head_dim = kv_head_dim_list[1] * 2 + kv_head_dim_list = [k_quant_head_dim, v_quant_head_dim] + return calc_split_factor(kv_head_dim_list) + def maybe_update_config(self, model_name: str, revision: str | None = None) -> None: """Load the ModelSlim quantization config from model directory. @@ -606,6 +636,7 @@ class AscendModelSlimConfig(QuantizationConfig): with open(config_path) as f: self.quant_description = json.load(f) self._apply_extra_quant_adaptations() + self._add_kvcache_quant_metadata() return # Collect diagnostic info for the error message @@ -678,3 +709,13 @@ class AscendModelSlimConfig(QuantizationConfig): def get_scaled_act_names(self) -> list[str]: return [] + + def _add_kvcache_quant_metadata(self): + fa_quant_type = self.quant_description.get("fa_quant_type", "") + self.enable_fa_quant = fa_quant_type != "" + self.kvcache_quant_layers = [] + if self.enable_fa_quant: + for key in self.quant_description: + if "fa_k.scale" in key: + _id = "".join(re.findall(r"\.(\d+)\.", key)) + self.kvcache_quant_layers.append(int(_id)) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 66d94475..458b95cd 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1211,3 +1211,36 @@ def get_rope_dim(vllm_config): rope_dim = int(model_config.hf_text_config.rotary_dim) return rope_dim + + +def calc_split_factor(num_list: list[int]): + total = sum(num_list) + split_factor_list = [] + for num in num_list: + split_factor_list.append(total / num) + return split_factor_list + + +# NOTE: The last two dimensions of ND are transferred to NZ +def trans_nd_to_nz(cache_tensor: torch.Tensor): + assert len(cache_tensor.shape) >= 2 + batch = cache_tensor.shape[:-2] + a, b = cache_tensor.shape[-2], cache_tensor.shape[-1] + + dtype = cache_tensor.dtype + if dtype == torch.int8: + a0, b0 = 16, 32 + else: + a0, b0 = 16, 16 + + nz_shape = list(batch) + [math.ceil(b / b0), math.ceil(a / a0), a0, b0] + + # Generate the axis order for the transpose operation. + offset = len(cache_tensor.shape) - 2 + base = [2, 0, 1, 3] + array_trans = [i for i in range(offset)] + [i + offset for i in base] + # Perform shape transformation and transpose operation. + *_, n1, m1, m0, n0 = nz_shape + cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0]) + cache_tensor = cache_tensor.permute(*array_trans) + return cache_tensor diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index fc02fc0a..8fc168ac 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -117,6 +117,7 @@ from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer from vllm_ascend.utils import ( + calc_split_factor, check_gdn_layer, enable_sp, enable_sp_by_pass, @@ -2683,12 +2684,6 @@ class NPUModelRunner(GPUModelRunner): # as it only support the 0-dim of kv_cache is `num_blocks`. # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim # and rope head dim. - if self.model_config.use_mla: - head_size = ( - self.model_config.hf_text_config.qk_rope_head_dim - + self.model_config.hf_text_config.kv_lora_rank - ) - if not self.model_config.use_mla: # for non-mla model, use FullAttentionSpec k_tensor_split_factor = 2.0 @@ -2703,8 +2698,16 @@ class NPUModelRunner(GPUModelRunner): dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] else: # for other deepseek models, use MLAAttentionSpec - k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank - v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim + kv_head_dim_list = [ + self.model_config.hf_text_config.kv_lora_rank, + self.model_config.hf_text_config.qk_rope_head_dim, + ] + if self.is_kv_consumer and self.vllm_config.quant_config is not None: + k_tensor_split_factor, v_tensor_split_factor = ( + self.vllm_config.quant_config.get_kv_quant_split_factor(layer_name, kv_head_dim_list) + ) + else: + k_tensor_split_factor, v_tensor_split_factor = calc_split_factor(kv_head_dim_list) k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor) v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor) @@ -2881,8 +2884,13 @@ class NPUModelRunner(GPUModelRunner): num_kv_heads, self.model_config.hf_text_config.qk_rope_head_dim, ] - k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape) - v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape) + k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype + if self.is_kv_consumer and self.vllm_config.quant_config is not None: + k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype( + layer_name, kv_cache_spec.dtype, self.model_config + ) + k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape) + v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape) if self.use_sparse: dsa_k_cache_shape = ( @@ -3199,12 +3207,17 @@ class NPUModelRunner(GPUModelRunner): elif spec := attn_module.get_kv_cache_spec(self.vllm_config): assert isinstance(spec, MLAAttentionSpec) from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec + if getattr(attn_module.impl, "fa_quant_layer", False): + head_size = attn_module.head_size + attn_module.qk_rope_head_dim + dtype, cache_dtype_str = attn_module.impl.dtype, None + else: + head_size, dtype, cache_dtype_str = spec.head_size, spec.dtype, spec.cache_dtype_str kv_cache_spec[layer_name] = AscendMLAAttentionSpec( block_size=spec.block_size, num_kv_heads=spec.num_kv_heads, - head_size=spec.head_size, - dtype=spec.dtype, - cache_dtype_str=spec.cache_dtype_str, + head_size=head_size, + dtype=dtype, + cache_dtype_str=cache_dtype_str, ) elif isinstance(attn_module, MambaBase):