[CI] Add mla ut (#4280)

### What this PR does / why we need it?
add mla_v1.py and mla.py ut
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
`pytest tests/ut/attention/test_mla_v1.py`
`pytest tests/ut/models/test_mla.py`

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

Signed-off-by: GDzhu01 <809721801@qq.com>
This commit is contained in:
Zhu Yi Lin
2025-11-20 20:29:09 +08:00
committed by GitHub
parent 470fe05df6
commit 15c1eb025c
2 changed files with 475 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
@@ -11,6 +12,7 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
class TestAscendMLABackend(TestBase):
@@ -306,6 +308,264 @@ class TestAscendMLAMetadataBuilder(TestBase):
input_batch.swap_states.assert_called_once_with(1, 2)
class TestAscendMLAMetadataBuilderBuild(TestBase):
def setUp(self):
self.mock_vllm_config = MagicMock(spec=VllmConfig)
self.mock_vllm_config.model_config = ModelConfig(max_model_len=2048)
self.mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 32
self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
self.mock_vllm_config.scheduler_config = SchedulerConfig(
max_num_seqs=8, chunked_prefill_enabled=True)
self.mock_vllm_config.speculative_config = None
self.mock_device = torch.device("cpu")
self.kv_cache_spec = MagicMock()
self.kv_cache_spec.num_layers = 32
self.kv_cache_spec.head_size = 128
self.kv_cache_spec.num_heads = 32
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_prefix_no_cache_metadata(self, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size.return_value = 1
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
seq_lens_cpu=torch.tensor([5, 6]),
num_reqs=2,
num_actual_tokens=10,
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
seq_lens=None)
base_inputs = {
"num_actual_tokens": 10,
"slot_mapping": torch.tensor(range(10)),
"query_start_loc": torch.tensor([0, 3, 7]),
"seq_lens": torch.tensor([5, 6]),
"block_tables": torch.zeros((10, 10)),
"num_prefills": 2,
}
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
base_inputs["num_actual_tokens"])
self.assertTrue(
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_chunked_prefix_metadata(self, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size.return_value = 1
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None)
base_inputs = {
"num_actual_tokens": 15,
"slot_mapping": torch.tensor(range(15)),
"query_start_loc": torch.tensor([0, 2, 5, 9]),
"seq_lens": torch.tensor([4, 5, 6]),
"block_tables": torch.zeros((10, 10)),
"num_prefills": 3,
}
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
base_inputs["num_actual_tokens"])
self.assertTrue(
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_decode_only_metadata(self, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size.return_value = 1
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=3,
max_query_len=1,
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(3)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
decode_token_per_req=torch.tensor([1, 1, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((3, 3)),
spec_attn_mask=None,
attn_state=AscendAttentionState.DecodeOnly,
num_computed_tokens_cpu=None,
seq_lens=None)
base_inputs = {
"num_actual_tokens": 3,
"slot_mapping": torch.tensor(range(3)),
"query_start_loc": torch.tensor([0, 1, 2, 3]),
"seq_lens": torch.tensor([4, 5, 6]),
"num_decodes": 3,
}
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
metadata = builder.build(1, common_attn_metadata, mock_model)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
base_inputs["num_actual_tokens"])
self.assertTrue(
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_for_graph_capture_decode_only(self, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size.return_value = 1
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=3,
max_query_len=1,
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(3)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
decode_token_per_req=torch.tensor([1, 1, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((3, 3)),
spec_attn_mask=None,
attn_state=AscendAttentionState.DecodeOnly,
num_computed_tokens_cpu=None,
seq_lens=None)
base_inputs = {
"num_actual_tokens": 3,
"slot_mapping": torch.tensor(range(3)),
"query_start_loc": torch.tensor([0, 1, 2, 3]),
"seq_lens": torch.tensor([4, 5, 6]),
"num_decodes": 3,
}
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
metadata = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.DecodeOnly, mock_model)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_actual_tokens,
base_inputs["num_actual_tokens"])
self.assertTrue(
torch.all(metadata.slot_mapping == base_inputs["slot_mapping"]))
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch(
"vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_build_for_graph_capture_prefill(self, mock_get_ascend_config,
mock_dcp_world_size):
mock_dcp_world_size.return_value = 1
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
seq_lens_cpu=torch.tensor([5, 6]),
num_reqs=2,
num_actual_tokens=10,
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
seq_lens=None)
builder = AscendMLAMetadataBuilder(kv_cache_spec=self.kv_cache_spec,
layer_names=["layer_0", "layer_1"],
vllm_config=self.mock_vllm_config,
device=self.mock_device)
mock_model = MagicMock()
with self.assertRaises(NotImplementedError) as ctx:
builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.PrefillNoCache,
mock_model)
self.assertIn(
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state",
str(ctx.exception))
class TestAscendMLAImpl(TestBase):
@patch('vllm.distributed.parallel_state._DCP',