[BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)
### What this PR does / why we need it?
The `-1` padding modification is taken from
https://github.com/vllm-project/vllm/pull/25743.
Known limitation: batched chunked prefill still has bugs.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -107,8 +107,7 @@ from vllm.v1.worker.ec_connector_model_runner_mixin import \
|
||||
ECConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
|
||||
gather_mm_placeholders,
|
||||
from vllm.v1.worker.utils import (AttentionGroup, gather_mm_placeholders,
|
||||
sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders)
|
||||
|
||||
@@ -138,6 +137,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
||||
from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
||||
from vllm_ascend.eplb.utils import model_register
|
||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||
@@ -619,8 +619,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
)
|
||||
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int64)
|
||||
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
# Only relevant for multimodal models
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||
@@ -1808,17 +1808,26 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
# Iterate over the dictionary rather than all requests since not all
|
||||
# requests have draft tokens.
|
||||
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||||
# For chunked prefills, use -1 as mask rather than 0, as guided
|
||||
# decoding may rollback speculative tokens.
|
||||
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
|
||||
for req_id, draft_token_ids in (
|
||||
scheduler_output.scheduled_spec_decode_tokens.items()):
|
||||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||||
num_draft_tokens[req_idx] = len(draft_token_ids)
|
||||
num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
|
||||
self.input_batch.num_computed_tokens_cpu[req_idx]
|
||||
>= self.input_batch.num_prompt_tokens[req_idx]) else -1)
|
||||
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
|
||||
self.num_draft_tokens.np[num_reqs:].fill(0)
|
||||
self.num_draft_tokens.copy_to_gpu()
|
||||
|
||||
# For DECODE only cuda graph of some attention backends (e.g., GDN).
|
||||
self.num_decode_draft_tokens.np[:
|
||||
num_reqs] = num_decode_draft_tokens
|
||||
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
|
||||
self.num_decode_draft_tokens.copy_to_gpu()
|
||||
# save logits_indices for pcp spec decode usage
|
||||
self.logits_indices = logits_indices
|
||||
|
||||
@@ -1983,11 +1992,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
builder = attn_group.get_metadata_builder()
|
||||
if isinstance(builder, GDNAttentionMetadataBuilder):
|
||||
if use_spec_decode:
|
||||
patch_torch_npu_argsort()
|
||||
extra_attn_metadata_args = dict(
|
||||
num_accepted_tokens=self.num_accepted_tokens.
|
||||
gpu[:num_reqs],
|
||||
num_decode_draft_tokens_cpu=self.num_draft_tokens.
|
||||
gpu[:num_reqs],
|
||||
num_decode_draft_tokens_cpu=self.
|
||||
num_decode_draft_tokens.cpu[:num_reqs],
|
||||
)
|
||||
attn_metadata_i = builder.build(
|
||||
common_prefix_len=common_prefix_len,
|
||||
@@ -3485,6 +3495,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
|
||||
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
|
||||
kv_cache_raw_tensors)
|
||||
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
bind_kv_cache(kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
|
||||
Reference in New Issue
Block a user