feat(attention_cp): support chunked prefill for Qwen3Next with PCP&DCP (#6900)
### What this PR does / why we need it?
Support chunked prefill for Qwen3Next with PCP and DCP (prefill and decode context parallelism).
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -141,12 +141,14 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||
assert num_computed_tokens_of_pcp_dcp is not None
|
||||
chunked_context_metadata = None
|
||||
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
||||
if num_prefills > 0:
|
||||
query_lens = query_lens[num_decode_tokens:]
|
||||
context_lens_cpu = num_computed_tokens_cpu[num_decodes:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
pcp_size = get_pcp_group().world_size
|
||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||
if self.pcp_size > 1 and common_long_seq_metadata.pcp_use_hybrid_attn:
|
||||
query_lens = attn_mask_seqlens[0] * 2
|
||||
local_context_lens_allranks = (
|
||||
torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
|
||||
.to(self.device)
|
||||
@@ -163,7 +165,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
# when only using dcp.
|
||||
if self.pcp_size > 1:
|
||||
kv_inverse_idx_for_chunk = torch.argsort(
|
||||
common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(
|
||||
common_long_seq_metadata.pcp_allgather_restore_idx[self.pcp_size * num_decode_tokens :].to(
|
||||
torch.float32
|
||||
)
|
||||
)
|
||||
@@ -172,29 +174,23 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
kv_inverse_idx_for_chunk = None
|
||||
cp_kv_recover_idx_for_chunk = None
|
||||
|
||||
batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
|
||||
batch_chunk_seq_mask = torch.repeat_interleave(
|
||||
batch_chunk_seq_mask, repeats=(query_lens * self.pcp_size).to(self.device)
|
||||
)
|
||||
chunk_seq_mask_filtered_indices = filter_chunked_req_indices(query_lens, chunked_req_mask).to(
|
||||
self.device
|
||||
)
|
||||
chunked_context_metadata = AscendMetadataForPrefill.ChunkedContextMetadata(
|
||||
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
|
||||
actual_chunk_seq_lengths=torch.cumsum(query_lens * self.pcp_size, dim=0),
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
chunked_req_mask=chunked_req_mask,
|
||||
starts=local_chunk_starts,
|
||||
local_context_lens_allranks=local_context_lens_allranks,
|
||||
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
|
||||
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
|
||||
batch_chunk_seq_mask=batch_chunk_seq_mask,
|
||||
chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices,
|
||||
local_total_toks=local_total_toks.item(),
|
||||
)
|
||||
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
||||
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
||||
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
|
||||
if pcp_size > 1:
|
||||
if self.pcp_size > 1:
|
||||
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0], dim=0).tolist()
|
||||
head_attn_nomask_seqlens = torch.cumsum(head_attn_nomask_seqlens[1], dim=0).tolist()
|
||||
tail_attn_nomask_seqlens = torch.cumsum(tail_attn_nomask_seqlens[1], dim=0).tolist()
|
||||
@@ -220,6 +216,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
|
||||
prefill_metadata = AscendMetadataForPrefill(
|
||||
pcp_metadata=pcp_metadata,
|
||||
pcp_exit_fa_scatter_idx=common_long_seq_metadata.pcp_exit_fa_scatter_idx,
|
||||
chunked_context=chunked_context_metadata,
|
||||
block_tables=block_table[self.num_decodes_flatten :, ...],
|
||||
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
|
||||
@@ -475,9 +472,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
|
||||
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
|
||||
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
|
||||
if attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn:
|
||||
fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx
|
||||
query = torch.index_select(query, 0, fa_query_idx)
|
||||
|
||||
q_head = torch.index_select(query, 0, q_head_idx)
|
||||
q_tail = torch.index_select(query, 0, q_tail_idx)
|
||||
@@ -541,7 +535,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
assert self.value_cache is not None
|
||||
|
||||
if self.dcp_size > 1:
|
||||
query = get_dcp_group().all_gather(query, 1)
|
||||
query = get_dcp_group().all_gather(query.contiguous(), 1)
|
||||
num_heads = self.num_heads * self.dcp_size
|
||||
else:
|
||||
num_heads = self.num_heads
|
||||
@@ -936,6 +930,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
if pcp_use_hybrid_attn:
|
||||
prefill_query = query[self.pcp_size * num_decode_tokens :]
|
||||
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
|
||||
fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx
|
||||
prefill_query = torch.index_select(prefill_query, 0, fa_query_idx)
|
||||
else:
|
||||
prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
|
||||
key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded].contiguous()
|
||||
@@ -993,7 +990,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
attn_metadata,
|
||||
)
|
||||
|
||||
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
|
||||
if has_chunked_context:
|
||||
# update the output of current chunk with context part
|
||||
torch.npu.current_stream().wait_stream(cp_chunkedprefill_comm_stream())
|
||||
global_context_output = global_context_output.permute([2, 0, 1]).contiguous()
|
||||
@@ -1005,9 +1002,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
if self.pcp_size > 1 and pcp_use_hybrid_attn:
|
||||
# layer_idx != num_layers - 1
|
||||
assert attn_metadata.prefill.pcp_metadata is not None
|
||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
|
||||
pcp_exit_fa_scatter_idx = attn_metadata.prefill.pcp_exit_fa_scatter_idx
|
||||
attn_output_prefill = get_pcp_group().all_gather(attn_output_prefill.contiguous(), dim=0)
|
||||
attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_allgather_restore_idx)
|
||||
attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_exit_fa_scatter_idx)
|
||||
fla_padding = attn_output_prefill.shape[0] + num_decode_tokens - output.shape[0]
|
||||
output = F.pad(output, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0)
|
||||
|
||||
|
||||
@@ -78,11 +78,11 @@ class AscendMetadataForPrefill:
|
||||
local_context_lens_allranks: list[list[int]] | None = None
|
||||
cp_kv_recover_idx_for_chunk: list[int] | None = None
|
||||
kv_inverse_idx_for_chunk: list[int] | None = None
|
||||
batch_chunk_seq_mask: list[bool] | None = None
|
||||
local_total_toks: int | None = None
|
||||
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
pcp_metadata: AscendPCPMetadata | None = None
|
||||
pcp_exit_fa_scatter_idx: torch.Tensor | None = None
|
||||
chunked_context: ChunkedContextMetadata | None = None
|
||||
block_tables: torch.Tensor = None
|
||||
actual_seq_lengths_q: torch.Tensor = None
|
||||
|
||||
Reference in New Issue
Block a user