[feature] support pcp + mtp in full graph (#4572)

1. Support pcp + mtp in full graph
2. Fix pcp/dcp-related mtp bugs
3. Support pcp + mtpx

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
zhangsicheng5
2025-12-22 16:13:39 +08:00
committed by GitHub
parent 12d581605b
commit 78aa7f2693
10 changed files with 478 additions and 94 deletions

View File

@@ -75,7 +75,7 @@ class BlockTable:
logical_table_size = max_num_blocks_per_req
duplicate_size = 1
if self.pcp_world_size > 1:
if self.pcp_world_size * self.dcp_world_size > 1:
duplicate_size += num_speculative_tokens
self.block_table = self._make_buffer(max_num_reqs * duplicate_size,
logical_table_size,

View File

@@ -280,7 +280,7 @@ class NPUModelRunner(GPUModelRunner):
dtype=torch.int32,
device=self.device)
self.num_actual_tokens_pcp_padded = 0
if self.speculative_config and self.pcp_size > 1:
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens,
dtype=torch.int32)
self.query_start_loc_pcp_full = self._make_buffer(
@@ -289,8 +289,9 @@ class NPUModelRunner(GPUModelRunner):
dtype=torch.int64,
device="cpu",
pin_memory=True)
self.decode_token_per_req += self.speculative_config.num_speculative_tokens
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.decode_threshold = 1 + (
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
@@ -575,6 +576,7 @@ class NPUModelRunner(GPUModelRunner):
if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
self.generate_kv_idx(scheduler_output)
tokens_before_update = tokens.copy()
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
@@ -591,7 +593,8 @@ class NPUModelRunner(GPUModelRunner):
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
for num_tokens, i in zip(tokens, req_ids)
for num_tokens, i in zip((tokens_before_update if self.
pcp_size > 1 else tokens), req_ids)
],
dtype=np.int32)
@@ -909,7 +912,8 @@ class NPUModelRunner(GPUModelRunner):
>= self.input_batch.num_prompt_tokens[req_idx]) else -1)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
num_draft_tokens, cu_num_tokens,
self.num_pcp_pads[:num_reqs].numpy())
logits_indices = spec_decode_metadata.logits_indices
# For DECODE only cuda graph of some attention backends (e.g., GDN).
@@ -931,10 +935,11 @@ class NPUModelRunner(GPUModelRunner):
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
if self.speculative_config and self.pcp_size > 1:
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self._generate_pcp_mtp_input(
num_reqs, scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens)
scheduler_output.num_scheduled_tokens, with_prefill,
req_indices, positions_np, cu_num_tokens)
long_seq_metadata = self._generate_pcp_metadata(
total_num_scheduled_tokens)
@@ -1040,7 +1045,7 @@ class NPUModelRunner(GPUModelRunner):
prefill_context_parallel_metadata=long_seq_metadata,
)
if self.speculative_config and self.pcp_size > 1:
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
# For pcp + spec decode, we flatten block_table
# to avoid irregular spec_attn_mask shape, e.g.,
# num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
@@ -1048,12 +1053,13 @@ class NPUModelRunner(GPUModelRunner):
# (num_reqs_d + num_reqs_p, max_num_blocks),
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
self.query_start_loc_pcp_full.cpu[:num_reqs]
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs]
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
num_decode_reqs_flatten = \
ori_query_lens_cpu[:num_decode_reqs].sum().item()
blk_table_tensor[
num_decode_reqs_flatten:num_decode_reqs_flatten +
num_prefill_reqs].copy_(
@@ -1061,9 +1067,15 @@ class NPUModelRunner(GPUModelRunner):
num_prefill_reqs].clone())
blk_table_tensor[:num_decode_reqs_flatten].copy_(
blk_table_tensor[:num_decode_reqs].repeat_interleave(
self.decode_threshold, dim=0))
ori_query_lens[:num_decode_reqs], dim=0))
common_attn_metadata.block_table_tensor = \
blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
if 'pad_size' in locals() and pad_size > 0:
ori_query_lens_cpu[-pad_size:] = \
torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
long_seq_metadata.max_query_len_pcp_full = \
ori_query_lens_cpu.max().item()
if self.speculative_config and \
self.spec_decode_common_attn_metadata is None:
@@ -1861,7 +1873,7 @@ class NPUModelRunner(GPUModelRunner):
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=long_seq_metadata,
)
if self.pcp_size > 1:
if self.pcp_size * self.dcp_size > 1:
common_attn_metadata.block_table_tensor = \
block_table_tensor[:num_reqs * self.decode_threshold]
attn_state = AscendAttentionState.DecodeOnly
@@ -3029,9 +3041,7 @@ class NPUModelRunner(GPUModelRunner):
num_reqs = self.input_batch.num_reqs
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
tokens = np.array(tokens, dtype=np.int32)
num_decode_reqs = sum(
self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
self.input_batch.num_prompt_tokens[:num_reqs])
num_decode_reqs = (np.array(tokens) <= self.decode_threshold).sum()
num_decode_tokens = sum(tokens[:num_decode_reqs])
num_padded_scheduled_tokens = np.ceil(
tokens /
@@ -3118,8 +3128,10 @@ class NPUModelRunner(GPUModelRunner):
def _generate_pcp_metadata(self, total_num_scheduled_tokens):
# In dummy run num_reqs == 0, update it from seq_lens
num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
>= self.input_batch.num_prompt_tokens[:num_reqs])
query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \
if self.pcp_size > 1 and self.speculative_config else self.query_lens
num_decodes = (query_lens <= self.decode_threshold).sum().item()
num_prefills = num_reqs - num_decodes
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
@@ -3137,16 +3149,41 @@ class NPUModelRunner(GPUModelRunner):
dtype=torch.int32,
)
# For pcp + spec decode, we flatten seq_lens
# to avoid irregular spec_attn_mask shape
# to avoid irregular spec_attn_mask shape.
# Same as block_table, we flatten decode seq_lens to query_lens,
# and keep prefill seq_lens unchanged.
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_cp_local_seq_lens(
torch.tensor(context_lens),
torch.tensor(context_lens) - decode_idx,
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
)
if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = []
if num_decodes:
num_decodes_flatten = \
self.query_lens[:num_decodes].sum().item()
if self.query_lens[:num_decodes].min().item(
) == self.decode_threshold:
decode_flatten_idx = list(range(num_decodes_flatten))
else:
decode_flatten_idx = []
for req_id in range(num_decodes):
offset = (req_id + 1) * self.decode_threshold
decode_flatten_idx += \
list(range(offset - self.query_lens[req_id], offset))
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
if num_prefills:
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[
(num_decodes + 1) * self.decode_threshold -
1::self.decode_threshold])
num_computed_tokens_of_pcp_dcp = torch.cat(
num_computed_tokens_of_pcp_dcp_list, dim=0)
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
@@ -3278,6 +3315,10 @@ class NPUModelRunner(GPUModelRunner):
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: dict[str, int],
with_prefill: bool = True,
req_indices=None,
positions_np=None,
cu_num_tokens=None,
):
"""
While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
@@ -3288,6 +3329,8 @@ class NPUModelRunner(GPUModelRunner):
num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
for i, req_id in enumerate(self.input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
@@ -3313,11 +3356,45 @@ class NPUModelRunner(GPUModelRunner):
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full.
cpu[:total_num_scheduled_tokens_pcp_full])
self.query_lens_pcp_full.copy_to_gpu()
self.query_start_loc_pcp_full.copy_to_gpu()
self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_(
self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
non_blocking=True,
)
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
# For mtpx, pre-allocate mtp slot_mapping here
if self.decode_threshold > 2 and not with_prefill:
num_tokens_ori = sum(list(num_scheduled_tokens.values()))
num_tokens_mtp = \
num_tokens_ori + num_reqs * (self.decode_threshold - 2)
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size
req_indices_split = np.array_split(req_indices,
cu_num_tokens)[:num_reqs]
positions_split = np.array_split(positions_np,
cu_num_tokens)[:num_reqs]
for req_idx in range(num_reqs):
ori_req_indice = req_indices_split[req_idx]
ori_position = positions_split[req_idx]
req_indices_split[req_idx] = np.append(
ori_req_indice,
np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
positions_split[req_idx] = np.append(
ori_position,
np.arange(ori_position[-1] + 1,
ori_position[-1] + self.decode_threshold - 1))
req_indices_mtp = np.concatenate(req_indices_split)
positions_mtp = np.concatenate(positions_split)
self.input_batch.block_table.compute_slot_mapping(
req_indices_mtp, positions_mtp)
mtp_slot_ori = self.input_batch.block_table.block_tables[
0].slot_mapping.cpu[:num_tokens_mtp]
unpad_mask = np.repeat(False, num_tokens_mtp_pad)
unpad_mask[::self.pcp_size] = True
mtp_slot_pad = \
torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
mtp_slot_pad[unpad_mask] = mtp_slot_ori
self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
@contextmanager