feat: add mtp ut and fix some bugs (#2453)

### What this PR does / why we need it?
Fix mtp mode ut

### Does this PR introduce _any_ user-facing change?
Nothing

### How was this patch tested?
This can be tested in the same way as a unit test.


- vLLM version: v0.10.0
- vLLM main:
53415653ff

Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
This commit is contained in:
ZhaoJiangJiang
2025-08-22 17:09:08 +08:00
committed by GitHub
parent dd04a96ee3
commit 3629bc4431
10 changed files with 129 additions and 75 deletions

View File

@@ -374,18 +374,12 @@ class AscendMLAMetadataBuilder:
decode_metadata = None
if num_decodes > 0:
actual_seq_lengths_q = query_start_loc[1:].tolist()
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decode_tokens]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decode_tokens, ...]
seq_lens_list = seq_lens.tolist()
# TODO(xyx): whether this block is necessary without torchair
# mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
batch_size = slot_mapping.size(0)
if actual_seq_lengths_q[-1] != batch_size \
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
actual_seq_lengths_q[-1] = batch_size
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)