[feature] support pcp + mtp (with pd disaggregate) (#3822)

### What this PR does / why we need it?
Support PCP (prefill context parallel) + MTP (multi-token prediction) with PD disaggregation; PCP is applied only on the prefill (P) nodes.
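As a rough illustration of the pcp split scheme this PR uses (a minimal sketch, not code from the PR; `cdiv` here mirrors `vllm.utils.cdiv`): each request is padded to a multiple of `2 * pcp_size`, and rank `r` takes head chunk `r` plus the mirrored tail chunk, which helps balance causal-attention load across ranks.

```python
# Illustrative sketch of the pcp head+tail chunk split (not part of the diff).
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

def pcp_split_positions(num_tokens: int, pcp_size: int, pcp_rank: int) -> list[int]:
    # Pad the request to a multiple of 2 * pcp_size, then give each rank
    # one chunk from the head and the mirrored chunk from the tail.
    padded = cdiv(num_tokens, 2 * pcp_size) * 2 * pcp_size
    chunk = padded // (2 * pcp_size)
    head = list(range(pcp_rank * chunk, (pcp_rank + 1) * chunk))
    tail = list(range(padded - (pcp_rank + 1) * chunk, padded - pcp_rank * chunk))
    return head + tail

# 6 tokens, pcp_size=2 -> padded to 8, chunk=2:
# rank 0 takes positions [0, 1] + [6, 7] (the two pads),
# rank 1 takes positions [2, 3] + [4, 5].
assert pcp_split_positions(6, 2, 0) == [0, 1, 6, 7]
assert pcp_split_positions(6, 2, 1) == [2, 3, 4, 5]
```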

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
zhangsicheng5
2025-10-31 15:43:22 +08:00
committed by GitHub
parent f99762eb25
commit 0f70698d6d
2 changed files with 185 additions and 7 deletions


@@ -3,6 +3,7 @@ from typing import Optional
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from vllm.config import (CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor
@@ -13,6 +14,7 @@ from vllm.model_executor.model_loader.utils import \
     process_weights_after_loading
 from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -22,7 +24,8 @@ from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         AscendPrefillContextParallelMetadata)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
                                vllm_version_is)
@@ -67,6 +70,10 @@ class MtpProposer(Proposer):
         # hidden size (e.g., Llama 3.3 70B).
         self.hidden_size = self.draft_model_config.get_hidden_size()
+        self.pcp_size = self.runner.pcp_size
+        self.dcp_size = self.runner.dcp_size
+        self.pcp_rank = self.runner.pcp_rank
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
             AttentionMetadataBuilder] = None
@@ -99,6 +106,9 @@ class MtpProposer(Proposer):
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
             device=device)
+        self.full_indices = range(
+            self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
+            self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
 
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -238,12 +248,24 @@ class MtpProposer(Proposer):
                 self.runner.num_discarded_requests
             )
 
+        is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
+        req_scheduled_tokens = scheduler_output.num_scheduled_tokens
+        long_seq_metadata: AscendPrefillContextParallelMetadata = \
+            self.runner.long_seq_metadata if self.pcp_size > 1 else None
         if spec_decode_metadata is None:
-            token_indices_to_sample = None
-            # input_ids can be None for multimodal models.
-            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
-            target_positions = positions[:num_scheduled_tokens]
-            target_hidden_states = hidden_states[:num_scheduled_tokens]
+            # update pcp related params
+            if self.pcp_size > 1 and is_prefill:
+                token_indices_to_sample = None
+                target_token_ids = self.runner.input_ids_pcp_full[:
+                                                                  num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states
+            else:
+                token_indices_to_sample = None
+                # input_ids can be None for multimodal models.
+                target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states[:num_scheduled_tokens]
         else:
             if self.speculative_config.disable_padded_drafter_batch:
                 token_indices_to_sample = None
@@ -271,6 +293,9 @@ class MtpProposer(Proposer):
             last_token_indices=token_indices_to_sample,
             common_attn_metadata=common_attn_metadata,
             sampling_metadata=sampling_metadata,
+            is_prefill=is_prefill,
+            req_scheduled_tokens=req_scheduled_tokens,
+            long_seq_metadata=long_seq_metadata,
         )
 
         return draft_token_ids
@@ -397,6 +422,9 @@ class MtpProposer(Proposer):
         sampling_metadata: SamplingMetadata,
         mm_embed_inputs: Optional[tuple[list[torch.Tensor],
                                         torch.Tensor]] = None,
+        is_prefill=False,
+        req_scheduled_tokens=None,
+        long_seq_metadata=None,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
@@ -417,6 +445,22 @@ class MtpProposer(Proposer):
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         self.input_ids[last_token_indices] = next_token_ids
+        # update pcp related params
+        if self.pcp_size > 1 and is_prefill:
+            num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens = \
+                self._split_pcp_input(req_scheduled_tokens, num_tokens, target_hidden_states)
+            # graph mode padding not considered now
+            num_input_tokens = num_tokens
+            self.input_ids[:num_input_tokens].copy_(input_ids)
+            common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            common_attn_metadata.num_actual_tokens = num_tokens
+            common_attn_metadata.max_query_len = max_query_len
+            common_attn_metadata.seq_lens_cpu = seq_lens.cpu()
+            common_attn_metadata.query_start_loc = \
+                cu_num_tokens[:batch_size + 1]
+            common_attn_metadata.query_start_loc_cpu = \
+                cu_num_tokens[:batch_size + 1].cpu()
 
         assert self.runner is not None
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
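For intuition, the `query_start_loc` update above boils down to a cumulative sum with a leading zero over the per-rank token counts (toy values below, not from the PR):

```python
# Toy illustration of the query_start_loc bookkeeping: per-request token
# counts after the pcp split -> cumulative offsets, with a leading zero.
import numpy as np
import torch

num_pcp_scheduled_tokens = [4, 4, 2]  # three requests on this pcp rank
cu_num_tokens = torch.tensor(
    np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
print(cu_num_tokens)  # tensor([ 0,  4,  8, 10]); request i owns tokens [cu[i], cu[i+1])
```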
@@ -767,3 +811,65 @@ class MtpProposer(Proposer):
                 1 - num_rejected_tokens_gpu)
         return spec_common_attn_metadata, token_indices, token_indices_to_sample
+
+    def _split_pcp_input(self, req_scheduled_tokens, num_tokens,
+                         target_hidden_states):
+        """
+        Split input_ids and target_hidden_states in the pcp group.
+        1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
+        2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
+        3. split target_hidden_states (which already include cp padding):
+           [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
+        4. also update max_query_len, seq_lens, cu_num_tokens according to the pcp split.
+        """
+
+        def _pcp_pad_and_split(num_tokens):
+            num_pcp_padded_scheduled_tokens = cdiv(
+                num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size
+            pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens
+            chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size)
+            # split position_ids (and use the split position_ids to split input_ids afterwards)
+            req_position_cp: list[int] = []
+            req_position_cp.extend(
+                self.full_indices[self.pcp_rank *
+                                  chunk_size:(self.pcp_rank + 1) * chunk_size])
+            req_position_cp.extend(
+                self.full_indices[num_pcp_padded_scheduled_tokens -
+                                  (self.pcp_rank + 1) *
+                                  chunk_size:num_pcp_padded_scheduled_tokens -
+                                  self.pcp_rank * chunk_size])
+            return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad
+
+        num_pcp_scheduled_tokens = []
+        input_ids_list = self.input_ids[:num_tokens]
+        ori_start_index = 0
+        pad_start_index = 0
+        pcp_split_input_ids_list = []
+        pcp_split_hidden_states_list = []
+        for ori_num_tokens in req_scheduled_tokens.values():
+            req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \
+                _pcp_pad_and_split(ori_num_tokens)
+            actual_num_tokens = len(req_position_pcp)
+            num_pcp_scheduled_tokens.append(actual_num_tokens)
+            pad_input_ids = F.pad(
+                input_ids_list[ori_start_index:ori_start_index +
+                               ori_num_tokens], (0, num_pcp_pad))
+            ori_start_index += ori_num_tokens
+            pcp_chunk_indices = [
+                pad_start_index + pos for pos in req_position_pcp
+            ]
+            pcp_split_input_ids = pad_input_ids[req_position_pcp]
+            pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices]
+            pcp_split_input_ids_list.append(pcp_split_input_ids)
+            pcp_split_hidden_states_list.append(pcp_split_hidden_states)
+            pad_start_index += num_pcp_padded_scheduled_tokens
+        num_tokens = sum(num_pcp_scheduled_tokens)
+        input_ids = torch.cat(pcp_split_input_ids_list)
+        target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0)
+        max_query_len = max(num_pcp_scheduled_tokens)
+        seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
+        cu_num_tokens = torch.tensor(
+            np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
+        return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens
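The per-request loop above can be summarized with a self-contained sketch (assumed toy sizes: two requests, `pcp_size=2`, rank 1; names are illustrative, not the PR's API). Note the two index bases: `input_ids` arrive unpadded and are padded per request, while `target_hidden_states` already carry the cp padding, so their indices are offset by the running padded length:

```python
# Illustrative mirror of _split_pcp_input's pad-then-index logic (toy values).
import torch
import torch.nn.functional as F

pcp_size, pcp_rank = 2, 1
req_lens = [6, 3]                         # two prefill requests
input_ids = torch.arange(sum(req_lens))   # unpadded, concatenated: t0..t8
hidden = torch.randn(8 + 4, 16)           # hidden states already cp-padded (8 and 4)

ori, pad_start, ids_out, hs_out = 0, 0, [], []
for n in req_lens:
    padded = -(-n // (2 * pcp_size)) * (2 * pcp_size)  # pad to multiple of 2*pcp
    chunk = padded // (2 * pcp_size)
    # head chunk + mirrored tail chunk for this rank
    pos = list(range(pcp_rank * chunk, (pcp_rank + 1) * chunk)) + \
          list(range(padded - (pcp_rank + 1) * chunk, padded - pcp_rank * chunk))
    # input_ids are padded per request before indexing
    ids_out.append(F.pad(input_ids[ori:ori + n], (0, padded - n))[pos])
    # hidden states are already padded, so offset by the running padded length
    hs_out.append(hidden[[pad_start + p for p in pos]])
    ori, pad_start = ori + n, pad_start + padded
print(torch.cat(ids_out))                 # tensor([2, 3, 4, 5, 7, 8])
print(torch.cat(hs_out).shape)            # torch.Size([6, 16])
```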


@@ -479,6 +479,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
                                                    dtype=torch.int32,
                                                    device=self.device)
+        if self.speculative_config and self.pcp_size > 1:
+            self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
+                                                  dtype=torch.int32,
+                                                  device="cpu",
+                                                  pin_memory=True)
+            self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1,
+                                                        dtype=torch.int32,
+                                                        device="cpu",
+                                                        pin_memory=True)
+            self.query_start_loc_pcp_full_np = \
+                self.query_start_loc_pcp_full.numpy()
+            self.positions_pcp_full = torch.zeros(self.max_num_tokens,
+                                                  dtype=torch.int64,
+                                                  device="cpu",
+                                                  pin_memory=True)
+            self.positions_np_pcp_full = self.positions_pcp_full.numpy()
 
         self.use_aclgraph = self._use_aclgraph()
         self.aclgraph_batch_sizes = list(
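The `.numpy()` calls above create views that share storage with the pinned CPU tensors, so numpy writes in `_generate_pcp_mtp_input` land directly in page-locked memory ready for device copies. A toy demonstration (assumes a torch build where CPU tensors can be pinned, i.e. an accelerator is available):

```python
# Toy illustration of the pinned-buffer pattern (not code from the PR).
import torch

buf = torch.zeros(8, dtype=torch.int32, device="cpu", pin_memory=True)
buf_np = buf.numpy()        # numpy view sharing the pinned storage
buf_np[:3] = [7, 8, 9]      # fill on the CPU side with numpy
print(buf[:3])              # tensor([7, 8, 9], dtype=torch.int32)
```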
@@ -1598,7 +1614,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         ]
         num_tokens_np = np.array(num_tokens, dtype=np.int32)
         num_reqs = self.input_batch.num_reqs
-        discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
+        if self.pcp_size == 1:
+            discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
+        else:
+            # when pcp > 1, we need the original num_scheduled_tokens
+            # (before the pcp split) to compute discard_requests_mask
+            original_seq_lens_np = (
+                self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+                np.array(list(scheduler_output.num_scheduled_tokens.values())))
+            discard_requests_mask = original_seq_lens_np < num_tokens_np
         discard_request_indices = np.nonzero(discard_requests_mask)[0]
         self.num_discarded_requests = len(discard_request_indices)
         self.discard_request_indices.np[:self.num_discarded_requests] = (
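A worked toy example of the `pcp_size > 1` branch (all values invented for illustration): since `seq_lens_np` holds the post-split lengths, the pre-split lengths are reconstructed from the computed-token counts plus the originally scheduled counts:

```python
import numpy as np

num_computed_tokens_cpu = np.array([10, 0])      # toy: tokens already in cache
num_scheduled_tokens = {"req-a": 6, "req-b": 4}  # pre-split scheduled counts
num_tokens_np = np.array([17, 3])                # toy per-request requirement

original_seq_lens_np = (num_computed_tokens_cpu +
                        np.array(list(num_scheduled_tokens.values())))
discard_requests_mask = original_seq_lens_np < num_tokens_np
print(discard_requests_mask)  # [ True False] -> req-a is discarded
```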
@@ -1730,6 +1754,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.num_accepted_tokens.np[num_reqs:].fill(1)
         self.num_accepted_tokens.copy_to_gpu()
+
+        is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
+        if self.speculative_config and self.pcp_size > 1 and is_prefill:
+            self._generate_pcp_mtp_input(
+                num_reqs, scheduler_output.total_num_scheduled_tokens,
+                scheduler_output.num_scheduled_tokens)
 
         # prepare pcp meta data
         long_seq_metadata = self._generate_pcp_metadata(
             total_num_scheduled_tokens, seq_lens_cpu)
@@ -4419,4 +4449,46 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 'tail_attn_nomask_seqlens']
         long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
             'pcp_prefill_mask']
+        self.long_seq_metadata = long_seq_metadata
         return long_seq_metadata
+
+    def _generate_pcp_mtp_input(
+        self,
+        num_reqs: int,
+        total_num_scheduled_tokens: int,
+        num_scheduled_tokens: dict[str, int],
+    ):
+        """
+        When pcp > 1, model inputs (input_ids, positions, etc.) are split
+        across the pcp group, but MTP needs to shift the original input_ids
+        from before the pcp split, so we record the original input_ids here.
+        """
+        total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
+        num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
+        for i, req_id in enumerate(self.input_batch.req_ids):
+            num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
+        req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
+                                         num_scheduled_tokens_pcp_full)
+        cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
+        self.query_start_loc_pcp_full_np[0] = 0
+        self.query_start_loc_pcp_full_np[1:num_reqs +
+                                         1] = cu_num_tokens_pcp_full
+        self.query_start_loc_pcp_full_np[num_reqs + 1:].fill(-1)
+        cumsums_offsets_pcp_full = np.repeat(
+            cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
+            num_scheduled_tokens_pcp_full)
+        arange_pcp_full = (
+            self.arange_np[:total_num_scheduled_tokens_pcp_full] -
+            cumsums_offsets_pcp_full)
+        positions_np_pcp_full = \
+            self.positions_np_pcp_full[:total_num_scheduled_tokens_pcp_full]
+        np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
+               arange_pcp_full,
+               out=positions_np_pcp_full)
+        token_indices_pcp_full = (
+            positions_np_pcp_full +
+            req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
+        torch.index_select(
+            self.input_batch.token_ids_cpu_tensor.flatten(),
+            0,
+            torch.from_numpy(token_indices_pcp_full),
+            out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full])
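The gather at the end of `_generate_pcp_mtp_input` can be reproduced standalone (toy sizes, hypothetical values): flatten the `(num_reqs, max_model_len)` token table and select each scheduled token by `position + row * max_model_len`:

```python
# Standalone sketch of the token gather (toy sizes, not the runner's state).
import numpy as np
import torch

token_ids_cpu = torch.tensor([[11, 12, 13, 14],
                              [21, 22, 23, 24]])   # (num_reqs, max_model_len)
num_computed = np.array([1, 0])                    # already-computed tokens
num_scheduled = np.array([2, 3])                   # pre-split scheduled counts

req_indices = np.repeat(np.arange(2), num_scheduled)     # [0 0 1 1 1]
cu = np.cumsum(num_scheduled)
arange = np.arange(cu[-1]) - np.repeat(cu - num_scheduled, num_scheduled)
positions = num_computed[req_indices] + arange           # per-token position
token_indices = positions + req_indices * token_ids_cpu.shape[1]
print(torch.index_select(token_ids_cpu.flatten(), 0,
                         torch.from_numpy(token_indices)))
# tensor([12, 13, 21, 22, 23])
```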