[BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)

### What this PR does / why we need it?
The pad `-1` modification is taken from
https://github.com/vllm-project/vllm/pull/25743.

Qwen3-Next-MTP still has bugs under batched chunked prefill; this PR adapts it accordingly.
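
For reference, a minimal NumPy sketch (not code from this PR; the variable names follow the diff below) of what the `-1` pad buys over a `0` pad when a batch mixes chunked-prefill and decode requests:

```python
import numpy as np

# Toy batch: request 0 is still mid chunked prefill, request 1 is decoding.
num_prompt_tokens = np.array([100, 16], dtype=np.int32)
num_computed_tokens = np.array([48, 20], dtype=np.int32)
scheduled_spec_tokens = {0: 2, 1: 2}  # draft tokens scheduled per request index

# Pad with -1 rather than 0: a genuine decode step can end up with 0 draft
# tokens (e.g. guided decoding rolls them back), so 0 cannot mean
# "this request is not at a decode step yet".
num_decode_draft_tokens = np.full(2, -1, dtype=np.int32)
for req_idx, n_draft in scheduled_spec_tokens.items():
    if num_computed_tokens[req_idx] >= num_prompt_tokens[req_idx]:
        num_decode_draft_tokens[req_idx] = n_draft

print(num_decode_draft_tokens)  # [-1  2]: only request 1 is treated as decode
```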

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Author: drslark
Date: 2025-12-10 22:54:24 +08:00
Committed by: GitHub
Parent: 490ddf536f
Commit: 0fb1dc43a1
8 changed files with 646 additions and 28 deletions


@@ -107,8 +107,7 @@ from vllm.v1.worker.ec_connector_model_runner_mixin import \
     ECConnectorModelRunnerMixin
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
-                                  gather_mm_placeholders,
+from vllm.v1.worker.utils import (AttentionGroup, gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)
@@ -138,6 +137,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
 from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
+from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.logits_processor import build_logitsprocs
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -619,8 +619,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         )
         self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int64)
-        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
-                                                  dtype=torch.int32)
+        self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs,
+                                                         dtype=torch.int32)
         # Only relevant for multimodal models
         self.mm_registry = MULTIMODAL_REGISTRY
         self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
@@ -1808,17 +1808,26 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+            # For chunked prefills, use -1 as mask rather than 0, as guided
+            # decoding may rollback speculative tokens.
+            num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
             for req_id, draft_token_ids in (
                     scheduler_output.scheduled_spec_decode_tokens.items()):
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
+                num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
+                    self.input_batch.num_computed_tokens_cpu[req_idx]
+                    >= self.input_batch.num_prompt_tokens[req_idx]) else -1)
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
             logits_indices = spec_decode_metadata.logits_indices
-            self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
-            self.num_draft_tokens.np[num_reqs:].fill(0)
-            self.num_draft_tokens.copy_to_gpu()
             # For DECODE only cuda graph of some attention backends (e.g., GDN).
+            self.num_decode_draft_tokens.np[:
+                                            num_reqs] = num_decode_draft_tokens
+            self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
+            self.num_decode_draft_tokens.copy_to_gpu()
             # save logits_indices for pcp spec decode usage
             self.logits_indices = logits_indices
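
A small sketch of the fixed-size buffer update in the hunk above (`max_num_reqs`, `num_reqs` and the plain NumPy array are assumed stand-ins for the runner's `_make_buffer` state):

```python
import numpy as np

# Assumed stand-ins for the runner's fixed-size host buffer and batch size.
max_num_reqs, num_reqs = 8, 2
num_decode_draft_tokens = np.array([-1, 2], dtype=np.int32)  # from the loop above

buf = np.empty(max_num_reqs, dtype=np.int32)
buf[:num_reqs] = num_decode_draft_tokens
# Unused (padded) slots are also set to -1, so a padded batch can never be
# mistaken for extra decode requests carrying draft tokens.
buf[num_reqs:] = -1
print(buf)  # [-1  2 -1 -1 -1 -1 -1 -1]
```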
@@ -1983,11 +1992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                 builder = attn_group.get_metadata_builder()
                 if isinstance(builder, GDNAttentionMetadataBuilder):
                     if use_spec_decode:
+                        patch_torch_npu_argsort()
                         extra_attn_metadata_args = dict(
                             num_accepted_tokens=self.num_accepted_tokens.
                             gpu[:num_reqs],
-                            num_decode_draft_tokens_cpu=self.num_draft_tokens.
-                            gpu[:num_reqs],
+                            num_decode_draft_tokens_cpu=self.
+                            num_decode_draft_tokens.cpu[:num_reqs],
                         )
                 attn_metadata_i = builder.build(
                     common_prefix_len=common_prefix_len,
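
On the consumer side, a hedged sketch of the kind of check a metadata builder can make with the `-1` pad; this is an illustration only, not the actual `GDNAttentionMetadataBuilder` implementation:

```python
import torch

# num_decode_draft_tokens_cpu as passed in the hunk above: entries >= 0 mark
# decode requests carrying that many draft tokens, -1 marks (chunked-)prefill
# requests that must take the non-spec-decode path.
num_decode_draft_tokens_cpu = torch.tensor([-1, 2, 2, -1], dtype=torch.int32)

decode_mask = num_decode_draft_tokens_cpu >= 0
num_spec_decode_reqs = int(decode_mask.sum())                # 2
prefill_req_indices = torch.nonzero(~decode_mask).flatten()  # tensor([0, 3])
```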
@@ -3485,6 +3495,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
                                                    kv_cache_raw_tensors)
+        from vllm.v1.worker.utils import bind_kv_cache
         bind_kv_cache(kv_caches,
                       self.compilation_config.static_forward_context,
                       self.kv_caches)