feat: add mtp ut and fix some bugs (#2453)
### What this PR does / why we need it?
Fix mtp mode ut
### Does this PR introduce _any_ user-facing change?
Nothing
### How was this patch tested?
This can be tested in the same way as a unit test.
- vLLM version: v0.10.0
- vLLM main:
53415653ff
Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
This commit is contained in:
@@ -374,18 +374,12 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
actual_seq_lengths_q = query_start_loc[1:].tolist()
|
||||
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
seq_lens = seq_lens[:num_decode_tokens]
|
||||
input_positions = input_positions[:num_decode_tokens]
|
||||
block_table = block_table[:num_decode_tokens, ...]
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
# TODO(xyx): whether this block is necessary without torchair
|
||||
# mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
|
||||
batch_size = slot_mapping.size(0)
|
||||
if actual_seq_lengths_q[-1] != batch_size \
|
||||
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||
actual_seq_lengths_q[-1] = batch_size
|
||||
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
|
||||
Reference in New Issue
Block a user