[feature] support pcp + mtp (in pd co-locate scenario) (#4098)
1. support pcp + mtp in pd co-locate scenario
2. llmdatadist connector pcp related bugfix and cleancode
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -25,12 +25,15 @@ from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
AscendPrefillContextParallelMetadata)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||
prefill_context_parallel_enable,
|
||||
vllm_version_is)
|
||||
|
||||
if prefill_context_parallel_enable():
|
||||
from vllm.distributed import get_pcp_group
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.utils import is_pin_memory_available
|
||||
@@ -84,6 +87,7 @@ class MtpProposer(Proposer):
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold = 1 + self.num_speculative_tokens
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||
# We need to get the hidden size from the draft model config because
|
||||
@@ -276,16 +280,28 @@ class MtpProposer(Proposer):
|
||||
self.runner.num_discarded_requests
|
||||
)
|
||||
|
||||
is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
long_seq_metadata: AscendPrefillContextParallelMetadata = \
|
||||
self.runner.long_seq_metadata if self.pcp_size > 1 else None
|
||||
if self.pcp_size > 1:
|
||||
long_seq_metadata = self.runner.long_seq_metadata
|
||||
input_ids_pcp_full = self.runner.input_ids_pcp_full
|
||||
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
|
||||
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
|
||||
num_reqs = self.runner.input_batch.num_reqs
|
||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||
num_prefill_reqs = (ori_query_lens
|
||||
> self.decode_threshold).sum().item()
|
||||
num_decode_reqs = num_reqs - num_prefill_reqs
|
||||
else:
|
||||
long_seq_metadata = None
|
||||
num_prefill_reqs = 0
|
||||
num_decode_reqs = 0
|
||||
if spec_decode_metadata is None:
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1 and is_prefill:
|
||||
token_indices_to_sample = None
|
||||
target_token_ids = self.runner.input_ids_pcp_full[:
|
||||
num_scheduled_tokens]
|
||||
if self.pcp_size > 1:
|
||||
token_indices_to_sample = \
|
||||
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
|
||||
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states
|
||||
else:
|
||||
@@ -295,6 +311,11 @@ class MtpProposer(Proposer):
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
else:
|
||||
if self.pcp_size > 1:
|
||||
common_attn_metadata.query_start_loc_cpu = \
|
||||
query_start_loc_pcp_full_cpu[:num_reqs + 1]
|
||||
common_attn_metadata.query_start_loc = \
|
||||
query_start_loc_pcp_full[:num_reqs + 1]
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
token_indices_to_sample = None
|
||||
common_attn_metadata, token_indices =\
|
||||
@@ -309,9 +330,14 @@ class MtpProposer(Proposer):
|
||||
common_attn_metadata,
|
||||
spec_decode_metadata,
|
||||
valid_sampled_tokens_count)
|
||||
target_token_ids = self.runner.input_ids[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
if self.pcp_size > 1:
|
||||
target_token_ids = input_ids_pcp_full[token_indices]
|
||||
target_positions = positions
|
||||
target_hidden_states = hidden_states
|
||||
else:
|
||||
target_token_ids = self.runner.input_ids[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
target_hidden_states = hidden_states[token_indices]
|
||||
|
||||
draft_token_ids = self._propose(
|
||||
target_token_ids=target_token_ids,
|
||||
@@ -321,9 +347,10 @@ class MtpProposer(Proposer):
|
||||
last_token_indices=token_indices_to_sample,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata,
|
||||
is_prefill=is_prefill,
|
||||
req_scheduled_tokens=req_scheduled_tokens,
|
||||
long_seq_metadata=long_seq_metadata,
|
||||
num_prefill_reqs=num_prefill_reqs,
|
||||
num_decode_reqs=num_decode_reqs,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
@@ -464,9 +491,10 @@ class MtpProposer(Proposer):
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embed_inputs: Optional[tuple[list[torch.Tensor],
|
||||
torch.Tensor]] = None,
|
||||
is_prefill=False,
|
||||
req_scheduled_tokens=None,
|
||||
long_seq_metadata=None,
|
||||
num_prefill_reqs=0,
|
||||
num_decode_reqs=0,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
@@ -488,20 +516,65 @@ class MtpProposer(Proposer):
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1 and is_prefill:
|
||||
num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens = \
|
||||
self._split_pcp_input(req_scheduled_tokens, num_tokens, target_hidden_states)
|
||||
# graph mode padding not considered now
|
||||
num_input_tokens = num_tokens
|
||||
self.input_ids[:num_input_tokens].copy_(input_ids)
|
||||
if self.pcp_size > 1:
|
||||
assert long_seq_metadata is not None
|
||||
common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
|
||||
# 1. preprocess decode/prefill input_ids & target_hidden_states
|
||||
# decode input_ids: keep unchanged
|
||||
# decode target_hidden_states: remove padding
|
||||
# prefill input_ids: add padding and pcp split
|
||||
# prefill target_hidden_states: pcp split
|
||||
num_tokens_d = num_decode_reqs * self.decode_threshold
|
||||
num_tokens_d_padded = num_tokens_d * self.pcp_size
|
||||
input_ids_d = self.input_ids[:num_tokens_d]
|
||||
input_ids_p = self.input_ids[num_tokens_d:num_tokens]
|
||||
target_hidden_states_d_padded = \
|
||||
target_hidden_states[:num_tokens_d_padded]
|
||||
if num_tokens_d:
|
||||
# remove padding (from pcp all-gather) in decode part
|
||||
target_hidden_states_d = target_hidden_states_d_padded.reshape(
|
||||
[
|
||||
num_decode_reqs, self.decode_threshold * self.pcp_size,
|
||||
-1
|
||||
])[:, :self.decode_threshold, :].reshape(
|
||||
[num_tokens_d, -1])
|
||||
else:
|
||||
target_hidden_states_d = target_hidden_states_d_padded
|
||||
target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
|
||||
req_scheduled_tokens_p = {}
|
||||
for i, req_id in enumerate(self.runner.input_batch.req_ids):
|
||||
if i >= num_decode_reqs:
|
||||
req_scheduled_tokens_p[req_id] = \
|
||||
req_scheduled_tokens[req_id]
|
||||
(num_tokens_p, input_ids_p, target_hidden_states_p,
|
||||
max_query_len_p, seq_lens_p, cu_num_tokens_p) = \
|
||||
self._split_pcp_input(
|
||||
req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
|
||||
num_tokens = num_tokens_d + num_tokens_p
|
||||
target_positions = target_positions[:num_tokens]
|
||||
self.input_ids[:num_tokens].copy_(
|
||||
torch.cat([input_ids_d, input_ids_p], dim=0))
|
||||
target_hidden_states = torch.cat(
|
||||
[target_hidden_states_d, target_hidden_states_p], dim=0)
|
||||
# 2. update attn_metadata params that may be influenced by pcp
|
||||
common_attn_metadata.num_actual_tokens = num_tokens
|
||||
common_attn_metadata.max_query_len = max_query_len
|
||||
common_attn_metadata.seq_lens_cpu = seq_lens.cpu()
|
||||
common_attn_metadata.query_start_loc = \
|
||||
cu_num_tokens[:batch_size + 1]
|
||||
common_attn_metadata.query_start_loc_cpu = \
|
||||
cu_num_tokens[:batch_size + 1].cpu()
|
||||
common_attn_metadata.max_query_len = max(self.decode_threshold,
|
||||
max_query_len_p)
|
||||
common_attn_metadata.seq_lens[num_decode_reqs:] = seq_lens_p
|
||||
common_attn_metadata.seq_lens_cpu[num_decode_reqs:] = seq_lens_p
|
||||
query_start_loc_p = cu_num_tokens_p[1:] + \
|
||||
common_attn_metadata.query_start_loc[num_decode_reqs].item()
|
||||
common_attn_metadata.query_start_loc[num_decode_reqs + 1:] = \
|
||||
query_start_loc_p
|
||||
common_attn_metadata.query_start_loc_cpu[num_decode_reqs + 1:] = \
|
||||
query_start_loc_p
|
||||
# 3. update sample_indices according to main model
|
||||
if num_decode_reqs:
|
||||
last_token_indices[:num_decode_reqs] = \
|
||||
self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
|
||||
if num_prefill_reqs:
|
||||
last_token_indices[-num_prefill_reqs:] = \
|
||||
self.runner.logits_indices[-num_prefill_reqs:]
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
@@ -575,6 +648,12 @@ class MtpProposer(Proposer):
|
||||
last_token_indices,
|
||||
(0, max_num_reqs_across_dp - num_indices))
|
||||
|
||||
if self.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
|
||||
hidden_states = torch.index_select(
|
||||
hidden_states, 0, self.runner.
|
||||
pcp_allgather_restore_idx[:hidden_states.shape[0]])
|
||||
|
||||
sample_hidden_states = hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||
@@ -854,16 +933,26 @@ class MtpProposer(Proposer):
|
||||
|
||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||
|
||||
def _split_pcp_input(self, req_scheduled_tokens, num_tokens,
|
||||
def _split_pcp_input(self, req_scheduled_tokens, input_ids,
|
||||
target_hidden_states):
|
||||
"""
|
||||
Split input_ids and target_hidden_states in pcp group.
|
||||
Split prefill input_ids and target_hidden_states in pcp group.
|
||||
1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
|
||||
2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
|
||||
3. split target_hidden_states (already include cp padding):
|
||||
3. split target_hidden_states (already include pcp padding):
|
||||
[h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
|
||||
4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split.
|
||||
"""
|
||||
if len(req_scheduled_tokens) == 0:
|
||||
# no prefill inputs to split, return empty result
|
||||
return (
|
||||
0,
|
||||
torch.zeros([0], device='npu'),
|
||||
torch.zeros([0, target_hidden_states.size(1)], device='npu'),
|
||||
0,
|
||||
torch.zeros([0]),
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
)
|
||||
|
||||
def _pcp_pad_and_split(num_tokens):
|
||||
num_pcp_padded_scheduled_tokens = cdiv(
|
||||
@@ -885,7 +974,6 @@ class MtpProposer(Proposer):
|
||||
return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad
|
||||
|
||||
num_pcp_scheduled_tokens = []
|
||||
input_ids_list = self.input_ids[:num_tokens]
|
||||
ori_start_index = 0
|
||||
pad_start_index = 0
|
||||
pcp_split_input_ids_list = []
|
||||
@@ -896,8 +984,8 @@ class MtpProposer(Proposer):
|
||||
actual_num_tokens = len(req_position_pcp)
|
||||
num_pcp_scheduled_tokens.append(actual_num_tokens)
|
||||
pad_input_ids = F.pad(
|
||||
input_ids_list[ori_start_index:ori_start_index +
|
||||
ori_num_tokens], (0, num_pcp_pad))
|
||||
input_ids[ori_start_index:ori_start_index + ori_num_tokens],
|
||||
(0, num_pcp_pad))
|
||||
ori_start_index += ori_num_tokens
|
||||
pcp_chunk_indices = [
|
||||
pad_start_index + pos for pos in req_position_pcp
|
||||
|
||||
Reference in New Issue
Block a user