vllm-ascend: support chunked prefill (#1172)

### What this PR does / why we need it?
This PR adds chunked prefill support for MLA (Multi-head Latent Attention) models to vllm-ascend.
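
For context, chunked prefill splits a long prompt's prefill across multiple scheduler steps so that decode requests are not starved behind it. A minimal sketch of turning it on through vLLM's standard engine arguments; the model name and token budget below are illustrative placeholders, not taken from this PR:

```python
# Sketch: enabling chunked prefill on an MLA-style model via vLLM's
# standard engine arguments. Model name and budget are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # any MLA model; illustrative choice
    enable_chunked_prefill=True,           # split long prefills into chunks
    max_num_batched_tokens=2048,           # per-step token budget for chunking
)

outputs = llm.generate(["A very long prompt ..."],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```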


---------

Signed-off-by: fems14 <1804143737@qq.com>
Author: fems14
Date: 2025-06-14 22:31:16 +08:00
Committed by: GitHub
Parent: a3b5af8307
Commit: ab5d110fcc
5 changed files with 303 additions and 20 deletions


@@ -134,7 +134,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.lora_config = vllm_config.lora_config
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
-        self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
+        ascend_config = get_ascend_config()
+        if ascend_config.ascend_scheduler_config.enabled:
+            self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
+        else:
+            self.chunked_prefill_enabled = True
         self.device = device
         self.is_multimodal_model = self.model_config.is_multimodal_model
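
The hunk above decides the runner's chunked-prefill flag: when the Ascend scheduler is enabled, the user-facing scheduler flag is honored; otherwise the flag is forced on, presumably because the default V1 scheduler always runs with chunked prefill. A minimal standalone sketch of that resolution logic, using hypothetical stand-in dataclasses rather than the real vllm/vllm-ascend config classes:

```python
# Stand-in types; the real classes live in vllm / vllm-ascend.
from dataclasses import dataclass

@dataclass
class SchedulerConfig:
    chunked_prefill_enabled: bool

@dataclass
class AscendSchedulerConfig:
    enabled: bool

def resolve_chunked_prefill(ascend_sched: AscendSchedulerConfig,
                            sched: SchedulerConfig) -> bool:
    # With the Ascend scheduler, honor the user-facing scheduler flag;
    # otherwise chunked prefill is always on.
    if ascend_sched.enabled:
        return sched.chunked_prefill_enabled
    return True

assert resolve_chunked_prefill(AscendSchedulerConfig(True),
                               SchedulerConfig(False)) is False
assert resolve_chunked_prefill(AscendSchedulerConfig(False),
                               SchedulerConfig(False)) is True
```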
@@ -1260,6 +1264,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
+        discard_sampled_tokens_req_indices = []
         for i, req_id in enumerate(self.input_batch.req_ids):
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
@@ -1270,6 +1275,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 generator = self.input_batch.generators.get(i)
                 if generator is not None:
                     generator.set_offset(generator.get_offset() - 4)
+                discard_sampled_tokens_req_indices.append(i)
         # NOTE: NPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
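
The two hunks above add bookkeeping for partially prefilled requests: while `num_computed_tokens` plus the tokens scheduled this step is still short of the request's total prompt length, the token sampled this step is a placeholder, so its batch index is recorded for later removal (and the generator offset is rewound so the RNG state is as if nothing was sampled). A hedged sketch of that check, with illustrative names rather than the runner's actual fields:

```python
# Illustrative stand-in for the loop above; argument names are not
# the runner's real attributes.
def collect_discard_indices(num_computed, num_scheduled, num_tokens):
    """Return batch indices of requests that are still mid-prefill."""
    discard = []
    for i, (done, step, total) in enumerate(
            zip(num_computed, num_scheduled, num_tokens)):
        if done + step < total:  # prompt not fully prefilled yet
            discard.append(i)
    return discard

# Request 0 finishes its prompt this step; request 1 still has a chunk left.
print(collect_discard_indices([96, 64], [32, 32], [128, 128]))  # -> [1]
```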
@@ -1290,6 +1296,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.input_batch.vocab_size,
         )
+        for i in discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
+
         spec_token_ids = self._get_spec_token_ids(
             valid_sampled_token_ids,
             sampling_metadata,
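
Finally, the recorded entries are emptied in place rather than deleted, which keeps list positions aligned with batch slots for the downstream speculative-token step. A small illustration with made-up token ids:

```python
# Illustrative values only: slot 1 was mid-prefill, so its sampled
# token is dropped while the list stays index-aligned with the batch.
valid_sampled_token_ids = [[101], [202], [303]]
for i in [1]:  # indices collected in the discard loop
    valid_sampled_token_ids[i].clear()
print(valid_sampled_token_ids)  # [[101], [], [303]]: slot 1 yields no token
```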