[feat] ds3.2 PCP: support MTP and chunked prefill (#6917)
### What this PR does / why we need it?
Enables the ds3.2 PCP (prefill context parallel) path to work in combination with the MTP (multi-token prediction) and chunked-prefill features.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
```diff
@@ -45,6 +45,14 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
         self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd(
             self.block_size, self.cp_virtual_block_size
         )
+        self.slot_mapping_buf = torch.empty(
+            (
+                vllm_config.scheduler_config.max_num_batched_tokens
+                + 2 * self.pcp_size * vllm_config.scheduler_config.max_num_seqs,
+            ),
+            dtype=torch.int32,
+            device=device,
+        )

     def build(
         self,
```
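The new buffer is sized for the worst-case PCP-padded token stream: the scheduler's token budget plus up to `2 * pcp_size` padded slots per request. A minimal sketch of the sizing arithmetic, with assumed example values (none of these numbers come from the PR):

```python
# Assumed example values, for illustration only.
pcp_size = 4                   # prefill context-parallel world size
max_num_batched_tokens = 8192  # scheduler token budget per step
max_num_seqs = 256             # maximum concurrent requests

# Worst case: every request contributes 2 * pcp_size padding slots.
buf_len = max_num_batched_tokens + 2 * pcp_size * max_num_seqs
assert buf_len == 10240
```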
```diff
@@ -82,15 +90,31 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert long_seq_metadata is not None
         num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded
-        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
+        self.slot_mapping_buf[:num_actual_tokens_pcp_padded].copy_(
+            common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded], non_blocking=True
+        )
         if self.enable_mlapo:
-            slot_mapping[:num_decode_tokens] = slot_mapping[: num_decode_tokens * self.pcp_size : self.pcp_size]
-            slot_mapping[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
-        metadata_cls.slot_mapping = slot_mapping
+            self.slot_mapping_buf[:num_decode_tokens] = self.slot_mapping_buf[
+                : num_decode_tokens * self.pcp_size : self.pcp_size
+            ]
+            self.slot_mapping_buf[num_decode_tokens : num_decode_tokens * self.pcp_size].fill_(-1)
+        elif self.speculative_config is not None and num_decodes > 0:
+            # when mtp, pcp_allgather_restore_idx=[696,-1,697,-1,560,-1,561,-1,100,101,102],
+            # slot_mapping should be [696,697,-1,-1,560,561,-1,-1,100,101,102]
+            num_tokens_per_request = num_decode_tokens // num_decodes
+            decode_slot_mapping = self.slot_mapping_buf[: num_decode_tokens * self.pcp_size].reshape(
+                num_decodes, -1
+            )
+            decode_slot_mapping[:, :num_tokens_per_request] = decode_slot_mapping[
+                :, : num_tokens_per_request * self.pcp_size : self.pcp_size
+            ]
+            decode_slot_mapping[:, num_tokens_per_request : num_tokens_per_request * self.pcp_size].fill_(-1)
+            self.slot_mapping_buf[: num_decode_tokens * self.pcp_size] = decode_slot_mapping.flatten()
+        metadata_cls.slot_mapping = self.slot_mapping_buf[:num_actual_tokens_pcp_padded]
         actual_seq_lengths_query = metadata_cls.cum_query_lens
         if num_prefills > 0 and num_decode_tokens > 0:
             prefill_q_cum_seqlens = (
-                actual_seq_lengths_query[num_decode_tokens:] - actual_seq_lengths_query[num_decode_tokens - 1]
+                actual_seq_lengths_query[num_decodes:] - actual_seq_lengths_query[num_decodes - 1]
             )
         else:
             prefill_q_cum_seqlens = actual_seq_lengths_query
```
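The new `elif` branch reorders the decode portion of the slot mapping so each request's MTP tokens sit contiguously, with the PCP padding pushed behind them. A standalone sketch reproducing the example from the in-code comment (tensor names are local to the sketch; it mirrors the same in-place strided-slice trick):

```python
import torch

# From the comment: pcp_size=2, two decode requests with two MTP tokens
# each (num_decode_tokens=4), followed by prefill tokens [100, 101, 102].
pcp_size, num_decodes, num_decode_tokens = 2, 2, 4
buf = torch.tensor([696, -1, 697, -1, 560, -1, 561, -1, 100, 101, 102],
                   dtype=torch.int32)

n = num_decode_tokens // num_decodes                  # tokens per request
decode = buf[: num_decode_tokens * pcp_size].reshape(num_decodes, -1)
decode[:, :n] = decode[:, : n * pcp_size : pcp_size]  # gather the real slots
decode[:, n : n * pcp_size].fill_(-1)                 # pad the remainder
buf[: num_decode_tokens * pcp_size] = decode.flatten()

assert buf.tolist() == [696, 697, -1, -1, 560, 561, -1, -1, 100, 101, 102]
```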
```diff
@@ -108,8 +132,9 @@ class AscendSFACPMetadataBuilder(AscendSFAMetadataBuilder):
     ) -> AscendPCPMetadata | None:
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         assert common_long_seq_metadata is not None
-        q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1)
-        q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank
+        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(seq_lens.device)
+        q_head_kv_lens = (seq_lens // 2) * (self.pcp_rank + 1) + num_computed_tokens
+        q_tail_kv_lens = seq_lens * self.pcp_size - (seq_lens // 2) * self.pcp_rank + num_computed_tokens
         return AscendPCPMetadata(
             q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
             q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
```
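This is the chunked-prefill half of the fix: the head/tail KV bounds previously assumed the whole prompt is processed in one prefill step, but under chunked prefill the cache already holds `num_computed_tokens` entries per request, which must be added to both bounds. A worked example under assumed values:

```python
import torch

# Assumed: pcp_size=4, this rank is 1, one request scheduling 32 new
# tokens in this chunk with 64 tokens already computed by earlier chunks.
pcp_size, pcp_rank = 4, 1
seq_lens = torch.tensor([32])
num_computed_tokens = torch.tensor([64])

q_head_kv_lens = (seq_lens // 2) * (pcp_rank + 1) + num_computed_tokens
q_tail_kv_lens = (seq_lens * pcp_size
                  - (seq_lens // 2) * pcp_rank
                  + num_computed_tokens)
assert q_head_kv_lens.item() == 96   # 16 * 2 + 64
assert q_tail_kv_lens.item() == 176  # 128 - 16 + 64
```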
```diff
@@ -181,6 +206,7 @@ class AscendSFACPImpl(AscendSFAImpl):
             return self._execute_sparse_flash_attention(
                 ql_nope, q_pe, kv, key_rope, block_table, topk_indices, actual_seq_lengths_query, actual_seq_lengths_key
             )
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefills = attn_metadata.num_prefills
         decode_attn_out = None
```
```diff
@@ -190,10 +216,10 @@ class AscendSFACPImpl(AscendSFAImpl):
                 q_pe[:num_decode_tokens],
                 kv,
                 key_rope,
-                block_table[:num_decode_tokens],
+                block_table[:num_decodes],
                 topk_indices[:num_decode_tokens],
-                actual_seq_lengths_query[:num_decode_tokens],
-                actual_seq_lengths_key[:num_decode_tokens],
+                actual_seq_lengths_query[:num_decodes],
+                actual_seq_lengths_key[:num_decodes],
             )

         if num_prefills < 1:
```
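The recurring `num_decode_tokens` → `num_decodes` substitution in these hunks is the MTP fix: per-request tensors such as `block_table` and the cumulative sequence-length arrays have one entry per request, while per-token tensors such as `topk_indices` keep the token-count slice; once MTP gives each decode request more than one query token, the two counts diverge. A small sketch of the distinction, with assumed values:

```python
import torch

# Assumed: MTP drafts one extra token, so each decode request carries
# two query tokens and token count != request count.
num_spec_tokens = 1
num_decodes = 2                                          # decode requests
num_decode_tokens = num_decodes * (1 + num_spec_tokens)  # 4 decode tokens

block_table = torch.arange(12).reshape(6, 2)  # one row per request
decode_rows = block_table[:num_decodes]       # correct: 2 decode rows
# block_table[:num_decode_tokens] would wrongly pull in 2 prefill rows.
```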
```diff
@@ -205,10 +231,10 @@ class AscendSFACPImpl(AscendSFAImpl):
         ql_nope = ql_nope[num_decode_tokens:]
         q_pe = q_pe[num_decode_tokens:]
         topk_indices = topk_indices[num_decode_tokens:]
-        block_table = block_table[num_decode_tokens:]
+        block_table = block_table[num_decodes:]

         # q head compute
-        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
+        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
         q_head_output = self._execute_sparse_flash_attention(
             torch.index_select(ql_nope, 0, q_head_idx),
             torch.index_select(q_pe, 0, q_head_idx),
```
```diff
@@ -221,7 +247,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )

         # q tail compute
-        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
+        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
         q_tail_output = self._execute_sparse_flash_attention(
             torch.index_select(ql_nope, 0, q_tail_idx),
             torch.index_select(q_pe, 0, q_tail_idx),
```
```diff
@@ -321,6 +347,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )

         # decode compute
+        num_decodes = attn_metadata.num_decodes
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefills = attn_metadata.num_prefills
         decode_topk_indices = None
```
```diff
@@ -329,9 +356,9 @@ class AscendSFACPImpl(AscendSFAImpl):
                 q[:num_decode_tokens],
                 key,
                 weights[:num_decode_tokens],
-                actual_seq_lengths_query[:num_decode_tokens],
-                actual_seq_lengths_key[:num_decode_tokens],
-                block_table[:num_decode_tokens],
+                actual_seq_lengths_query[:num_decodes],
+                actual_seq_lengths_key[:num_decodes],
+                block_table[:num_decodes],
             )

         # prefill compute
```
```diff
@@ -339,14 +366,14 @@ class AscendSFACPImpl(AscendSFAImpl):
             return decode_topk_indices
         q = q[num_decode_tokens:]
         weights = weights[num_decode_tokens:]
-        actual_seq_lengths_key = actual_seq_lengths_key[num_decode_tokens:]
-        block_table = block_table[num_decode_tokens:]
+        actual_seq_lengths_key = actual_seq_lengths_key[num_decodes:]
+        block_table = block_table[num_decodes:]
         # pcp split for head and tail
         q_head_idx = attn_metadata.sfa_cp_metadata.q_head_idx
         q_tail_idx = attn_metadata.sfa_cp_metadata.q_tail_idx

         # q head compute
-        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decode_tokens:]
+        q_head_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.head_attn_nomask_seqlens[num_decodes:]
         q_head_topk_indices = self._execute_indexer_select(
             q=torch.index_select(q, 0, q_head_idx),
             key=key,
```
```diff
@@ -357,7 +384,7 @@ class AscendSFACPImpl(AscendSFAImpl):
         )

         # q tail compute
-        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decode_tokens:]
+        q_tail_actual_seq_lengths_key = attn_metadata.sfa_cp_metadata.tail_attn_nomask_seqlens[num_decodes:]
         q_tail_topk_indices = self._execute_indexer_select(
             q=torch.index_select(q, 0, q_tail_idx),
             key=key,
```
```diff
@@ -246,36 +246,7 @@ class NPUModelRunner(GPUModelRunner):
         self.max_num_reqs = self.scheduler_config.max_num_seqs
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.dp_rank = vllm_config.parallel_config.data_parallel_rank
-        try:
-            self.dcp_size = get_dcp_group().world_size
-            self.dcp_rank = get_dcp_group().rank_in_group
-            self.pcp_size = get_pcp_group().world_size
-            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
-        except Exception:
-            self.dcp_size = 1
-            self.dcp_rank = 0
-            self.pcp_size = 1
-            self.pcp_rank = 0
-        if self.pcp_size > 1:
-            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
-        max_buffer_num_tokens = self.max_num_tokens
-        if self.pcp_size * self.dcp_size > 1:
-            max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
-        self.pcp_manager = PCPManager(
-            self.pcp_size,
-            self.pcp_rank,
-            self.dcp_size,
-            self.dcp_rank,
-            max_buffer_num_tokens,
-            self.max_num_reqs,
-            self.device,
-            self.vllm_config,
-            self.use_async_scheduling,
-            self.pin_memory,
-        )
-        # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
-        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
-        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)

         self.sampler = AscendSampler()
         self.attn_state: AscendAttentionState | None = None
```
```diff
@@ -310,6 +281,38 @@ class NPUModelRunner(GPUModelRunner):
             use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm,
         )

+        try:
+            self.dcp_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+            self.pcp_size = get_pcp_group().world_size
+            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
+        except Exception:
+            self.dcp_size = 1
+            self.dcp_rank = 0
+            self.pcp_size = 1
+            self.pcp_rank = 0
+        if self.pcp_size > 1:
+            self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
+        max_buffer_num_tokens = self.max_num_tokens
+        if self.pcp_size * self.dcp_size > 1:
+            max_buffer_num_tokens = self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size
+        self.pcp_manager = PCPManager(
+            self.pcp_size,
+            self.pcp_rank,
+            self.dcp_size,
+            self.dcp_rank,
+            max_buffer_num_tokens,
+            self.max_num_reqs,
+            self.device,
+            self.vllm_config,
+            self.use_async_scheduling,
+            self.pin_memory,
+            self.use_sparse,
+        )
+        # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
+        self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
+        self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
+
         self._set_up_drafter()

         # kv role
```
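The block is moved later in `__init__` (after `self.use_sparse` is known, so it can be forwarded to `PCPManager`) and keeps the defensive probe around the DCP/PCP group lookups: when those groups were never created (single-device or non-CP runs) the getters raise, and the runner falls back to world size 1. A minimal sketch of that pattern (the `get_group` callable is a stand-in, not a vLLM API):

```python
def probe_group(get_group) -> tuple[int, int]:
    """Return (world_size, rank_in_group), defaulting to a single-rank
    group when the distributed group was never initialized."""
    try:
        group = get_group()
        return group.world_size, group.rank_in_group
    except Exception:
        return 1, 0
```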
```diff
@@ -56,6 +56,7 @@ class PCPManager:
         vllm_config: VllmConfig,
         use_async_scheduling: bool,
         pin_memory: bool = False,
+        use_sparse: bool = False,
     ) -> None:
         self.pcp_world_size = pcp_world_size
         self.pcp_world_rank = pcp_rank
```
```diff
@@ -97,6 +98,7 @@ class PCPManager:
                 + self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
             )
         )
+        self.use_sparse = use_sparse
         if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
             self.input_ids_pcp_full = CpuGpuBuffer(
                 self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
```
```diff
@@ -784,16 +786,19 @@ class PCPManager:
         num_prefill_reqs = self.num_prefill_reqs
         num_decode_reqs = self.num_decode_reqs
         num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
-        block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
-            block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
-        )
-        block_table_tensor[:num_decode_reqs_flatten].copy_(
-            block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
-        )
-        block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
-        if num_reqs_padded > num_reqs:
-            pad_size = num_reqs_padded - num_reqs
-            ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
+        if not self.use_sparse:
+            block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
+                block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
+            )
+            block_table_tensor[:num_decode_reqs_flatten].copy_(
+                block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
+            )
+            block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
+            if num_reqs_padded > num_reqs:
+                pad_size = num_reqs_padded - num_reqs
+                ori_query_lens_cpu[-pad_size:] = torch.full(
+                    [pad_size], ori_query_lens_cpu[-pad_size - 1].item()
+                )
         pcp_unpad_mask = self.pcp_unpad_mask_cpu[: self.pcp_padded_tokens_length]
         long_seq_metadata = AscendPrefillContextParallelMetadata(
             pcp_use_hybrid_attn=self.pcp_use_hybrid_attn,
```
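For context, the guarded block expands decode block tables from one row per request to one row per decode token via `repeat_interleave`; the new `if not self.use_sparse:` skips that expansion on the sparse-attention (SFA) path, which keeps per-request indexing (matching the `num_decodes` slicing in the SFA hunks above). A toy sketch of the expansion the dense path still performs, with assumed values:

```python
import torch

# Assumed: two decode requests with 2 MTP tokens each, one prefill request.
ori_query_lens = torch.tensor([2, 2, 1])
block_table = torch.tensor([[10, 11], [20, 21], [30, 31]])

num_decode_reqs = 2
expanded = torch.cat([
    # One row per decode *token* instead of per decode request.
    block_table[:num_decode_reqs].repeat_interleave(
        ori_query_lens[:num_decode_reqs], dim=0),
    block_table[num_decode_reqs:],   # prefill rows follow unchanged
])
assert expanded.tolist() == [[10, 11], [10, 11], [20, 21], [20, 21], [30, 31]]
```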