[feature] support PCP + MTP (with PD disaggregation) (#3822)

### What this PR does / why we need it?
Support PCP (prefill context parallelism) + MTP (multi-token prediction speculative decoding) with PD disaggregation; PCP is enabled only on the P (prefill) nodes.
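As background (not part of the diff): MTP builds its draft inputs by shifting the full token stream, so the shift has to happen before the sequence is sharded across the PCP group. A minimal numpy sketch, using a hypothetical contiguous `pcp_split` helper (the real split layout may differ), shows why shifting a per-rank shard loses the token at each shard boundary:

```python
import numpy as np

def pcp_split(x: np.ndarray, pcp_size: int, rank: int) -> np.ndarray:
    # Hypothetical contiguous split of one request's tokens across pcp ranks.
    return np.array_split(x, pcp_size)[rank]

input_ids = np.array([1, 2, 3, 4, 5, 6])  # one prefill request
mtp_full = input_ids[1:]                  # shift on the full sequence

rank0 = pcp_split(input_ids, pcp_size=2, rank=0)  # [1, 2, 3]
mtp_sharded = rank0[1:]                           # [2, 3]: token 4 is lost
assert not np.array_equal(mtp_sharded, pcp_split(mtp_full, 2, 0))  # [2, 3, 4]
```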

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Author: zhangsicheng5
Committed: 2025-10-31 15:43:22 +08:00 (via GitHub)
Commit: 0f70698d6d, parent: f99762eb25
2 changed files with 185 additions and 7 deletions

@@ -479,6 +479,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
                                                   dtype=torch.int32,
                                                   device=self.device)
        if self.speculative_config and self.pcp_size > 1:
            self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
                                                  dtype=torch.int32,
                                                  device="cpu",
                                                  pin_memory=True)
            self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1,
                                                        dtype=torch.int32,
                                                        device="cpu",
                                                        pin_memory=True)
            self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy()
            self.positions_pcp_full = torch.zeros(self.max_num_tokens,
                                                  dtype=torch.int64,
                                                  device="cpu",
                                                  pin_memory=True)
            self.positions_np_pcp_full = self.positions_pcp_full.numpy()
        self.use_aclgraph = self._use_aclgraph()
        self.aclgraph_batch_sizes = list(
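The buffers above follow a common host-staging pattern: a pinned CPU tensor is allocated once, a numpy view of it is kept for cheap vectorized fills, and the filled prefix is later copied to the device. A minimal sketch of the pattern (generic torch, not this PR's code):

```python
import torch

buf = torch.zeros(8, dtype=torch.int32, device="cpu", pin_memory=True)
buf_np = buf.numpy()           # shares storage with `buf`
buf_np[:4] = [10, 11, 12, 13]  # numpy writes are visible through the tensor
assert buf[2].item() == 12
# Because `buf` is pinned, the eventual host-to-device copy can be issued
# asynchronously, e.g. buf[:4].to(device, non_blocking=True).
```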
@@ -1598,7 +1614,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        ]
        num_tokens_np = np.array(num_tokens, dtype=np.int32)
        num_reqs = self.input_batch.num_reqs
        if self.pcp_size == 1:
            discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
        else:
            # When pcp > 1, we need the original num_scheduled_tokens from
            # before the pcp split to calculate discard_requests_mask.
            original_seq_lens_np = (
                self.input_batch.num_computed_tokens_cpu[:num_reqs] +
                np.array(list(scheduler_output.num_scheduled_tokens.values())))
            discard_requests_mask = original_seq_lens_np < num_tokens_np
        discard_request_indices = np.nonzero(discard_requests_mask)[0]
        self.num_discarded_requests = len(discard_request_indices)
        self.discard_request_indices.np[:self.num_discarded_requests] = (
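A toy trace of the new `pcp_size > 1` branch (all values invented): the full sequence length of each request is reconstructed from the computed-token counts plus the scheduler's original per-request token counts, then compared against the per-request token requirement:

```python
import numpy as np

num_computed_tokens_cpu = np.array([0, 100])      # per request
num_scheduled_tokens = {"req-0": 6, "req-1": 2}   # original, pre-split counts
num_tokens_np = np.array([8, 3], dtype=np.int32)  # required tokens per request

original_seq_lens_np = (num_computed_tokens_cpu +
                        np.array(list(num_scheduled_tokens.values())))
discard_requests_mask = original_seq_lens_np < num_tokens_np  # [True, False]
discard_request_indices = np.nonzero(discard_requests_mask)[0]  # [0]
```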
@@ -1730,6 +1754,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.num_accepted_tokens.np[num_reqs:].fill(1)
        self.num_accepted_tokens.copy_to_gpu()
        is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
        if self.speculative_config and self.pcp_size > 1 and is_prefill:
            self._generate_pcp_mtp_input(
                num_reqs, scheduler_output.total_num_scheduled_tokens,
                scheduler_output.num_scheduled_tokens)
        # prepare pcp metadata
        long_seq_metadata = self._generate_pcp_metadata(
            total_num_scheduled_tokens, seq_lens_cpu)
@@ -4419,4 +4449,46 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            'tail_attn_nomask_seqlens']
        long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
            'pcp_prefill_mask']
        self.long_seq_metadata = long_seq_metadata
        return long_seq_metadata

    def _generate_pcp_mtp_input(
        self,
        num_reqs: int,
        total_num_scheduled_tokens: int,
        num_scheduled_tokens: dict[str, int],
    ):
        """
        When pcp > 1, model inputs (input_ids, positions, etc.) are split
        across the pcp group, but MTP needs to shift the original input_ids
        from before the pcp split, so the original (un-split) input_ids are
        recorded here.
        """
        total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
        num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
        for i, req_id in enumerate(self.input_batch.req_ids):
            num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
        # Request index of every scheduled token, e.g. [0, 0, 0, 1, 1] for two
        # requests scheduling 3 and 2 tokens.
        req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
                                         num_scheduled_tokens_pcp_full)
        cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
        self.query_start_loc_pcp_full_np[0] = 0
        self.query_start_loc_pcp_full_np[1:num_reqs + 1] = cu_num_tokens_pcp_full
        self.query_start_loc_pcp_full_np[num_reqs + 1:].fill(-1)
        # Offset of every token within its own request: [0, 1, 2, 0, 1] in the
        # example above.
        cumsums_offsets_pcp_full = np.repeat(
            cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
            num_scheduled_tokens_pcp_full)
        arange_pcp_full = self.arange_np[
            :total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
        # Absolute position of every scheduled token in its request.
        positions_np_pcp_full = self.positions_np_pcp_full[
            :total_num_scheduled_tokens_pcp_full]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
               arange_pcp_full,
               out=positions_np_pcp_full)
        # Gather the original token ids from the flattened
        # [max_num_reqs, max_model_len] token table.
        token_indices_pcp_full = (
            positions_np_pcp_full +
            req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
        torch.index_select(
            self.input_batch.token_ids_cpu_tensor.flatten(),
            0,
            torch.from_numpy(token_indices_pcp_full),
            out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full])
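A standalone trace of the indexing arithmetic in `_generate_pcp_mtp_input`, with invented sizes, showing how the repeat/cumsum pattern recovers per-token positions and flat indices into a `[max_num_reqs, max_model_len]` token table:

```python
import numpy as np

num_scheduled = np.array([3, 2])  # tokens scheduled per request
num_computed = np.array([4, 0])   # tokens already computed per request
max_model_len = 10

req_indices = np.repeat(np.arange(2), num_scheduled)    # [0 0 0 1 1]
cu = np.cumsum(num_scheduled)                           # [3 5]
offsets = np.repeat(cu - num_scheduled, num_scheduled)  # [0 0 0 3 3]
within_req = np.arange(5) - offsets                     # [0 1 2 0 1]
positions = num_computed[req_indices] + within_req      # [4 5 6 0 1]
flat_indices = positions + req_indices * max_model_len  # [4 5 6 10 11]
```

`torch.index_select` on the flattened token table with these indices then yields exactly the original, un-split token ids for the five scheduled tokens.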