[feature] support pcp + mtp in full graph (#4572)

1. Support PCP + MTP in full graph
2. Fix PCP/DCP-related MTP bugs
3. Support PCP + MTPx
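
For context on the feature: with PCP (prefill context parallelism), each prefill request's tokens are padded and sharded across the pcp ranks, which is why the drafter below must all-gather hidden states and de-pad decode tokens. A minimal standalone sketch of that sharding (illustrative only; it assumes contiguous shards and is not vllm-ascend code):

import torch

def pcp_split(token_ids: torch.Tensor, pcp_size: int, pcp_rank: int) -> torch.Tensor:
    # Pad to a multiple of pcp_size, then take this rank's contiguous shard.
    pad = (-token_ids.numel()) % pcp_size
    padded = torch.nn.functional.pad(token_ids, (0, pad))
    return padded.chunk(pcp_size)[pcp_rank]

tokens = torch.arange(10)  # one request with 10 prompt tokens
print([pcp_split(tokens, 4, r).tolist() for r in range(4)])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 0]] -> the last shard carries padding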

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Author: zhangsicheng5
Committed: 2025-12-22 16:13:39 +08:00 (via GitHub)
Parent: 12d581605b
Commit: 78aa7f2693
10 changed files with 478 additions and 94 deletions


@@ -32,6 +32,7 @@ from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
+                                               update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
@@ -98,6 +99,7 @@ class MtpProposer(Proposer):
         self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
+        self.dcp_rank = self.runner.dcp_rank
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
@@ -267,6 +269,13 @@ class MtpProposer(Proposer):
             attn_state=self.runner.attn_state,
             decode_token_per_req=self.runner.decode_token_per_req,
         )
+        if self.pcp_size * self.dcp_size > 1:
+            # update long_seq related params and flatten block_table
+            common_attn_metadata.prefill_context_parallel_metadata = \
+                self.runner.long_seq_metadata
+            common_attn_metadata.block_table_tensor = \
+                self.runner.input_batch.block_table[0].get_device_tensor()[
+                    :num_reqs * self.decode_threshold]
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
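
The block_table_tensor slice above assumes the runner's device block table is laid out with one row per decode token (num_reqs * decode_threshold rows) rather than one per request; a later hunk "folds" it back. A hedged sketch of that per-token flattening, with illustrative names standing in for the runner's actual layout:

import torch

num_reqs, decode_threshold = 2, 3
block_table = torch.tensor([[10, 11, 12, 13],
                            [20, 21, 22, 23]])   # one row per request
flat = block_table.repeat_interleave(decode_threshold, dim=0)
print(flat.shape)   # torch.Size([6, 4]) -> num_reqs * decode_threshold rows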
@@ -310,9 +319,15 @@ class MtpProposer(Proposer):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
                 not forward_context.capturing:
             if self.vllm_config.model_config.use_mla and not self.use_sparse:
-                update_mla_attn_params(
-                    self.update_stream, forward_context, num_tokens,
-                    self.vllm_config.speculative_config)
+                if self.pcp_size * self.dcp_size > 1:
+                    update_mla_attn_dcp_pcp_params(
+                        self.update_stream, forward_context,
+                        num_tokens)
+                else:
+                    update_mla_attn_params(
+                        self.update_stream, forward_context,
+                        num_tokens,
+                        self.vllm_config.speculative_config)
         if self.enable_shared_expert_dp:
             positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 positions, True)
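
The new dispatch exists because a fully captured graph replays fixed kernels over fixed buffers, so per-step attention params (seq lens, slot mappings) must be rewritten in place before replay; that is what both update_mla_attn_params and update_mla_attn_dcp_pcp_params do on the side update_stream. A minimal sketch of the general pattern, with a plain tensor standing in for a captured buffer:

import torch

buf = torch.zeros(4)      # stands in for a buffer baked into a captured graph

def replay():             # stands in for replaying the captured graph
    return buf * 2

for step_vals in ([1., 2., 3., 4.], [5., 6., 7., 8.]):
    buf.copy_(torch.tensor(step_vals))  # in-place update; shape and storage unchanged
    print(replay().tolist())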
@@ -364,11 +379,11 @@ class MtpProposer(Proposer):
                 valid_sampled_tokens_count)
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
-            input_ids_pcp_full = self.runner.input_ids_pcp_full
-            query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
-            query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
+            input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu
+            query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu
+            query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu
             num_reqs = self.runner.input_batch.num_reqs
             ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
                 query_start_loc_pcp_full_cpu[:num_reqs]
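
ori_query_lens is just a first difference of the cumulative query_start_loc array; a small self-contained check with assumed values:

import torch

num_reqs = 3
query_start_loc_pcp_full_cpu = torch.tensor([0, 4, 9, 12])  # 3 reqs with lens 4, 5, 3
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs + 1] - \
    query_start_loc_pcp_full_cpu[:num_reqs]
print(ori_query_lens.tolist())  # [4, 5, 3]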
@@ -396,12 +411,11 @@ class MtpProposer(Proposer):
             target_hidden_states = hidden_states[:num_scheduled_tokens]
         else:
             if self.pcp_size > 1:
-                common_attn_metadata.query_start_loc_cpu = \
+                common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
                     query_start_loc_pcp_full_cpu[:num_reqs + 1]
-                common_attn_metadata.query_start_loc = \
+                common_attn_metadata.query_start_loc[:num_reqs + 1] = \
                     query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
-            # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices =\
                 self._prepare_inputs(
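
The switch from rebinding (x = new) to slice assignment (x[:n] = new) is the likely full-graph fix here: rebinding detaches the metadata field from the buffer a captured graph keeps reading, while slice assignment mutates that buffer in place. A tiny illustration:

import torch

captured = torch.zeros(5, dtype=torch.int64)  # what a captured graph would read
view = captured
view = torch.tensor([0, 2, 4])        # rebind: `captured` is untouched
print(captured.tolist())              # [0, 0, 0, 0, 0]
view = captured
view[:3] = torch.tensor([0, 2, 4])    # slice assign: `captured` is updated
print(captured.tolist())              # [0, 2, 4, 0, 0]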
@@ -630,15 +644,18 @@ class MtpProposer(Proposer):
         self.input_ids[last_token_indices] = next_token_ids
         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
             ori_last_token_indices = last_token_indices.clone()
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
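
The num_tokens_d change matters once per-request decode lengths diverge (e.g. multi-step MTP with partial acceptance): the old formula charges decode_threshold tokens to every decode request, the new one counts what was actually scheduled. Tiny illustration with assumed lengths:

import torch

decode_threshold = 3
query_lens_d = torch.tensor([3, 1, 2])  # per-request decode lens can differ
padded_count = query_lens_d.numel() * decode_threshold  # old formula: 9
true_count = int(query_lens_d.sum().item())             # new formula: 6
print(padded_count, true_count)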
@@ -646,12 +663,17 @@ class MtpProposer(Proposer):
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [
-                        num_decode_reqs, self.decode_threshold * self.pcp_size,
-                        -1
-                    ])[:, :self.decode_threshold, :].reshape(
-                        [num_tokens_d, -1])
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
+                              mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
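
A self-contained replay of the mask construction above: per decode request it keeps the first query_lens_d[i] rows out of that request's query_lens_d[i] * pcp_size all-gathered (padded) rows, which the old fixed reshape could not express for ragged lengths:

import torch

pcp_size = 2
query_lens_d = torch.tensor([2, 1])   # ragged decode query lens
mask_start_loc = torch.cat([torch.tensor([0]),
                            torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]])
mask = []
for req_id in range(len(query_lens_d)):
    mask += list(range(mask_start_loc[req_id],
                       mask_start_loc[req_id] + query_lens_d[req_id]))
print(mask)  # [0, 1, 4] -> rows kept out of the 6 all-gathered rows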
@@ -670,25 +692,26 @@ class MtpProposer(Proposer):
                 torch.cat([input_ids_d, input_ids_p], dim=0))
             target_hidden_states = torch.cat(
                 [target_hidden_states_d, target_hidden_states_p], dim=0)
-            # 2. update attn_metadata params that may be influenced by pcp
-            common_attn_metadata.num_actual_tokens = num_tokens
-            common_attn_metadata.max_query_len = max(self.decode_threshold,
-                                                     max_query_len_p)
-            common_attn_metadata.seq_lens[num_decode_reqs:] = seq_lens_p
-            common_attn_metadata.seq_lens_cpu[num_decode_reqs:] = seq_lens_p
-            query_start_loc_p = cu_num_tokens_p[1:] + \
-                common_attn_metadata.query_start_loc[num_decode_reqs].item()
-            common_attn_metadata.query_start_loc[num_decode_reqs + 1:] = \
-                query_start_loc_p
-            common_attn_metadata.query_start_loc_cpu[num_decode_reqs + 1:] = \
-                query_start_loc_p
-            # 3. update sample_indices according to main model
+            # 2. update sample_indices according to main model
             if num_decode_reqs:
                 last_token_indices[:num_decode_reqs] = \
                     self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
             if num_prefill_reqs:
                 last_token_indices[-num_prefill_reqs:] = \
                     self.runner.logits_indices[-num_prefill_reqs:]
+            # 3. update attn_metadata params that may be influenced by pcp
+            common_attn_metadata.num_actual_tokens = num_tokens
+            common_attn_metadata.max_query_len = max(
+                self.decode_threshold, max_query_len_p)
+            common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
+            common_attn_metadata.seq_lens_cpu[
+                -num_prefill_reqs:] = seq_lens_p
+            query_start_loc_p = cu_num_tokens_p[1:] + \
+                common_attn_metadata.query_start_loc[num_decode_reqs].item()
+            common_attn_metadata.query_start_loc[-num_prefill_reqs:] = \
+                query_start_loc_p
+            common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = \
+                query_start_loc_p
         assert self.runner is not None
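
A small replay of how the prefill tail of query_start_loc is rebuilt: the prefill cumulative token counts are shifted by the last decode boundary and written into the trailing slots (values assumed):

import torch

num_decode_reqs, num_prefill_reqs = 2, 2
query_start_loc = torch.tensor([0, 2, 4, 0, 0])  # decode part already filled
cu_num_tokens_p = torch.tensor([0, 5, 8])        # prefill lens 5 and 3
query_start_loc_p = cu_num_tokens_p[1:] + query_start_loc[num_decode_reqs].item()
query_start_loc[-num_prefill_reqs:] = query_start_loc_p
print(query_start_loc.tolist())  # [0, 2, 4, 9, 12]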
@@ -796,10 +819,15 @@ class MtpProposer(Proposer):
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
                 if self.vllm_config.model_config.use_mla and not self.use_sparse:
-                    update_mla_attn_params(
-                        self.update_stream, forward_context,
-                        num_input_tokens,
-                        self.vllm_config.speculative_config)
+                    if self.pcp_size * self.dcp_size > 1:
+                        update_mla_attn_dcp_pcp_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens)
+                    else:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens,
+                            self.vllm_config.speculative_config)
             if self.enable_shared_expert_dp:
                 hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
@@ -814,7 +842,9 @@ class MtpProposer(Proposer):
                     last_token_indices,
                     (0, max_num_reqs_across_dp - num_indices))
-            if self.pcp_size > 1:
+            if self.pcp_size > 1 and step == 0:
+                # remove graph padding before all_gather
+                hidden_states = hidden_states[:num_tokens]
                 hidden_states = get_pcp_group().all_gather(hidden_states, 0)
                 hidden_states = torch.index_select(
                     hidden_states, 0, self.runner.
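
Stripping the graph padding before the all-gather matters because every pcp rank must contribute exactly its num_tokens real rows; gathering padded buffers would leave padding rows in the middle of the result. An illustration with torch.cat standing in for get_pcp_group().all_gather:

import torch

num_tokens, padded = 3, 4
rank0 = torch.arange(padded).unsqueeze(1)  # rows 0..3; row 3 is graph padding
rank1 = rank0 + 10
print(torch.cat([rank0, rank1]).squeeze().tolist())
# [0, 1, 2, 3, 10, 11, 12, 13] -> pad row 3 sits between the ranks' tokens
print(torch.cat([rank0[:num_tokens], rank1[:num_tokens]]).squeeze().tolist())
# [0, 1, 2, 10, 11, 12] -> clean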
@@ -855,6 +885,51 @@ class MtpProposer(Proposer):
                 last_token_indices = self.arange[:batch_size]
                 if getattr(attn_metadata_i, "num_decode_tokens", 0):
                     attn_metadata_i.num_decode_tokens = batch_size
+                if self.pcp_size * self.dcp_size > 1:
+                    positions = target_positions[ori_last_token_indices]
+                    # For pcp/dcp, tokens are split across different cp ranks,
+                    # so we cannot simply update slot_mapping by += 1.
+                    # Instead, we pre-allocate mtp slot_mapping in model_runner
+                    # (_generate_pcp_mtp_input), and use updated slot_indices
+                    # to get the corresponding slot_mapping at each step.
+                    num_reject_tokens = torch.tensor(
+                        self.runner.cu_num_tokens_pcp_full,
+                        dtype=torch.int32).to(
+                            self.device) - ori_last_token_indices - 1
+                    num_accept_tokens = \
+                        query_lens_d.to(self.device) - num_reject_tokens
+                    ori_seq_len = attn_metadata_i.seq_lens
+                    mtp_slot_mapping = self.runner.mtp_slot_pad
+                    # slot_mapping index base offset:
+                    # scheduled tokens + pre-allocated mtp tokens + accepted tokens
+                    slot_idx_base = (
+                        torch.cat([
+                            torch.tensor(
+                                [0], dtype=torch.int32, device=self.device),
+                            (torch.cumsum(query_lens_d, dim=0)[:-1] *
+                             self.pcp_size).to(self.device)
+                        ]) +
+                        torch.arange(num_decode_reqs, device=self.device) *
+                        (self.num_speculative_tokens - 1) * self.pcp_size +
+                        (num_accept_tokens - 1) * self.pcp_size)
+                    slot_indices_list = []
+                    for req_id in range(num_decode_reqs):
+                        slot_indices_list.append(
+                            torch.arange(slot_idx_base[req_id],
+                                         slot_idx_base[req_id] + self.pcp_size,
+                                         device=self.device))
+                    slot_indices = torch.cat(slot_indices_list, dim=0)
+                    # fold block_table (restore it to its original, pre-flattened size)
+                    block_indices = torch.cat([
+                        torch.tensor([0], dtype=torch.int32),
+                        torch.cumsum(query_lens_d, dim=0)[:-1]
+                    ])
+                    attn_metadata_i.decode.block_table[:batch_size] = \
+                        attn_metadata_i.decode.block_table[block_indices]
+                    attn_metadata_i.decode.block_table = \
+                        attn_metadata_i.decode.block_table[:batch_size]
             input_ids = draft_token_ids_list[-1].int()
             positions += 1
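
A replay of the slot-index arithmetic under small assumed shapes; per the comment above, each request's base skips its pcp-split scheduled tokens, the pre-allocated MTP slots of earlier steps, and its accepted tokens, then takes pcp_size consecutive slots:

import torch

pcp_size, num_speculative_tokens = 2, 2
query_lens_d = torch.tensor([2, 1])       # decode tokens scheduled per request
num_accept_tokens = torch.tensor([1, 1])  # accepted draft tokens per request
num_decode_reqs = query_lens_d.numel()
slot_idx_base = (
    torch.cat([torch.tensor([0]),
               torch.cumsum(query_lens_d, dim=0)[:-1] * pcp_size])
    + torch.arange(num_decode_reqs) * (num_speculative_tokens - 1) * pcp_size
    + (num_accept_tokens - 1) * pcp_size)
slot_indices = torch.cat(
    [torch.arange(b, b + pcp_size) for b in slot_idx_base.tolist()])
print(slot_idx_base.tolist(), slot_indices.tolist())  # [0, 6] [0, 1, 6, 7]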
@@ -901,13 +976,40 @@ class MtpProposer(Proposer):
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
             slot_mapping += 1
+            if self.pcp_size > 1:
+                exceeds_max_model_len = exceeds_max_model_len.repeat_interleave(
+                    slot_mapping.size(0) // exceeds_max_model_len.size(0))
             slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:hidden_states.shape[0]] = hidden_states
-            attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
+            if self.pcp_size * self.dcp_size > 1:
+                # update local seq_len and batch_seq_mask
+                num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+                    ori_seq_len + step + 1,
+                    self.pcp_size,
+                    self.dcp_size,
+                    self.runner.parallel_config.cp_kv_cache_interleave_size,
+                )
+                cp_seq_len = \
+                    num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank]
+                batch_seq_mask = (cp_seq_len == 0)
+                builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+                    batch_seq_mask, non_blocking=True)
+                batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                attn_metadata_i.decode.cp_seq_len = cp_seq_len
+                attn_metadata_i.decode.batch_seq_mask = batch_seq_mask
+                # update slot_mapping
+                slot_indices += self.pcp_size
+                slot_mapping = mtp_slot_mapping[slot_indices]
+                attn_metadata_i.slot_mapping[:batch_size *
+                                             self.pcp_size] = slot_mapping
+            else:
+                attn_metadata_i.slot_mapping[:batch_size] = slot_mapping
             if self.speculative_config.disable_padded_drafter_batch:
                 self.positions[batch_size:num_input_tokens] = 0
                 self.input_ids[batch_size:num_input_tokens] = 0
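
A replay of the cp_seq_len / batch_seq_mask guard above: a rank holding zero KV for some request cannot hand seq_len 0 to the attention kernel, so the length is clamped to 1 and the mask records which rows' outputs must be ignored:

import torch

cp_seq_len = torch.tensor([5, 0, 3])  # this rank's local KV lens per request
batch_seq_mask = (cp_seq_len == 0)    # True -> attention output is a dummy
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
print(cp_seq_len.tolist(), batch_seq_mask.tolist())  # [5, 1, 3] [False, True, False]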