[long seq feat] GQA: support long-prefill-token-threshold and fix bug (#4209)
### What this PR does / why we need it?
GQA chunked prefill with PCP and DCP now supports `long-prefill-token-threshold`.
The results in markdown format are as below:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8kdataset | - | accuracy | gen | 96.13 |
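
For readers new to the knob: roughly, requests whose prompt exceeds the threshold are treated as long prefills and processed in bounded chunks instead of one step. A minimal sketch of the chunking idea (hypothetical helper, not vLLM's scheduler code):

```python
# Hypothetical sketch of threshold-based prefill chunking; names are
# illustrative, not the vLLM scheduler implementation.
def split_prefill(prompt_len: int, threshold: int) -> list[int]:
    """Split a prompt of prompt_len tokens into chunks of at most
    threshold tokens, so one long request cannot monopolize a step."""
    chunks = []
    remaining = prompt_len
    while remaining > 0:
        step = min(remaining, threshold)
        chunks.append(step)
        remaining -= step
    return chunks

# A 10k-token prompt with a 4k threshold prefills over three steps.
assert split_prefill(10_000, 4_096) == [4_096, 4_096, 1_808]
```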
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
---------
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: Delphine-Nic <t00608739@china.huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <t00608739@china.huawei.com>
@@ -166,10 +166,12 @@ class AscendMetadataForPrefill:
         actual_chunk_seq_lengths: list[int]
         actual_seq_lengths_kv: list[int]
         starts: torch.Tensor
+        chunk_seq_mask_filtered_indices: torch.Tensor
         chunked_req_mask: Optional[list[bool]] = None
         local_context_lens_allranks: Optional[list[list[int]]] = None
         cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
         kv_inverse_idx_for_chunk: Optional[list[int]] = None
+        batch_chunk_seq_mask: Optional[list[bool]] = None

     """ Prefill Specific Metadata for Ascend"""
     pcp_metadata: Optional[AscendPCPMetadata] = None
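
As an aside on the dataclass change: the two new tensors ride along in the chunked-context metadata, with the optional one defaulted so older construction sites remain valid. A minimal sketch of the pattern with hypothetical names:

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class ChunkMeta:  # hypothetical stand-in for ChunkedContextMetadata
    starts: torch.Tensor
    # New fields default to None, so callers that predate this PR can
    # still construct the metadata without passing them.
    batch_chunk_seq_mask: Optional[torch.Tensor] = None
    chunk_seq_mask_filtered_indices: Optional[torch.Tensor] = None

meta = ChunkMeta(starts=torch.zeros(4, dtype=torch.int64))  # still valid
```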
@@ -401,6 +403,14 @@ class AscendAttentionMetadataBuilder:
             cp_kv_recover_idx_for_chunk.to(torch.float32)
         ) if cp_kv_recover_idx_for_chunk is not None else None

+        batch_chunk_seq_mask = (
+            local_context_lens_allranks[:, self.pcp_rank,
+                                        self.dcp_rank] == 0)
+        batch_chunk_seq_mask = torch.repeat_interleave(
+            batch_chunk_seq_mask,
+            repeats=(query_lens * self.pcp_size).to(self.device))
+        chunk_seq_mask_filtered_indices = filter_chunked_req_indices(
+            query_lens, chunked_req_mask).to(self.device)
         chunked_context_metadata = \
             AscendMetadataForPrefill.ChunkedContextMetadata(
                 actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
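
The builder derives `batch_chunk_seq_mask` per request (True where this `(pcp_rank, dcp_rank)` holds zero local context) and then expands it to per-token granularity with `torch.repeat_interleave`. A self-contained sketch of that expansion, with made-up sizes:

```python
import torch

# Per-request flag: True where this (pcp_rank, dcp_rank) holds no
# local context for the request's chunk.
per_request_mask = torch.tensor([True, False, True])
query_lens = torch.tensor([2, 3, 1])
pcp_size = 2

# Repeat each request's flag once per token it contributes
# (query_lens * pcp_size tokens per request).
per_token_mask = torch.repeat_interleave(per_request_mask,
                                         repeats=query_lens * pcp_size)
print(per_token_mask)
# tensor([ True,  True,  True,  True, False, False, False, False,
#         False, False,  True,  True])
```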
@@ -409,7 +419,9 @@ class AscendAttentionMetadataBuilder:
                 starts=local_chunk_starts,
                 local_context_lens_allranks=local_context_lens_allranks,
                 cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
-                kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk
+                kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
+                batch_chunk_seq_mask=batch_chunk_seq_mask,
+                chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices
             )
         attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
         head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
@@ -571,10 +583,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
+                kv_cache: Tuple[torch.Tensor],
                 attn_metadata: AscendMetadata,
                 output: torch.Tensor,
                 num_tokens=0):
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            if self.pcp_size * self.dcp_size > 1:
+                intermediate_output = self._forward_pcp_dcp(
+                    query, key, value, kv_cache, attn_metadata, output)
+                return intermediate_output, query.shape[0]
         elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             block_size = 128
             block_table = None
             actual_seq_lengths_kv = attn_metadata.query_start_loc_list
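
Note the dispatch condition: since both group sizes are at least 1, `self.pcp_size * self.dcp_size > 1` is true exactly when at least one context-parallel dimension is actually sharded. A tiny sketch of the check (hypothetical helper name):

```python
def needs_cp_path(pcp_size: int, dcp_size: int) -> bool:
    # Both sizes are >= 1, so the product exceeds 1 exactly when at
    # least one context-parallel dimension is sharded.
    return pcp_size * dcp_size > 1

assert needs_cp_path(1, 1) is False
assert needs_cp_path(2, 1) is True
assert needs_cp_path(1, 4) is True
```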
@@ -1276,9 +1293,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]

         assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
-        seq_len = attn_metadata.query_lens.detach().clone()
-        filtered_indices = filter_chunked_req_indices(
-            seq_len, attn_metadata.prefill.chunked_context.chunked_req_mask)
+        filtered_indices = attn_metadata.prefill.chunked_context.chunk_seq_mask_filtered_indices

         attn_output_prefill_filtered = current_attn_output_prefill[
             filtered_indices, :, :]
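
The fix here replaces a per-call `filter_chunked_req_indices` recomputation with indices precomputed once in the metadata builder and reused by every layer. A sketch of the underlying mask-to-indices pattern (hypothetical shapes):

```python
import torch

# Suppose requests 0 and 2 are chunked (per-request mask) and each
# request contributes query_lens[i] tokens.
query_lens = torch.tensor([2, 1, 3])
chunked_req_mask = torch.tensor([True, False, True])

# Expand to a per-token mask, then materialize indices once.
per_token = torch.repeat_interleave(chunked_req_mask, query_lens)
filtered_indices = torch.nonzero(per_token, as_tuple=True)[0]
print(filtered_indices)  # tensor([0, 1, 3, 4, 5])

# Every layer can now gather with the cached indices instead of
# recomputing the filter on each forward call.
attn_output = torch.randn(6, 8, 16)
filtered = attn_output[filtered_indices]
```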
@@ -1322,9 +1337,11 @@ class AscendAttentionBackendImpl(AttentionImpl):

         local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
                                                            self.dcp_rank]
+        total_toks = local_chunked_kv_lens_rank.sum()

         key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
-                                             local_chunked_kv_lens_rank, query)
+                                             local_chunked_kv_lens_rank, query,
+                                             total_toks)
         if self.dcp_size > 1:
             num_heads = self.num_heads * self.dcp_size
         else:
@@ -1340,7 +1357,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                   dtype=torch.float32,
                                   device=query.device)

-        if not torch.all(local_chunked_kv_lens_rank == 0).item():
+        if total_toks > 0:
             prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
                 query,
                 key,
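
`total_toks > 0` replaces `not torch.all(... == 0).item()`: the token count is already needed for buffer allocation, so the dedicated all-zero reduction is redundant. A sketch of the equivalence:

```python
import torch

local_chunked_kv_lens_rank = torch.tensor([0, 0, 0])

# Before: a dedicated reduction (and host sync via .item()) per call.
has_chunk_kv_old = not torch.all(local_chunked_kv_lens_rank == 0).item()

# After: reuse the token count already computed for buffer allocation.
total_toks = local_chunked_kv_lens_rank.sum()
has_chunk_kv_new = total_toks > 0

assert bool(has_chunk_kv_new) == has_chunk_kv_old  # same decision
```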
@@ -1358,6 +1375,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 actual_seq_lengths_kv,
                 actual_seq_lengths=attn_metadata.prefill.chunked_context.
                 actual_chunk_seq_lengths)
+            batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
+            out_mask = batch_chunk_seq_mask[:, None, None].expand_as(
+                prefix_chunk_output)
+            prefix_chunk_output = torch.where(out_mask, 0, prefix_chunk_output)
+            lse_mask = batch_chunk_seq_mask[:, None,
+                                            None].expand_as(prefix_chunk_lse)
+            prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
+                                           prefix_chunk_lse)

             prefix_output, prefix_lse = self._update_chunk_attn_out_lse(
                 prefix_chunk_output, prefix_chunk_lse)
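
This masking is the numerical core of the change: a rank that holds no KV for a chunk still produces a partial attention output, so the code zeroes that output and sets its log-sum-exp to `-inf`, giving it zero weight in the merge (`exp(-inf) = 0`). A self-contained sketch of such an LSE merge (generic formulation, not the `_update_chunk_attn_out_lse` internals):

```python
import torch

def merge_attn_partials(out1, lse1, out2, lse2):
    """Merge two partial attention outputs with their log-sum-exp
    weights, as in standard chunked/ring attention reductions."""
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)
    w2 = torch.exp(lse2 - max_lse)
    return (out1 * w1 + out2 * w2) / (w1 + w2)

out_valid = torch.randn(4, 2, 8)
lse_valid = torch.randn(4, 2, 8)

# A rank with no local KV masks its partial: zero output, -inf LSE.
out_masked = torch.zeros_like(out_valid)
lse_masked = torch.full_like(lse_valid, float("-inf"))

merged = merge_attn_partials(out_valid, lse_valid, out_masked, lse_masked)
assert torch.allclose(merged, out_valid)  # masked partial contributes nothing
```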
@@ -1413,14 +1438,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
         return prefix_output, prefix_lse

     def _load_kv_for_chunk(self, attn_metadata, kv_cache,
-                           local_chunked_kv_lens_rank, query):
+                           local_chunked_kv_lens_rank, query, total_toks):
         cache_key = kv_cache[0]
         cache_value = kv_cache[1]
         num_heads = cache_key.size(2)
         head_size = kv_cache[0].size(-1)

-        total_toks = local_chunked_kv_lens_rank.sum()
-
         key = torch.empty(total_toks,
                           num_heads,
                           head_size,
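
`_load_kv_for_chunk` now takes `total_toks` from the caller instead of reducing `local_chunked_kv_lens_rank` a second time. A hypothetical sketch of the preallocation pattern:

```python
import torch

def load_kv_for_chunk_sketch(num_heads: int, head_size: int,
                             total_toks: torch.Tensor):
    """Hypothetical sketch: preallocate flat K/V buffers for all chunk
    tokens held on this rank; total_toks is computed once by the caller."""
    key = torch.empty(int(total_toks), num_heads, head_size)
    value = torch.empty(int(total_toks), num_heads, head_size)
    # ...copy the per-request KV slices from the paged cache here...
    return key, value

kv_lens_rank = torch.tensor([3, 0, 5])
total_toks = kv_lens_rank.sum()  # computed once, reused everywhere
k, v = load_kv_for_chunk_sketch(4, 16, total_toks)
assert k.shape == (8, 4, 16)
```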
@@ -1579,7 +1602,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 query, attn_metadata, output)
         else:
             intermediate_output, num_tokens = self.full_graph_attention(
-                query, key, value, attn_metadata, output)
+                query, key, value, kv_cache, attn_metadata, output)
         output[:num_tokens] = intermediate_output[:num_tokens]

         return output
@@ -294,21 +294,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
-        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
-                                           self.block_size)
-        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        decode_max_num_seqs = getattr(self.scheduler_config,
-                                      'decode_max_num_seqs', 0)
-        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
-                                decode_max_num_seqs)
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        self.dcp_size = get_dcp_group().world_size
-        self.dcp_rank = get_dcp_group().rank_in_group
         self.pcp_size = get_prefill_context_model_parallel_world_size(
         ) if prefill_context_parallel_enable() else 1
         self.pcp_rank = get_prefill_context_model_parallel_rank(
         ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_dcp_group().world_size
+        self.dcp_rank = get_dcp_group().rank_in_group
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
+                                decode_max_num_seqs)
+        if self.pcp_size > 1:
+            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
+        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
+                                           self.block_size)
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.device = device
         if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
             self.prefetch_stream = torch.npu.Stream(device=device)
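
The bug fix in this hunk is ordering: `self.max_num_reqs` must exist before `max_model_len` is padded by `2 * pcp_size * max_num_reqs`, and the per-request block budget must be derived from the padded length. A standalone sketch of the dependency order (illustrative values):

```python
# Hypothetical sketch of the init-ordering fix: derived sizes must be
# computed only after every quantity they depend on is known.
def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # ceiling division

max_model_len = 32_768
block_size = 128
max_num_seqs, decode_max_num_seqs = 256, 0
pcp_size = 2

# 1) max_num_reqs first: the pcp padding below depends on it.
max_num_reqs = max(max_num_seqs, decode_max_num_seqs)

# 2) Pad max_model_len for prefill context parallelism.
if pcp_size > 1:
    max_model_len += 2 * pcp_size * max_num_reqs

# 3) Only now derive the per-request block budget.
max_num_blocks_per_req = cdiv(max_model_len, block_size)
assert max_num_blocks_per_req == cdiv(32_768 + 1_024, 128)  # 264
```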
@@ -1007,10 +1009,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):

     def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
         # pcp situation.
+        if self.pcp_size > 1:
+            return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
         # dcp situation.
         if self.dcp_size > 1:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
         # Pooling situation.
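
After the change, `_make_attention_mask` short-circuits by parallel mode: PCP returns no mask at all and DCP always takes the split-fuse mask, so the per-state logic below only applies to the non-CP path. A sketch of that dispatch order with hypothetical plumbing:

```python
# Hypothetical sketch of the mask-selection order after this PR.
from typing import Optional
import torch

def make_attention_mask_sketch(pcp_size: int, dcp_size: int,
                               splitfuse_mask: torch.Tensor
                               ) -> Optional[torch.Tensor]:
    if pcp_size > 1:
        return None  # under pcp the mask is handled elsewhere (per this PR)
    if dcp_size > 1:
        return splitfuse_mask  # dcp always uses the split-fuse mask
    # ...fall through to pooling / chunked-prefill / prefill cases...
    return splitfuse_mask

assert make_attention_mask_sketch(2, 1, torch.ones(1, 1)) is None
```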
@@ -1018,12 +1022,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             return self.attn_mask_builder.get_pooling_mask(self.device)
         # Chunk Prefill situation.
         elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
-            if self.dcp_size > 1:
-                max_seq_len = max(seq_lens.max().item(), 0)
-                return self.attn_mask_builder.get_attn_mask(
-                    max_seq_len, self.dtype, self.device)
-            else:
-                return self.attn_mask_builder.get_splitfuse_attn_mask()
+            return self.attn_mask_builder.get_splitfuse_attn_mask()

         # Prefill without cache situation.
         elif attn_state == AscendAttentionState.PrefillNoCache:
@@ -1039,6 +1038,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             return None

     def _make_fia_attention_mask(self) -> torch.Tensor:
+        # pcp situation.
+        if self.pcp_size > 1:
+            return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
         return self.attn_mask_builder.get_splitfuse_attn_mask()