feat: add mtp ut and fix some bugs (#2453)
### What this PR does / why we need it?
Fix the unit tests (UT) for MTP mode.
### Does this PR introduce _any_ user-facing change?
No, this PR does not introduce any user-facing change.
### How was this patch tested?
This patch can be tested by running the unit tests.
- vLLM version: v0.10.0
- vLLM main: 53415653ff
Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
This commit is contained in:
@@ -492,17 +492,17 @@ class AscendMLATorchairMetadataBuilder:
|
||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
||||
use_torchair_graph = graph_pad_size != -1
|
||||
if num_decodes > 0:
|
||||
actual_seq_lengths_q = query_start_loc[1:].tolist()
|
||||
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
seq_lens = seq_lens[:num_decode_tokens]
|
||||
input_positions = input_positions[:num_decode_tokens]
|
||||
block_table = block_table[:num_decode_tokens, ...]
|
||||
num_token_pad_size = 0
|
||||
if use_torchair_graph and common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
num_reqs_pad_size = 0
|
||||
num_token_pad_size = 0
|
||||
if graph_pad_size != 0:
|
||||
pad_value = 0
|
||||
num_token_pad_size = graph_pad_size - num_decode_tokens
|
||||
@@ -535,13 +535,14 @@ class AscendMLATorchairMetadataBuilder:
|
||||
device=input_positions.device)
|
||||
input_positions = torch.cat(
|
||||
[input_positions, position_padding])
|
||||
actual_seq_lengths_q = query_start_loc[1:].tolist(
|
||||
) + common_attn_metadata.actual_seq_lengths_q[
|
||||
num_reqs:num_reqs + num_reqs_pad_size]
|
||||
actual_seq_lengths_q = (
|
||||
actual_seq_lengths_q + common_attn_metadata.
|
||||
actual_seq_lengths_q[num_reqs:num_reqs +
|
||||
num_reqs_pad_size])
|
||||
else:
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
# In the MTP torchair + PD scenario, the last element of actual_seq_lengths_q must be equal to batch_size (num_tokens)
|
||||
batch_size = slot_mapping.size(0)
|
||||
batch_size = num_decode_tokens + num_token_pad_size
|
||||
if actual_seq_lengths_q[-1] != batch_size \
|
||||
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||
actual_seq_lengths_q[-1] = batch_size
|
||||
|
||||
Reference in New Issue
Block a user