diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index c625969f..1cf661bd 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -807,6 +807,7 @@ class TestAscendMLAImpl(TestBase): attn_type=None, kv_sharing_target_layer_name=None, **kwargs) + self.impl.fa_quant_layer = False def test_init(self): self.assertEqual(self.impl.num_heads, 256) @@ -938,9 +939,9 @@ class TestAscendMLAImpl(TestBase): @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") - @patch("torch_npu.npu_fused_infer_attention_score") + @patch("torch_npu.npu_fused_infer_attention_score_v2") def test_forward_decode_without_graph(self, - mock_npu_fused_infer_attention_score, + mock_npu_fused_infer_attention_score_v2, mock_up_proj, mock_get_forward_context): num_tokens = 100 @@ -956,8 +957,8 @@ class TestAscendMLAImpl(TestBase): metadata = MagicMock() metadata.decode = MagicMock() metadata.decode.block_table = MagicMock() - metadata.decode.seq_lens = 10 - mock_npu_fused_infer_attention_score.return_value = [ + metadata.decode.actual_seq_lengths = 10 + mock_npu_fused_infer_attention_score_v2.return_value = [ torch.randn(num_tokens, self.impl.num_heads, self.impl.kv_lora_rank), None ] @@ -971,7 +972,7 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(result.shape[1], self.impl.num_heads) self.assertEqual(result.shape[2], self.impl.v_head_dim) mock_up_proj.assert_called_once() - mock_npu_fused_infer_attention_score.assert_called_once() + mock_npu_fused_infer_attention_score_v2.assert_called_once() @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad") @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method", @@ -1103,8 +1104,8 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank) @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch("torch_npu.npu_fused_infer_attention_score") - def test_forward_decode(self, mock_npu_fused_infer_attention_score, + @patch("torch_npu.npu_fused_infer_attention_score_v2") + def test_forward_decode(self, mock_npu_fused_infer_attention_score_v2, mock_get_forward_context): B = 2 N = self.impl.num_kv_heads @@ -1121,11 +1122,11 @@ class TestAscendMLAImpl(TestBase): attn_metadata = MagicMock() attn_metadata.attn_state = AscendAttentionState.SpecDecoding attn_metadata.decode = MagicMock() - attn_metadata.decode.actual_seq_lengths_q = MagicMock() - attn_metadata.decode.seq_lens_list = MagicMock() + attn_metadata.decode.actual_seq_qlen = MagicMock() + attn_metadata.decode.actual_seq_kvlen = MagicMock() self.impl.enable_kv_nz = True - mock_npu_fused_infer_attention_score.return_value = [ + mock_npu_fused_infer_attention_score_v2.return_value = [ torch.randn(B, N, self.impl.kv_lora_rank), None ] mock_get_forward_context.return_value = MagicMock(capturing=False) diff --git a/tests/ut/quantization/test_kv_c8.py b/tests/ut/quantization/test_kv_c8.py new file mode 100644 index 00000000..cfafefc9 --- /dev/null +++ b/tests/ut/quantization/test_kv_c8.py @@ -0,0 +1,468 @@ +import unittest +import torch +import torch.nn as nn +from unittest.mock import Mock, patch + + +class TestWeightLoader(unittest.TestCase): + """Test cases for weight_loader function in kv_c8.py""" + + def setUp(self): + """Set up test environment before each test""" + # Import the module under test + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + self.weight_loader = weight_loader + + # Mock distributed functions 
+ self.tp_rank_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank" + ) + self.tp_size_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size" + ) + self.mock_tp_rank = self.tp_rank_patch.start() + self.mock_tp_size = self.tp_size_patch.start() + + def tearDown(self): + """Clean up after each test""" + self.tp_rank_patch.stop() + self.tp_size_patch.stop() + + def test_weight_loader_single_element(self): + """Test weight_loader when both tensors contain a single element""" + # Create tensors with single element + param = torch.tensor([0.0]) + loaded_weight = torch.tensor([5.0]) + + # Call weight_loader + self.weight_loader(param, loaded_weight) + + # Verify the value was filled correctly + self.assertEqual(param.item(), 5.0) + self.assertEqual(param.dtype, torch.float32) + + def test_weight_loader_single_element_int(self): + """Test weight_loader with integer tensors""" + param = torch.tensor([0], dtype=torch.int32) + loaded_weight = torch.tensor([10], dtype=torch.int32) + + self.weight_loader(param, loaded_weight) + + self.assertEqual(param.item(), 10) + + def test_weight_loader_tp_sharding_first_rank(self): + """Test weight_loader with tensor parallelism sharding for first rank""" + # Configure mocks for rank 0 of 4 + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 4 + + # Create test tensors + param = torch.zeros(2, 5) # Target param shape (2,5) + loaded_weight = torch.ones(8, 5) # Full weight (8,5) + + # Mock narrow to track the call + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: narrow(dim=0, start=0, length=2) + mock_narrow.assert_called_once_with(0, 0, 2) + + # Verify data was copied + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_tp_sharding_middle_rank(self): + """Test weight_loader with tensor parallelism sharding for middle rank""" + # Configure mocks for rank 2 of 4 + self.mock_tp_rank.return_value = 2 + self.mock_tp_size.return_value = 4 + + param = torch.zeros(2, 5) + loaded_weight = torch.ones(8, 5) + + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: start = shard_size * rank = 2 * 2 = 4 + mock_narrow.assert_called_once_with(0, 4, 2) + + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_tp_sharding_last_rank(self): + """Test weight_loader with tensor parallelism sharding for last rank""" + # Configure mocks for rank 3 of 4 + self.mock_tp_rank.return_value = 3 + self.mock_tp_size.return_value = 4 + + param = torch.zeros(2, 5) + loaded_weight = torch.ones(8, 5) + + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: start = 2 * 3 = 6 + mock_narrow.assert_called_once_with(0, 6, 2) + + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_shape_mismatch(self): + """Test weight_loader raises assertion error on shape mismatch""" + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 2 + + param = torch.zeros(2, 3) + loaded_weight = torch.ones(4, 4) # Different shape after sharding + + # Mock narrow to return tensor with wrong shape + with patch.object(loaded_weight, 'narrow', return_value=torch.ones(2, 4)): + with self.assertRaises(AssertionError) as context: 
+ self.weight_loader(param, loaded_weight) + + # Verify error message contains expected information + self.assertIn("Attempted to load weight", str(context.exception)) + self.assertIn("into parameter", str(context.exception)) + + def test_weight_loader_with_different_dtypes(self): + """Test weight_loader handles different dtypes correctly""" + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 1 # No sharding + + param = torch.zeros(2, 3, dtype=torch.float32) + loaded_weight = torch.ones(2, 3, dtype=torch.float16) + + self.weight_loader(param, loaded_weight) + + # Verify data was copied and converted + self.assertTrue(torch.all(param == 1)) + self.assertEqual(param.dtype, torch.float32) + + +class TestAscendFAQuantAttentionMethodInit(unittest.TestCase): + """Test cases for AscendFAQuantAttentionMethod initialization""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + # Create mock config with attributes + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 128 + self.mock_hf_config.qk_rope_head_dim = 64 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class after patching + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + self.method_class = AscendFAQuantAttentionMethod + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + + def test_init_with_full_config(self): + """Test initialization when config has all attributes""" + method = self.method_class() + + self.assertTrue(method.transpose_weight) + self.assertFalse(method.printFlag) + self.assertEqual(method.kv_lora_rank, 128) + self.assertEqual(method.qk_rope_head_dim, 64) + + def test_init_without_kv_lora_rank(self): + """Test initialization when config lacks kv_lora_rank""" + delattr(self.mock_hf_config, "kv_lora_rank") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 0) + self.assertEqual(method.qk_rope_head_dim, 64) + + def test_init_without_qk_rope_head_dim(self): + """Test initialization when config lacks qk_rope_head_dim""" + delattr(self.mock_hf_config, "qk_rope_head_dim") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 128) + self.assertEqual(method.qk_rope_head_dim, 0) + + def test_init_without_both_attributes(self): + """Test initialization when config lacks both attributes""" + delattr(self.mock_hf_config, "kv_lora_rank") + delattr(self.mock_hf_config, "qk_rope_head_dim") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 0) + self.assertEqual(method.qk_rope_head_dim, 0) + + +class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase): + """Test cases for create_weights method""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 128 + self.mock_hf_config.qk_rope_head_dim = 64 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + 
self.method_class = AscendFAQuantAttentionMethod + + # Mock torch functions + self.default_dtype_patch = patch("torch.get_default_dtype", return_value=torch.float32) + self.mock_default_dtype = self.default_dtype_patch.start() + + # Create a real nn.Module for testing + self.layer = nn.Module() + self.layer.num_heads = 32 + self.layer.num_kv_heads = 1 + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + self.default_dtype_patch.stop() + + def test_create_weights_adds_submodules(self): + """Test that create_weights adds fa_q, fa_k, fa_v submodules""" + method = self.method_class() + + with patch("torch.empty") as mock_empty: + mock_empty.return_value = torch.zeros(1, 1) + + method.create_weights(self.layer) + + # Verify submodules were added + self.assertTrue(hasattr(self.layer, "fa_q")) + self.assertTrue(hasattr(self.layer, "fa_k")) + self.assertTrue(hasattr(self.layer, "fa_v")) + + # Verify they are instances of nn.Module + self.assertIsInstance(self.layer.fa_q, nn.Module) + self.assertIsInstance(self.layer.fa_k, nn.Module) + self.assertIsInstance(self.layer.fa_v, nn.Module) + + def test_create_weights_creates_correct_tensors(self): + """Test that create_weights creates tensors with correct shapes and dtypes""" + method = self.method_class() + + # Track torch.empty calls + empty_calls = [] + + def mock_empty(size, dtype=None): + empty_calls.append((size, dtype)) + return torch.zeros(size, dtype=dtype if dtype else torch.float32) + + with patch("torch.empty", side_effect=mock_empty): + method.create_weights(self.layer) + + # Verify tensor creations + expected_calls = [ + ((32, 1), torch.float32), # fa_q.scale + ((1, 1), torch.float32), # fa_k.scale + ((1, 1), torch.float32), # fa_v.scale + ((32, 1), torch.int8), # fa_q.offset + ((1, 1), torch.int8), # fa_k.offset + ((1, 1), torch.int8), # fa_v.offset + ] + + # Compare without considering order + self.assertEqual(len(empty_calls), len(expected_calls)) + for call in expected_calls: + self.assertIn(call, empty_calls) + + def test_create_weights_registers_parameters(self): + """Test that create_weights registers parameters with correct attributes""" + method = self.method_class() + + # Create real tensors for testing + def create_tensor(*args, **kwargs): + size = args[0] if args else kwargs.get('size', (1,)) + dtype = kwargs.get('dtype', torch.float32) + return torch.zeros(*size, dtype=dtype) + + with patch("torch.empty", side_effect=create_tensor): + method.create_weights(self.layer) + + # Import weight_loader for comparison + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + + # Verify each parameter exists and has weight_loader + self.assertTrue(hasattr(self.layer.fa_q, "scale")) + self.assertTrue(hasattr(self.layer.fa_q.scale, "weight_loader")) + self.assertEqual(self.layer.fa_q.scale.weight_loader, weight_loader) + self.assertFalse(self.layer.fa_q.scale.requires_grad) + + self.assertTrue(hasattr(self.layer.fa_k, "scale")) + self.assertTrue(hasattr(self.layer.fa_k.scale, "weight_loader")) + + self.assertTrue(hasattr(self.layer.fa_v, "scale")) + self.assertTrue(hasattr(self.layer.fa_v.scale, "weight_loader")) + + self.assertTrue(hasattr(self.layer.fa_q, "offset")) + self.assertTrue(hasattr(self.layer.fa_q.offset, "weight_loader")) + self.assertEqual(self.layer.fa_q.offset.dtype, torch.int8) + + +class TestAscendFAQuantAttentionMethodProcessWeights(unittest.TestCase): + """Test cases for process_weights_after_loading method""" + + def setUp(self): + """Set up test environment""" + # Mock 
vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 64 + self.mock_hf_config.qk_rope_head_dim = 32 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + self.method_class = AscendFAQuantAttentionMethod + + # Create method instance with real layer + self.method = self.method_class() + + # Create a real nn.Module for testing + self.layer = nn.Module() + + # Create real tensors for fa_k + self.fa_k_scale = torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float16) # Shape (1,3) + self.fa_k_offset = torch.tensor([[1, 2, 3]], dtype=torch.int8) # Shape (1,3) + + # Create fa_k module with parameters + self.layer.fa_k = nn.Module() + self.layer.fa_k.scale = nn.Parameter(self.fa_k_scale, requires_grad=False) + self.layer.fa_k.offset = nn.Parameter(self.fa_k_offset, requires_grad=False) + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + + def test_process_weights_with_single_value_scale(self): + """Test process_weights with single value scale""" + # Create new layer with single value scale + layer = nn.Module() + layer.fa_k = nn.Module() + layer.fa_k.scale = nn.Parameter(torch.tensor([[2.0]], dtype=torch.float16), requires_grad=False) + layer.fa_k.offset = nn.Parameter(torch.tensor([[1]], dtype=torch.int8), requires_grad=False) + + self.method.kv_lora_rank = 4 + self.method.process_weights_after_loading(layer) + + self.assertEqual(layer.quant_kscale.shape, (1, 4)) + self.assertEqual(layer.quant_kscale.dtype, torch.float32) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete kv_c8 functionality""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 64 + self.mock_hf_config.qk_rope_head_dim = 32 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Mock distributed functions + self.tp_rank_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank" + ) + self.tp_size_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size" + ) + self.mock_tp_rank = self.tp_rank_patch.start() + self.mock_tp_size = self.tp_size_patch.start() + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + self.tp_rank_patch.stop() + self.tp_size_patch.stop() + + def test_complete_workflow(self): + """Test complete workflow from weight creation to processing""" + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + + # Create method instance + method = AscendFAQuantAttentionMethod() + + # Create real layer + layer = nn.Module() + layer.num_heads = 32 + layer.num_kv_heads = 1 + + # Step 1: Create weights + method.create_weights(layer) + + # Verify weights were created with correct structure + self.assertTrue(hasattr(layer, "fa_q")) + self.assertTrue(hasattr(layer, "fa_k")) + self.assertTrue(hasattr(layer, "fa_v")) + + 
self.assertTrue(hasattr(layer.fa_q, "scale")) + self.assertTrue(hasattr(layer.fa_q, "offset")) + self.assertTrue(hasattr(layer.fa_k, "scale")) + self.assertTrue(hasattr(layer.fa_k, "offset")) + self.assertTrue(hasattr(layer.fa_v, "scale")) + self.assertTrue(hasattr(layer.fa_v, "offset")) + + # Step 2: Simulate weight loading + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 1 + + # Create dummy weights + q_scale = torch.randn(32, 1) + k_scale = torch.randn(1, 1) + v_scale = torch.randn(1, 1) + q_offset = torch.randint(-128, 127, (32, 1), dtype=torch.int8) + k_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8) + v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8) + + # Load weights using weight_loader + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + + with torch.no_grad(): + weight_loader(layer.fa_q.scale, q_scale) + weight_loader(layer.fa_k.scale, k_scale) + weight_loader(layer.fa_v.scale, v_scale) + weight_loader(layer.fa_q.offset, q_offset) + weight_loader(layer.fa_k.offset, k_offset) + weight_loader(layer.fa_v.offset, v_offset) + + # Verify weights were loaded correctly + self.assertTrue(torch.all(layer.fa_q.scale == q_scale)) + self.assertTrue(torch.all(layer.fa_k.scale == k_scale)) + self.assertTrue(torch.all(layer.fa_v.scale == v_scale)) + + # Step 3: Process after loading + method.process_weights_after_loading(layer) + + # Verify processed parameters + self.assertTrue(hasattr(layer, "fak_descale")) + self.assertTrue(hasattr(layer, "fak_offset")) + self.assertTrue(hasattr(layer, "quant_kscale")) + + +if __name__ == "__main__": + unittest.main(verbosity=2) \ No newline at end of file diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py index 556c8a4a..745b4736 100644 --- a/tests/ut/quantization/test_modelslim_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -24,6 +24,7 @@ class TestAscendModelSlimConfig(TestBase): self.sample_config = { "weight": "INT8", "fa_quant_type": "C8", + "layers.1.fa_k.scale": "C8", "layer1.weight": "INT8", "layer2.weight": "FLOAT", "fused_layer.weight": "FLOAT", @@ -119,6 +120,9 @@ class TestAscendModelSlimConfig(TestBase): # Test with fa_quant_type method = self.ascend_config.get_quant_method( attention_layer, ".attn") + self.assertIs(method, None) + method = self.ascend_config.get_quant_method( + attention_layer, "layers.1.attn") self.assertIs(method, mock_ascend_kvcache.return_value) def test_get_quant_method_for_fused_moe(self): diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index c1bf1d36..c1550a0b 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -612,6 +612,7 @@ class AscendMlaCPImpl(AscendMLAImpl): k_pe: torch.Tensor, block_size: int, attn_metadata: AscendMLAMetadata, + dequant_scale_q_nope=None, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 3cb9615b..dd6b018b 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -47,7 +47,7 @@ from vllm_ascend.ops.layer_shard_linear import ( register_all_layers_to_shard_weight_series, ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla -from vllm_ascend.quantization.methods import AscendW8A8LinearMethod +from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod 
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors from vllm_ascend.worker.npu_input_batch import NPUInputBatch @@ -658,6 +658,7 @@ class DecodeMLAPreprocessResult(NamedTuple): k_nope: torch.Tensor | None = None k_pe: torch.Tensor | None = None decode_q_wo_k_up: torch.Tensor | None = None + dequant_scale_q_nope: torch.Tensor | None = None class PrefillMLAPreprocessResult(NamedTuple): @@ -725,6 +726,12 @@ class AscendMLAImpl(MLAAttentionImpl): self.is_kv_producer = ( self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer ) + self.layer_name = kwargs.get("layer_name") + quant_config = self.vllm_config.quant_config + self.fa_quant_layer = ( + quant_config.enabling_fa_quant(self.vllm_config, self.layer_name) if quant_config is not None else False + ) + self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype self.layer_sharding_kwargs = [] for layer_name in get_ascend_config().layer_sharding or []: if layer_name in kwargs: @@ -775,6 +782,8 @@ class AscendMLAImpl(MLAAttentionImpl): actual_seq_lengths, attn_output, softmax_lse, + dequant_scale_q_nope, + fak_descale_float, ) = param seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model: @@ -793,26 +802,35 @@ class AscendMLAImpl(MLAAttentionImpl): seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list)) torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu.npu_fused_infer_attention_score.out( + extra_args = {} + if dequant_scale_q_nope is not None: + extra_args = { + "query_quant_mode": 3, + "key_quant_mode": 0, + "value_quant_mode": 0, + "dequant_scale_query": dequant_scale_q_nope, + "dequant_scale_key": fak_descale_float, + "dequant_scale_value": fak_descale_float, + } + torch_npu.npu_fused_infer_attention_score_v2.out( q_nope, k_nope, k_nope, query_rope=q_pe, key_rope=k_pe, - num_heads=num_heads, + num_query_heads=num_heads, num_key_value_heads=num_kv_heads, input_layout=input_layout, atten_mask=attn_mask, sparse_mode=sparse_mode, - scale=scale, - antiquant_mode=0, - antiquant_scale=None, + softmax_scale=scale, block_table=block_table, block_size=block_size, - actual_seq_lengths_kv=seq_lens_list, - actual_seq_lengths=actual_seq_lengths, + actual_seq_kvlen=seq_lens_list, + actual_seq_qlen=actual_seq_lengths, workspace=graph_params.workspaces.get(num_tokens), out=[attn_output, softmax_lse], + **extra_args, ) torch.npu.graph_task_update_end(update_stream) @@ -887,6 +905,8 @@ class AscendMLAImpl(MLAAttentionImpl): ) if self.enable_mlapo: self._process_weights_for_fused_mlapo(act_dtype) + elif self.fa_quant_layer: + self._process_weights_for_fused_fa_quant() else: # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) @@ -895,6 +915,32 @@ class AscendMLAImpl(MLAAttentionImpl): if is_hidden_layer(layer): post_process_after_loading_for_shard_weight_series(layer) + def _process_weights_for_fused_fa_quant(self): + self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr] + self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr] + + wu_q = self.q_proj.weight.data + + self.wu_q = wu_q + + q_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr] + + self.wd_q = q_a_proj_fa3 + + kv_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # 
type: ignore[union-attr] + + self.wd_kv = kv_a_proj_fa3 + + self.dequant_scale_w_uq_qr = self.q_proj.weight_scale.data.view(1, -1).to(torch.float) + q_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr] + self.dequant_scale_w_dq = q_a_proj_deq_scl.view(1, -1).to(torch.float) + kv_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr] + self.dequant_scale_w_dkv_kr = kv_a_proj_deq_scl.view(1, -1).to(torch.float) + + layer = self.vllm_config.compilation_config.static_forward_context[self.layer_name] + self.quant_kscale = layer.quant_kscale + self.fak_descale_float = layer.fak_descale_float + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): assert self.fused_qkv_a_proj is not None assert self.q_a_layernorm is not None @@ -1236,6 +1282,7 @@ class AscendMLAImpl(MLAAttentionImpl): k_pe: torch.Tensor, block_size: int, attn_metadata: AscendMLAMetadata, + dequant_scale_q_nope=None, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None @@ -1243,7 +1290,15 @@ class AscendMLAImpl(MLAAttentionImpl): # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] actual_seq_lengths = None - if self.enable_kv_nz: + if self.fa_quant_layer: + nz_fmt_last_dim = 16 + k_nope = k_nope.view( + -1, self.num_kv_heads, self.kv_lora_rank // (nz_fmt_last_dim * 2), block_size, nz_fmt_last_dim * 2 + ) + k_pe = k_pe.view( + -1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim + ) + elif self.enable_kv_nz: nz_fmt_last_dim = 16 k_nope = k_nope.view( -1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim @@ -1278,6 +1333,15 @@ class AscendMLAImpl(MLAAttentionImpl): sparse_mode = 3 attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q + elif self.fa_quant_layer: + attn_mask = None + input_layout = "BSND_NBSD" + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous() + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1).contiguous() + dequant_scale_q_nope = dequant_scale_q_nope.view(num_tokens, 1, self.num_heads) + sparse_mode = 0 + actual_seq_lengths = None + attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank) else: # The output layout is set to NBSD to eliminate the need for a # transpose operation after attention. 
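The NZ-style views introduced above for the fa_quant_layer decode path are easiest to sanity-check with concrete shapes. Below is a minimal standalone sketch, assuming an illustrative DeepSeek-like MLA configuration (kv_lora_rank=512, qk_rope_head_dim=64, num_kv_heads=1, block_size=128, 8 blocks); these numbers are not taken from this diff, and the snippet only demonstrates the view arithmetic, not the FIA operator itself:

import torch

num_blocks, block_size, num_kv_heads = 8, 128, 1
kv_lora_rank, qk_rope_head_dim = 512, 64
nz_fmt_last_dim = 16

# int8 nope cache: the innermost fractal holds 32 int8 elements (nz_fmt_last_dim * 2),
# so the fractal count along kv_lora_rank is halved compared with the 16-bit layout.
k_nope = torch.zeros(num_blocks, block_size, num_kv_heads * kv_lora_rank, dtype=torch.int8)
k_nope_nz = k_nope.view(
    -1, num_kv_heads, kv_lora_rank // (nz_fmt_last_dim * 2), block_size, nz_fmt_last_dim * 2
)
print(k_nope_nz.shape)  # torch.Size([8, 1, 16, 128, 32])

# bf16 rope cache keeps the plain 16-element fractal, as in the existing enable_kv_nz path.
k_pe = torch.zeros(num_blocks, block_size, num_kv_heads * qk_rope_head_dim, dtype=torch.bfloat16)
k_pe_nz = k_pe.view(-1, num_kv_heads, qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim)
print(k_pe_nz.shape)  # torch.Size([8, 1, 4, 128, 16])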
@@ -1299,19 +1363,27 @@ class AscendMLAImpl(MLAAttentionImpl): common_kwargs = { "query_rope": q_pe, "key_rope": k_pe, - "num_heads": self.num_heads, + "num_query_heads": self.num_heads, "num_key_value_heads": self.num_kv_heads, "input_layout": input_layout, "atten_mask": attn_mask, "sparse_mode": sparse_mode, - "scale": self.scale, - "antiquant_mode": 0, - "antiquant_scale": None, + "softmax_scale": self.scale, "block_table": decode_meta.block_table, "block_size": block_size, - "actual_seq_lengths": actual_seq_lengths, - "actual_seq_lengths_kv": decode_meta.seq_lens_list, + "actual_seq_qlen": actual_seq_lengths, + "actual_seq_kvlen": decode_meta.seq_lens_list, } + if self.fa_quant_layer: + extra_fa_args = { + "query_quant_mode": 3, + "key_quant_mode": 0, + "value_quant_mode": 0, + "dequant_scale_query": dequant_scale_q_nope, + "dequant_scale_key": self.fak_descale_float, + "dequant_scale_value": self.fak_descale_float, + } + common_kwargs.update(extra_fa_args) if _EXTRA_CTX.is_draft_model: graph_params = get_draft_graph_params() else: @@ -1325,8 +1397,33 @@ class AscendMLAImpl(MLAAttentionImpl): graph_params.events[num_tokens].append(event) workspace = graph_params.workspaces.get(num_tokens) + attn_output = torch.empty(attn_output_shape, dtype=q_pe.dtype, device=q_pe.device) + softmax_lse = torch.empty(num_tokens, dtype=q_pe.dtype, device=q_pe.device) + attn_params = ( + weak_ref_tensors(q_nope), + weak_ref_tensors(k_nope), + weak_ref_tensors(q_pe), + weak_ref_tensors(k_pe), + self.num_heads, + self.num_kv_heads, + input_layout, + weak_ref_tensors(attn_mask) if attn_mask is not None else None, + sparse_mode, + self.scale, + decode_meta.block_table, + block_size, + decode_meta.seq_lens_list, + actual_seq_lengths, + weak_ref_tensors(attn_output), + weak_ref_tensors(softmax_lse), + ) + if self.fa_quant_layer: + attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) # type: ignore + else: + attn_params = attn_params + (None, None) # type: ignore + if workspace is None: - workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( + workspace = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace( q_nope, k_nope, k_nope, **common_kwargs ) if _EXTRA_CTX.is_draft_model: @@ -1334,38 +1431,16 @@ class AscendMLAImpl(MLAAttentionImpl): else: update_graph_params_workspaces(num_tokens, workspace) - attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device) - softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device) - - graph_params.attn_params[num_tokens].append( - ( - weak_ref_tensors(q_nope), - weak_ref_tensors(k_nope), - weak_ref_tensors(q_pe), - weak_ref_tensors(k_pe), - self.num_heads, - self.num_kv_heads, - input_layout, - weak_ref_tensors(attn_mask) if attn_mask is not None else None, - sparse_mode, - self.scale, - decode_meta.block_table, - block_size, - decode_meta.seq_lens_list, - actual_seq_lengths, - weak_ref_tensors(attn_output), - weak_ref_tensors(softmax_lse), - ) - ) + graph_params.attn_params[num_tokens].append(attn_params) torch.npu.graph_task_group_begin(stream) - torch_npu.npu_fused_infer_attention_score.out( + torch_npu.npu_fused_infer_attention_score_v2.out( q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse] ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) else: - attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs) + attn_output, _ = 
torch_npu.npu_fused_infer_attention_score_v2(q_nope, k_nope, k_nope, **common_kwargs) return self._v_up_proj(attn_output) @@ -1381,55 +1456,81 @@ class AscendMLAImpl(MLAAttentionImpl): sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1] - decode_q_nope = torch.empty( - (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - decode_q_pe = torch.empty( - (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + dequant_scale_q_nope = None + if self.fa_quant_layer: + quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope = torch_npu.npu_mla_prolog_v2( + quantized_x, + self.wd_q, + self.wu_q, + self.W_UK_T, + self.wd_kv, + self.gamma1, + self.gamma2, + sin, + cos, + attn_metadata.slot_mapping[:bsz].to(torch.int64), + decode_k_nope, + decode_k_pe, + dequant_scale_x=pertoken_scale.view(-1, 1), + dequant_scale_w_dq=self.dequant_scale_w_dq, + dequant_scale_w_uq_qr=self.dequant_scale_w_uq_qr, + dequant_scale_w_dkv_kr=self.dequant_scale_w_dkv_kr, + quant_scale_ckv=self.quant_kscale, + cache_mode="PA_NZ", + ) + else: + decode_q_nope = torch.empty( + (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + decode_q_pe = torch.empty( + (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) - torch.ops._C_ascend.mla_preprocess( - hidden_states, - self.wd_qkv, - self.deq_scale_qkv, - self.gamma1, - self.beta1, - self.wu_q, - self.qb_deq_scl, - self.gamma2, - cos, - sin, - self.W_UK_T, - decode_k_nope, - decode_k_pe, - attn_metadata.slot_mapping[:bsz], - quant_scale0=self.quant_scale0, - quant_offset0=self.quant_offset0, - bias0=self.quant_bias_qkv, - quant_scale1=self.quant_scale1, - quant_offset1=self.quant_offset1, - bias1=self.qb_qt_bias, - ctkv_scale=self.ctkv_scale, - q_nope_scale=self.q_nope_scale, - cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", - quant_mode="per_tensor_quant_asymm", - q_out0=decode_q_nope, - kv_cache_out0=decode_k_nope, - q_out1=decode_q_pe, - kv_cache_out1=decode_k_pe, - enable_inner_out=False, - inner_out=torch.tensor([], device=hidden_states.device), - ) - decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) - decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + torch.ops._C_ascend.mla_preprocess( + hidden_states, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.W_UK_T, + decode_k_nope, + decode_k_pe, + attn_metadata.slot_mapping[:bsz], + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=self.ctkv_scale, + q_nope_scale=self.q_nope_scale, + cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", + quant_mode="per_tensor_quant_asymm", + q_out0=decode_q_nope, + kv_cache_out0=decode_k_nope, + q_out1=decode_q_pe, + kv_cache_out1=decode_k_pe, + enable_inner_out=False, + inner_out=torch.tensor([], device=hidden_states.device), + ) + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) + decode_q_pe = 
decode_q_pe.view(bsz, self.num_heads, -1) decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe) - decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) + decode_preprocess_res = DecodeMLAPreprocessResult( + decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope=dequant_scale_q_nope + ) return decode_preprocess_res, None def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata): @@ -1576,7 +1677,7 @@ class AscendMLAImpl(MLAAttentionImpl): o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device) # MLA Preprocess - if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: + if self.fa_quant_layer or (self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS): hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states.contiguous(), need_gather_q_kv ) @@ -1596,6 +1697,7 @@ class AscendMLAImpl(MLAAttentionImpl): decode_preprocess_res.k_pe, kv_cache[0].shape[1], attn_metadata, + decode_preprocess_res.dequant_scale_q_nope, ) o_proj_input[:num_decode_tokens] = output_decode diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 415bb6c9..ea82322b 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -66,7 +66,7 @@ from vllm_ascend.distributed.kv_transfer.utils.utils import ( kv_alltoall_and_rearrange, parallel_info, ) -from vllm_ascend.utils import npu_stream_switch +from vllm_ascend.utils import npu_stream_switch, trans_nd_to_nz # isort: off if TYPE_CHECKING: @@ -124,6 +124,9 @@ class SendTask: # pd_head_ratio > 1 use k_cache: torch.Tensor | None = None v_cache: torch.Tensor | None = None + # kv cache quantization layer use + k_quant_cache: torch.Tensor | None = None + v_quant_cache: torch.Tensor | None = None layer_idx: int = 0 layer_name: str = "" # trans block info @@ -210,6 +213,9 @@ class KVCacheSendingLayerThread(threading.Thread): use_mla: bool, k_buffer: torch.Tensor, v_buffer: torch.Tensor, + enable_kv_quant: bool, + k_quant_buffer: torch.Tensor | None, + v_quant_buffer: torch.Tensor | None, resharding_stream: torch.npu.Stream, callback_func: Callable[..., None] = lambda x: None, ): @@ -232,6 +238,9 @@ class KVCacheSendingLayerThread(threading.Thread): self.send_queue = queue.Queue[SendTask]() self.k_buffer = k_buffer self.v_buffer = v_buffer + self.enable_kv_quant = enable_kv_quant + self.k_quant_buffer = k_quant_buffer + self.v_quant_buffer = v_quant_buffer self.ready_event = ready_event self.callback_func = callback_func @@ -325,19 +334,43 @@ class KVCacheSendingLayerThread(threading.Thread): grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous( remote_block_ids, local_block_ids ) - for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( - zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) - ): - block_len = block_lens[k] - for group_remote_block_id, group_local_block_id in zip( - grouped_remote_block_ids, grouped_local_block_ids + # kv cache quantization scenario + if self.enable_kv_quant and send_task.k_quant_cache is not None: + assert len(block_lens) == 2, "Quantization block length must be 2!" 
+ quant_block_lens = [block_lens[0] // 2, block_lens[1]] + layer_local_quant_kv_addr = [self.k_quant_buffer.data_ptr(), self.v_quant_buffer.data_ptr()] + rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx] + # eg:[5,6,7,9] -> {5:0, 6:1, 7:2, 9:3} + rearrange_block_dict = { + value: index + for index, value in enumerate(rearrange_block_ids) # type:ignore + } + for block_len, src_layer_base_addr, dst_layer_base_addr in zip( + quant_block_lens, layer_local_quant_kv_addr, layer_remote_kv_base_addr ): - src = src_layer_base_addr + group_local_block_id[0] * block_len - dst = dst_layer_base_addr + group_remote_block_id[0] * block_len - length = len(group_local_block_id) * block_len - src_list.append(src) - dst_list.append(dst) - length_list.append(length) + for group_remote_block_id, group_local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids + ): + src = src_layer_base_addr + rearrange_block_dict[group_local_block_id[0]] * block_len + dst = dst_layer_base_addr + group_remote_block_id[0] * block_len + length = len(group_local_block_id) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) + else: + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) + ): + block_len = block_lens[k] + for group_remote_block_id, group_local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids + ): + src = src_layer_base_addr + group_local_block_id[0] * block_len + dst = dst_layer_base_addr + group_remote_block_id[0] * block_len + length = len(group_local_block_id) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) else: rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx] rearrange_block_dict = { @@ -380,6 +413,14 @@ class KVCacheSendingLayerThread(threading.Thread): value = value.view(-1, key.shape[-1]) # type:ignore self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] -> self.v_buffer[: value.shape[0]].copy_(value) + if send_task.k_quant_cache is not None: + with npu_stream_switch(self.resharding_stream): + key_quant = send_task.k_quant_cache + key_quant = key_quant.view(-1, key_quant.shape[-1]) # type:ignore + self.k_quant_buffer[: key_quant.shape[0]].copy_(key_quant) + value_quant = send_task.v_quant_cache + value_quant = value_quant.view(-1, value_quant.shape[-1]) # type:ignore + self.v_quant_buffer[: value_quant.shape[0]].copy_(value_quant) # Merge transmission tasks of the same session session_meta: dict[str, TransferMeta] = {} @@ -395,7 +436,9 @@ class KVCacheSendingLayerThread(threading.Thread): session_meta[session_id].length.extend(length_list) session_meta[session_id].req_ids.append(req_id) - if self.pd_head_ratio == 1: + if send_task.k_quant_cache is not None: + self.resharding_stream.synchronize() + elif self.pd_head_ratio == 1: """ Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. This issue will be fixed in CANN version 8.5.rc1. 
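As a side note on the quantized-KV branch added above, the source/destination offset computation can be checked with a small standalone sketch. The block lengths, block ids and base addresses below are made up for illustration; only the indexing scheme mirrors the diff (local offsets come from the rearranged buffer staging slot, remote offsets from the remote block id, and the int8 K cache takes half the bytes of the original dtype while V stays unquantized for MLA):

# Hypothetical per-block byte lengths for the K and V layouts of one layer.
block_lens = [65536, 32768]
quant_block_lens = [block_lens[0] // 2, block_lens[1]]  # int8 K halves its byte length, V unchanged

rearrange_block_ids = [5, 6, 7, 9]  # local blocks, in buffer-staging order
rearrange_block_dict = {blk: idx for idx, blk in enumerate(rearrange_block_ids)}  # {5: 0, 6: 1, 7: 2, 9: 3}

grouped_local_block_ids = [[5, 6, 7], [9]]       # contiguous runs of local block ids
grouped_remote_block_ids = [[11, 12, 13], [20]]  # matching runs on the remote (decode) side

src_base, dst_base = 0x10000000, 0x20000000  # made-up buffer base addresses for the K transfer
block_len = quant_block_lens[0]

transfers = []
for remote_run, local_run in zip(grouped_remote_block_ids, grouped_local_block_ids):
    src = src_base + rearrange_block_dict[local_run[0]] * block_len  # offset by staging slot, not block id
    dst = dst_base + remote_run[0] * block_len                       # remote side indexes by block id
    transfers.append((src, dst, len(local_run) * block_len))
print(transfers)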
@@ -628,7 +671,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1, SupportsHMA): self.connector_worker.wait_for_layer_load(layer_name) def save_kv_layer( - self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs + self, layer_name: str, kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", **kwargs ) -> None: """MooncakeLayerwiseConnector does not save explicitly.""" assert self.connector_worker is not None @@ -962,10 +1005,13 @@ class MooncakeLayerwiseConnectorWorker: self.layer_metadata: dict[str, LayerMetadata] = {} self.attn_resharding_group_idx = set[int]() + self.enable_kv_quant = ( + vllm_config.quant_config.enable_fa_quant if vllm_config.quant_config is not None else False + ) self.pd_head_ratio = get_ascend_config().pd_head_ratio self.num_head_replica = get_ascend_config().num_head_replica self.resharding_stream = None - if self.pd_head_ratio > 1: + if self.pd_head_ratio > 1 or self.enable_kv_quant: self.resharding_stream = torch.npu.Stream() self.remote_poller = zmq.Poller() # type: ignore @@ -985,11 +1031,16 @@ class MooncakeLayerwiseConnectorWorker: self.timeout = 1.0 # seconds self.k_buffer: torch.Tensor | None = None self.v_buffer: torch.Tensor | None = None + # TODO(kunpengW-code): Reuse k_buffer, v_buffer + self.k_quant_buffer: torch.Tensor | None = None + self.v_quant_buffer: torch.Tensor | None = None - def create_kv_buffer(self, first_kv_cache): + def create_kv_buffer(self, first_kv_cache_tuple): + alignment = 2 * 1024 * 1024 + buffer_list = [] + first_kv_cache = first_kv_cache_tuple[0] if self.pd_head_ratio > 1: # regesit kv buffer for tp inequal - alignment = 2 * 1024 * 1024 self.k_buffer = torch.zeros( first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device ) @@ -1002,18 +1053,34 @@ class MooncakeLayerwiseConnectorWorker: self.v_buffer = align_memory(self.v_buffer, alignment)[: first_kv_cache.numel()].view( -1, first_kv_cache.shape[-1] ) + buffer_list.append(self.k_buffer) + buffer_list.append(self.v_buffer) + if self.enable_kv_quant: + quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2 + self.k_quant_buffer = torch.zeros( + quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device + ) + self.k_quant_buffer = align_memory(self.k_quant_buffer, alignment)[:quant_k_cache_numel].view( + -1, first_kv_cache.shape[-1] + ) + quant_v_cache_numel = first_kv_cache_tuple[1].numel() + self.v_quant_buffer = torch.zeros( + quant_v_cache_numel + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device + ) + self.v_quant_buffer = align_memory(self.v_quant_buffer, alignment)[:quant_v_cache_numel].view( + -1, first_kv_cache_tuple[1].shape[-1] + ) + buffer_list.append(self.k_quant_buffer) + buffer_list.append(self.v_quant_buffer) - for tensor in (self.k_buffer, self.v_buffer): - assert tensor.data_ptr() % alignment == 0, ( - "The address of the registered kv cache should be aligned to 2M" - ) - ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel()) - logger.info( - f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} " - f"{tensor.numel()} {ret_value=}" - ) - if ret_value != 0: - raise RuntimeError("Mooncake memory registration failed. 
") + for tensor in buffer_list: + assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel()) + logger.info( + f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}" + ) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed. ") def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data.""" @@ -1042,8 +1109,8 @@ class MooncakeLayerwiseConnectorWorker: ptrs = [] lengths = [] - use_resharding_buffer = False - resharding_buffer = None + use_kv_buffer = False + kv_buffer = None for layer_name, kv_cache_tuple in kv_caches.items(): if isinstance(kv_cache_tuple, (list, tuple)) is False: kv_cache_tuple = [kv_cache_tuple] @@ -1051,12 +1118,13 @@ class MooncakeLayerwiseConnectorWorker: layer_kv_cache_spec = kv_cache_groups[layer_kv_group_id].kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] - if self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))): + if ( + self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))) + ) or self.enable_kv_quant: self.attn_resharding_group_idx.add(layer_kv_group_id) - if use_resharding_buffer is False: - use_resharding_buffer = True - resharding_buffer = kv_cache_tuple[0] - self.resharding_stream = torch.npu.Stream() + if use_kv_buffer is False: + use_kv_buffer = True + kv_buffer = kv_cache_tuple single_layer_meta = LayerMetadata([], [], [], []) for single_kv_cache in kv_cache_tuple: block_start_rank = 1 @@ -1092,8 +1160,8 @@ class MooncakeLayerwiseConnectorWorker: lengths.append(kv_cache_tensor.size) global_te.register_buffer(ptrs, lengths) - if use_resharding_buffer: - self.create_kv_buffer(resharding_buffer) + if use_kv_buffer: + self.create_kv_buffer(kv_buffer) num_attn_module = 2 if self.vllm_config.model_config.hf_text_config.model_type == "longcat_flash" else 1 mtp_layer_name = "" @@ -1133,6 +1201,9 @@ class MooncakeLayerwiseConnectorWorker: use_mla=self.use_mla, k_buffer=self.k_buffer, v_buffer=self.v_buffer, + enable_kv_quant=self.enable_kv_quant, + k_quant_buffer=self.k_quant_buffer, + v_quant_buffer=self.v_quant_buffer, resharding_stream=self.resharding_stream, callback_func=self.send_done_send_signal, ) @@ -1380,7 +1451,7 @@ class MooncakeLayerwiseConnectorWorker: metadata.requests[req_id] = update_metadata[req_id] # update send task trans block info - if self.pd_head_ratio != 1: + if self.pd_head_ratio != 1 or self.enable_kv_quant: send_task = metadata.send_task send_task.group_rearrange_block_ids = [[] for _ in range(self.num_kv_cache_groups)] send_task.group_num_blocks = [0 for _ in range(self.num_kv_cache_groups)] @@ -1388,7 +1459,7 @@ class MooncakeLayerwiseConnectorWorker: send_task.group_block_table = [None for _ in range(self.num_kv_cache_groups)] send_task.group_block_len_tensor = [None for _ in range(self.num_kv_cache_groups)] send_task.group_seq_start_tensor = [None for _ in range(self.num_kv_cache_groups)] - device = self.k_buffer.device # type: ignore + device = self.k_buffer.device if self.k_buffer is not None else self.k_quant_buffer.device # type: ignore for i in self.attn_resharding_group_idx: send_task.group_rearrange_block_ids[i].extend( sorted( @@ -1415,7 +1486,7 @@ class MooncakeLayerwiseConnectorWorker: def save_kv_layer( 
self, layer_name: str, - kv_layer: tuple[torch.Tensor, torch.Tensor], + kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", connector_metadata: MooncakeLayerwiseConnectorMetadata, **kwargs, @@ -1490,12 +1561,51 @@ class MooncakeLayerwiseConnectorWorker: values = values.reshape(-1, *kv_layer[1].shape[2:]) (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values) + quant_keys = None + quant_values = None + if self.enable_kv_quant and self.current_layer in self.vllm_config.quant_config.kvcache_quant_layers: + assert self.resharding_stream is not None + with npu_stream_switch(self.resharding_stream): + reshape_cache_event.wait() + device = self.k_quant_buffer.device # type: ignore + layer = self.vllm_config.compilation_config.static_forward_context[layer_name] + # Initialize buffers + # [num_tokens, kv_head, head_dim] + quant_key = torch.empty( + (send_task.group_num_tokens[layer_group_idx], *kv_layer[0].size()[-2:]), + dtype=kv_layer[0].dtype, + device=device, + ) + quant_values = torch.empty( + (send_task.group_num_tokens[layer_group_idx], *kv_layer[1].size()[-2:]), + dtype=kv_layer[1].dtype, + device=device, + ) + + # Load cache data into buffers + torch_npu.atb.npu_paged_cache_load( + kv_layer[0], + kv_layer[1], + send_task.group_block_table[layer_group_idx], + send_task.group_block_len_tensor[layer_group_idx], + seq_starts=send_task.group_seq_start_tensor[layer_group_idx], + key=quant_key, + value=quant_values, + ) + quant_keys = torch.ops.vllm.quantize( + quant_key, layer.fak_descale, layer.fak_descale_reciprocal, layer.fak_offset + ) + quant_keys = self.get_nz_cache(quant_keys, layer_group_idx) + quant_values = self.get_nz_cache(quant_values, layer_group_idx) + assert self.kv_send_layer_thread is not None assert reshape_cache_event is not None layer_send_task = SendTask( wait_event=reshape_cache_event, k_cache=keys, v_cache=values, + k_quant_cache=quant_keys, + v_quant_cache=quant_values, layer_idx=self.current_layer, layer_name=layer_name, group_rearrange_block_ids=send_task.group_rearrange_block_ids, @@ -1510,6 +1620,15 @@ class MooncakeLayerwiseConnectorWorker: self.kv_send_layer_thread.send_queue.put(layer_send_task) self.current_layer += 1 + # NOTE: Due to the FIA operator constraints, the expected kv cache is ND format, NZ shape, + # while the npu_format_cast method only modifies the memory layout, we manually convert it to NZ shape here + def get_nz_cache(self, cache_tensor: torch.Tensor, layer_group_idx: int): + head_num, head_dim = cache_tensor.shape[-2], cache_tensor.shape[-1] + cache_tensor = cache_tensor.view(-1, self.block_size[layer_group_idx], head_num * head_dim) + cache_tensor = trans_nd_to_nz(cache_tensor) + cache_tensor = cache_tensor.reshape(-1, head_num, head_dim) + return cache_tensor + def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore """Get a socket to the remote host.""" remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index c7ae6046..de2e2fa4 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -120,6 +120,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, kv_a_layernorm=mla_modules.kv_a_layernorm, o_proj=mla_modules.o_proj, + layer_name=f"{prefix}.attn", ) original_process_weights = self.mla_attn.process_weights_after_loading diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py 
index 41463c7b..ad6d1d9b 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -528,3 +528,17 @@ # Future Plan: # Remove this patch when: # design a dispatch mechanism for batch_memcpy_kernel. +# +# ** 23. File: worker/patch_weight_utils.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.model_executor.models.deepseek_v2.DeepseekV2ForCausalLM.load_weights` +# Why: +# The C8 weight quantized by modelslim will modify the model structure, +# and the scale and offset required for kvcache quantization will increase. +# In addition, the names of the quantization parameters are different from +# those in the community. +# How: +# we have enhanced the maybe_remap_kv_scale_name function. +# Future Plan: +# The maybe_remap_kv_scale_name function of the community is reconstructed to support +# multiple backends. diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index 06aa5d2a..11982294 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -22,6 +22,7 @@ if HAS_TRITON: # isort: off +import vllm_ascend.patch.worker.patch_weight_utils # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa import vllm_ascend.patch.worker.patch_bert # noqa diff --git a/vllm_ascend/patch/worker/patch_weight_utils.py b/vllm_ascend/patch/worker/patch_weight_utils.py new file mode 100644 index 00000000..f2e50c43 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_weight_utils.py @@ -0,0 +1,86 @@ +import sys +from typing import Any + +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name + +logger = init_logger(__name__) + + +class ImportPatchDecorator: + """Import patch decorator""" + + _patches: dict[str, Any] = {} + + @classmethod + def register(cls, module_name): + """Decorator for registering module patches""" + + def decorator(func): + cls._patches[module_name] = func + return func + + return decorator + + @classmethod + def apply_patches(cls): + """Apply all patches""" + for module_name, patch_func in cls._patches.items(): + if module_name in sys.modules: + module = sys.modules[module_name] + try: + patch_func(module) + except Exception as e: + logger.error(f"Patch application failed {module_name}: {e}") + + +@ImportPatchDecorator.register("vllm.model_executor.models.deepseek_v2") +def patch_deepseek(module): + ori_maybe_remap_kv_scale_name = maybe_remap_kv_scale_name + + def new_remap(name: str, params_dict: dict): + name = ori_maybe_remap_kv_scale_name(name, params_dict) + + replace_scale_names = ["fa_q.scale", "fa_k.scale", "fa_v.scale", "fa_q.offset", "fa_k.offset", "fa_v.offset"] + + for scale_name in replace_scale_names: + if name.endswith(scale_name): + remap_name = name.replace(scale_name, f"mla_attn.mla_attn.{scale_name}") + if remap_name in params_dict: + return remap_name + else: + return remap_name.replace(".mla_attn", "") + + return name + + if hasattr(module, "maybe_remap_kv_scale_name"): + module._original_maybe_remap_kv_scale_name = module.maybe_remap_kv_scale_name + module.maybe_remap_kv_scale_name = new_remap + + +@ImportPatchDecorator.register("vllm.model_executor.model_loader.weight_utils") +def patch_weight_utils(module): + if "vllm.model_executor.models.deepseek_v2" in sys.modules: + deepseek = sys.modules["vllm.model_executor.models.deepseek_v2"] + if hasattr(deepseek, "maybe_remap_kv_scale_name"): + 
module.maybe_remap_kv_scale_name = deepseek.maybe_remap_kv_scale_name + + +original_import = __builtins__["__import__"] # type: ignore + + +def patched_import(name, globals=None, locals=None, fromlist=(), level=0): + module = original_import(name, globals, locals, fromlist, level) + + if name in ImportPatchDecorator._patches: + try: + ImportPatchDecorator._patches[name](module) + except Exception as e: + logger.error(f"Patch application failed during import {name}: {e}") + + return module + + +__builtins__["__import__"] = patched_import + +ImportPatchDecorator.apply_patches() diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index 38840295..2e91ba2d 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -32,10 +32,11 @@ from typing import Any # Import base classes from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType +# Import all scheme classes for external access +from .kv_c8 import AscendFAQuantAttentionMethod + # Import registry functions from .registry import get_scheme_class, register_scheme - -# Import all scheme classes for external access from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod @@ -77,4 +78,5 @@ __all__ = [ "AscendW4A16FusedMoEMethod", "AscendW4A4FlatQuantDynamicLinearMethod", "AscendW4A4LaosDynamicLinearMethod", + "AscendFAQuantAttentionMethod", ] diff --git a/vllm_ascend/quantization/methods/kv_c8.py b/vllm_ascend/quantization/methods/kv_c8.py new file mode 100644 index 00000000..8a700484 --- /dev/null +++ b/vllm_ascend/quantization/methods/kv_c8.py @@ -0,0 +1,65 @@ +import torch +from vllm.config import get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size + +from .registry import register_scheme + + +def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor): + """fa_q weight loader.""" + if param.numel() == 1 and loaded_weight.numel() == 1: + param.data.fill_(loaded_weight.item()) + else: + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + shard_size = loaded_weight.shape[0] // tp_size + loaded_weight = loaded_weight.narrow(0, shard_size * tp_rank, shard_size) + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()}) when TP is ({tp_size})" + ) + + param.data.copy_(loaded_weight) + + +@register_scheme("FAKQuant", "attention") +class AscendFAQuantAttentionMethod: + def __init__(self): + self.transpose_weight = True + self.printFlag = False + vllm_config = get_current_vllm_config() + config = vllm_config.model_config.hf_config + self.kv_lora_rank = getattr(config, "kv_lora_rank", 0) + self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + + def create_weights(self, layer: torch.nn.Module) -> None: + extra_module_names = ["fa_q", "fa_k", "fa_v"] + for name in extra_module_names: + setattr(layer, name, torch.nn.Module()) + params_dict = {} + dtype = torch.get_default_dtype() + params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1), dtype=dtype) + params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype) + params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype) + params_dict["fa_q.offset"] = 
diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py
index 38840295..2e91ba2d 100644
--- a/vllm_ascend/quantization/methods/__init__.py
+++ b/vllm_ascend/quantization/methods/__init__.py
@@ -32,10 +32,11 @@ from typing import Any
 # Import base classes
 from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType

+# Import all scheme classes for external access
+from .kv_c8 import AscendFAQuantAttentionMethod
+
 # Import registry functions
 from .registry import get_scheme_class, register_scheme
-
-# Import all scheme classes for external access
 from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod
 from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod
 from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod
@@ -77,4 +78,5 @@ __all__ = [
     "AscendW4A16FusedMoEMethod",
     "AscendW4A4FlatQuantDynamicLinearMethod",
     "AscendW4A4LaosDynamicLinearMethod",
+    "AscendFAQuantAttentionMethod",
 ]
diff --git a/vllm_ascend/quantization/methods/kv_c8.py b/vllm_ascend/quantization/methods/kv_c8.py
new file mode 100644
index 00000000..8a700484
--- /dev/null
+++ b/vllm_ascend/quantization/methods/kv_c8.py
@@ -0,0 +1,65 @@
+import torch
+from vllm.config import get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
+
+from .registry import register_scheme
+
+
+def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor):
+    """Weight loader for the fa_q/fa_k/fa_v scale and offset parameters."""
+    if param.numel() == 1 and loaded_weight.numel() == 1:
+        param.data.fill_(loaded_weight.item())
+    else:
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        shard_size = loaded_weight.shape[0] // tp_size
+        loaded_weight = loaded_weight.narrow(0, shard_size * tp_rank, shard_size)
+        assert param.size() == loaded_weight.size(), (
+            f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()}) when TP is ({tp_size})"
+        )
+
+        param.data.copy_(loaded_weight)
+
+
+@register_scheme("FAKQuant", "attention")
+class AscendFAQuantAttentionMethod:
+    def __init__(self):
+        self.transpose_weight = True
+        self.printFlag = False
+        vllm_config = get_current_vllm_config()
+        config = vllm_config.model_config.hf_config
+        self.kv_lora_rank = getattr(config, "kv_lora_rank", 0)
+        self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
+
+    def create_weights(self, layer: torch.nn.Module) -> None:
+        extra_module_names = ["fa_q", "fa_k", "fa_v"]
+        for name in extra_module_names:
+            setattr(layer, name, torch.nn.Module())
+        params_dict = {}
+        dtype = torch.get_default_dtype()
+        params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1), dtype=dtype)
+        params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype)
+        params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype)
+        params_dict["fa_q.offset"] = torch.empty((layer.num_heads, 1), dtype=torch.int8)
+        params_dict["fa_k.offset"] = torch.empty((layer.num_kv_heads, 1), dtype=torch.int8)
+        params_dict["fa_v.offset"] = torch.empty((layer.num_kv_heads, 1), dtype=torch.int8)
+
+        for name, weight in params_dict.items():
+            module_name, weight_name = name.rsplit(".", 1)
+            module = getattr(layer, module_name)
+            weight_param = torch.nn.Parameter(weight, requires_grad=False)
+            module.register_parameter(weight_name, weight_param)
+            # When loading weights, shard them according to TP
+            weight_param.weight_loader = weight_loader
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        fa_k_scale = torch.squeeze(layer.fa_k.scale).unsqueeze(0)
+        layer.fak_descale_float = torch.nn.Parameter(fa_k_scale.to(torch.float), requires_grad=False)
+        layer.fak_descale = torch.nn.Parameter(fa_k_scale, requires_grad=False)
+        layer.fak_descale_reciprocal = 1.0 / torch.nn.Parameter(fa_k_scale, requires_grad=False)
+        fa_k_offset = torch.squeeze(layer.fa_k.offset).unsqueeze(0)
+        layer.fak_offset = torch.nn.Parameter(fa_k_offset.to(layer.fak_descale.dtype), requires_grad=False)
+
+        repeated_quant_kscale = fa_k_scale.repeat(self.kv_lora_rank)
+        layer.quant_kscale = repeated_quant_kscale.view(1, self.kv_lora_rank)
+        layer.quant_kscale = 1.0 / torch.nn.Parameter(layer.quant_kscale.to(torch.float), requires_grad=False)
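The weight_loader above either copies a scalar directly or shards a per-head scale/offset tensor across tensor-parallel ranks. A rough standalone sketch with made-up sizes (128 heads, TP size 8); unlike the real loader, it takes tp_rank/tp_size as arguments instead of querying vllm.distributed.

```python
import torch


def shard_scale(param: torch.Tensor, loaded: torch.Tensor, tp_rank: int, tp_size: int):
    if param.numel() == 1 and loaded.numel() == 1:
        param.data.fill_(loaded.item())
        return
    shard = loaded.shape[0] // tp_size                # heads per TP rank
    loaded = loaded.narrow(0, shard * tp_rank, shard)  # this rank's slice
    assert param.size() == loaded.size()
    param.data.copy_(loaded)


full_scale = torch.rand(128, 1)    # one scale per head, as stored in the checkpoint
local_param = torch.empty(16, 1)   # 128 heads / tp_size 8 = 16 local heads
shard_scale(local_param, full_scale, tp_rank=3, tp_size=8)
print(local_param.shape)           # torch.Size([16, 1])
```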
diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py
index c682856b..82b3d279 100644
--- a/vllm_ascend/quantization/modelslim_config.py
+++ b/vllm_ascend/quantization/modelslim_config.py
@@ -24,6 +24,7 @@ configs generated by the ModelSlim tool, along with model-specific mappings.
 import glob
 import json
 import os
+import re
 from collections.abc import Mapping
 from types import MappingProxyType
 from typing import Any, Optional
@@ -31,6 +32,7 @@ from typing import Any, Optional
 import torch
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import register_quantization_config
@@ -38,7 +40,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding
 from vllm.model_executor.models.utils import WeightsMapper

-from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
+from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, calc_split_factor

 from .methods import get_scheme_class

@@ -438,6 +440,7 @@ class AscendModelSlimConfig(QuantizationConfig):
                 new_k = k.replace("weight_packed", "weight")
                 extra_quant_dict[new_k] = self.quant_description[k]
             self.quant_description.update(extra_quant_dict)
+        self._add_kvcache_quant_metadata()

     def __repr__(self) -> str:
         return "AscendModelSlimConfig:\n" + super().__repr__()
@@ -509,8 +512,6 @@ class AscendModelSlimConfig(QuantizationConfig):
             self.packed_modules_mapping = packed_modules_model_mapping[model_type]
         prefix = self.quant_prefix_mapper(model_type, prefix)

-        from vllm.model_executor.layers.attention import Attention
-
         if model_type != "kimi_k2":
             if prefix.startswith("language_model"):
                 prefix = prefix.split(".", 1)[-1]
@@ -522,11 +523,7 @@ class AscendModelSlimConfig(QuantizationConfig):
                 return AscendUnquantizedLinearMethod()
            scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping)
            return AscendLinearMethod(scheme)
-        elif (
-            isinstance(layer, Attention)
-            and "fa_quant_type" in self.quant_description
-            and self.quant_description["fa_quant_type"] is not None
-        ):
+        elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix):
             scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping)
             return AscendKVCacheMethod(scheme)
         elif isinstance(layer, FusedMoE):
@@ -573,6 +570,39 @@ class AscendModelSlimConfig(QuantizationConfig):
         assert is_skipped is not None
         return is_skipped

+    def is_fa_quant_layer(self, prefix):
+        if self.enable_fa_quant:
+            layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix))
+            if layer_id_str.isdigit() and int(layer_id_str) in self.kvcache_quant_layers:
+                return True
+        return False
+
+    def enabling_fa_quant(self, vllm_config, layer_name) -> bool:
+        is_decode_instance = (
+            vllm_config.kv_transfer_config is not None
+            and vllm_config.kv_transfer_config.is_kv_consumer
+            and not vllm_config.kv_transfer_config.is_kv_producer
+        )
+        return bool(is_decode_instance and self.is_fa_quant_layer(layer_name))
+
+    def get_kv_quant_dtype(self, layer_name, cache_dtype, model_config):
+        if self.enable_fa_quant and self.is_fa_quant_layer(layer_name):
+            ori_dtype = model_config.dtype
+            quant_dtype = torch.int8
+            # For MLA models like deepseek, only the K cache is quantized to preserve accuracy
+            if model_config.use_mla:
+                return quant_dtype, ori_dtype
+            else:
+                return quant_dtype, quant_dtype
+        return cache_dtype, cache_dtype
+
+    def get_kv_quant_split_factor(self, layer_name, kv_head_dim_list):
+        if self.enable_fa_quant and self.is_fa_quant_layer(layer_name):
+            k_quant_head_dim = kv_head_dim_list[0]
+            v_quant_head_dim = kv_head_dim_list[1] * 2
+            kv_head_dim_list = [k_quant_head_dim, v_quant_head_dim]
+        return calc_split_factor(kv_head_dim_list)
+
     def maybe_update_config(self, model_name: str, revision: str | None = None) -> None:
         """Load the ModelSlim quantization config from model directory.

@@ -606,6 +636,7 @@ class AscendModelSlimConfig(QuantizationConfig):
             with open(config_path) as f:
                 self.quant_description = json.load(f)
             self._apply_extra_quant_adaptations()
+            self._add_kvcache_quant_metadata()
             return

         # Collect diagnostic info for the error message
@@ -678,3 +709,13 @@

     def get_scaled_act_names(self) -> list[str]:
         return []
+
+    def _add_kvcache_quant_metadata(self):
+        fa_quant_type = self.quant_description.get("fa_quant_type", "")
+        self.enable_fa_quant = fa_quant_type != ""
+        self.kvcache_quant_layers = []
+        if self.enable_fa_quant:
+            for key in self.quant_description:
+                if "fa_k.scale" in key:
+                    _id = "".join(re.findall(r"\.(\d+)\.", key))
+                    self.kvcache_quant_layers.append(int(_id))
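is_fa_quant_layer and _add_kvcache_quant_metadata both recover the decoder-layer index from a dotted name with the same `\.(\d+)\.` pattern. A small standalone illustration (the example names are hypothetical):

```python
import re


def layer_id(name: str) -> str:
    # Concatenates every ".<digits>." group; typical layer names contain exactly one.
    return "".join(re.findall(r"\.(\d+)\.", name))


print(layer_id("model.layers.61.self_attn.fa_k.scale"))             # "61"
print(layer_id("model.layers.61.self_attn.mla_attn.fa_k.offset"))   # "61"
```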
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 66d94475..458b95cd 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -1211,3 +1211,36 @@ def get_rope_dim(vllm_config):
         rope_dim = int(model_config.hf_text_config.rotary_dim)

     return rope_dim
+
+
+def calc_split_factor(num_list: list[int]):
+    total = sum(num_list)
+    split_factor_list = []
+    for num in num_list:
+        split_factor_list.append(total / num)
+    return split_factor_list
+
+
+# NOTE: Convert the last two dimensions from ND layout to NZ layout.
+def trans_nd_to_nz(cache_tensor: torch.Tensor):
+    assert len(cache_tensor.shape) >= 2
+    batch = cache_tensor.shape[:-2]
+    a, b = cache_tensor.shape[-2], cache_tensor.shape[-1]
+
+    dtype = cache_tensor.dtype
+    if dtype == torch.int8:
+        a0, b0 = 16, 32
+    else:
+        a0, b0 = 16, 16
+
+    nz_shape = list(batch) + [math.ceil(b / b0), math.ceil(a / a0), a0, b0]
+
+    # Generate the axis order for the transpose operation.
+    offset = len(cache_tensor.shape) - 2
+    base = [2, 0, 1, 3]
+    array_trans = [i for i in range(offset)] + [i + offset for i in base]
+    # Perform shape transformation and transpose operation.
+    *_, n1, m1, m0, n0 = nz_shape
+    cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0])
+    cache_tensor = cache_tensor.permute(*array_trans)
+    return cache_tensor
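calc_split_factor turns a list of head dims into per-tensor split factors (total / dim). A worked example with DeepSeek-style MLA dimensions, used here purely for illustration (kv_lora_rank=512, qk_rope_head_dim=64):

```python
def calc_split_factor(num_list):
    total = sum(num_list)
    return [total / num for num in num_list]


print(calc_split_factor([512, 64]))   # [1.125, 9.0]: K gets 1/1.125 and V 1/9 of the cache tensor
# On an fa-quant decode instance, get_kv_quant_split_factor doubles the rope part
# before splitting: calc_split_factor([512, 128]) -> [1.25, 5.0]
```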
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index fc02fc0a..8fc168ac 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -117,6 +117,7 @@ from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
 from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
 from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer
 from vllm_ascend.utils import (
+    calc_split_factor,
     check_gdn_layer,
     enable_sp,
     enable_sp_by_pass,
@@ -2683,12 +2684,6 @@
         # as it only support the 0-dim of kv_cache is `num_blocks`.
         # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim
         # and rope head dim.
-        if self.model_config.use_mla:
-            head_size = (
-                self.model_config.hf_text_config.qk_rope_head_dim
-                + self.model_config.hf_text_config.kv_lora_rank
-            )
-
         if not self.model_config.use_mla:
             # for non-mla model, use FullAttentionSpec
             k_tensor_split_factor = 2.0
@@ -2703,8 +2698,16 @@
             dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
         else:
             # for other deepseek models, use MLAAttentionSpec
-            k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
-            v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim
+            kv_head_dim_list = [
+                self.model_config.hf_text_config.kv_lora_rank,
+                self.model_config.hf_text_config.qk_rope_head_dim,
+            ]
+            if self.is_kv_consumer and self.vllm_config.quant_config is not None:
+                k_tensor_split_factor, v_tensor_split_factor = (
+                    self.vllm_config.quant_config.get_kv_quant_split_factor(layer_name, kv_head_dim_list)
+                )
+            else:
+                k_tensor_split_factor, v_tensor_split_factor = calc_split_factor(kv_head_dim_list)

             k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
             v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
@@ -2881,8 +2884,13 @@
                 num_kv_heads,
                 self.model_config.hf_text_config.qk_rope_head_dim,
             ]
-            k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape)
-            v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape)
+            k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype
+            if self.is_kv_consumer and self.vllm_config.quant_config is not None:
+                k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype(
+                    layer_name, kv_cache_spec.dtype, self.model_config
+                )
+            k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape)
+            v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape)

             if self.use_sparse:
                 dsa_k_cache_shape = (
@@ -3199,12 +3207,17 @@
             elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
                 assert isinstance(spec, MLAAttentionSpec)
                 from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec
+                if getattr(attn_module.impl, "fa_quant_layer", False):
+                    head_size = attn_module.head_size + attn_module.qk_rope_head_dim
+                    dtype, cache_dtype_str = attn_module.impl.dtype, None
+                else:
+                    head_size, dtype, cache_dtype_str = spec.head_size, spec.dtype, spec.cache_dtype_str
                 kv_cache_spec[layer_name] = AscendMLAAttentionSpec(
                     block_size=spec.block_size,
                     num_kv_heads=spec.num_kv_heads,
-                    head_size=spec.head_size,
-                    dtype=spec.dtype,
-                    cache_dtype_str=spec.cache_dtype_str,
+                    head_size=head_size,
+                    dtype=dtype,
+                    cache_dtype_str=cache_dtype_str,
                 )
             elif isinstance(attn_module, MambaBase):