[feature] support pcp + mtp (with pd disaggregate) (#3822)
### What this PR does / why we need it?
Support PCP together with MTP speculative decoding, including PD-disaggregated deployments where PCP is enabled only on the prefill (P) nodes.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
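
Under PCP, the prefill tokens of a request are split across the context-parallel ranks, so each rank only builds its own slice of `input_ids`/`positions`. MTP, however, drafts from the token ids shifted by one position over the full sequence, which is why this change keeps an unsplit copy of the inputs on the prefill side. A minimal sketch of why a per-rank slice is not enough, assuming a plain contiguous split (the actual PCP split in vllm-ascend may differ):

```python
import numpy as np

# Toy example: one prefill request with 8 scheduled tokens, pcp_size = 2.
# The names and the contiguous split below are illustrative assumptions,
# not vllm-ascend APIs.
input_ids_full = np.array([11, 12, 13, 14, 15, 16, 17, 18], dtype=np.int32)
pcp_size = 2

per_rank_ids = np.array_split(input_ids_full, pcp_size)
# rank 0 sees [11 12 13 14], rank 1 sees [15 16 17 18]

# MTP drafting wants each position's input shifted by one over the *full*
# sequence; shifting only a rank's slice loses the token at the chunk
# boundary (rank 0 has no way to see 15).
shifted_full = input_ids_full[1:]        # [12 13 14 15 16 17 18]
shifted_rank0 = per_rank_ids[0][1:]      # [12 13 14]  -- boundary token missing
```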
```diff
@@ -479,6 +479,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
                                                    dtype=torch.int32,
                                                    device=self.device)
+        if self.speculative_config and self.pcp_size > 1:
+            self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
+                                                  dtype=torch.int32,
+                                                  device="cpu",
+                                                  pin_memory=True)
+            self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1,
+                                                        dtype=torch.int32,
+                                                        device="cpu",
+                                                        pin_memory=True)
+            self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy(
+            )
+            self.positions_pcp_full = torch.zeros(self.max_num_tokens,
+                                                  dtype=torch.int64,
+                                                  device="cpu",
+                                                  pin_memory=True)
+            self.positions_np_pcp_full = self.positions_pcp_full.numpy()
 
         self.use_aclgraph = self._use_aclgraph()
         self.aclgraph_batch_sizes = list(
```
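
The new `*_pcp_full` buffers follow the runner's usual pattern: persistent pinned-CPU tensors sized for the worst case, each paired with a `.numpy()` view so the hot path can fill them with vectorized numpy code before copying to the device. A minimal sketch of that pattern (not vllm-ascend code; the device copy is only a hypothetical comment):

```python
import torch

max_num_tokens = 8

# Persistent pinned-CPU buffer plus a numpy view over the same storage:
# writes through the view are visible to the tensor, so the buffer can be
# filled with numpy ops without any extra allocation or copy.
buf = torch.zeros(max_num_tokens, dtype=torch.int32, device="cpu",
                  pin_memory=True)
buf_np = buf.numpy()                 # shares memory with `buf`

buf_np[:4] = [1, 2, 3, 4]            # fill via numpy
assert buf[:4].tolist() == [1, 2, 3, 4]

# In the runner the filled prefix would then be copied to the device,
# e.g. something like `buf[:n].to(device, non_blocking=True)`; pinning is
# what makes the non-blocking host-to-device copy effective.
```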
```diff
@@ -1598,7 +1614,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         ]
         num_tokens_np = np.array(num_tokens, dtype=np.int32)
         num_reqs = self.input_batch.num_reqs
-        discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
+        if self.pcp_size == 1:
+            discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
+        else:
+            # while pcp > 1, we need the original num_scheduled_tokens before split
+            # to calculate discard_requests_mask
+            original_seq_lens_np = (
+                self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+                np.array(list(scheduler_output.num_scheduled_tokens.values())))
+            discard_requests_mask = original_seq_lens_np < num_tokens_np
         discard_request_indices = np.nonzero(discard_requests_mask)[0]
         self.num_discarded_requests = len(discard_request_indices)
         self.discard_request_indices.np[:self.num_discarded_requests] = (
```
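
The discard mask flags requests that, after this step, still do not have enough tokens to produce a valid sample. With PCP enabled, `seq_lens_np` only reflects the tokens assigned to the current rank, so the mask has to be rebuilt from the pre-split totals (computed tokens plus originally scheduled tokens). A toy numpy rendering of the `else` branch above, with made-up numbers:

```python
import numpy as np

# Hypothetical numbers for 3 requests; all values are illustrative only.
num_computed_tokens_cpu = np.array([0, 16, 32], dtype=np.int32)
num_scheduled_tokens = {"req0": 8, "req1": 8, "req2": 1}   # pre-split counts
num_tokens_np = np.array([9, 30, 33], dtype=np.int32)      # required lengths

original_seq_lens_np = (num_computed_tokens_cpu +
                        np.array(list(num_scheduled_tokens.values())))
discard_requests_mask = original_seq_lens_np < num_tokens_np
print(original_seq_lens_np)    # [ 8 24 33]
print(discard_requests_mask)   # [ True  True False]
```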
```diff
@@ -1730,6 +1754,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.num_accepted_tokens.np[num_reqs:].fill(1)
             self.num_accepted_tokens.copy_to_gpu()
 
+        is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
+        if self.speculative_config and self.pcp_size > 1 and is_prefill:
+            self._generate_pcp_mtp_input(
+                num_reqs, scheduler_output.total_num_scheduled_tokens,
+                scheduler_output.num_scheduled_tokens)
+
         # prepare pcp meta data
         long_seq_metadata = self._generate_pcp_metadata(
             total_num_scheduled_tokens, seq_lens_cpu)
```
```diff
@@ -4419,4 +4449,46 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 'tail_attn_nomask_seqlens']
             long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
                 'pcp_prefill_mask']
+        self.long_seq_metadata = long_seq_metadata
         return long_seq_metadata
+
+    def _generate_pcp_mtp_input(
+        self,
+        num_reqs: int,
+        total_num_scheduled_tokens: int,
+        num_scheduled_tokens: dict[str, int],
+    ):
+        """
+        While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
+        but mtp need to shift original input_ids before pcp splitting,
+        so we record original input_ids here.
+        """
+        total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
+        num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
+        for i, req_id in enumerate(self.input_batch.req_ids):
+            num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
+        req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
+                                         num_scheduled_tokens_pcp_full)
+        cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
+        self.query_start_loc_pcp_full_np[0] = 0
+        self.query_start_loc_pcp_full_np[1:num_reqs +
+                                         1] = cu_num_tokens_pcp_full
+        self.query_start_loc_pcp_full_np[num_reqs + 1:].fill(-1)
+        cumsums_offsets_pcp_full = np.repeat(
+            cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
+            num_scheduled_tokens_pcp_full)
+        arange_pcp_full = self.arange_np[:
+                                         total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
+        positions_np_pcp_full = self.positions_np_pcp_full[:
+                                                           total_num_scheduled_tokens_pcp_full]
+        np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
+               arange_pcp_full,
+               out=positions_np_pcp_full)
+        token_indices_pcp_full = (
+            positions_np_pcp_full +
+            req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
+        torch.index_select(
+            self.input_batch.token_ids_cpu_tensor.flatten(),
+            0,
+            torch.from_numpy(token_indices_pcp_full),
+            out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full])
```
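
`_generate_pcp_mtp_input` rebuilds the unsplit batch layout on the CPU with the same cumsum/repeat arithmetic the runner uses elsewhere: the cumsum of per-request token counts gives the query start offsets, `np.repeat` expands per-request values to per-token values, and subtracting the repeated start offsets turns a global arange into within-request offsets, from which per-token positions and flat indices into the 2-D token-id table follow. A standalone sketch of the same arithmetic on toy data (names and numbers are illustrative):

```python
import numpy as np

# Toy batch: 2 requests with 3 and 2 scheduled tokens (pre-split) and
# 4 and 10 already-computed tokens.
num_scheduled = np.array([3, 2], dtype=np.int32)
num_computed = np.array([4, 10], dtype=np.int32)
num_reqs = len(num_scheduled)
total = int(num_scheduled.sum())                             # 5

req_indices = np.repeat(np.arange(num_reqs), num_scheduled)  # [0 0 0 1 1]
cu_num_tokens = np.cumsum(num_scheduled)                     # [3 5]
# start offset of each token's request, expanded per token:   [0 0 0 3 3]
cumsum_offsets = np.repeat(cu_num_tokens - num_scheduled, num_scheduled)
# within-request offset of each token:                        [0 1 2 0 1]
arange = np.arange(total) - cumsum_offsets
# absolute position of every token in its sequence:           [4 5 6 10 11]
positions = num_computed[req_indices] + arange

# flat indices into a (num_reqs, max_model_len) token-id table, mirroring
# `positions + req_indices * token_ids_cpu.shape[1]` in the diff above
max_model_len = 16
token_indices = positions + req_indices * max_model_len      # [4 5 6 26 27]
print(positions, token_indices)
```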