feat(attention_cp): support chunked prefill for Qwen3Next with PCP&DCP (#6900)

### What this PR does / why we need it?
Support chunked prefill for Qwen3Next when prefill context parallelism (PCP) and decode context parallelism (DCP) are enabled. A short sketch of the chunked-context gating this adds follows the version notes below.

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
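
With chunked prefill, a long prompt is processed over several scheduler steps, so a prefill request can already carry previously computed context tokens when a new chunk arrives. The gate added in the metadata builder below enables the chunked-context path exactly in that case. A minimal sketch of that gating with made-up token counts (the values are illustrative, not from the PR):

```python
import torch

# Illustrative per-request computed-token counts at one scheduler step:
# decode requests first, then prefill requests (values made up).
num_decodes, num_reqs = 2, 5
num_computed_tokens_cpu = torch.tensor([40, 37, 128, 0, 512])
chunked_prefill_enabled = True

# Context lengths of the prefill requests: tokens already processed by
# earlier chunks of the same prompt.
context_lens_cpu = num_computed_tokens_cpu[num_decodes:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()

if chunked_prefill_enabled and max_context_len_cpu > 0:
    # At least one prefill request must attend to cached context,
    # so the builder assembles ChunkedContextMetadata for it.
    print("build chunked-context metadata")
```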

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Author: Qiu
Committed by GitHub on 2026-03-09 17:55:09 +08:00
Parent: a76a509fae · Commit: 13adcbe44b
6 changed files with 63 additions and 63 deletions


@@ -141,12 +141,14 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None
chunked_context_metadata = None
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
if num_prefills > 0:
query_lens = query_lens[num_decode_tokens:]
context_lens_cpu = num_computed_tokens_cpu[num_decodes:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
pcp_size = get_pcp_group().world_size
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
if self.pcp_size > 1 and common_long_seq_metadata.pcp_use_hybrid_attn:
query_lens = attn_mask_seqlens[0] * 2
local_context_lens_allranks = (
torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
.to(self.device)
@@ -163,7 +165,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
# when only using dcp.
if self.pcp_size > 1:
kv_inverse_idx_for_chunk = torch.argsort(
- common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(
+ common_long_seq_metadata.pcp_allgather_restore_idx[self.pcp_size * num_decode_tokens :].to(
torch.float32
)
)
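
Here `torch.argsort` of the all-gather restore index produces the inverse permutation, which maps the KV gathered for the current chunk back into its pre-gather order. A toy check of that identity with an illustrative index (the float32 cast mirrors the source; its motivation is not shown in this hunk):

```python
import torch

# Toy permutation standing in for pcp_allgather_restore_idx.
restore_idx = torch.tensor([2, 0, 3, 1])

# argsort of a permutation yields its inverse: applying one after the
# other restores the original order.
inverse_idx = torch.argsort(restore_idx.to(torch.float32))

tokens = torch.arange(4) * 10                 # [0, 10, 20, 30]
shuffled = tokens[restore_idx]                # [20, 0, 30, 10]
assert torch.equal(shuffled[inverse_idx], tokens)
```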
@@ -172,29 +174,23 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
kv_inverse_idx_for_chunk = None
cp_kv_recover_idx_for_chunk = None
batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
batch_chunk_seq_mask = torch.repeat_interleave(
batch_chunk_seq_mask, repeats=(query_lens * self.pcp_size).to(self.device)
)
chunk_seq_mask_filtered_indices = filter_chunked_req_indices(query_lens, chunked_req_mask).to(
self.device
)
chunked_context_metadata = AscendMetadataForPrefill.ChunkedContextMetadata(
- actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
+ actual_chunk_seq_lengths=torch.cumsum(query_lens * self.pcp_size, dim=0),
actual_seq_lengths_kv=actual_seq_lengths_kv,
chunked_req_mask=chunked_req_mask,
starts=local_chunk_starts,
local_context_lens_allranks=local_context_lens_allranks,
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
batch_chunk_seq_mask=batch_chunk_seq_mask,
chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices,
local_total_toks=local_total_toks.item(),
)
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
- if pcp_size > 1:
+ if self.pcp_size > 1:
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0], dim=0).tolist()
head_attn_nomask_seqlens = torch.cumsum(head_attn_nomask_seqlens[1], dim=0).tolist()
tail_attn_nomask_seqlens = torch.cumsum(tail_attn_nomask_seqlens[1], dim=0).tolist()
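
Two patterns above are worth calling out: `torch.repeat_interleave` broadcasts the per-request "no local context on this (pcp, dcp) rank" mask to one entry per query token (scaled by `pcp_size`), and `torch.cumsum` turns per-request lengths into the cumulative offsets the attention kernels consume. A small self-contained sketch with made-up lengths:

```python
import torch

pcp_size = 2
# Per-request query lengths on this rank (illustrative values).
query_lens = torch.tensor([3, 1, 2])
# True where a request has no KV context on this (pcp, dcp) rank.
per_req_mask = torch.tensor([False, True, False])

# One mask entry per query token, accounting for the pcp_size-fold
# replication of prefill tokens.
per_token_mask = torch.repeat_interleave(per_req_mask, repeats=query_lens * pcp_size)
print(per_token_mask.shape)                   # torch.Size([12])

# Cumulative sequence lengths, the layout the attention kernels expect.
actual_chunk_seq_lengths = torch.cumsum(query_lens * pcp_size, dim=0)
print(actual_chunk_seq_lengths)               # tensor([ 6,  8, 12])
```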
@@ -220,6 +216,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata,
pcp_exit_fa_scatter_idx=common_long_seq_metadata.pcp_exit_fa_scatter_idx,
chunked_context=chunked_context_metadata,
block_tables=block_table[self.num_decodes_flatten :, ...],
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
@@ -475,9 +472,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
if attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn:
fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx
query = torch.index_select(query, 0, fa_query_idx)
q_head = torch.index_select(query, 0, q_head_idx)
q_tail = torch.index_select(query, 0, q_tail_idx)
@@ -541,7 +535,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
assert self.value_cache is not None
if self.dcp_size > 1:
- query = get_dcp_group().all_gather(query, 1)
+ query = get_dcp_group().all_gather(query.contiguous(), 1)
num_heads = self.num_heads * self.dcp_size
else:
num_heads = self.num_heads
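
The DCP branch now calls `.contiguous()` on the query before the all-gather: collective ops generally require a densely laid-out buffer, and a tensor sliced or narrowed along an inner dimension is usually a non-contiguous view. A quick illustration (shapes are made up):

```python
import torch

q = torch.randn(8, 16, 64)                     # [tokens, heads, head_dim]
q_view = q[:, :4, :]                           # narrow the head dimension

# Narrowing an inner dimension produces a strided view, so the buffer
# is no longer dense; collectives expect a contiguous buffer.
print(q_view.is_contiguous())                  # False
print(q_view.contiguous().is_contiguous())     # True
```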
@@ -936,6 +930,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
if pcp_use_hybrid_attn:
prefill_query = query[self.pcp_size * num_decode_tokens :]
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx
prefill_query = torch.index_select(prefill_query, 0, fa_query_idx)
else:
prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded].contiguous()
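
After the PCP all-gather, the layout evidently places the decode tokens' contributions from all `pcp_size` ranks first, so the prefill region starts at offset `pcp_size * num_decode_tokens`; in the hybrid-attention case the prefill query is then reordered with `pcp_fa_query_idx` via `torch.index_select`. A toy version of that slicing with hypothetical sizes (the random permutation stands in for the real index):

```python
import torch

# Hypothetical sizes: 3 decode tokens, 5 prefill tokens, pcp_size of 2.
pcp_size, num_decode_tokens, num_prefill_tokens = 2, 3, 5
hidden = 8

# Gathered layout: decode contributions from every PCP rank come first,
# followed by the prefill tokens.
gathered = torch.randn(pcp_size * num_decode_tokens + num_prefill_tokens, hidden)
prefill_part = gathered[pcp_size * num_decode_tokens:]

# Reordering of the prefill tokens, as done with pcp_fa_query_idx.
fa_query_idx = torch.randperm(prefill_part.shape[0])
prefill_part = torch.index_select(prefill_part, 0, fa_query_idx)
print(prefill_part.shape)                      # torch.Size([5, 8])
```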
@@ -993,7 +990,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
attn_metadata,
)
- if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
+ if has_chunked_context:
# update the output of current chunk with context part
torch.npu.current_stream().wait_stream(cp_chunkedprefill_comm_stream())
global_context_output = global_context_output.permute([2, 0, 1]).contiguous()
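
The chunked-context contribution is computed on a separate communication stream, and `torch.npu.current_stream().wait_stream(cp_chunkedprefill_comm_stream())` makes the current stream wait for it before the chunk output is merged. The same producer/consumer pattern, sketched with CUDA streams since the Ascend stream API mirrors it (tensors and work here are placeholders):

```python
import torch

# The pattern sketched with CUDA streams; the PR uses the analogous
# torch.npu stream API on Ascend hardware.
if torch.cuda.is_available():
    chunk_out = torch.zeros(4, 8, device="cuda")

    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        # Work issued on the side stream, standing in for the
        # chunked-context attention computation.
        context_out = torch.ones(4, 8, device="cuda") * 2

    # Block the current stream until the side stream finishes before
    # consuming context_out, mirroring current_stream().wait_stream(...).
    torch.cuda.current_stream().wait_stream(side_stream)
    merged = chunk_out + context_out
```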
@@ -1005,9 +1002,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
if self.pcp_size > 1 and pcp_use_hybrid_attn:
# layer_idx != num_layers - 1
assert attn_metadata.prefill.pcp_metadata is not None
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
pcp_exit_fa_scatter_idx = attn_metadata.prefill.pcp_exit_fa_scatter_idx
attn_output_prefill = get_pcp_group().all_gather(attn_output_prefill.contiguous(), dim=0)
attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_allgather_restore_idx)
attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_exit_fa_scatter_idx)
fla_padding = attn_output_prefill.shape[0] + num_decode_tokens - output.shape[0]
output = F.pad(output, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0)
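
At the exit back to the linear-attention (FLA) path, the per-rank prefill outputs are all-gathered, put back in order with two `index_select` passes, and finally padded with `F.pad` so the token dimension lines up with the padded FLA output. Note that `F.pad`'s tuple pads dimensions from the last one backwards, so the trailing pair pads the first (token) dimension. A minimal shape check with made-up sizes:

```python
import torch
import torch.nn.functional as F

attn_out = torch.randn(10, 4, 64)              # [tokens, heads, head_dim]
target_tokens = 12                             # length the FLA path expects

# The trailing (0, n) pair appends n zero rows along the token dimension.
fla_padding = target_tokens - attn_out.shape[0]
padded = F.pad(attn_out, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0)
print(padded.shape)                            # torch.Size([12, 4, 64])
```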


@@ -78,11 +78,11 @@ class AscendMetadataForPrefill:
local_context_lens_allranks: list[list[int]] | None = None
cp_kv_recover_idx_for_chunk: list[int] | None = None
kv_inverse_idx_for_chunk: list[int] | None = None
batch_chunk_seq_mask: list[bool] | None = None
local_total_toks: int | None = None
""" Prefill Specific Metadata for Ascend"""
pcp_metadata: AscendPCPMetadata | None = None
pcp_exit_fa_scatter_idx: torch.Tensor | None = None
chunked_context: ChunkedContextMetadata | None = None
block_tables: torch.Tensor = None
actual_seq_lengths_q: torch.Tensor = None


@@ -113,6 +113,10 @@ class AscendPrefillContextParallelMetadata:
# when entering from linear-attention to attention
pcp_enter_fa_restore_idx: torch.Tensor = None
# scatter the full sequence across all pcp ranks
# when exiting from attention to linear-attention
pcp_exit_fa_scatter_idx: torch.Tensor = None
# the number of tokens padded in linear-attn per rank
pcp_padded_tokens_fla: int = 0
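
`pcp_exit_fa_scatter_idx` redistributes the restored full sequence back into the linear-attention layout across PCP ranks. As a rough illustration only (both the index values and the even per-rank chunking below are assumptions, not taken from the PR), a scatter index applied with `torch.index_select` followed by an even split could shard the sequence like this:

```python
import torch

pcp_size = 2
full_seq = torch.arange(8).float().unsqueeze(-1)   # restored full sequence, 8 tokens

# Made-up scatter index: reorder into the FLA layout, then let each PCP
# rank take one contiguous shard (assumptions for illustration only).
pcp_exit_fa_scatter_idx = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])
fla_layout = torch.index_select(full_seq, 0, pcp_exit_fa_scatter_idx)
per_rank_shards = fla_layout.chunk(pcp_size, dim=0)
print([s.squeeze(-1).tolist() for s in per_rank_shards])
# [[0.0, 4.0, 1.0, 5.0], [2.0, 6.0, 3.0, 7.0]]
```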