feat(attention_cp): support chunked prefill for Qwen3Next with PCP&DCP (#6900)
### What this PR does / why we need it?
Support chunked prefill for Qwen3Next with PCP&DCP
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -80,6 +80,7 @@ def test_models_pcp_dcp_basic():
|
|||||||
decode_context_parallel_size=1,
|
decode_context_parallel_size=1,
|
||||||
max_num_batched_tokens=1024,
|
max_num_batched_tokens=1024,
|
||||||
enable_expert_parallel=True,
|
enable_expert_parallel=True,
|
||||||
|
long_prefill_token_threshold=4,
|
||||||
gpu_memory_utilization=0.8,
|
gpu_memory_utilization=0.8,
|
||||||
block_size=128) as runner:
|
block_size=128) as runner:
|
||||||
runner.model.generate(prompts, sampling_params)
|
runner.model.generate(prompts, sampling_params)
|
||||||
|
|||||||
@@ -169,8 +169,6 @@ class TestAscendAttentionCPImpl(TestBase):
|
|||||||
attn_metadata.prefill.chunked_context = MagicMock()
|
attn_metadata.prefill.chunked_context = MagicMock()
|
||||||
local_context_lens_allranks = torch.tensor([[[256, 256], [256, 256]]])
|
local_context_lens_allranks = torch.tensor([[[256, 256], [256, 256]]])
|
||||||
attn_metadata.prefill.chunked_context.local_context_lens_allranks = local_context_lens_allranks
|
attn_metadata.prefill.chunked_context.local_context_lens_allranks = local_context_lens_allranks
|
||||||
attn_metadata.prefill.chunked_context.batch_chunk_seq_mask = torch.randint(
|
|
||||||
0, 2, (1024, ), dtype=torch.bool)
|
|
||||||
attn_metadata.prefill.chunked_context.local_total_toks = local_context_lens_allranks[:,
|
attn_metadata.prefill.chunked_context.local_total_toks = local_context_lens_allranks[:,
|
||||||
0,
|
0,
|
||||||
0].sum(
|
0].sum(
|
||||||
|
|||||||
@@ -141,12 +141,14 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
|||||||
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||||
assert num_computed_tokens_of_pcp_dcp is not None
|
assert num_computed_tokens_of_pcp_dcp is not None
|
||||||
chunked_context_metadata = None
|
chunked_context_metadata = None
|
||||||
|
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
||||||
if num_prefills > 0:
|
if num_prefills > 0:
|
||||||
query_lens = query_lens[num_decode_tokens:]
|
query_lens = query_lens[num_decode_tokens:]
|
||||||
context_lens_cpu = num_computed_tokens_cpu[num_decodes:num_reqs]
|
context_lens_cpu = num_computed_tokens_cpu[num_decodes:num_reqs]
|
||||||
max_context_len_cpu = context_lens_cpu.max().item()
|
max_context_len_cpu = context_lens_cpu.max().item()
|
||||||
pcp_size = get_pcp_group().world_size
|
|
||||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||||
|
if self.pcp_size > 1 and common_long_seq_metadata.pcp_use_hybrid_attn:
|
||||||
|
query_lens = attn_mask_seqlens[0] * 2
|
||||||
local_context_lens_allranks = (
|
local_context_lens_allranks = (
|
||||||
torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
|
torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
|
||||||
.to(self.device)
|
.to(self.device)
|
||||||
@@ -163,7 +165,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
|||||||
# when only using dcp.
|
# when only using dcp.
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
kv_inverse_idx_for_chunk = torch.argsort(
|
kv_inverse_idx_for_chunk = torch.argsort(
|
||||||
common_long_seq_metadata.pcp_allgather_restore_idx[pcp_size * num_decode_tokens :].to(
|
common_long_seq_metadata.pcp_allgather_restore_idx[self.pcp_size * num_decode_tokens :].to(
|
||||||
torch.float32
|
torch.float32
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -172,29 +174,23 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
|||||||
kv_inverse_idx_for_chunk = None
|
kv_inverse_idx_for_chunk = None
|
||||||
cp_kv_recover_idx_for_chunk = None
|
cp_kv_recover_idx_for_chunk = None
|
||||||
|
|
||||||
batch_chunk_seq_mask = local_context_lens_allranks[:, self.pcp_rank, self.dcp_rank] == 0
|
|
||||||
batch_chunk_seq_mask = torch.repeat_interleave(
|
|
||||||
batch_chunk_seq_mask, repeats=(query_lens * self.pcp_size).to(self.device)
|
|
||||||
)
|
|
||||||
chunk_seq_mask_filtered_indices = filter_chunked_req_indices(query_lens, chunked_req_mask).to(
|
chunk_seq_mask_filtered_indices = filter_chunked_req_indices(query_lens, chunked_req_mask).to(
|
||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
chunked_context_metadata = AscendMetadataForPrefill.ChunkedContextMetadata(
|
chunked_context_metadata = AscendMetadataForPrefill.ChunkedContextMetadata(
|
||||||
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
|
actual_chunk_seq_lengths=torch.cumsum(query_lens * self.pcp_size, dim=0),
|
||||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||||
chunked_req_mask=chunked_req_mask,
|
chunked_req_mask=chunked_req_mask,
|
||||||
starts=local_chunk_starts,
|
starts=local_chunk_starts,
|
||||||
local_context_lens_allranks=local_context_lens_allranks,
|
local_context_lens_allranks=local_context_lens_allranks,
|
||||||
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
|
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
|
||||||
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
|
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
|
||||||
batch_chunk_seq_mask=batch_chunk_seq_mask,
|
|
||||||
chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices,
|
chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices,
|
||||||
local_total_toks=local_total_toks.item(),
|
local_total_toks=local_total_toks.item(),
|
||||||
)
|
)
|
||||||
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
|
||||||
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
||||||
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
|
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
|
||||||
if pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0], dim=0).tolist()
|
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0], dim=0).tolist()
|
||||||
head_attn_nomask_seqlens = torch.cumsum(head_attn_nomask_seqlens[1], dim=0).tolist()
|
head_attn_nomask_seqlens = torch.cumsum(head_attn_nomask_seqlens[1], dim=0).tolist()
|
||||||
tail_attn_nomask_seqlens = torch.cumsum(tail_attn_nomask_seqlens[1], dim=0).tolist()
|
tail_attn_nomask_seqlens = torch.cumsum(tail_attn_nomask_seqlens[1], dim=0).tolist()
|
||||||
@@ -220,6 +216,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
|||||||
|
|
||||||
prefill_metadata = AscendMetadataForPrefill(
|
prefill_metadata = AscendMetadataForPrefill(
|
||||||
pcp_metadata=pcp_metadata,
|
pcp_metadata=pcp_metadata,
|
||||||
|
pcp_exit_fa_scatter_idx=common_long_seq_metadata.pcp_exit_fa_scatter_idx,
|
||||||
chunked_context=chunked_context_metadata,
|
chunked_context=chunked_context_metadata,
|
||||||
block_tables=block_table[self.num_decodes_flatten :, ...],
|
block_tables=block_table[self.num_decodes_flatten :, ...],
|
||||||
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
|
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
|
||||||
@@ -475,9 +472,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
|
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
|
||||||
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
|
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
|
||||||
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
|
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
|
||||||
if attn_metadata.prefill.pcp_metadata.pcp_use_hybrid_attn:
|
|
||||||
fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx
|
|
||||||
query = torch.index_select(query, 0, fa_query_idx)
|
|
||||||
|
|
||||||
q_head = torch.index_select(query, 0, q_head_idx)
|
q_head = torch.index_select(query, 0, q_head_idx)
|
||||||
q_tail = torch.index_select(query, 0, q_tail_idx)
|
q_tail = torch.index_select(query, 0, q_tail_idx)
|
||||||
@@ -541,7 +535,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
assert self.value_cache is not None
|
assert self.value_cache is not None
|
||||||
|
|
||||||
if self.dcp_size > 1:
|
if self.dcp_size > 1:
|
||||||
query = get_dcp_group().all_gather(query, 1)
|
query = get_dcp_group().all_gather(query.contiguous(), 1)
|
||||||
num_heads = self.num_heads * self.dcp_size
|
num_heads = self.num_heads * self.dcp_size
|
||||||
else:
|
else:
|
||||||
num_heads = self.num_heads
|
num_heads = self.num_heads
|
||||||
@@ -936,6 +930,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||||
if pcp_use_hybrid_attn:
|
if pcp_use_hybrid_attn:
|
||||||
prefill_query = query[self.pcp_size * num_decode_tokens :]
|
prefill_query = query[self.pcp_size * num_decode_tokens :]
|
||||||
|
assert attn_metadata.prefill is not None and attn_metadata.prefill.pcp_metadata is not None
|
||||||
|
fa_query_idx = attn_metadata.prefill.pcp_metadata.pcp_fa_query_idx
|
||||||
|
prefill_query = torch.index_select(prefill_query, 0, fa_query_idx)
|
||||||
else:
|
else:
|
||||||
prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
|
prefill_query = query[num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
|
||||||
key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded].contiguous()
|
key = key[self.pcp_size * num_decode_tokens : attn_metadata.num_actual_tokens_pcp_padded].contiguous()
|
||||||
@@ -993,7 +990,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
attn_metadata,
|
attn_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
|
if has_chunked_context:
|
||||||
# update the output of current chunk with context part
|
# update the output of current chunk with context part
|
||||||
torch.npu.current_stream().wait_stream(cp_chunkedprefill_comm_stream())
|
torch.npu.current_stream().wait_stream(cp_chunkedprefill_comm_stream())
|
||||||
global_context_output = global_context_output.permute([2, 0, 1]).contiguous()
|
global_context_output = global_context_output.permute([2, 0, 1]).contiguous()
|
||||||
@@ -1005,9 +1002,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
if self.pcp_size > 1 and pcp_use_hybrid_attn:
|
if self.pcp_size > 1 and pcp_use_hybrid_attn:
|
||||||
# layer_idx != num_layers - 1
|
# layer_idx != num_layers - 1
|
||||||
assert attn_metadata.prefill.pcp_metadata is not None
|
assert attn_metadata.prefill.pcp_metadata is not None
|
||||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_metadata.pcp_allgather_restore_idx
|
pcp_exit_fa_scatter_idx = attn_metadata.prefill.pcp_exit_fa_scatter_idx
|
||||||
attn_output_prefill = get_pcp_group().all_gather(attn_output_prefill.contiguous(), dim=0)
|
attn_output_prefill = get_pcp_group().all_gather(attn_output_prefill.contiguous(), dim=0)
|
||||||
attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_allgather_restore_idx)
|
attn_output_prefill = torch.index_select(attn_output_prefill, 0, pcp_exit_fa_scatter_idx)
|
||||||
fla_padding = attn_output_prefill.shape[0] + num_decode_tokens - output.shape[0]
|
fla_padding = attn_output_prefill.shape[0] + num_decode_tokens - output.shape[0]
|
||||||
output = F.pad(output, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0)
|
output = F.pad(output, pad=(0, 0, 0, 0, 0, fla_padding), mode="constant", value=0)
|
||||||
|
|
||||||
|
|||||||
@@ -78,11 +78,11 @@ class AscendMetadataForPrefill:
|
|||||||
local_context_lens_allranks: list[list[int]] | None = None
|
local_context_lens_allranks: list[list[int]] | None = None
|
||||||
cp_kv_recover_idx_for_chunk: list[int] | None = None
|
cp_kv_recover_idx_for_chunk: list[int] | None = None
|
||||||
kv_inverse_idx_for_chunk: list[int] | None = None
|
kv_inverse_idx_for_chunk: list[int] | None = None
|
||||||
batch_chunk_seq_mask: list[bool] | None = None
|
|
||||||
local_total_toks: int | None = None
|
local_total_toks: int | None = None
|
||||||
|
|
||||||
""" Prefill Specific Metadata for Ascend"""
|
""" Prefill Specific Metadata for Ascend"""
|
||||||
pcp_metadata: AscendPCPMetadata | None = None
|
pcp_metadata: AscendPCPMetadata | None = None
|
||||||
|
pcp_exit_fa_scatter_idx: torch.Tensor | None = None
|
||||||
chunked_context: ChunkedContextMetadata | None = None
|
chunked_context: ChunkedContextMetadata | None = None
|
||||||
block_tables: torch.Tensor = None
|
block_tables: torch.Tensor = None
|
||||||
actual_seq_lengths_q: torch.Tensor = None
|
actual_seq_lengths_q: torch.Tensor = None
|
||||||
|
|||||||
@@ -113,6 +113,10 @@ class AscendPrefillContextParallelMetadata:
|
|||||||
# when entering from linear-attention to attention
|
# when entering from linear-attention to attention
|
||||||
pcp_enter_fa_restore_idx: torch.Tensor = None
|
pcp_enter_fa_restore_idx: torch.Tensor = None
|
||||||
|
|
||||||
|
# scatter the full sequence across all pcp ranks
|
||||||
|
# when exiting from attention to linear-attention
|
||||||
|
pcp_exit_fa_scatter_idx: torch.Tensor = None
|
||||||
|
|
||||||
# the number of tokens padded in linear-attn per rank
|
# the number of tokens padded in linear-attn per rank
|
||||||
pcp_padded_tokens_fla: int = 0
|
pcp_padded_tokens_fla: int = 0
|
||||||
|
|
||||||
|
|||||||
@@ -75,6 +75,12 @@ class PCPManager:
|
|||||||
device=device,
|
device=device,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
|
self.pcp_exit_fa_scatter_idx = CpuGpuBuffer(
|
||||||
|
max_buffer_num_tokens,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
self.pcp_padded_slot_mapping = torch.full(
|
self.pcp_padded_slot_mapping = torch.full(
|
||||||
(max_buffer_num_tokens,),
|
(max_buffer_num_tokens,),
|
||||||
fill_value=-1,
|
fill_value=-1,
|
||||||
@@ -110,9 +116,9 @@ class PCPManager:
|
|||||||
self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=pin_memory
|
self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=pin_memory
|
||||||
)
|
)
|
||||||
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
|
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
|
||||||
self.query_lens_pcp_full = CpuGpuBuffer(
|
self.query_lens_pcp_full = CpuGpuBuffer(
|
||||||
self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory
|
self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory
|
||||||
)
|
)
|
||||||
self.pcp_fa_query_idx = torch.zeros(
|
self.pcp_fa_query_idx = torch.zeros(
|
||||||
self.max_num_tokens + 2 * self.max_num_reqs, dtype=torch.int32, device=self.device
|
self.max_num_tokens + 2 * self.max_num_reqs, dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
@@ -164,6 +170,10 @@ class PCPManager:
|
|||||||
self.num_prefill_reqs = num_reqs - self.num_decode_reqs
|
self.num_prefill_reqs = num_reqs - self.num_decode_reqs
|
||||||
self.num_decode_tokens = num_scheduled_tokens[: self.num_decode_reqs].sum()
|
self.num_decode_tokens = num_scheduled_tokens[: self.num_decode_reqs].sum()
|
||||||
|
|
||||||
|
self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens)
|
||||||
|
self.query_lens_pcp_full.cpu[self.num_reqs :].fill_(0)
|
||||||
|
self.query_lens_pcp_full.copy_to_gpu()
|
||||||
|
|
||||||
def update_tokens_for_pcp(
|
def update_tokens_for_pcp(
|
||||||
self,
|
self,
|
||||||
num_scheduled_tokens: np.ndarray,
|
num_scheduled_tokens: np.ndarray,
|
||||||
@@ -301,6 +311,17 @@ class PCPManager:
|
|||||||
num_scheduled_tokens[: self.num_decode_reqs], arange_np
|
num_scheduled_tokens[: self.num_decode_reqs], arange_np
|
||||||
)[1]
|
)[1]
|
||||||
|
|
||||||
|
# Build the restore index used after allgather.
|
||||||
|
all_positions_lst = [
|
||||||
|
get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size)
|
||||||
|
]
|
||||||
|
all_positions = np.concatenate(all_positions_lst)
|
||||||
|
self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort()
|
||||||
|
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
|
||||||
|
|
||||||
|
self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
|
||||||
|
self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum()
|
||||||
|
|
||||||
if self.pcp_use_hybrid_attn:
|
if self.pcp_use_hybrid_attn:
|
||||||
max_scheduled_prefill_tokens = 0
|
max_scheduled_prefill_tokens = 0
|
||||||
self.pcp_padded_tokens_fla = 0
|
self.pcp_padded_tokens_fla = 0
|
||||||
@@ -405,7 +426,7 @@ class PCPManager:
|
|||||||
for rank_i in range(self.pcp_world_size)
|
for rank_i in range(self.pcp_world_size)
|
||||||
]
|
]
|
||||||
all_positions_prefill_tensor = torch.from_numpy(np.concatenate(all_positions_prefill))
|
all_positions_prefill_tensor = torch.from_numpy(np.concatenate(all_positions_prefill))
|
||||||
all_enter_fla_restore_idx = all_positions_prefill_tensor.float().argsort()
|
all_exit_fa_restore_idx = all_positions_prefill_tensor.float().argsort()
|
||||||
unpad_mask_prefill = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length][
|
unpad_mask_prefill = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length][
|
||||||
self.num_decode_reqs * self.pcp_world_size :
|
self.num_decode_reqs * self.pcp_world_size :
|
||||||
]
|
]
|
||||||
@@ -413,14 +434,15 @@ class PCPManager:
|
|||||||
ori_tokens_start_loc = np.roll(np.cumsum(num_scheduled_tokens[self.num_decode_tokens :]), 1)
|
ori_tokens_start_loc = np.roll(np.cumsum(num_scheduled_tokens[self.num_decode_tokens :]), 1)
|
||||||
ori_tokens_start_loc[0] = 0
|
ori_tokens_start_loc[0] = 0
|
||||||
# [0,1,2] [3,4] | [0,1,7,8] [2,3,9] [4,5,10] [6,11]
|
# [0,1,2] [3,4] | [0,1,7,8] [2,3,9] [4,5,10] [6,11]
|
||||||
enter_fla_scatter_idx = positions_linear[self.num_decode_reqs :] + np.repeat(
|
exit_fa_scatter_indices = positions_linear[self.num_decode_reqs :] + np.repeat(
|
||||||
ori_tokens_start_loc, num_prefill_scheduled_tokens_linear
|
ori_tokens_start_loc, num_prefill_scheduled_tokens_linear
|
||||||
)
|
)
|
||||||
enter_fla_restore_idx = torch.index_select(
|
|
||||||
all_enter_fla_restore_idx[unpad_mask_prefill], 0, torch.from_numpy(enter_fla_scatter_idx)
|
exit_fa_scatter_idx = torch.index_select(
|
||||||
|
all_exit_fa_restore_idx[unpad_mask_prefill], 0, torch.from_numpy(exit_fa_scatter_indices)
|
||||||
)
|
)
|
||||||
self.pcp_allgather_restore_idx.gpu[: enter_fla_restore_idx.shape[0]].copy_(
|
self.pcp_exit_fa_scatter_idx.gpu[: exit_fa_scatter_idx.shape[0]].copy_(
|
||||||
enter_fla_restore_idx.long(), non_blocking=True
|
exit_fa_scatter_idx.long(), non_blocking=True
|
||||||
)
|
)
|
||||||
|
|
||||||
positions_prefill = all_positions_prefill[self.pcp_world_rank]
|
positions_prefill = all_positions_prefill[self.pcp_world_rank]
|
||||||
@@ -434,18 +456,7 @@ class PCPManager:
|
|||||||
self.pcp_tokens_padded = pcp_tokens[: self.num_reqs]
|
self.pcp_tokens_padded = pcp_tokens[: self.num_reqs]
|
||||||
self.num_scheduled_tokens_padded = np.array(self.pcp_tokens_padded, dtype=np.int32)
|
self.num_scheduled_tokens_padded = np.array(self.pcp_tokens_padded, dtype=np.int32)
|
||||||
return num_padded_scheduled_tokens, positions_linear
|
return num_padded_scheduled_tokens, positions_linear
|
||||||
else:
|
return pcp_tokens[: self.num_reqs], positions
|
||||||
# Build the restore index used after allgather.
|
|
||||||
all_positions_lst = [
|
|
||||||
get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size)
|
|
||||||
]
|
|
||||||
all_positions = np.concatenate(all_positions_lst)
|
|
||||||
self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort()
|
|
||||||
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
|
|
||||||
|
|
||||||
self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
|
|
||||||
self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum()
|
|
||||||
return pcp_tokens[: self.num_reqs], positions
|
|
||||||
|
|
||||||
def get_logits_indices(
|
def get_logits_indices(
|
||||||
self,
|
self,
|
||||||
@@ -539,7 +550,6 @@ class PCPManager:
|
|||||||
num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
|
num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
|
||||||
for i, req_id in enumerate(input_batch.req_ids):
|
for i, req_id in enumerate(input_batch.req_ids):
|
||||||
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
|
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
|
||||||
self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens_pcp_full)
|
|
||||||
req_indices_pcp_full = np.repeat(arange_np[: self.num_reqs], num_scheduled_tokens_pcp_full)
|
req_indices_pcp_full = np.repeat(arange_np[: self.num_reqs], num_scheduled_tokens_pcp_full)
|
||||||
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
|
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
|
||||||
self.query_start_loc_pcp_full.np[0] = 0
|
self.query_start_loc_pcp_full.np[0] = 0
|
||||||
@@ -567,7 +577,6 @@ class PCPManager:
|
|||||||
cu_num_tokens_pcp_full,
|
cu_num_tokens_pcp_full,
|
||||||
num_spec_tokens,
|
num_spec_tokens,
|
||||||
)
|
)
|
||||||
self.query_lens_pcp_full.copy_to_gpu()
|
|
||||||
self.query_start_loc_pcp_full.copy_to_gpu()
|
self.query_start_loc_pcp_full.copy_to_gpu()
|
||||||
self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full)
|
self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full)
|
||||||
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
|
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
|
||||||
@@ -719,15 +728,10 @@ class PCPManager:
|
|||||||
if self.pcp_world_size > 1 and self.pcp_use_hybrid_attn:
|
if self.pcp_world_size > 1 and self.pcp_use_hybrid_attn:
|
||||||
assert self.num_scheduled_tokens_padded is not None
|
assert self.num_scheduled_tokens_padded is not None
|
||||||
total_num_scheduled_tokens = self.num_scheduled_tokens_padded.sum()
|
total_num_scheduled_tokens = self.num_scheduled_tokens_padded.sum()
|
||||||
query_lens_new = (
|
|
||||||
self.query_lens_pcp_full.cpu[:num_reqs]
|
|
||||||
if self.pcp_world_size > 1 and self.speculative_config
|
|
||||||
else query_lens
|
|
||||||
)
|
|
||||||
num_decodes = (query_lens_new <= self.decode_threshold).sum().item()
|
|
||||||
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
|
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
|
||||||
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
||||||
long_seq_metadata = None
|
long_seq_metadata = None
|
||||||
|
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
|
||||||
if self.pcp_world_size * self.dcp_world_size > 1:
|
if self.pcp_world_size * self.dcp_world_size > 1:
|
||||||
assert num_scheduled_tokens is not None
|
assert num_scheduled_tokens is not None
|
||||||
decode_context_lens = (
|
decode_context_lens = (
|
||||||
@@ -753,7 +757,6 @@ class PCPManager:
|
|||||||
self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
|
self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
ori_query_lens_cpu = None
|
|
||||||
if self.decode_threshold > 1:
|
if self.decode_threshold > 1:
|
||||||
num_computed_tokens_of_pcp_dcp_list = []
|
num_computed_tokens_of_pcp_dcp_list = []
|
||||||
if self.num_decode_reqs:
|
if self.num_decode_reqs:
|
||||||
@@ -781,7 +784,6 @@ class PCPManager:
|
|||||||
# (num_reqs_d + num_reqs_p, max_num_blocks),
|
# (num_reqs_d + num_reqs_p, max_num_blocks),
|
||||||
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
|
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
|
||||||
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
|
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
|
||||||
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
|
|
||||||
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
|
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
|
||||||
num_prefill_reqs = self.num_prefill_reqs
|
num_prefill_reqs = self.num_prefill_reqs
|
||||||
num_decode_reqs = self.num_decode_reqs
|
num_decode_reqs = self.num_decode_reqs
|
||||||
@@ -806,10 +808,9 @@ class PCPManager:
|
|||||||
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
|
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
|
||||||
pcp_unpad_mask=torch.from_numpy(pcp_unpad_mask),
|
pcp_unpad_mask=torch.from_numpy(pcp_unpad_mask),
|
||||||
pcp_padded_tokens_fla=self.pcp_padded_tokens_fla,
|
pcp_padded_tokens_fla=self.pcp_padded_tokens_fla,
|
||||||
|
query_lens_pcp_full_cpu=ori_query_lens_cpu,
|
||||||
|
max_query_len_pcp_full=ori_query_lens_cpu.max().item(),
|
||||||
)
|
)
|
||||||
if ori_query_lens_cpu is not None:
|
|
||||||
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
|
|
||||||
long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item()
|
|
||||||
if self.pcp_world_size > 1:
|
if self.pcp_world_size > 1:
|
||||||
q_head_idx, q_tail_idx = [], []
|
q_head_idx, q_tail_idx = [], []
|
||||||
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
||||||
@@ -906,19 +907,18 @@ class PCPManager:
|
|||||||
"head_attn_nomask_seqlens": head_attn_nomask_seqlens,
|
"head_attn_nomask_seqlens": head_attn_nomask_seqlens,
|
||||||
"tail_attn_nomask_seqlens": tail_attn_nomask_seqlens,
|
"tail_attn_nomask_seqlens": tail_attn_nomask_seqlens,
|
||||||
}
|
}
|
||||||
if not self.pcp_use_hybrid_attn:
|
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
|
||||||
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
|
:num_actual_tokens_pcp_padded
|
||||||
:num_actual_tokens_pcp_padded
|
]
|
||||||
]
|
if self.pcp_use_hybrid_attn:
|
||||||
else:
|
long_seq_metadata.pcp_exit_fa_scatter_idx = self.pcp_exit_fa_scatter_idx.gpu[
|
||||||
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
|
: num_scheduled_tokens.sum() - self.num_decode_reqs
|
||||||
: num_scheduled_tokens.sum() - num_decodes
|
|
||||||
]
|
]
|
||||||
long_seq_metadata.pcp_fa_query_idx = self.pcp_fa_query_idx[
|
long_seq_metadata.pcp_fa_query_idx = self.pcp_fa_query_idx[
|
||||||
: num_actual_tokens_pcp_padded // self.pcp_world_size - num_decodes
|
: num_actual_tokens_pcp_padded // self.pcp_world_size - self.num_decode_reqs
|
||||||
]
|
]
|
||||||
long_seq_metadata.pcp_enter_fa_restore_idx = self.pcp_enter_fa_restore_idx[
|
long_seq_metadata.pcp_enter_fa_restore_idx = self.pcp_enter_fa_restore_idx[
|
||||||
: pcp_unpad_mask.sum() + num_decodes * (self.pcp_world_size - 1)
|
: pcp_unpad_mask.sum() + self.num_decode_reqs * (self.pcp_world_size - 1)
|
||||||
]
|
]
|
||||||
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
|
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
|
||||||
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
|
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
|
||||||
|
|||||||
Reference in New Issue
Block a user