fix : support chunked_prefill with deepseek_mtp (#2711)

### What this PR does / why we need it?
fix : support chunked_prefill with deepseek_mtp

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```
vllm serve $MODEL_PATH
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --no-enforce-eager \
    --max-num-seqs 24 \
    --max-model-len 32768 \
    --max-num-batched-tokens 16384 \
    --block-size 128 \
    --no-enable-prefix-caching \
    --disable-log-requests \
    --speculative-config '{"num_speculative_tokens":1, "method": "deepseek_mtp"}' \
    --additional-config '{"torchair_graph_config":{"enabled":true,"use_cached_graph":true,"graph_batch_sizes":[24],"enable_multistream_mla": true},"ascend_scheduler_config":{"enabled":false},"expert_tensor_parallel_size":16, "chunked_prefill_for_mla":true}' \
   --gpu-memory-utilization 0.95
```

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-10-22 11:52:27 +08:00
committed by GitHub
parent 2f1b9a7a64
commit f2dd5f8d08

View File

@@ -426,8 +426,8 @@ class AscendMLATorchairMetadataBuilder:
if num_prefills > 0: if num_prefills > 0:
reqs_start = num_decodes # prefill_start reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens tokens_start = num_decode_tokens
max_query_len = query_lens[tokens_start:].max().item() max_query_len = query_lens[reqs_start:].max().item()
max_seq_lens = seq_lens[tokens_start:].max().item() max_seq_lens = seq_lens[reqs_start:].max().item()
prefill_query_start_loc = query_start_loc[ prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start] reqs_start:] - query_start_loc[reqs_start]
@@ -473,9 +473,9 @@ class AscendMLATorchairMetadataBuilder:
1).unsqueeze(2) 1).unsqueeze(2)
prefill_metadata = AscendMLATorchairPrefillMetadata( prefill_metadata = AscendMLATorchairPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask, attn_mask=common_attn_metadata.attn_mask,
query_lens=query_lens[tokens_start:].to(torch.int32), query_lens=query_lens[reqs_start:].to(torch.int32),
seq_lens=seq_lens, seq_lens=seq_lens,
context_lens=seq_lens[tokens_start:], context_lens=seq_lens[reqs_start:],
input_positions=prefill_input_positions, input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...], block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len, max_query_len=max_query_len,