[long seq feat] GQA: support long-prefill-token-threshold and fix bug (#4209)

### What this PR does / why we need it?
GQA chunked prefill with PCP (prefill context parallelism) and DCP (decode context parallelism) now supports `long-prefill-token-threshold`.
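
For context, `long-prefill-token-threshold` is the vLLM scheduler knob that classifies a prompt as a "long" prefill so chunked prefill can limit how much of such requests is prefilled per step. A minimal sketch of enabling it offline; the flag names come from upstream vLLM, while the model and the concrete values are placeholders, not taken from this PR:

```python
# Hedged sketch: chunked prefill with a long-prefill threshold in vLLM.
# Flag names are from upstream vLLM; model and values are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",   # any GQA model (placeholder choice)
    enable_chunked_prefill=True,
    max_num_batched_tokens=2048,        # per-step token budget
    long_prefill_token_threshold=512,   # prompts above this count as "long"
)
outputs = llm.generate(["<a very long prompt>"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```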

The accuracy results are as below:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8kdataset | - | accuracy | gen | 96.13 |

- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

---------

Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: Delphine-Nic <t00608739@china.huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <t00608739@china.huawei.com>
Authored by Delphine-Nic on 2025-11-19 18:10:27 +08:00; committed by GitHub.
Commit a3e9673137 (parent 97daf7f78c). 2 changed files with 51 additions and 26 deletions.


@@ -294,21 +294,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
-        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
-                                           self.block_size)
-        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        decode_max_num_seqs = getattr(self.scheduler_config,
-                                      'decode_max_num_seqs', 0)
-        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
-                                decode_max_num_seqs)
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        self.dcp_size = get_dcp_group().world_size
-        self.dcp_rank = get_dcp_group().rank_in_group
         self.pcp_size = get_prefill_context_model_parallel_world_size(
         ) if prefill_context_parallel_enable() else 1
         self.pcp_rank = get_prefill_context_model_parallel_rank(
         ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
+                                decode_max_num_seqs)
+        if self.pcp_size > 1:
+            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.device = device
         if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
             self.prefetch_stream = torch.npu.Stream(device=device)
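
A note on the reorder above: `max_num_reqs` is now computed before the block-budget math because the new PCP branch grows `max_model_len` by `2 * self.pcp_size * self.max_num_reqs` before `max_num_blocks_per_req` is derived, presumably so each request can be padded to a multiple of `2 * pcp_size` for balanced context-parallel splits. A standalone sketch of that arithmetic; the helper mirrors `cdiv` from vllm.utils, and all constants are made-up example values:

```python
# Hedged sketch of the block-budget math above; all values are illustrative.
def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring vllm.utils.cdiv."""
    return -(-a // b)

max_model_len = 32768
block_size = 128
pcp_size = 2        # prefill-context-parallel world size (assumed)
max_num_reqs = 16   # max(max_num_seqs, decode_max_num_seqs)

if pcp_size > 1:
    # Pad so every request's tokens can split evenly across PCP ranks.
    max_model_len += 2 * pcp_size * max_num_reqs   # 32768 -> 32832

max_num_blocks_per_req = cdiv(max_model_len, block_size)
print(max_num_blocks_per_req)  # 257
```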
@@ -1007,10 +1009,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
     def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
+        # pcp situation.
+        if self.pcp_size > 1:
+            return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
+        # dcp situation.
         if self.dcp_size > 1:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
         # Pooling situation.
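
With the early return above, the runner builds no global mask when `pcp_size > 1`; under prefill context parallelism each rank presumably derives its own mask inside the attention path (my reading of the change, not stated in the PR). For reference, a minimal sketch of the kind of per-chunk causal mask a builder produces for chunked prefill; the real `AttentionMaskBuilder` API is not shown here, and these shapes are assumptions:

```python
# Hedged sketch: a per-request chunked-prefill causal mask built from
# (query_len, context_len). Shapes and semantics are assumptions, not the
# real AttentionMaskBuilder API.
import torch

def chunk_causal_mask(query_len: int, context_len: int,
                      dtype=torch.bool) -> torch.Tensor:
    # Query token i may attend to all cached context plus the first i + 1
    # tokens of the current chunk.
    total = context_len + query_len
    mask = torch.ones(query_len, total, dtype=dtype)
    mask[:, context_len:] = torch.tril(
        torch.ones(query_len, query_len, dtype=dtype))
    return mask  # True = attend, False = masked

print(chunk_causal_mask(3, 2).int())
```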
@@ -1018,12 +1022,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             return self.attn_mask_builder.get_pooling_mask(self.device)
         # Chunk Prefill situation.
         elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
-            if self.dcp_size > 1:
-                max_seq_len = max(seq_lens.max().item(), 0)
-                return self.attn_mask_builder.get_attn_mask(
-                    max_seq_len, self.dtype, self.device)
-            else:
-                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            return self.attn_mask_builder.get_splitfuse_attn_mask()
         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
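
One plausible reading of the fix above: the removed DCP branch materialized a dense `max_seq_len x max_seq_len` causal mask, which grows quadratically with the longest sequence, while the shared split-fuse mask stays at a fixed size. A rough sketch of the footprint difference, with illustrative numbers (the fixed side length is an assumption):

```python
# Hedged sketch of the footprint argument: a dense causal mask scales
# quadratically with sequence length; a fixed split-fuse mask does not.
max_seq_len = 32768        # longest sequence in the batch (example value)
splitfuse_side = 2048      # assumed fixed side of the shared split-fuse mask

dense_elems = max_seq_len * max_seq_len
fixed_elems = splitfuse_side * splitfuse_side
print(f"dense: {dense_elems / 2**20:.0f} Mi bool elements")  # 1024 Mi
print(f"fixed: {fixed_elems / 2**20:.0f} Mi bool elements")  # 4 Mi
```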
@@ -1039,6 +1038,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         return None

     def _make_fia_attention_mask(self) -> torch.Tensor:
+        # pcp situation.
+        if self.pcp_size > 1:
+            return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
         return self.attn_mask_builder.get_splitfuse_attn_mask()
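
One small follow-up the diff leaves open: both helpers are annotated `-> torch.Tensor` but now return `None` when `pcp_size > 1`, so `Optional[torch.Tensor]` would be the accurate annotation. A reduced sketch of the guard with corrected typing; this is a standalone illustration, not the runner itself:

```python
# Hedged sketch: the pcp guard pattern with an Optional return annotation.
from typing import Optional

import torch


class MaskHelper:
    """Standalone reduction of the runner's mask helpers (illustrative)."""

    def __init__(self, pcp_size: int, attn_mask_builder) -> None:
        self.pcp_size = pcp_size
        self.attn_mask_builder = attn_mask_builder

    def make_fia_attention_mask(self) -> Optional[torch.Tensor]:
        if self.pcp_size > 1:
            return None  # PCP ranks skip the global FIA mask entirely.
        if self.attn_mask_builder is None:
            raise ValueError("Attn mask builder is None")
        return self.attn_mask_builder.get_splitfuse_attn_mask()
```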