[BugFix][MTP] Fix prefill misclassified as decode when prompt tokens == num_spec_tokens + 1 (#6835)

## Problem
When MTP is enabled, prefill requests with `prompt_tokens ==
num_spec_tokens + 1` are incorrectly classified as decode requests,
causing accuracy issues.

## Root Cause
The `uniform_decode` condition only checked:
- `max_num_scheduled_tokens == uniform_decode_query_len`
- `num_tokens == max_num_scheduled_tokens * num_reqs`

This is insufficient: a prefill request whose prompt length is exactly
`num_spec_tokens + 1` also satisfies both conditions.
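
To make the collision concrete, the sketch below re-implements the two original conditions as a standalone function. The function name `old_uniform_decode`, the list-based input, and the value `num_spec_tokens = 3` are illustrative assumptions, not code from the repository. It also assumes `uniform_decode_query_len` equals `num_spec_tokens + 1`, as the problem statement implies.

```python
def old_uniform_decode(scheduled_tokens_per_req: list[int],
                       uniform_decode_query_len: int) -> bool:
    """Hypothetical standalone version of the pre-fix check."""
    num_reqs = len(scheduled_tokens_per_req)
    num_tokens = sum(scheduled_tokens_per_req)
    max_num_scheduled_tokens = max(scheduled_tokens_per_req)
    return (
        max_num_scheduled_tokens == uniform_decode_query_len
        and num_tokens == max_num_scheduled_tokens * num_reqs
    )


num_spec_tokens = 3  # assumed value, for illustration only
uniform_decode_query_len = num_spec_tokens + 1

# Two genuine decode requests, each scheduling 1 + num_spec_tokens tokens.
decode_result = old_uniform_decode([4, 4], uniform_decode_query_len)

# One fresh prefill request whose prompt happens to be 4 tokens long.
prefill_result = old_uniform_decode([4], uniform_decode_query_len)

# A longer prefill does not collide with the decode query length.
long_prefill_result = old_uniform_decode([9], uniform_decode_query_len)
```

Both the decode batch and the 4-token prefill pass the check, because token counts alone cannot tell them apart.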

## Fix
Add an `is_all_decode` check that requires every request in the batch to
have `num_computed_tokens > 0` before the batch is classified as uniform
decode, since a decode request has always computed at least one token.
The check is applied only when a speculative config is set.
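
The effect of the extra guard can be sketched outside the model runner as follows. This is a minimal illustration, not the runner's actual method: the function name `is_uniform_decode`, its parameters, and the sample batches are hypothetical, and it assumes a non-empty batch.

```python
import numpy as np


def is_uniform_decode(num_scheduled_tokens, num_computed_tokens,
                      uniform_decode_query_len, speculative_enabled=True):
    """Hypothetical standalone version of the post-fix check."""
    num_scheduled = np.asarray(num_scheduled_tokens)
    num_computed = np.asarray(num_computed_tokens)
    num_reqs = len(num_scheduled)
    max_num_scheduled_tokens = int(num_scheduled.max())
    num_tokens = int(num_scheduled.sum())
    # A request that has computed nothing yet is still in prefill.
    is_all_decode = bool(np.all(num_computed[:num_reqs] > 0))
    return (
        (is_all_decode if speculative_enabled else True)
        and max_num_scheduled_tokens == uniform_decode_query_len
        and num_tokens == max_num_scheduled_tokens * num_reqs
    )


query_len = 4  # assumed: num_spec_tokens = 3, so 1 + 3 tokens per decode step

fresh_prefill = is_uniform_decode([4], [0], query_len)
pure_decode = is_uniform_decode([4, 4], [17, 32], query_len)
mixed_batch = is_uniform_decode([4, 4], [17, 0], query_len)
no_spec = is_uniform_decode([4], [0], query_len, speculative_enabled=False)
```

With the guard, the 4-token prefill and the mixed batch are rejected, while the pure decode batch is still accepted. With speculation disabled, the guard is skipped and the original two conditions decide the result.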

- vLLM version: v0.15.0
- vLLM main: 83b47f67b1

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Cao Yi
2026-03-05 17:33:10 +08:00
committed by GitHub
parent 91c39ebae6
commit 50441e4650
2 changed files with 4 additions and 1 deletion

@@ -100,6 +100,7 @@ def test_qwen3_next_mtp_correctness_tp4(model_name: str,
         "The president of the United States is",
         "The capital of France is",
         "The future of AI is",
+        "Who are you?",
     ]
     max_tokens = 20

@@ -1810,9 +1810,11 @@ class NPUModelRunner(GPUModelRunner):
         num_encoder_reqs: int = 0,
     ) -> tuple[CUDAGraphMode, BatchDescriptor, bool, torch.Tensor | None, CUDAGraphStat | None]:
         num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
+        is_all_decode = np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] > 0)
         uniform_decode = (
             (
-                (max_num_scheduled_tokens == self.uniform_decode_query_len)
+                (is_all_decode if self.speculative_config else True)
+                and (max_num_scheduled_tokens == self.uniform_decode_query_len)
                 and (num_tokens == max_num_scheduled_tokens * num_reqs)
             )
             if force_uniform_decode is None