[feature] support pcp + mtp (in pd co-locate scenario) (#4098)

1. support pcp + mtp in pd co-locate scenario
2. llmdatadist connector pcp related bugfix and cleancode

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
zhangsicheng5
2025-11-12 17:22:21 +08:00
committed by GitHub
parent 1b4ce63ec9
commit a123f355e9
6 changed files with 246 additions and 97 deletions

View File

@@ -281,14 +281,17 @@ class AscendMLAMetadataBuilder:
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
self.batch_seq_mask_buf = torch.empty(max_num_seqs *
self.decode_threshold,
dtype=torch.uint8,
device=device)
self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
self.seq_mask_pcp_buf = torch.empty(max_num_seqs *
self.decode_threshold,
self.pcp_size,
dtype=torch.uint8,
device=device)
self.seq_mask_dcp_buf = torch.empty(max_num_seqs,
self.seq_mask_dcp_buf = torch.empty(max_num_seqs *
self.decode_threshold,
self.dcp_size,
dtype=torch.uint8,
device=device)
@@ -504,12 +507,18 @@ class AscendMLAMetadataBuilder:
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decodes, ...]
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
if self.pcp_size > 1:
block_table = block_table.repeat_interleave(
self.decode_threshold, dim=0)
seq_lens_list = seq_lens.tolist()
if num_computed_tokens_of_pcp_dcp is not None:
# [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp
)[:num_decodes] # [bs, pcp_size, dcp_size]
num_computed_tokens_of_pcp_dcp)[:num_decodes *
self.decode_threshold]
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
self.pcp_rank,