[feat]ds3.2 pcp support mtp and chunkprefill (#6917)
### What this PR does / why we need it?
ds3.2 pcp supports the combination of MTP and chunkprefill features.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -45,6 +45,14 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
|
||||
self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
|
||||
self.block_size, self.cp_virtual_block_size
|
||||
)
|
||||
self.slot_mapping_buf = torch.empty(
|
||||
(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens
|
||||
+ 2 * self.pcp_size * vllm_config.scheduler_config.max_num_seqs,
|
||||
),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def build(
|
||||
self,
|
||||
@@ -82,15 +90,31 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
|
||||
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
assert long_seq_metadata is not None
|
||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
|
||||
self.slot_mapping_buf[:num_actual_tokens_pcp_padded].copy_(
|
||||
common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded], non_blocking=True
|
||||
)
|
||||
if self.enable_mlapo:
|
||||
slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
|
||||
slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
|
||||
metadata_cls.slot_mapping = slot_mapping
|
||||
self.slot_mapping_buf[:num_decode_tokens] = self.slot_mapping_buf[
|
||||
: num_decode_tokens * self.pcp_size : self.pcp_size
|
||||
]
|
||||
self.slot_mapping_buf[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
|
||||
elif self.speculative_config is not None and num_decodes > 0:
|
||||
# when mtp, pcp_allgather_restore_idx=[696,-1,697,-1,560,-1,561,-1,100,101,102],
|
||||
# slot_mapping should be [696,697,-1,-1,560,561,-1,-1,100,101,102]
|
||||
num_tokens_per_request = num_decode_tokens // num_decodes
|
||||
decode_slot_mapping = self.slot_mapping_buf[: num_decode_tokens * self.pcp_size].reshape(
|
||||
num_decodes, -1
|
||||
)
|
||||
decode_slot_mapping[:, :num_tokens_per_request] = decode_slot_mapping[
|
||||
:, : num_tokens_per_request * self.pcp_size : self.pcp_size
|
||||
]
|
||||
decode_slot_mapping[:, num_tokens_per_request : num_tokens_per_request * self.pcp_size].fill_(-1)
|
||||
self.slot_mapping_buf[: num_decode_tokens * self.pcp_size] = decode_slot_mapping.flatten()
|
||||
metadata_cls.slot_mapping = self.slot_mapping_buf[:num_actual_tokens_pcp_padded]
|
||||
actual_seq_lengths_query = metadata_cls.cum_query_lens
|
||||
if num_prefills > 0 and num_decode_tokens > 0:
|
||||
prefill_q_cum_seqlens = (
|
||||
actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1]
|
||||
actual_seq_lengths_query[num_decodes:] - actual_seq_lengths_query[num_decodes - 1]
|
||||
)
|
||||
else:
|
||||
prefill_q_cum_seqlens = actual_seq_lengths_query
|
||||
@@ -108,8 +132,9 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
|
||||
) -> AscendPCPMetadata | None:
|
||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
assert common_long_seq_metadata is not None
|
||||
q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1)
|
||||
q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(seq_lens.device)
|
||||
q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1) + num_computed_tokens
|
||||
q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank + num_computed_tokens
|
||||
return AscendPCPMetadata(
|
||||
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
|
||||
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
|
||||
@@ -181,6 +206,7 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
return self._execute_sparse_flash_attention(
|
||||
ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
|
||||
)
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_prefills = attn_metadata.num_prefills
|
||||
decode_attn_out = None
|
||||
@@ -190,10 +216,10 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
q_pe[:num_decode_tokens],
|
||||
kv,
|
||||
key_rope,
|
||||
block_table[:num_decode_tokens],
|
||||
block_table[:num_decodes],
|
||||
topk_indices[:num_decode_tokens],
|
||||
actual_seq_lengths_query[:num_decode_tokens],
|
||||
actual_seq_lengths_key[:num_decode_tokens],
|
||||
actual_seq_lengths_query[:num_decodes],
|
||||
actual_seq_lengths_key[:num_decodes],
|
||||
)
|
||||
|
||||
if num_prefills < 1:
|
||||
@@ -205,10 +231,10 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
ql_nope = ql_nope[num_decode_tokens:]
|
||||
q_pe = q_pe[num_decode_tokens:]
|
||||
topk_indices = topk_indices[num_decode_tokens:]
|
||||
block_table = block_table[num_decode_tokens:]
|
||||
block_table = block_table[num_decodes:]
|
||||
|
||||
# q head compute
|
||||
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
|
||||
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
|
||||
q_head_output = self._execute_sparse_flash_attention(
|
||||
torch.index_select(ql_nope, 0, q_head_idx),
|
||||
torch.index_select(q_pe, 0, q_head_idx),
|
||||
@@ -221,7 +247,7 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
)
|
||||
|
||||
# q tail compute
|
||||
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
|
||||
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
|
||||
q_tail_output = self._execute_sparse_flash_attention(
|
||||
torch.index_select(ql_nope, 0, q_tail_idx),
|
||||
torch.index_select(q_pe, 0, q_tail_idx),
|
||||
@@ -321,6 +347,7 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
)
|
||||
|
||||
# decode compute
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_prefills = attn_metadata.num_prefills
|
||||
decode_topk_indices = None
|
||||
@@ -329,9 +356,9 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
q[:num_decode_tokens],
|
||||
key,
|
||||
weights[:num_decode_tokens],
|
||||
actual_seq_lengths_query[:num_decode_tokens],
|
||||
actual_seq_lengths_key[:num_decode_tokens],
|
||||
block_table[:num_decode_tokens],
|
||||
actual_seq_lengths_query[:num_decodes],
|
||||
actual_seq_lengths_key[:num_decodes],
|
||||
block_table[:num_decodes],
|
||||
)
|
||||
|
||||
# prefill compute
|
||||
@@ -339,14 +366,14 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
return decode_topk_indices
|
||||
q = q[num_decode_tokens:]
|
||||
weights = weights[num_decode_tokens:]
|
||||
actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:]
|
||||
block_table = block_table[num_decode_tokens:]
|
||||
actual_seq_lengths_key = actual_seq_lengths_key[num_decodes:]
|
||||
block_table = block_table[num_decodes:]
|
||||
# pcp split for head and tail
|
||||
q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
|
||||
q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx
|
||||
|
||||
# q head compute
|
||||
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
|
||||
q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
|
||||
q_head_topk_indices = self._execute_indexer_select(
|
||||
q=torch.index_select(q, 0, q_head_idx),
|
||||
key=key,
|
||||
@@ -357,7 +384,7 @@ class AscendSFACPImpl(AscendSFAImpl):
|
||||
)
|
||||
|
||||
# q tail compute
|
||||
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
|
||||
q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
|
||||
q_tail_topk_indices = self._execute_indexer_select(
|
||||
q=torch.index_select(q, 0, q_tail_idx),
|
||||
key=key,
|
||||
|
||||
Reference in New Issue
Block a user