[long seq feat] GQA: support long-prefill-token-threshold and fix bug (#4209)
### What this PR does / why we need it?
GQA chunked prefill with pcp and dcp now supports long-prefill-token-threshold.
The results, in markdown format, are shown below:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| gsm8kdataset | - | accuracy | gen | 96.13 |
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: Delphine-Nic <t00608739@china.huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <t00608739@china.huawei.com>
This commit is contained in:
@@ -166,10 +166,12 @@ class AscendMetadataForPrefill:
|
|||||||
actual_chunk_seq_lengths: list[int]
|
actual_chunk_seq_lengths: list[int]
|
||||||
actual_seq_lengths_kv: list[int]
|
actual_seq_lengths_kv: list[int]
|
||||||
starts: torch.Tensor
|
starts: torch.Tensor
|
||||||
|
chunk_seq_mask_filtered_indices: torch.Tensor
|
||||||
chunked_req_mask: Optional[list[bool]] = None
|
chunked_req_mask: Optional[list[bool]] = None
|
||||||
local_context_lens_allranks: Optional[list[list[int]]] = None
|
local_context_lens_allranks: Optional[list[list[int]]] = None
|
||||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||||
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
||||||
|
batch_chunk_seq_mask: Optional[list[bool]] = None
|
||||||
|
|
||||||
""" Prefill Specific Metadata for Ascend"""
|
""" Prefill Specific Metadata for Ascend"""
|
||||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||||
@@ -401,6 +403,14 @@ class AscendAttentionMetadataBuilder:
|
|||||||
cp_kv_recover_idx_for_chunk.to(torch.float32)
|
cp_kv_recover_idx_for_chunk.to(torch.float32)
|
||||||
) if cp_kv_recover_idx_for_chunk is not None else None
|
) if cp_kv_recover_idx_for_chunk is not None else None
|
||||||
|
|
||||||
|
batch_chunk_seq_mask = (
|
||||||
|
local_context_lens_allranks[:, self.pcp_rank,
|
||||||
|
self.dcp_rank] == 0)
|
||||||
|
batch_chunk_seq_mask = torch.repeat_interleave(
|
||||||
|
batch_chunk_seq_mask,
|
||||||
|
repeats=(query_lens * self.pcp_size).to(self.device))
|
||||||
|
chunk_seq_mask_filtered_indices = filter_chunked_req_indices(
|
||||||
|
query_lens, chunked_req_mask).to(self.device)
|
||||||
chunked_context_metadata = \
|
chunked_context_metadata = \
|
||||||
AscendMetadataForPrefill.ChunkedContextMetadata(
|
AscendMetadataForPrefill.ChunkedContextMetadata(
|
||||||
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
|
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
|
||||||
@@ -409,7 +419,9 @@ class AscendAttentionMetadataBuilder:
|
|||||||
starts=local_chunk_starts,
|
starts=local_chunk_starts,
|
||||||
local_context_lens_allranks=local_context_lens_allranks,
|
local_context_lens_allranks=local_context_lens_allranks,
|
||||||
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
|
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
|
||||||
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk
|
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
|
||||||
|
batch_chunk_seq_mask=batch_chunk_seq_mask,
|
||||||
|
chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices
|
||||||
)
|
)
|
||||||
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
||||||
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
||||||
@@ -571,10 +583,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
|
kv_cache: Tuple[torch.Tensor],
|
||||||
attn_metadata: AscendMetadata,
|
attn_metadata: AscendMetadata,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
num_tokens=0):
|
num_tokens=0):
|
||||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
intermediate_output = self._forward_pcp_dcp(
|
||||||
|
query, key, value, kv_cache, attn_metadata, output)
|
||||||
|
return intermediate_output, query.shape[0]
|
||||||
|
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
block_size = 128
|
block_size = 128
|
||||||
block_table = None
|
block_table = None
|
||||||
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
||||||
@@ -1276,9 +1293,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
|
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
|
||||||
|
|
||||||
assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
|
assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
|
||||||
seq_len = attn_metadata.query_lens.detach().clone()
|
filtered_indices = attn_metadata.prefill.chunked_context.chunk_seq_mask_filtered_indices
|
||||||
filtered_indices = filter_chunked_req_indices(
|
|
||||||
seq_len, attn_metadata.prefill.chunked_context.chunked_req_mask)
|
|
||||||
|
|
||||||
attn_output_prefill_filtered = current_attn_output_prefill[
|
attn_output_prefill_filtered = current_attn_output_prefill[
|
||||||
filtered_indices, :, :]
|
filtered_indices, :, :]
|
||||||
@@ -1322,9 +1337,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
|
|
||||||
local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
|
local_chunked_kv_lens_rank = local_chunked_kv_lens[:, self.pcp_rank,
|
||||||
self.dcp_rank]
|
self.dcp_rank]
|
||||||
|
total_toks = local_chunked_kv_lens_rank.sum()
|
||||||
|
|
||||||
key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
|
key, value = self._load_kv_for_chunk(attn_metadata, kv_cache,
|
||||||
local_chunked_kv_lens_rank, query)
|
local_chunked_kv_lens_rank, query,
|
||||||
|
total_toks)
|
||||||
if self.dcp_size > 1:
|
if self.dcp_size > 1:
|
||||||
num_heads = self.num_heads * self.dcp_size
|
num_heads = self.num_heads * self.dcp_size
|
||||||
else:
|
else:
|
||||||
@@ -1340,7 +1357,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=query.device)
|
device=query.device)
|
||||||
|
|
||||||
if not torch.all(local_chunked_kv_lens_rank == 0).item():
|
if total_toks > 0:
|
||||||
prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
|
prefix_chunk_output, prefix_chunk_lse = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@@ -1358,6 +1375,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
actual_seq_lengths_kv,
|
actual_seq_lengths_kv,
|
||||||
actual_seq_lengths=attn_metadata.prefill.chunked_context.
|
actual_seq_lengths=attn_metadata.prefill.chunked_context.
|
||||||
actual_chunk_seq_lengths)
|
actual_chunk_seq_lengths)
|
||||||
|
batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask
|
||||||
|
out_mask = batch_chunk_seq_mask[:, None, None].expand_as(
|
||||||
|
prefix_chunk_output)
|
||||||
|
prefix_chunk_output = torch.where(out_mask, 0, prefix_chunk_output)
|
||||||
|
lse_mask = batch_chunk_seq_mask[:, None,
|
||||||
|
None].expand_as(prefix_chunk_lse)
|
||||||
|
prefix_chunk_lse = torch.where(lse_mask, -torch.inf,
|
||||||
|
prefix_chunk_lse)
|
||||||
|
|
||||||
prefix_output, prefix_lse = self._update_chunk_attn_out_lse(
|
prefix_output, prefix_lse = self._update_chunk_attn_out_lse(
|
||||||
prefix_chunk_output, prefix_chunk_lse)
|
prefix_chunk_output, prefix_chunk_lse)
|
||||||
@@ -1413,14 +1438,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
return prefix_output, prefix_lse
|
return prefix_output, prefix_lse
|
||||||
|
|
||||||
def _load_kv_for_chunk(self, attn_metadata, kv_cache,
|
def _load_kv_for_chunk(self, attn_metadata, kv_cache,
|
||||||
local_chunked_kv_lens_rank, query):
|
local_chunked_kv_lens_rank, query, total_toks):
|
||||||
cache_key = kv_cache[0]
|
cache_key = kv_cache[0]
|
||||||
cache_value = kv_cache[1]
|
cache_value = kv_cache[1]
|
||||||
num_heads = cache_key.size(2)
|
num_heads = cache_key.size(2)
|
||||||
head_size = kv_cache[0].size(-1)
|
head_size = kv_cache[0].size(-1)
|
||||||
|
|
||||||
total_toks = local_chunked_kv_lens_rank.sum()
|
|
||||||
|
|
||||||
key = torch.empty(total_toks,
|
key = torch.empty(total_toks,
|
||||||
num_heads,
|
num_heads,
|
||||||
head_size,
|
head_size,
|
||||||
@@ -1579,7 +1602,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
query, attn_metadata, output)
|
query, attn_metadata, output)
|
||||||
else:
|
else:
|
||||||
intermediate_output, num_tokens = self.full_graph_attention(
|
intermediate_output, num_tokens = self.full_graph_attention(
|
||||||
query, key, value, attn_metadata, output)
|
query, key, value, kv_cache, attn_metadata, output)
|
||||||
output[:num_tokens] = intermediate_output[:num_tokens]
|
output[:num_tokens] = intermediate_output[:num_tokens]
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -294,21 +294,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.scheduler_config = vllm_config.scheduler_config
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
self.speculative_config = vllm_config.speculative_config
|
self.speculative_config = vllm_config.speculative_config
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
|
||||||
self.block_size)
|
|
||||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
|
||||||
decode_max_num_seqs = getattr(self.scheduler_config,
|
|
||||||
'decode_max_num_seqs', 0)
|
|
||||||
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
|
|
||||||
decode_max_num_seqs)
|
|
||||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||||
|
self.dcp_size = get_dcp_group().world_size
|
||||||
|
self.dcp_rank = get_dcp_group().rank_in_group
|
||||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||||
) if prefill_context_parallel_enable() else 1
|
) if prefill_context_parallel_enable() else 1
|
||||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||||
) if self.pcp_size > 1 else 0
|
) if self.pcp_size > 1 else 0
|
||||||
self.dcp_size = get_dcp_group().world_size
|
decode_max_num_seqs = getattr(self.scheduler_config,
|
||||||
self.dcp_rank = get_dcp_group().rank_in_group
|
'decode_max_num_seqs', 0)
|
||||||
|
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
|
||||||
|
decode_max_num_seqs)
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
|
||||||
|
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||||
|
self.block_size)
|
||||||
|
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
self.device = device
|
self.device = device
|
||||||
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
|
if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP:
|
||||||
self.prefetch_stream = torch.npu.Stream(device=device)
|
self.prefetch_stream = torch.npu.Stream(device=device)
|
||||||
@@ -1007,10 +1009,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
def _make_attention_mask(self, seq_lens, position,
|
def _make_attention_mask(self, seq_lens, position,
|
||||||
attn_state) -> torch.Tensor:
|
attn_state) -> torch.Tensor:
|
||||||
|
# pcp situation.
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
return None
|
return None
|
||||||
if self.attn_mask_builder is None:
|
if self.attn_mask_builder is None:
|
||||||
raise ValueError("Attn mask builder is None")
|
raise ValueError("Attn mask builder is None")
|
||||||
|
# dcp situation.
|
||||||
if self.dcp_size > 1:
|
if self.dcp_size > 1:
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
# Pooling situation.
|
# Pooling situation.
|
||||||
@@ -1018,12 +1022,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return self.attn_mask_builder.get_pooling_mask(self.device)
|
return self.attn_mask_builder.get_pooling_mask(self.device)
|
||||||
# Chunk Prefill situation.
|
# Chunk Prefill situation.
|
||||||
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
|
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
|
||||||
if self.dcp_size > 1:
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
max_seq_len = max(seq_lens.max().item(), 0)
|
|
||||||
return self.attn_mask_builder.get_attn_mask(
|
|
||||||
max_seq_len, self.dtype, self.device)
|
|
||||||
else:
|
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
|
||||||
|
|
||||||
# Prefill without cache situation.
|
# Prefill without cache situation.
|
||||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
elif attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
@@ -1039,6 +1038,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _make_fia_attention_mask(self) -> torch.Tensor:
|
def _make_fia_attention_mask(self) -> torch.Tensor:
|
||||||
|
# pcp situation.
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
return None
|
||||||
if self.attn_mask_builder is None:
|
if self.attn_mask_builder is None:
|
||||||
raise ValueError("Attn mask builder is None")
|
raise ValueError("Attn mask builder is None")
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
|
|||||||
Reference in New Issue
Block a user