[Feature] Refactor PCP &DCP related code (#5214)
### What this PR does / why we need it?
Refactor pcp& dcp related code. we use pcp_manager class to Unifiy
Manage pcp & dcp . as we do this , many code can be deleted from
model_runner, and can avoid break pcp & dcp by other developments.
RFC:https://github.com/vllm-project/vllm-ascend/issues/5449
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
This commit is contained in:
@@ -179,7 +179,7 @@ class MtpProposer(EagleProposer):
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
# update long_seq related params and flatten block_table
|
||||
common_attn_metadata.prefill_context_parallel_metadata = \
|
||||
self.runner.long_seq_metadata
|
||||
self.runner.pcp_manager.long_seq_metadata
|
||||
common_attn_metadata.block_table_tensor = \
|
||||
self.runner.input_batch.block_table[0].get_device_tensor()[
|
||||
:num_reqs * self.decode_threshold]
|
||||
@@ -286,9 +286,9 @@ class MtpProposer(EagleProposer):
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
long_seq_metadata = self.runner.long_seq_metadata
|
||||
input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu
|
||||
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu
|
||||
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu
|
||||
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
|
||||
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
|
||||
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
|
||||
num_reqs = self.runner.input_batch.num_reqs
|
||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||
@@ -303,7 +303,7 @@ class MtpProposer(EagleProposer):
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1:
|
||||
token_indices_to_sample = \
|
||||
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
|
||||
query_start_loc_pcp_full[1:num_reqs + 1] - 1
|
||||
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states
|
||||
@@ -751,8 +751,8 @@ class MtpProposer(EagleProposer):
|
||||
hidden_states = hidden_states[:num_tokens]
|
||||
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
|
||||
hidden_states = torch.index_select(
|
||||
hidden_states, 0, self.runner.
|
||||
pcp_allgather_restore_idx[:hidden_states.shape[0]])
|
||||
hidden_states, 0, self.runner.pcp_manager.
|
||||
pcp_allgather_restore_idx.gpu[:hidden_states.shape[0]])
|
||||
|
||||
sample_hidden_states = hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
@@ -797,13 +797,13 @@ class MtpProposer(EagleProposer):
|
||||
# (_generate_pcp_mtp_input), and use updated slot_indices
|
||||
# to get corresponding slot_mapping in each step.
|
||||
num_reject_tokens = torch.tensor(
|
||||
self.runner.cu_num_tokens_pcp_full,
|
||||
self.runner.pcp_manager.cu_num_tokens_pcp_full,
|
||||
dtype=torch.int32).to(
|
||||
self.device) - ori_last_token_indices - 1
|
||||
num_accept_tokens = \
|
||||
query_lens_d.to(self.device) - num_reject_tokens
|
||||
ori_seq_len = attn_metadata_i.seq_lens
|
||||
mtp_slot_mapping = self.runner.mtp_slot_pad
|
||||
mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad
|
||||
|
||||
# slot_mapping index base offset:
|
||||
# scheduled tokens + pre-allocated mtp tokens + accepted tokens
|
||||
@@ -889,7 +889,7 @@ class MtpProposer(EagleProposer):
|
||||
self.hidden_states[:hidden_states.shape[0]] = hidden_states
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
# update local seq_len and batch_seq_mask
|
||||
num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens(
|
||||
num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
|
||||
ori_seq_len + step + 1,
|
||||
self.pcp_size,
|
||||
self.dcp_size,
|
||||
|
||||
Reference in New Issue
Block a user