From 0f70698d6d4ea09ce827a1f90f84f3ad4d63ec6e Mon Sep 17 00:00:00 2001
From: zhangsicheng5
Date: Fri, 31 Oct 2025 15:43:22 +0800
Subject: [PATCH] [feature] support pcp + mtp (with pd disaggregate) (#3822)

### What this PR does / why we need it?
Support PCP (prefill context parallel) + MTP (multi-token prediction) with PD disaggregation; PCP is applied only on the P (prefill) nodes.

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac

Signed-off-by: zhangsicheng5
---
 vllm_ascend/spec_decode/mtp_proposer.py | 118 ++++++++++++++++++++++--
 vllm_ascend/worker/model_runner_v1.py   |  74 ++++++++++++++-
 2 files changed, 185 insertions(+), 7 deletions(-)

diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 17274a51..c2a2a845 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -3,6 +3,7 @@ from typing import Optional
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from vllm.config import (CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor
@@ -13,6 +14,7 @@ from vllm.model_executor.model_loader.utils import \
     process_weights_after_loading
 from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -22,7 +24,8 @@ from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         AscendPrefillContextParallelMetadata)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
                                vllm_version_is)
@@ -67,6 +70,10 @@ class MtpProposer(Proposer):
         # hidden size (e.g., Llama 3.3 70B).
         self.hidden_size = self.draft_model_config.get_hidden_size()
 
+        self.pcp_size = self.runner.pcp_size
+        self.dcp_size = self.runner.dcp_size
+        self.pcp_rank = self.runner.pcp_rank
+
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
             AttentionMetadataBuilder] = None
@@ -99,6 +106,9 @@ class MtpProposer(Proposer):
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
             device=device)
+        self.full_indices = range(
+            self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
+            self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
 
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -238,12 +248,24 @@ class MtpProposer(Proposer):
                     self.runner.num_discarded_requests
                 )
 
+        is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
+        req_scheduled_tokens = scheduler_output.num_scheduled_tokens
+        long_seq_metadata: AscendPrefillContextParallelMetadata = \
+            self.runner.long_seq_metadata if self.pcp_size > 1 else None
         if spec_decode_metadata is None:
-            token_indices_to_sample = None
-            # input_ids can be None for multimodal models.
-            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
-            target_positions = positions[:num_scheduled_tokens]
-            target_hidden_states = hidden_states[:num_scheduled_tokens]
+            # update pcp-related params
+            if self.pcp_size > 1 and is_prefill:
+                token_indices_to_sample = None
+                target_token_ids = self.runner.input_ids_pcp_full[:
+                                                                  num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states
+            else:
+                token_indices_to_sample = None
+                # input_ids can be None for multimodal models.
+                target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states[:num_scheduled_tokens]
         else:
             if self.speculative_config.disable_padded_drafter_batch:
                 token_indices_to_sample = None
@@ -271,6 +293,9 @@ class MtpProposer(Proposer):
             last_token_indices=token_indices_to_sample,
             common_attn_metadata=common_attn_metadata,
             sampling_metadata=sampling_metadata,
+            is_prefill=is_prefill,
+            req_scheduled_tokens=req_scheduled_tokens,
+            long_seq_metadata=long_seq_metadata,
         )
 
         return draft_token_ids
@@ -397,6 +422,9 @@ class MtpProposer(Proposer):
         sampling_metadata: SamplingMetadata,
         mm_embed_inputs: Optional[tuple[list[torch.Tensor],
                                         torch.Tensor]] = None,
+        is_prefill=False,
+        req_scheduled_tokens=None,
+        long_seq_metadata=None,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
@@ -417,6 +445,22 @@ class MtpProposer(Proposer):
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         self.input_ids[last_token_indices] = next_token_ids
 
+        # update pcp-related params
+        if self.pcp_size > 1 and is_prefill:
+            num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens = \
+                self._split_pcp_input(req_scheduled_tokens, num_tokens, target_hidden_states)
+            # graph mode padding is not considered for now
+            num_input_tokens = num_tokens
+            self.input_ids[:num_input_tokens].copy_(input_ids)
+            common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            common_attn_metadata.num_actual_tokens = num_tokens
+            common_attn_metadata.max_query_len = max_query_len
+            common_attn_metadata.seq_lens_cpu = seq_lens.cpu()
+            common_attn_metadata.query_start_loc = \
+                cu_num_tokens[:batch_size + 1]
+            common_attn_metadata.query_start_loc_cpu = \
+                cu_num_tokens[:batch_size + 1].cpu()
+
         assert self.runner is not None
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
@@ -767,3 +811,65 @@ class MtpProposer(Proposer):
                                            1 - num_rejected_tokens_gpu)
 
         return spec_common_attn_metadata, token_indices, token_indices_to_sample
+
+    def _split_pcp_input(self, req_scheduled_tokens, num_tokens,
+                         target_hidden_states):
+        """
+        Split input_ids and target_hidden_states across the pcp group.
+        1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
+        2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
+        3. split target_hidden_states (which already include cp padding):
+           [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
+        4. also update max_query_len, seq_lens, cu_num_tokens according to the pcp split.
+ """ + + def _pcp_pad_and_split(num_tokens): + num_pcp_padded_scheduled_tokens = cdiv( + num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size + pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens + chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size) + + # split position_ids (and use split position_ids to split input_ids afterwards) + req_position_cp: list[int] = [] + req_position_cp.extend( + self.full_indices[self.pcp_rank * + chunk_size:(self.pcp_rank + 1) * chunk_size]) + req_position_cp.extend( + self.full_indices[num_pcp_padded_scheduled_tokens - + (self.pcp_rank + 1) * + chunk_size:num_pcp_padded_scheduled_tokens - + self.pcp_rank * chunk_size]) + + return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad + + num_pcp_scheduled_tokens = [] + input_ids_list = self.input_ids[:num_tokens] + ori_start_index = 0 + pad_start_index = 0 + pcp_split_input_ids_list = [] + pcp_split_hidden_states_list = [] + for ori_num_tokens in req_scheduled_tokens.values(): + req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \ + _pcp_pad_and_split(ori_num_tokens) + actual_num_tokens = len(req_position_pcp) + num_pcp_scheduled_tokens.append(actual_num_tokens) + pad_input_ids = F.pad( + input_ids_list[ori_start_index:ori_start_index + + ori_num_tokens], (0, num_pcp_pad)) + ori_start_index += ori_num_tokens + pcp_chunk_indices = [ + pad_start_index + pos for pos in req_position_pcp + ] + pcp_split_input_ids = pad_input_ids[req_position_pcp] + pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices] + pcp_split_input_ids_list.append(pcp_split_input_ids) + pcp_split_hidden_states_list.append(pcp_split_hidden_states) + pad_start_index += num_pcp_padded_scheduled_tokens + num_tokens = sum(num_pcp_scheduled_tokens) + input_ids = torch.cat(pcp_split_input_ids_list) + target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0) + max_query_len = max(num_pcp_scheduled_tokens) + seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32) + cu_num_tokens = torch.tensor( + np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0)) + return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 85069454..a85d4cda 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -479,6 +479,22 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) + if self.speculative_config and self.pcp_size > 1: + self.input_ids_pcp_full = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy( + ) + self.positions_pcp_full = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=True) + self.positions_np_pcp_full = self.positions_pcp_full.numpy() self.use_aclgraph = self._use_aclgraph() self.aclgraph_batch_sizes = list( @@ -1598,7 +1614,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): ] num_tokens_np = np.array(num_tokens, dtype=np.int32) num_reqs = self.input_batch.num_reqs - discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np + if self.pcp_size == 1: + discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np + else: + # while pcp > 
1, we need the original num_scheduled_tokens before split + # to calculate discard_requests_mask + original_seq_lens_np = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + np.array(list(scheduler_output.num_scheduled_tokens.values()))) + discard_requests_mask = original_seq_lens_np < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) self.discard_request_indices.np[:self.num_discarded_requests] = ( @@ -1730,6 +1754,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() + is_prefill = len(scheduler_output.scheduled_new_reqs) > 0 + if self.speculative_config and self.pcp_size > 1 and is_prefill: + self._generate_pcp_mtp_input( + num_reqs, scheduler_output.total_num_scheduled_tokens, + scheduler_output.num_scheduled_tokens) + # prepare pcp meta data long_seq_metadata = self._generate_pcp_metadata( total_num_scheduled_tokens, seq_lens_cpu) @@ -4419,4 +4449,46 @@ class NPUModelRunner(LoRAModelRunnerMixin): 'tail_attn_nomask_seqlens'] long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ 'pcp_prefill_mask'] + self.long_seq_metadata = long_seq_metadata return long_seq_metadata + + def _generate_pcp_mtp_input( + self, + num_reqs: int, + total_num_scheduled_tokens: int, + num_scheduled_tokens: dict[str, int], + ): + """ + While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, + but mtp need to shift original input_ids before pcp splitting, + so we record original input_ids here. + """ + total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens + num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) + for i, req_id in enumerate(self.input_batch.req_ids): + num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] + req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens_pcp_full) + cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) + self.query_start_loc_pcp_full_np[0] = 0 + self.query_start_loc_pcp_full_np[1:num_reqs + + 1] = cu_num_tokens_pcp_full + self.query_start_loc_pcp_full_np[num_reqs + 1:].fill(-1) + cumsums_offsets_pcp_full = np.repeat( + cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, + num_scheduled_tokens_pcp_full) + arange_pcp_full = self.arange_np[: + total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full + positions_np_pcp_full = self.positions_np_pcp_full[: + total_num_scheduled_tokens_pcp_full] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full], + arange_pcp_full, + out=positions_np_pcp_full) + token_indices_pcp_full = ( + positions_np_pcp_full + + req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices_pcp_full), + out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full])
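
Note for reviewers (illustration only, not part of the patch): the nested `_pcp_pad_and_split` helper pads each request's token count to a multiple of `2 * pcp_size` and then gives each PCP rank one chunk from the head and one from the tail of the request. A minimal standalone sketch of that indexing, under the assumption that the helper can be rewritten without the surrounding runner state (the name `pcp_pad_and_split` and the use of `math.ceil` instead of vLLM's `cdiv` are choices made for this sketch), is:

```python
# Minimal sketch of the per-request PCP pad-and-split indexing.
from math import ceil


def pcp_pad_and_split(num_tokens: int, pcp_size: int, pcp_rank: int) -> list[int]:
    """Return the (padded) token indices that this pcp rank keeps for one request."""
    # Pad the request to a multiple of 2 * pcp_size and cut it into
    # 2 * pcp_size equal chunks.
    padded = ceil(num_tokens / (2 * pcp_size)) * (2 * pcp_size)
    chunk = padded // (2 * pcp_size)
    # Rank r takes chunk r from the head and chunk (2 * pcp_size - 1 - r)
    # from the tail, which balances causal-attention work across ranks.
    head = list(range(pcp_rank * chunk, (pcp_rank + 1) * chunk))
    tail = list(range(padded - (pcp_rank + 1) * chunk, padded - pcp_rank * chunk))
    return head + tail


if __name__ == "__main__":
    # 6 tokens with pcp_size=2: padded to 8 tokens, chunk size 2.
    # rank 0 -> [0, 1, 6, 7]  (positions 6 and 7 are padding)
    # rank 1 -> [2, 3, 4, 5]
    for rank in range(2):
        print(rank, pcp_pad_and_split(6, 2, rank))
```

With 6 tokens and `pcp_size=2` this reproduces the docstring example in `_split_pcp_input`: rank 0 keeps `[t0, t1, pad, pad]` and rank 1 keeps `[t2, t3, t4, t5]`.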