feat(attention_cp): support chunked prefill for Qwen3Next with PCP&DCP (#6900)

### What this PR does / why we need it?
Support chunked prefill for Qwen3Next with PCP & DCP (prefill / decode context parallelism).

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2

---------

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
Qiu
2026-03-09 17:55:09 +08:00
committed by GitHub
parent a76a509fae
commit 13adcbe44b
6 changed files with 63 additions and 63 deletions

View File

@@ -75,6 +75,12 @@ class PCPManager:
device=device,
pin_memory=pin_memory,
)
self.pcp_exit_fa_scatter_idx = CpuGpuBuffer(
max_buffer_num_tokens,
dtype=torch.int64,
device=device,
pin_memory=pin_memory,
)
self.pcp_padded_slot_mapping = torch.full(
(max_buffer_num_tokens,),
fill_value=-1,
@@ -110,9 +116,9 @@ class PCPManager:
self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=pin_memory
)
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.query_lens_pcp_full = CpuGpuBuffer(
self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory
)
self.query_lens_pcp_full = CpuGpuBuffer(
self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory
)
self.pcp_fa_query_idx = torch.zeros(
self.max_num_tokens + 2 * self.max_num_reqs, dtype=torch.int32, device=self.device
)
@@ -164,6 +170,10 @@ class PCPManager:
self.num_prefill_reqs = num_reqs - self.num_decode_reqs
self.num_decode_tokens = num_scheduled_tokens[: self.num_decode_reqs].sum()
self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens)
self.query_lens_pcp_full.cpu[self.num_reqs :].fill_(0)
self.query_lens_pcp_full.copy_to_gpu()
def update_tokens_for_pcp(
self,
num_scheduled_tokens: np.ndarray,
@@ -301,6 +311,17 @@ class PCPManager:
num_scheduled_tokens[: self.num_decode_reqs], arange_np
)[1]
# Build the restore index used after allgather.
all_positions_lst = [
get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size)
]
all_positions = np.concatenate(all_positions_lst)
self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort()
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum()
if self.pcp_use_hybrid_attn:
max_scheduled_prefill_tokens = 0
self.pcp_padded_tokens_fla = 0
@@ -405,7 +426,7 @@ class PCPManager:
for rank_i in range(self.pcp_world_size)
]
all_positions_prefill_tensor = torch.from_numpy(np.concatenate(all_positions_prefill))
all_enter_fla_restore_idx = all_positions_prefill_tensor.float().argsort()
all_exit_fa_restore_idx = all_positions_prefill_tensor.float().argsort()
unpad_mask_prefill = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length][
self.num_decode_reqs * self.pcp_world_size :
]
@@ -413,14 +434,15 @@ class PCPManager:
ori_tokens_start_loc = np.roll(np.cumsum(num_scheduled_tokens[self.num_decode_tokens :]), 1)
ori_tokens_start_loc[0] = 0
# [0,1,2] [3,4] | [0,1,7,8] [2,3,9] [4,5,10] [6,11]
enter_fla_scatter_idx = positions_linear[self.num_decode_reqs :] + np.repeat(
exit_fa_scatter_indices = positions_linear[self.num_decode_reqs :] + np.repeat(
ori_tokens_start_loc, num_prefill_scheduled_tokens_linear
)
enter_fla_restore_idx = torch.index_select(
all_enter_fla_restore_idx[unpad_mask_prefill], 0, torch.from_numpy(enter_fla_scatter_idx)
exit_fa_scatter_idx = torch.index_select(
all_exit_fa_restore_idx[unpad_mask_prefill], 0, torch.from_numpy(exit_fa_scatter_indices)
)
self.pcp_allgather_restore_idx.gpu[: enter_fla_restore_idx.shape[0]].copy_(
enter_fla_restore_idx.long(), non_blocking=True
self.pcp_exit_fa_scatter_idx.gpu[: exit_fa_scatter_idx.shape[0]].copy_(
exit_fa_scatter_idx.long(), non_blocking=True
)
positions_prefill = all_positions_prefill[self.pcp_world_rank]
@@ -434,18 +456,7 @@ class PCPManager:
self.pcp_tokens_padded = pcp_tokens[: self.num_reqs]
self.num_scheduled_tokens_padded = np.array(self.pcp_tokens_padded, dtype=np.int32)
return num_padded_scheduled_tokens, positions_linear
else:
# Build the restore index used after allgather.
all_positions_lst = [
get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size)
]
all_positions = np.concatenate(all_positions_lst)
self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort()
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum()
return pcp_tokens[: self.num_reqs], positions
return pcp_tokens[: self.num_reqs], positions
def get_logits_indices(
self,
@@ -539,7 +550,6 @@ class PCPManager:
num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
for i, req_id in enumerate(input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(arange_np[: self.num_reqs], num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
self.query_start_loc_pcp_full.np[0] = 0
@@ -567,7 +577,6 @@ class PCPManager:
cu_num_tokens_pcp_full,
num_spec_tokens,
)
self.query_lens_pcp_full.copy_to_gpu()
self.query_start_loc_pcp_full.copy_to_gpu()
self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full)
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
@@ -719,15 +728,10 @@ class PCPManager:
if self.pcp_world_size > 1 and self.pcp_use_hybrid_attn:
assert self.num_scheduled_tokens_padded is not None
total_num_scheduled_tokens = self.num_scheduled_tokens_padded.sum()
query_lens_new = (
self.query_lens_pcp_full.cpu[:num_reqs]
if self.pcp_world_size > 1 and self.speculative_config
else query_lens
)
num_decodes = (query_lens_new <= self.decode_threshold).sum().item()
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
if self.pcp_world_size * self.dcp_world_size > 1:
assert num_scheduled_tokens is not None
decode_context_lens = (
@@ -753,7 +757,6 @@ class PCPManager:
self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
)
)
ori_query_lens_cpu = None
if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = []
if self.num_decode_reqs:
@@ -781,7 +784,6 @@ class PCPManager:
# (num_reqs_d + num_reqs_p, max_num_blocks),
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
num_prefill_reqs = self.num_prefill_reqs
num_decode_reqs = self.num_decode_reqs
@@ -806,10 +808,9 @@ class PCPManager:
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
pcp_unpad_mask=torch.from_numpy(pcp_unpad_mask),
pcp_padded_tokens_fla=self.pcp_padded_tokens_fla,
query_lens_pcp_full_cpu=ori_query_lens_cpu,
max_query_len_pcp_full=ori_query_lens_cpu.max().item(),
)
if ori_query_lens_cpu is not None:
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item()
if self.pcp_world_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -906,19 +907,18 @@ class PCPManager:
"head_attn_nomask_seqlens": head_attn_nomask_seqlens,
"tail_attn_nomask_seqlens": tail_attn_nomask_seqlens,
}
if not self.pcp_use_hybrid_attn:
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
:num_actual_tokens_pcp_padded
]
else:
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
: num_scheduled_tokens.sum() - num_decodes
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
:num_actual_tokens_pcp_padded
]
if self.pcp_use_hybrid_attn:
long_seq_metadata.pcp_exit_fa_scatter_idx = self.pcp_exit_fa_scatter_idx.gpu[
: num_scheduled_tokens.sum() - self.num_decode_reqs
]
long_seq_metadata.pcp_fa_query_idx = self.pcp_fa_query_idx[
: num_actual_tokens_pcp_padded // self.pcp_world_size - num_decodes
: num_actual_tokens_pcp_padded // self.pcp_world_size - self.num_decode_reqs
]
long_seq_metadata.pcp_enter_fa_restore_idx = self.pcp_enter_fa_restore_idx[
: pcp_unpad_mask.sum() + num_decodes * (self.pcp_world_size - 1)
: pcp_unpad_mask.sum() + self.num_decode_reqs * (self.pcp_world_size - 1)
]
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor