[long seq feat]GQA support long-prefill-token-threshold and fixbug (#4209)
### What this PR does / why we need it?
GQA chunk prefill with pcp and dcp support long-prefill-token-threshold
The markdown format results is as below:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8kdataset | - | accuracy | gen | 96.13 |
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: Delphine-Nic <t00608739@china.huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <t00608739@china.huawei.com>
This commit is contained in:
@@ -294,21 +294,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||
self.block_size)
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
decode_max_num_seqs = getattr(self.scheduler_config,
|
||||
'decode_max_num_seqs', 0)
|
||||
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
|
||||
decode_max_num_seqs)
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
self.dcp_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||
) if self.pcp_size > 1 else 0
|
||||
self.dcp_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
decode_max_num_seqs = getattr(self.scheduler_config,
|
||||
'decode_max_num_seqs', 0)
|
||||
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
|
||||
decode_max_num_seqs)
|
||||
if self.pcp_size > 1:
|
||||
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
|
||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||
self.block_size)
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.device = device
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
|
||||
self.prefetch_stream = torch.npu.Stream(device=device)
|
||||
@@ -1007,10 +1009,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def _make_attention_mask(self, seq_lens, position,
|
||||
attn_state) -> torch.Tensor:
|
||||
# pcp situation.
|
||||
if self.pcp_size > 1:
|
||||
return None
|
||||
if self.attn_mask_builder is None:
|
||||
raise ValueError("Attn mask builder is None")
|
||||
# dcp situation.
|
||||
if self.dcp_size > 1:
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
# Pooling situation.
|
||||
@@ -1018,12 +1022,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return self.attn_mask_builder.get_pooling_mask(self.device)
|
||||
# Chunk Prefill situation.
|
||||
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
|
||||
if self.dcp_size > 1:
|
||||
max_seq_len = max(seq_lens.max().item(), 0)
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
max_seq_len, self.dtype, self.device)
|
||||
else:
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
|
||||
# Prefill without cache situation.
|
||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
||||
@@ -1039,6 +1038,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return None
|
||||
|
||||
def _make_fia_attention_mask(self) -> torch.Tensor:
|
||||
# pcp situation.
|
||||
if self.pcp_size > 1:
|
||||
return None
|
||||
if self.attn_mask_builder is None:
|
||||
raise ValueError("Attn mask builder is None")
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
|
||||
Reference in New Issue
Block a user