feat: add mtp ut and fix some bugs (#2453)

### What this PR does / why we need it?
Fix mtp mode ut

### Does this PR introduce _any_ user-facing change?
Nothing

### How was this patch tested?
This can be tested in the same way as a unit test.


- vLLM version: v0.10.0
- vLLM main:
53415653ff

Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
This commit is contained in:
ZhaoJiangJiang
2025-08-22 17:09:08 +08:00
committed by GitHub
parent dd04a96ee3
commit 3629bc4431
10 changed files with 129 additions and 75 deletions

View File

@@ -113,6 +113,7 @@ class TestAscendQuantConfig(TestBase):
def test_get_quant_method_for_fused_moe(self):
fused_moe_layer = MagicMock(spec=FusedMoE)
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
# Test skipped layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \

View File

@@ -1,11 +1,13 @@
from unittest.mock import MagicMock, patch
import torch
from torch import nn
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.torchair.torchair_mla import (
AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
AscendMLATorchairImpl, AscendMLATorchairMetadata,
@@ -398,6 +400,68 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
assert torch.equal(sin_golden, metadata.decode.sin)
assert torch.equal(cos_golden, metadata.decode.cos)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_build_decode(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_device = 'cpu'
model = MagicMock(spec=nn.Module)
model.model = MagicMock(spec=nn.Module)
builder = AscendMLATorchairMetadataBuilder(
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
builder.rope_dim = 64
builder.sin_cache = torch.tensor([10, 10])
builder.cos_cache = torch.tensor([10, 10])
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
seq_lens_cpu=torch.tensor([1, 1, 1]),
num_reqs=3,
num_actual_tokens=3,
max_query_len=1,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
metadata = builder.build(common_attn_metadata, model)
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 0)
self.assertEqual(metadata.num_actual_tokens, 3)
self.assertEqual(metadata.num_decodes, 3)
self.assertEqual(metadata.num_decode_tokens, 3)
self.assertEqual(metadata.num_prefills, 0)
self.assertEqual(metadata.attn_state,
AscendAttentionState.ChunkedPrefill)
self.assertIsNone(metadata.prefill)
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
self.assertEqual(metadata.block_tables.shape[0], 3)
self.assertEqual(metadata.block_tables.shape[1], 10)
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 4)
class TestAscendMLATorchairImpl(TestBase):