[feature] support pcp + mtp (in pd co-locate scenario) (#4098)

1. Support PCP + MTP in the PD co-locate scenario.
2. Fix PCP-related bugs in the llmdatadist connector and clean up the code.

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Author: zhangsicheng5
Date: 2025-11-12 17:22:21 +08:00 (committed by GitHub)
Parent: 1b4ce63ec9
Commit: a123f355e9
6 changed files with 246 additions and 97 deletions


@@ -487,19 +487,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.speculative_config and self.pcp_size > 1:
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=True)
device=self.device)
self.input_ids_pcp_full_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy(
)
device=self.device)
self.query_start_loc_pcp_full_cpu = \
torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.query_start_loc_pcp_full_np = \
self.query_start_loc_pcp_full_cpu.numpy()
self.positions_pcp_full = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=True)
self.positions_np_pcp_full = self.positions_pcp_full.numpy()
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.decode_threshold = 1 + (
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
self.use_aclgraph = self._use_aclgraph()
self.aclgraph_batch_sizes = list(
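The hunk above splits each pcp-full buffer into a device tensor plus a pinned-host `*_cpu` staging tensor, so CPU-side preparation code can fill numpy views and later issue non-blocking host-to-device copies. A minimal sketch of that staging pattern, with illustrative names and sizes (not the repo's exact fields):

```python
import torch

# Pick an accelerator if one is available; pinned memory needs a backend.
device = "cuda" if torch.cuda.is_available() else "cpu"
pin = torch.cuda.is_available()

max_num_tokens = 1024
# Device-side buffer consumed by the model.
input_ids_full = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
# Pinned host staging buffer filled by CPU-side preparation code.
input_ids_full_cpu = torch.zeros(max_num_tokens, dtype=torch.int32,
                                 device="cpu", pin_memory=pin)

n = 16  # tokens prepared this step
input_ids_full_cpu[:n] = torch.arange(n, dtype=torch.int32)
# Pinned memory lets the H2D copy run asynchronously and overlap other work.
input_ids_full[:n].copy_(input_ids_full_cpu[:n], non_blocking=True)
```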
@@ -1854,8 +1864,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
self.device, non_blocking=True)
else:
# pcp not supported now
assert self.pcp_size == 1
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
@@ -1866,11 +1874,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
logits_indices = spec_decode_metadata.logits_indices
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
self.num_draft_tokens.np[num_reqs:].fill(0)
self.num_draft_tokens.copy_to_gpu()
# save logits_indices for pcp spec decode usage
self.logits_indices = logits_indices
# Used in the below loop.
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@@ -1883,8 +1893,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
if self.speculative_config and self.pcp_size > 1 and is_prefill:
if self.speculative_config and self.pcp_size > 1:
self._generate_pcp_mtp_input(
num_reqs, scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens)
@@ -2040,8 +2049,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context,
maybe_padded_num_tokens,
self.speculative_config)
maybe_padded_num_tokens)
else:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context,
@@ -2110,6 +2118,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self,
num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray,
num_pcp_pads: np.ndarray,
) -> SpecDecodeMetadata:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
@@ -2138,6 +2147,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange
# When pcp_size > 1, decode results may contain padding (from the pcp all-gather);
# update logits_indices after getting draft_token_ids from the original logits_indices
if self.pcp_size > 1:
cu_num_scheduled_tokens = cu_num_scheduled_tokens * self.pcp_size - num_pcp_pads
logits_indices_pcp = np.repeat(
cu_num_scheduled_tokens - num_sampled_tokens,
num_sampled_tokens)
logits_indices_pcp += arange
logits_indices_pcp = torch.from_numpy(logits_indices_pcp).to(
self.device, non_blocking=True)
# Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1
@@ -2173,6 +2193,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
if self.pcp_size > 1:
logits_indices = logits_indices_pcp
if vllm_version_is("0.11.0"):
metadata = SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
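For the pcp branch added to `_calc_spec_decode_metadata`, here is a worked numpy sketch using the example figures from the comments above; `pcp_size` and `num_pcp_pads` are illustrative assumptions, not values from the repo:

```python
import numpy as np

cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209])
num_draft_tokens = np.array([3, 0, 2, 0, 1])
num_sampled_tokens = num_draft_tokens + 1                # [4, 1, 3, 1, 2]
# Per-query offsets 0..num_sampled_tokens[i]-1, concatenated.
arange = np.concatenate([np.arange(n) for n in num_sampled_tokens])

pcp_size = 2
num_pcp_pads = np.array([0, 2, 1, 3, 1])                 # hypothetical pad counts

# Shift the cumulative offsets to the padded, all-gathered token layout,
# then take the last num_sampled_tokens positions of each request.
cu_padded = cu_num_scheduled_tokens * pcp_size - num_pcp_pads
logits_indices_pcp = np.repeat(cu_padded - num_sampled_tokens,
                               num_sampled_tokens) + arange
print(logits_indices_pcp)
```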
@@ -2920,8 +2942,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context,
positions.shape[0],
self.speculative_config)
positions.shape[0])
else:
# FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context,
@@ -4328,18 +4349,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_decode_reqs = sum(
self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
self.input_batch.num_prompt_tokens[:num_reqs])
num_decode_tokens = sum(tokens[:num_decode_reqs])
num_padded_scheduled_tokens = np.ceil(
tokens /
(2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_size
num_padded_scheduled_tokens[:num_decode_reqs] = (
tokens[:num_decode_reqs] * self.pcp_size)
self.num_pcp_pads = num_padded_scheduled_tokens - tokens
cu_padded_tokens, pcp_padded_arange = \
self._get_cumsum_and_arange(num_padded_scheduled_tokens)
unpad_mask = torch.from_numpy(
pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size]
unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size])
unpad_mask_decode[:, 0] = True
unpad_mask_decode[:, 1:] = False
pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
_, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
@@ -4356,14 +4384,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
np.repeat(head_start_loc, pcp_chunk_sizes)
# Decode reqs do not have tail chunks.
positions[~pcp_head_chunk_mask] = \
pcp_chunk_arange[num_decode_reqs:] + \
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:]
pcp_chunk_arange[num_decode_tokens:] + \
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]
return positions
positions = get_current_rank_positions(
np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
# Decode tokens are duplicated across pcp ranks and their positions are always 0.
positions[:num_decode_reqs] = 0
if num_decode_reqs > 0:
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
tokens[:num_decode_reqs])[1]
all_positions = [
get_current_rank_positions(cu_padded_tokens, rank_i)
@@ -4372,7 +4402,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
all_positions_tensor.float().argsort().long(), non_blocking=True)
pcp_tokens[:num_decode_reqs] = 1
return pcp_tokens, positions, unpad_mask
def _get_pcp_local_seq_lens(
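The padding math in the hunk above rounds each prefill request up to a multiple of 2 * pcp_size so every rank can take one head chunk and one tail chunk. An illustrative sketch of that split for a single request (the chunk-to-rank assignment shown is one common mirrored layout, not necessarily the repo's exact ordering):

```python
import numpy as np

pcp_size = 2
tokens = 10
# Round up so the request divides evenly into 2 * pcp_size chunks.
padded = int(np.ceil(tokens / (2 * pcp_size)) * (2 * pcp_size))   # 12
chunk = padded // (2 * pcp_size)                                   # 3

for rank in range(pcp_size):
    head = np.arange(rank * chunk, (rank + 1) * chunk)
    tail = np.arange(padded - (rank + 1) * chunk, padded - rank * chunk)
    # Positions >= tokens (here 10 and 11) are padding; the runner masks
    # them out later with unpad_mask.
    print(rank, head, tail)
# rank 0 -> head [0 1 2], tail [ 9 10 11]
# rank 1 -> head [3 4 5], tail [ 6  7  8]
```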
@@ -4524,7 +4553,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
seq_lens_origin):
num_reqs = self.input_batch.num_reqs
# In a dummy run num_reqs == 0, so derive it from seq_lens
num_reqs = self.input_batch.num_reqs or seq_lens.size(0)
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
>= self.input_batch.num_prompt_tokens[:num_reqs])
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
@@ -4535,14 +4565,28 @@ class NPUModelRunner(LoRAModelRunnerMixin):
local_chunked_kv_lens)
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
num_reqs * self.decode_threshold, self.pcp_size,
self.dcp_size
],
dtype=torch.int32,
)
# For pcp + spec decode, we flatten seq_lens
# to avoid irregular spec_attn_mask shape
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_pcp_local_seq_lens(
seq_lens_origin - decode_idx,
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
)
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
seq_lens_origin,
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
).numpy(),
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
numpy(),
local_chunked_kv_lens=local_chunked_kv_lens,
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
max_chunk_num=max_chunk_num)
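The metadata change above flattens the per-request sequence lengths over `decode_threshold` (1 + num_speculative_tokens) spec-decode slots. A small sketch of that fill pattern, with `_get_pcp_local_seq_lens` replaced by a stub that simply broadcasts the lengths (shapes and values are illustrative):

```python
import torch

num_reqs, pcp_size, dcp_size = 2, 2, 1
decode_threshold = 1 + 2                     # 1 + num_speculative_tokens
seq_lens_origin = torch.tensor([10, 7], dtype=torch.int32)

def local_seq_lens_stub(seq_lens, pcp, dcp):
    # Stand-in: the real helper shards each length across pcp * dcp ranks.
    return seq_lens.view(-1, 1, 1).expand(-1, pcp, dcp)

out = torch.zeros([num_reqs * decode_threshold, pcp_size, dcp_size],
                  dtype=torch.int32)
for decode_idx in range(decode_threshold):
    # Row decode_threshold-1-decode_idx of every request's block holds the
    # lengths as of decode_idx steps before the last speculative token.
    out[decode_threshold - 1 - decode_idx::decode_threshold] = \
        local_seq_lens_stub(seq_lens_origin - decode_idx, pcp_size, dcp_size)

# Per request: rows ordered oldest -> newest speculative position.
print(out.view(num_reqs, decode_threshold, pcp_size, dcp_size))
```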
@@ -4706,16 +4750,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens_pcp_full)
arange_pcp_full = self.arange_np[:
total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
positions_np_pcp_full = self.positions_np_pcp_full[:
positions_pcp_full_np = self.positions_pcp_full_np[:
total_num_scheduled_tokens_pcp_full]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
arange_pcp_full,
out=positions_np_pcp_full)
out=positions_pcp_full_np)
token_indices_pcp_full = (
positions_np_pcp_full +
positions_pcp_full_np +
req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(
self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full])
out=self.
input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full])
self.query_start_loc_pcp_full[:num_reqs + 1].copy_(
self.query_start_loc_pcp_full_cpu[:num_reqs + 1],
non_blocking=True,
)
self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full].copy_(
self.input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full],
non_blocking=True,
)
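The token gather above works on a flattened view of the per-request token table: each (request, position) pair maps to `position + request * max_model_len`, and a single `index_select` fills the pinned staging buffer before the non-blocking copy to device. A self-contained sketch with illustrative sizes:

```python
import numpy as np
import torch

max_model_len = 8
# 3 requests x max_model_len token table (values chosen so indices are obvious).
token_ids_cpu = torch.arange(3 * max_model_len, dtype=torch.int32).view(3, -1)

req_indices = np.array([0, 0, 0, 2, 2])   # request owning each scheduled token
positions = np.array([4, 5, 6, 1, 2])     # position of each token in its request

# Flattened index: position within the request plus the request's row offset.
token_indices = positions + req_indices * max_model_len
out = torch.empty(len(token_indices), dtype=torch.int32)
torch.index_select(token_ids_cpu.flatten(), 0,
                   torch.from_numpy(token_indices), out=out)
print(out)   # tensor([ 4,  5,  6, 17, 18], dtype=torch.int32)
```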