From 81aff9c5551b39386a888b874ceb04f58b814185 Mon Sep 17 00:00:00 2001 From: zouyida2052 Date: Thu, 9 Oct 2025 19:22:46 +0800 Subject: [PATCH] bugfix for mtp (#3300) ### What this PR does / why we need it? when mtp>1, we need to refresh cos and sin in each step. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? - vLLM version: v0.11.0 Signed-off-by: zouyida2052 --- vllm_ascend/attention/utils.py | 2 +- vllm_ascend/spec_decode/mtp_proposer.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 519cde0..efc1103 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -97,7 +97,7 @@ def split_decodes_and_prefills( return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() - assert torch.all(query_lens[first_prefill:] >= decode_threshold) + assert torch.all(query_lens[first_prefill:] > decode_threshold) assert torch.all(query_lens[:first_prefill] <= decode_threshold) num_decodes = first_prefill num_prefills = num_reqs - num_decodes diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 6889efb..f42d381 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -523,6 +523,14 @@ class MtpProposer(Proposer): input_ids = draft_token_ids_list[-1].int() positions += 1 + if not self.torchair_graph_enabled: + attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ + 1:batch_size + 1].tolist() + attn_metadata_i.decode.cos = builder.cos_cache[ + positions].unsqueeze(1).unsqueeze(2) + attn_metadata_i.decode.sin = builder.sin_cache[ + positions].unsqueeze(1).unsqueeze(2) + # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. 
Since it is complex # to remove such requests from the batch, we keep them in the batch @@ -556,6 +564,8 @@ class MtpProposer(Proposer): if attn_metadata_i.prefill is not None: attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens + attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist( + ) attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens attn_metadata_i.prefill.input_positions = self.positions[: num_input_tokens] @@ -565,6 +575,8 @@ class MtpProposer(Proposer): self.runner.model_config.max_model_len) if attn_metadata_i.decode is not None: attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens + attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( + ) attn_metadata_i.decode.input_positions = self.positions[: num_input_tokens] attn_metadata_i.decode.max_seq_lens += 1