From 50441e46506adfd255604c130c2d93ffd68e839e Mon Sep 17 00:00:00 2001
From: Cao Yi
Date: Thu, 5 Mar 2026 17:33:10 +0800
Subject: [PATCH] [BugFix][MTP] Fix prefill misclassified as decode when
 prompt tokens == num_spec_tokens + 1 (#6835)

## Problem

When MTP is enabled, prefill requests with `prompt_tokens == num_spec_tokens + 1` are incorrectly classified as decode requests, causing accuracy issues.

## Root Cause

The `uniform_decode` condition only checked:

- `max_num_scheduled_tokens == uniform_decode_query_len`
- `num_tokens == max_num_scheduled_tokens * num_reqs`

This is insufficient because a prefill request whose prompt is exactly `num_spec_tokens + 1` tokens long satisfies both conditions as well.

## Fix

Add an `is_all_decode` check to ensure all requests have `num_computed_tokens > 0` before classifying the batch as uniform decode, since a decode request must already have computed at least one token.

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/83b47f67b1dfad505606070ae4d9f83e50ad4ebd

---------

Signed-off-by: SlightwindSec
---
 .../e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py | 1 +
 vllm_ascend/worker/model_runner_v1.py                        | 4 +++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py b/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py
index 98cc4de6..470dd968 100644
--- a/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py
+++ b/tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py
@@ -100,6 +100,7 @@ def test_qwen3_next_mtp_correctness_tp4(model_name: str,
         "The president of the United States is",
         "The capital of France is",
         "The future of AI is",
+        "Who are you?",
     ]
     max_tokens = 20

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ca475dd3..0bbe13f5 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1810,9 +1810,11 @@ class NPUModelRunner(GPUModelRunner):
         num_encoder_reqs: int = 0,
     ) -> tuple[CUDAGraphMode, BatchDescriptor, bool, torch.Tensor | None, CUDAGraphStat | None]:
         num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
+        is_all_decode = np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] > 0)
         uniform_decode = (
             (
-                (max_num_scheduled_tokens == self.uniform_decode_query_len)
+                (is_all_decode if self.speculative_config else True)
+                and (max_num_scheduled_tokens == self.uniform_decode_query_len)
                 and (num_tokens == max_num_scheduled_tokens * num_reqs)
             )
             if force_uniform_decode is None
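
For reference, below is a minimal standalone sketch (not part of the change set) of the boundary case this fix targets. It assumes `num_spec_tokens = 2`, so that `uniform_decode_query_len == 1 + num_spec_tokens == 3` as implied by the PR title; the variable names mirror those in `model_runner_v1.py`, but the values are invented for illustration only.

```python
import numpy as np

num_spec_tokens = 2
uniform_decode_query_len = 1 + num_spec_tokens  # assumed: one target token + num_spec_tokens draft tokens

# A single fresh prefill request whose prompt is exactly num_spec_tokens + 1 = 3 tokens.
num_reqs = 1
max_num_scheduled_tokens = 3               # all 3 prompt tokens scheduled in one step
num_tokens = 3                             # total scheduled tokens in the batch
num_computed_tokens_cpu = np.array([0])    # prefill: nothing has been computed yet

# Old condition: the prefill batch passes and is misclassified as uniform decode.
old_uniform_decode = (max_num_scheduled_tokens == uniform_decode_query_len
                      and num_tokens == max_num_scheduled_tokens * num_reqs)

# Patched condition: additionally require every request to have computed >= 1 token.
is_all_decode = bool(np.all(num_computed_tokens_cpu[:num_reqs] > 0))
new_uniform_decode = (is_all_decode
                      and max_num_scheduled_tokens == uniform_decode_query_len
                      and num_tokens == max_num_scheduled_tokens * num_reqs)

print(old_uniform_decode)  # True  -> wrongly treated as a uniform decode batch
print(new_uniform_decode)  # False -> correctly kept on the prefill path
```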