diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index f3872a15..fabfb7b4 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import torch
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.linear import LinearBase
 
@@ -11,6 +12,7 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend, AscendMLAImpl,
                                           AscendMLAMetadata,
                                           AscendMLAMetadataBuilder,
                                           AscendMLAPrefillMetadata)
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 
 
 class TestAscendMLABackend(TestBase):
@@ -306,6 +308,264 @@ class TestAscendMLAMetadataBuilder(TestBase):
         input_batch.swap_states.assert_called_once_with(1, 2)
 
 
+class TestAscendMLAMetadataBuilderBuild(TestBase):
+
+    def setUp(self):
+        self.mock_vllm_config = MagicMock(spec=VllmConfig)
+        self.mock_vllm_config.model_config = ModelConfig(max_model_len=2048)
+        self.mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 32
+        self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
+        self.mock_vllm_config.scheduler_config = SchedulerConfig(
+            max_num_seqs=8, chunked_prefill_enabled=True)
+        self.mock_vllm_config.speculative_config = None
+        self.mock_device = torch.device("cpu")
+
+        self.kv_cache_spec = MagicMock()
+        self.kv_cache_spec.num_layers = 32
+        self.kv_cache_spec.head_size = 128
+        self.kv_cache_spec.num_heads = 32
+
+    @patch(
+        "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
+    )
+    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
+    def test_build_prefix_no_cache_metadata(self, mock_get_ascend_config,
+                                            mock_dcp_world_size):
+        mock_dcp_world_size.return_value = 1
+
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=torch.tensor([0, 3, 7]),
+            query_start_loc_cpu=torch.tensor([0, 3, 7]),
+            seq_lens_cpu=torch.tensor([5, 6]),
+            num_reqs=2,
+            num_actual_tokens=10,
+            max_query_len=5,
+            decode_token_per_req=torch.tensor([1, 1]),
+            block_table_tensor=torch.zeros((10, 10)),
+            slot_mapping=torch.tensor(range(20)),
+            actual_seq_lengths_q=torch.tensor([0, 1]),
+            positions=torch.tensor([10, 10]),
+            attn_mask=torch.ones((10, 10)),
+            spec_attn_mask=None,
+            attn_state=AscendAttentionState.PrefillNoCache,
+            num_computed_tokens_cpu=None,
+            seq_lens=None)
+
+        base_inputs = {
+            "num_actual_tokens": 10,
+            "slot_mapping": torch.tensor(range(10)),
+            "query_start_loc": torch.tensor([0, 3, 7]),
+            "seq_lens": torch.tensor([5, 6]),
+            "block_tables": torch.zeros((10, 10)),
+            "num_prefills": 2,
+        }
+
+        builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
+                                           layer_names=["layer_0", "layer_1"],
+                                           vllm_config=self.mock_vllm_config,
+                                           device=self.mock_device)
+
+        mock_model = MagicMock()
+        metadata = builder.build(1, common_attn_metadata, mock_model)
+
+        self.assertIsInstance(metadata, AscendMLAMetadata)
+        self.assertEqual(metadata.num_actual_tokens,
+                         base_inputs["num_actual_tokens"])
+        self.assertTrue(
+            torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
+        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
+
+    @patch(
+        "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
+    )
+    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
+    def test_build_chunked_prefix_metadata(self, mock_get_ascend_config,
+                                           mock_dcp_world_size):
+        mock_dcp_world_size.return_value = 1
+
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=torch.tensor([0, 2, 5, 9]),
+            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
+            seq_lens_cpu=torch.tensor([4, 5, 6]),
+            num_reqs=3,
+            num_actual_tokens=15,
+            max_query_len=6,
+            decode_token_per_req=torch.tensor([1, 1, 1]),
+            block_table_tensor=torch.zeros((10, 10)),
+            slot_mapping=torch.tensor(range(20)),
+            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
+            positions=torch.tensor([10, 10]),
+            attn_mask=torch.ones((15, 15)),
+            spec_attn_mask=None,
+            attn_state=AscendAttentionState.ChunkedPrefill,
+            num_computed_tokens_cpu=None,
+            seq_lens=None)
+
+        base_inputs = {
+            "num_actual_tokens": 15,
+            "slot_mapping": torch.tensor(range(15)),
+            "query_start_loc": torch.tensor([0, 2, 5, 9]),
+            "seq_lens": torch.tensor([4, 5, 6]),
+            "block_tables": torch.zeros((10, 10)),
+            "num_prefills": 3,
+        }
+
+        builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
+                                           layer_names=["layer_0", "layer_1"],
+                                           vllm_config=self.mock_vllm_config,
+                                           device=self.mock_device)
+
+        mock_model = MagicMock()
+        metadata = builder.build(1, common_attn_metadata, mock_model)
+
+        self.assertIsInstance(metadata, AscendMLAMetadata)
+        self.assertEqual(metadata.num_actual_tokens,
+                         base_inputs["num_actual_tokens"])
+        self.assertTrue(
+            torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
+        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
+
+    @patch(
+        "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
+    )
+    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
+    def test_build_decode_only_metadata(self, mock_get_ascend_config,
+                                        mock_dcp_world_size):
+        mock_dcp_world_size.return_value = 1
+
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=torch.tensor([0, 1, 2, 3]),
+            query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
+            seq_lens_cpu=torch.tensor([4, 5, 6]),
+            num_reqs=3,
+            num_actual_tokens=3,
+            max_query_len=1,
+            block_table_tensor=torch.zeros((10, 10)),
+            slot_mapping=torch.tensor(range(3)),
+            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
+            decode_token_per_req=torch.tensor([1, 1, 1]),
+            positions=torch.tensor([10, 10]),
+            attn_mask=torch.ones((3, 3)),
+            spec_attn_mask=None,
+            attn_state=AscendAttentionState.DecodeOnly,
+            num_computed_tokens_cpu=None,
+            seq_lens=None)
+
+        base_inputs = {
+            "num_actual_tokens": 3,
+            "slot_mapping": torch.tensor(range(3)),
+            "query_start_loc": torch.tensor([0, 1, 2, 3]),
+            "seq_lens": torch.tensor([4, 5, 6]),
+            "num_decodes": 3,
+        }
+
+        builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
+                                           layer_names=["layer_0", "layer_1"],
+                                           vllm_config=self.mock_vllm_config,
+                                           device=self.mock_device)
+
+        mock_model = MagicMock()
+        metadata = builder.build(1, common_attn_metadata, mock_model)
+
+        self.assertIsInstance(metadata, AscendMLAMetadata)
+        self.assertEqual(metadata.num_actual_tokens,
+                         base_inputs["num_actual_tokens"])
+        self.assertTrue(
+            torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
+        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
+
+    @patch(
+        "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
+    )
+    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
+    def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
+                                                 mock_dcp_world_size):
+        mock_dcp_world_size.return_value = 1
+
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=torch.tensor([0, 1, 2, 3]),
+            query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
+            seq_lens_cpu=torch.tensor([4, 5, 6]),
+            num_reqs=3,
+            num_actual_tokens=3,
+            max_query_len=1,
+            block_table_tensor=torch.zeros((10, 10)),
+            slot_mapping=torch.tensor(range(3)),
+            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
+            decode_token_per_req=torch.tensor([1, 1, 1]),
+            positions=torch.tensor([10, 10]),
+            attn_mask=torch.ones((3, 3)),
+            spec_attn_mask=None,
+            attn_state=AscendAttentionState.DecodeOnly,
+            num_computed_tokens_cpu=None,
+            seq_lens=None)
+
+        base_inputs = {
+            "num_actual_tokens": 3,
+            "slot_mapping": torch.tensor(range(3)),
+            "query_start_loc": torch.tensor([0, 1, 2, 3]),
+            "seq_lens": torch.tensor([4, 5, 6]),
+            "num_decodes": 3,
+        }
+
+        builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
+                                           layer_names=["layer_0", "layer_1"],
+                                           vllm_config=self.mock_vllm_config,
+                                           device=self.mock_device)
+
+        mock_model = MagicMock()
+        metadata = builder.build_for_graph_capture(
+            common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
+
+        self.assertIsInstance(metadata, AscendMLAMetadata)
+        self.assertEqual(metadata.num_actual_tokens,
+                         base_inputs["num_actual_tokens"])
+        self.assertTrue(
+            torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
+        self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
+
+    @patch(
+        "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
+    )
+    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
+    def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
+                                             mock_dcp_world_size):
+        mock_dcp_world_size.return_value = 1
+
+        common_attn_metadata = AscendCommonAttentionMetadata(
+            query_start_loc=torch.tensor([0, 3, 7]),
+            query_start_loc_cpu=torch.tensor([0, 3, 7]),
+            seq_lens_cpu=torch.tensor([5, 6]),
+            num_reqs=2,
+            num_actual_tokens=10,
+            max_query_len=5,
+            decode_token_per_req=torch.tensor([1, 1]),
+            block_table_tensor=torch.zeros((10, 10)),
+            slot_mapping=torch.tensor(range(20)),
+            actual_seq_lengths_q=torch.tensor([0, 1]),
+            positions=torch.tensor([10, 10]),
+            attn_mask=torch.ones((10, 10)),
+            spec_attn_mask=None,
+            attn_state=AscendAttentionState.PrefillNoCache,
+            num_computed_tokens_cpu=None,
+            seq_lens=None)
+
+        builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
+                                           layer_names=["layer_0", "layer_1"],
+                                           vllm_config=self.mock_vllm_config,
+                                           device=self.mock_device)
+
+        mock_model = MagicMock()
+
+        with self.assertRaises(NotImplementedError) as ctx:
+            builder.build_for_graph_capture(
+                common_attn_metadata, AscendAttentionState.PrefillNoCache,
+                mock_model)
+        self.assertIn(
+            "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
+            str(ctx.exception))
+
+
 class TestAscendMLAImpl(TestBase):
 
     @patch('vllm.distributed.parallel_state._DCP',
diff --git a/tests/ut/models/test_mla.py b/tests/ut/models/test_mla.py
new file mode 100644
index 00000000..0f3e166f
--- /dev/null
+++ b/tests/ut/models/test_mla.py
@@ -0,0 +1,215 @@
+from unittest.mock import MagicMock, patch
+
+import torch
+from torch import nn
+from vllm.config import CacheConfig, CompilationConfig, VllmConfig
+from vllm.forward_context import ForwardContext
+from vllm.model_executor.layers.mla import MLAModules
+
+from tests.ut.base import TestBase
+from vllm_ascend.models.layers.mla import (AscendMultiHeadLatentAttention,
+                                           IndexerWrapper)
+from vllm_ascend.utils import vllm_version_is
+
+
+class TestIndexerWrapper(TestBase):
+
+    def test_initialization(self):
+        mock_indexer = MagicMock()
+        mock_indexer.n_head = 64
+        mock_indexer.head_dim = 128
+        mock_indexer.topk_tokens = 2048
+        mock_indexer.q_lora_rank = 1536
+        mock_indexer.wq_b = nn.Linear(128, 128)
+        mock_indexer.wk = nn.Linear(128, 128)
+        mock_indexer.weights_proj = nn.Linear(128, 128)
+        mock_indexer.k_norm = nn.LayerNorm(128)
+        mock_indexer.softmax_scale = 0.123
+        mock_indexer.topk_indices_buffer = torch.randn(10)
+        mock_indexer.k_cache = torch.randn(10)
+
+        wrapper = IndexerWrapper(mock_indexer)
+
+        self.assertEqual(wrapper.n_head, 64)
+        self.assertEqual(wrapper.head_dim, 128)
+        self.assertEqual(wrapper.topk_tokens, 2048)
+        self.assertEqual(wrapper.q_lora_rank, 1536)
+        self.assertIs(wrapper.wq_b, mock_indexer.wq_b)
+        self.assertIs(wrapper.wk, mock_indexer.wk)
+        self.assertIs(wrapper.weights_proj, mock_indexer.weights_proj)
+        self.assertIs(wrapper.k_norm, mock_indexer.k_norm)
+        self.assertEqual(wrapper.softmax_scale, 0.123)
+
+        # The wrapper is expected to detach these cached buffers from the
+        # wrapped indexer once it takes ownership of them.
+        self.assertIsNone(mock_indexer.topk_indices_buffer)
+        self.assertIsNone(mock_indexer.k_cache)
+
+    def test_forward(self):
+        mock_indexer = MagicMock()
+        wrapper = IndexerWrapper(mock_indexer)
+        result = wrapper.forward()
+        self.assertIsNone(result)
+
+
+class TestAscendMultiHeadLatentAttention(TestBase):
+
+    def setUp(self):
+        self.hidden_size = 4096
+        self.num_heads = 32
+        self.scale = 0.123
+        self.qk_nope_head_dim = 64
+        self.qk_rope_head_dim = 64
+        self.v_head_dim = 128
+        self.q_lora_rank = 1536
+        self.kv_lora_rank = 128
+        self.prefix = "model.layers.0.mla"
+
+        self.mock_mla_modules = MagicMock(spec=MLAModules)
+        self.mock_mla_modules.indexer = MagicMock()
+        self.mock_mla_modules.is_sparse = False
+        self.mock_mla_modules.rotary_emb = MagicMock()
+        self.mock_mla_modules.fused_qkv_a_proj = MagicMock()
+        self.mock_mla_modules.q_b_proj = MagicMock()
+        self.mock_mla_modules.q_a_layernorm = MagicMock()
+        self.mock_mla_modules.q_proj = MagicMock()
+        self.mock_mla_modules.kv_a_proj_with_mqa = MagicMock()
+        self.mock_mla_modules.kv_a_layernorm = MagicMock()
+        self.mock_mla_modules.kv_b_proj = MagicMock()
+        self.mock_mla_modules.o_proj = MagicMock()
+
+        self.mock_cache_config = MagicMock(spec=CacheConfig)
+        self.mock_quant_config = MagicMock()
+
+    @patch("vllm_ascend.models.layers.mla.get_current_vllm_config")
+    @patch("vllm_ascend.models.layers.mla.get_ascend_config")
+    @patch(
+        "vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size")
+    def test_initialization(self, mock_tp_size, mock_ascend_config,
+                            mock_get_vllm_config):
+        if vllm_version_is("0.11.0"):
+            with patch("vllm_ascend.models.layers.mla.Attention",
+                       return_value=True):
+                mock_tp_size.return_value = 1
+                mock_ascend_config.return_value.enable_shared_expert_dp = False
+                mock_vllm_config = MagicMock(spec=VllmConfig)
+                mock_vllm_config.model_config.hf_config = MagicMock(
+                    num_hidden_layers=32, first_k_dense_replace=False)
+                mock_get_vllm_config.return_value = mock_vllm_config
+                mock_vllm_config.compilation_config = CompilationConfig()
+
+                attn = AscendMultiHeadLatentAttention(
+                    hidden_size=self.hidden_size,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    qk_nope_head_dim=self.qk_nope_head_dim,
+                    qk_rope_head_dim=self.qk_rope_head_dim,
+                    v_head_dim=self.v_head_dim,
+                    q_lora_rank=self.q_lora_rank,
+                    kv_lora_rank=self.kv_lora_rank,
+                    mla_modules=self.mock_mla_modules,
+                    cache_config=self.mock_cache_config,
+                    quant_config=self.mock_quant_config,
+                    prefix=self.prefix,
+                )
+
+                self.assertEqual(attn.hidden_size, self.hidden_size)
+                self.assertEqual(attn.kv_lora_rank, self.kv_lora_rank)
+                self.assertEqual(attn.debug_layer_idx, 0)
+                self.assertIsNotNone(attn.mla_attn)
+                self.assertIn(
+                    self.prefix,
+                    mock_vllm_config.compilation_config.static_forward_context)
+        else:
+            with patch("vllm_ascend.models.layers.mla.MLAAttention",
+                       return_value=True):
+                mock_tp_size.return_value = 2
+                mock_ascend_config.return_value.enable_shared_expert_dp = True
+                mock_vllm_config = MagicMock(spec=VllmConfig)
+                mock_vllm_config.model_config.hf_config = MagicMock(
+                    num_hidden_layers=32, first_k_dense_replace=True)
+                mock_get_vllm_config.return_value = mock_vllm_config
+                mock_vllm_config.compilation_config = CompilationConfig()
+
+                attn = AscendMultiHeadLatentAttention(
+                    hidden_size=self.hidden_size,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    qk_nope_head_dim=self.qk_nope_head_dim,
+                    qk_rope_head_dim=self.qk_rope_head_dim,
+                    v_head_dim=self.v_head_dim,
+                    q_lora_rank=self.q_lora_rank,
+                    kv_lora_rank=self.kv_lora_rank,
+                    mla_modules=self.mock_mla_modules,
+                    cache_config=self.mock_cache_config,
+                    quant_config=self.mock_quant_config,
+                    prefix=self.prefix,
+                )
+
+                self.assertEqual(attn.tp_size, 2)
+                self.assertTrue(attn.enable_shared_expert_dp)
+                self.assertIsNotNone(attn.mla_attn)
+
+    @patch("vllm_ascend.models.layers.mla.torch.ops.vllm.mla_forward")
+    @patch("vllm_ascend.models.layers.mla.get_current_vllm_config")
+    @patch("vllm_ascend.models.layers.mla.get_ascend_config")
+    @patch(
+        "vllm_ascend.models.layers.mla.get_tensor_model_parallel_world_size")
+    @patch("vllm_ascend.models.layers.mla.get_forward_context")
+    def test_forward(self, mock_get_forward_context, mock_tp_size,
+                     mock_ascend_config, mock_get_vllm_config,
+                     mock_mla_forward):
+        mock_tp_size.return_value = 1
+        mock_ascend_config.return_value.enable_shared_expert_dp = False
+        mock_vllm_config = MagicMock(spec=VllmConfig)
+        mock_vllm_config.model_config.hf_config = MagicMock(
+            num_hidden_layers=32, first_k_dense_replace=False)
+        mock_get_vllm_config.return_value = mock_vllm_config
+        mock_vllm_config.compilation_config = CompilationConfig()
+
+        if vllm_version_is("0.11.0"):
+            with patch("vllm_ascend.models.layers.mla.Attention",
+                       return_value=True):
+                attn = AscendMultiHeadLatentAttention(
+                    hidden_size=self.hidden_size,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    qk_nope_head_dim=self.qk_nope_head_dim,
+                    qk_rope_head_dim=self.qk_rope_head_dim,
+                    v_head_dim=self.v_head_dim,
+                    q_lora_rank=self.q_lora_rank,
+                    kv_lora_rank=self.kv_lora_rank,
+                    mla_modules=self.mock_mla_modules,
+                    cache_config=self.mock_cache_config,
+                    quant_config=self.mock_quant_config,
+                    prefix=self.prefix,
+                )
+        else:
+            with patch("vllm_ascend.models.layers.mla.MLAAttention",
+                       return_value=True):
+                attn = AscendMultiHeadLatentAttention(
+                    hidden_size=self.hidden_size,
+                    num_heads=self.num_heads,
+                    scale=self.scale,
+                    qk_nope_head_dim=self.qk_nope_head_dim,
+                    qk_rope_head_dim=self.qk_rope_head_dim,
+                    v_head_dim=self.v_head_dim,
+                    q_lora_rank=self.q_lora_rank,
+                    kv_lora_rank=self.kv_lora_rank,
+                    mla_modules=self.mock_mla_modules,
+                    cache_config=self.mock_cache_config,
+                    quant_config=self.mock_quant_config,
+                    prefix=self.prefix,
+                )
+        positions = torch.tensor([0, 1, 2])
+        hidden_states = torch.randn(3, self.hidden_size)
+
+        mock_forward_context = MagicMock(spec=ForwardContext)
+        mock_forward_context.sp_enabled = False
+        mock_get_forward_context.return_value = mock_forward_context
+
+        # The mocked mla_forward op must return a tensor, not a shape tuple:
+        # the assertions below read output.shape and call torch.allclose.
+        mock_mla_forward.return_value = torch.randn(3, self.hidden_size)
+
+        output = attn.forward(positions, hidden_states)
+
+        self.assertEqual(output.shape, (3, self.hidden_size))
+        self.assertTrue(
+            torch.allclose(output, output.view(-1, self.hidden_size)))