[Feature] Refactor PCP & DCP related code (#5214)

### What this PR does / why we need it?
Refactor PCP & DCP related code. A `pcp_manager` class now manages PCP & DCP in a unified way. As a result, a lot of code can be removed from `model_runner`, and it becomes much harder for unrelated development work to accidentally break PCP & DCP.
RFC: https://github.com/vllm-project/vllm-ascend/issues/5449
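For context, here is a minimal Python sketch of the idea, not the actual class added by this PR: the buffers and helpers that the diff below reads via `runner.pcp_manager.*` (e.g. `long_seq_metadata`, `pcp_allgather_restore_idx`, `cu_num_tokens_pcp_full`, `mtp_slot_pad`, `_get_cp_local_seq_lens`) are grouped into one manager object instead of living as loose attributes on the model runner. The class name `PcpManagerSketch`, the field types, and the even-split partitioning in `get_cp_local_seq_lens` are illustrative assumptions only.

```python
from dataclasses import dataclass
from typing import Any, Optional

import torch


@dataclass
class PcpManagerSketch:
    """Illustrative only: one object owns the PCP/DCP state that was
    previously scattered across the model runner."""
    pcp_size: int
    dcp_size: int
    # Buffers the diff accesses via `runner.pcp_manager.*`; real types differ.
    long_seq_metadata: Optional[Any] = None
    pcp_allgather_restore_idx: Optional[torch.Tensor] = None
    cu_num_tokens_pcp_full: Optional[torch.Tensor] = None
    mtp_slot_pad: Optional[torch.Tensor] = None

    @property
    def cp_world_size(self) -> int:
        # Combined context-parallel world size, checked throughout the
        # diff as `pcp_size * dcp_size > 1`.
        return self.pcp_size * self.dcp_size

    def get_cp_local_seq_lens(self, seq_lens: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in for `_get_cp_local_seq_lens`: split each
        # sequence length as evenly as possible across the CP ranks
        # (earlier ranks absorb the remainder). The real partitioning
        # scheme in the repo may differ.
        world = self.cp_world_size
        base = seq_lens // world
        remainder = seq_lens % world
        ranks = torch.arange(world, device=seq_lens.device)
        return base.unsqueeze(-1) + (ranks < remainder.unsqueeze(-1)).to(base.dtype)
```

With this structure, callers such as the MTP proposer touch `self.runner.pcp_manager.<buffer>` instead of reaching into runner attributes directly, which is what allows the runner-side code in this PR to shrink.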
### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Author: zhenwenqi2024
Date: 2025-12-31 09:29:57 +08:00 (committed by GitHub)
Parent: 46862ce1af
Commit: 5d9fde9819
7 changed files with 1156 additions and 1047 deletions


@@ -179,7 +179,7 @@ class MtpProposer(EagleProposer):
if self.pcp_size * self.dcp_size > 1:
# update long_seq related params and flatten block_table
common_attn_metadata.prefill_context_parallel_metadata = \
- self.runner.long_seq_metadata
+ self.runner.pcp_manager.long_seq_metadata
common_attn_metadata.block_table_tensor = \
self.runner.input_batch.block_table[0].get_device_tensor()[
:num_reqs * self.decode_threshold]
@@ -286,9 +286,9 @@ class MtpProposer(EagleProposer):
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size * self.dcp_size > 1:
long_seq_metadata = self.runner.long_seq_metadata
- input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu
- query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu
- query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu
+ input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
+ query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
+ query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
@@ -303,7 +303,7 @@ class MtpProposer(EagleProposer):
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
- query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
+ query_start_loc_pcp_full[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
@@ -751,8 +751,8 @@ class MtpProposer(EagleProposer):
hidden_states = hidden_states[:num_tokens]
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
hidden_states = torch.index_select(
- hidden_states, 0, self.runner.
- pcp_allgather_restore_idx[:hidden_states.shape[0]])
+ hidden_states, 0, self.runner.pcp_manager.
+ pcp_allgather_restore_idx.gpu[:hidden_states.shape[0]])
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
@@ -797,13 +797,13 @@ class MtpProposer(EagleProposer):
# (_generate_pcp_mtp_input), and use updated slot_indices
# to get corresponding slot_mapping in each step.
num_reject_tokens = torch.tensor(
- self.runner.cu_num_tokens_pcp_full,
+ self.runner.pcp_manager.cu_num_tokens_pcp_full,
dtype=torch.int32).to(
self.device) - ori_last_token_indices - 1
num_accept_tokens = \
query_lens_d.to(self.device) - num_reject_tokens
ori_seq_len = attn_metadata_i.seq_lens
- mtp_slot_mapping = self.runner.mtp_slot_pad
+ mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
# slot_mapping index base offset:
# scheduled tokens + pre-allocated mtp tokens + accepted tokens
@@ -889,7 +889,7 @@ class MtpProposer(EagleProposer):
self.hidden_states[:hidden_states.shape[0]] = hidden_states
if self.pcp_size * self.dcp_size > 1:
# update local seq_len and batch_seq_mask
- num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
+ num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
ori_seq_len + step + 1,
self.pcp_size,
self.dcp_size,