feat: add mtp ut and fix some bugs (#2453)

### What this PR does / why we need it?
Fix mtp mode ut

### Does this PR introduce _any_ user-facing change?
Nothing

### How was this patch tested?
This can be tested in the same way as a unit test.


- vLLM version: v0.10.0
- vLLM main:
53415653ff

Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
This commit is contained in:
ZhaoJiangJiang
2025-08-22 17:09:08 +08:00
committed by GitHub
parent dd04a96ee3
commit 3629bc4431
10 changed files with 129 additions and 75 deletions

View File

@@ -374,18 +374,12 @@ class AscendMLAMetadataBuilder:
decode_metadata = None
if num_decodes > 0:
actual_seq_lengths_q = query_start_loc[1:].tolist()
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decode_tokens]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decode_tokens, ...]
seq_lens_list = seq_lens.tolist()
# TODO(xyx): whether this block is necessary without torchair
# mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
batch_size = slot_mapping.size(0)
if actual_seq_lengths_q[-1] != batch_size \
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
actual_seq_lengths_q[-1] = batch_size
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)

View File

@@ -215,4 +215,4 @@ class CustomDeepSeekMTP(DeepSeekMTP):
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states
return hidden_states

View File

@@ -1178,7 +1178,7 @@ class AscendFusedMoE(FusedMoE):
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
self.moe = FusedMoEConfig.make(
moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
@@ -1188,8 +1188,10 @@ class AscendFusedMoE(FusedMoE):
in_dtype=params_dtype,
quant_config=quant_config)
self.moe_config = moe
if quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe)
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)

View File

@@ -102,7 +102,7 @@ class AscendQuantConfig(QuantizationConfig):
elif isinstance(layer, FusedMoE):
if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping):
return AscendUnquantizedFusedMoEMethod(layer.moe)
return AscendUnquantizedFusedMoEMethod(layer.moe_config)
return AscendFusedMoEMethod(self, prefix,
self.packed_modules_mapping)
elif isinstance(layer, VocabParallelEmbedding):

View File

@@ -492,17 +492,17 @@ class AscendMLATorchairMetadataBuilder:
graph_pad_size = common_attn_metadata.graph_pad_size
use_torchair_graph = graph_pad_size != -1
if num_decodes > 0:
actual_seq_lengths_q = query_start_loc[1:].tolist()
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decode_tokens]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decode_tokens, ...]
num_token_pad_size = 0
if use_torchair_graph and common_attn_metadata.attn_state in [
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
]:
num_reqs_pad_size = 0
num_token_pad_size = 0
if graph_pad_size != 0:
pad_value = 0
num_token_pad_size = graph_pad_size - num_decode_tokens
@@ -535,13 +535,14 @@ class AscendMLATorchairMetadataBuilder:
device=input_positions.device)
input_positions = torch.cat(
[input_positions, position_padding])
actual_seq_lengths_q = query_start_loc[1:].tolist(
) + common_attn_metadata.actual_seq_lengths_q[
num_reqs:num_reqs + num_reqs_pad_size]
actual_seq_lengths_q = (
actual_seq_lengths_q + common_attn_metadata.
actual_seq_lengths_q[num_reqs:num_reqs +
num_reqs_pad_size])
else:
seq_lens_list = seq_lens.tolist()
# mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
batch_size = slot_mapping.size(0)
batch_size = num_decode_tokens + num_token_pad_size
if actual_seq_lengths_q[-1] != batch_size \
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
actual_seq_lengths_q[-1] = batch_size

View File

@@ -190,11 +190,6 @@ class MtpProposer:
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if attn_metadata.prefill is not None:
attn_metadata.prefill.query_lens = query_lens.cpu()
attn_metadata.prefill.input_positions = target_positions
attn_metadata.prefill.seq_lens = seq_lens
if not self.torchair_graph_enabled:
# torch mode need to update num_tokens_across_dp
# TODO: adapt enable_dbo later
@@ -213,6 +208,7 @@ class MtpProposer:
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
@@ -315,6 +311,7 @@ class MtpProposer:
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:

View File

@@ -47,9 +47,14 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (init_ascend_soc_version,
register_ascend_customop, sleep_mode_enabled,
try_register_lib)
try_register_lib, vllm_version_is)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if not vllm_version_is("0.10.1.1"):
from vllm.v1.outputs import DraftTokenIds
else:
DraftTokenIds = None
class NPUWorker(WorkerBase):
@@ -343,3 +348,6 @@ class NPUWorker(WorkerBase):
def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
return self.model_runner.get_supported_tasks()
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
return self.model_runner.take_draft_token_ids()