[feature] support pcp + mtp (with pd disaggregate) (#3822)
### What this PR does / why we need it?
Support PCP + MTP (with PD disaggregation; only PCP is enabled on P nodes)
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from vllm.config import (CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor
|
||||
@@ -13,6 +14,7 @@ from vllm.model_executor.model_loader.utils import \
|
||||
process_weights_after_loading
|
||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@@ -22,7 +24,8 @@ from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
AscendPrefillContextParallelMetadata)
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||
vllm_version_is)
|
||||
@@ -67,6 +70,10 @@ class MtpProposer(Proposer):
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
|
||||
self.pcp_size = self.runner.pcp_size
|
||||
self.dcp_size = self.runner.dcp_size
|
||||
self.pcp_rank = self.runner.pcp_rank
|
||||
|
||||
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
|
||||
self.draft_indexer_metadata_builder: Optional[
|
||||
AttentionMetadataBuilder] = None
|
||||
@@ -99,6 +106,9 @@ class MtpProposer(Proposer):
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
self.full_indices = range(
|
||||
self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
|
||||
self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
|
||||
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
@@ -238,12 +248,24 @@ class MtpProposer(Proposer):
|
||||
self.runner.num_discarded_requests
|
||||
)
|
||||
|
||||
is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
|
||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
long_seq_metadata: AscendPrefillContextParallelMetadata = \
|
||||
self.runner.long_seq_metadata if self.pcp_size > 1 else None
|
||||
if spec_decode_metadata is None:
|
||||
token_indices_to_sample = None
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1 and is_prefill:
|
||||
token_indices_to_sample = None
|
||||
target_token_ids = self.runner.input_ids_pcp_full[:
|
||||
num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states
|
||||
else:
|
||||
token_indices_to_sample = None
|
||||
# input_ids can be None for multimodal models.
|
||||
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
else:
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
token_indices_to_sample = None
|
||||
@@ -271,6 +293,9 @@ class MtpProposer(Proposer):
|
||||
last_token_indices=token_indices_to_sample,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata,
|
||||
is_prefill=is_prefill,
|
||||
req_scheduled_tokens=req_scheduled_tokens,
|
||||
long_seq_metadata=long_seq_metadata,
|
||||
)
|
||||
|
||||
return draft_token_ids
|
||||
@@ -397,6 +422,9 @@ class MtpProposer(Proposer):
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embed_inputs: Optional[tuple[list[torch.Tensor],
|
||||
torch.Tensor]] = None,
|
||||
is_prefill=False,
|
||||
req_scheduled_tokens=None,
|
||||
long_seq_metadata=None,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
@@ -417,6 +445,22 @@ class MtpProposer(Proposer):
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
# update pcp related params
|
||||
if self.pcp_size > 1 and is_prefill:
|
||||
num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens = \
|
||||
self._split_pcp_input(req_scheduled_tokens, num_tokens, target_hidden_states)
|
||||
# graph mode padding not considered now
|
||||
num_input_tokens = num_tokens
|
||||
self.input_ids[:num_input_tokens].copy_(input_ids)
|
||||
common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
|
||||
common_attn_metadata.num_actual_tokens = num_tokens
|
||||
common_attn_metadata.max_query_len = max_query_len
|
||||
common_attn_metadata.seq_lens_cpu = seq_lens.cpu()
|
||||
common_attn_metadata.query_start_loc = \
|
||||
cu_num_tokens[:batch_size + 1]
|
||||
common_attn_metadata.query_start_loc_cpu = \
|
||||
cu_num_tokens[:batch_size + 1].cpu()
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
@@ -767,3 +811,65 @@ class MtpProposer(Proposer):
|
||||
1 - num_rejected_tokens_gpu)
|
||||
|
||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||
|
||||
def _split_pcp_input(self, req_scheduled_tokens, num_tokens,
                     target_hidden_states):
    """
    Split input_ids and target_hidden_states in pcp group.
    1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
    2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
    3. split target_hidden_states (already include cp padding):
       [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
    4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split.
    """

    def _zigzag_positions(req_num_tokens):
        # Pad the request to a multiple of 2 * pcp_size, then take two
        # mirrored chunks (rank r gets chunk r from the front and chunk r
        # from the back) so work is balanced across pcp ranks.
        padded_len = cdiv(req_num_tokens,
                          2 * self.pcp_size) * 2 * self.pcp_size
        chunk = padded_len // (2 * self.pcp_size)
        positions: list[int] = []
        positions.extend(
            self.full_indices[self.pcp_rank * chunk:(self.pcp_rank + 1) *
                              chunk])
        positions.extend(
            self.full_indices[padded_len - (self.pcp_rank + 1) *
                              chunk:padded_len - self.pcp_rank * chunk])
        return positions, padded_len, padded_len - req_num_tokens

    flat_input_ids = self.input_ids[:num_tokens]
    per_req_lens: list[int] = []  # tokens kept on this pcp rank, per request
    ids_chunks = []               # per-request input_id slices for this rank
    hidden_chunks = []            # per-request hidden-state slices
    src_offset = 0                # offset into the unpadded token stream
    padded_offset = 0             # offset into the cp-padded hidden states
    for req_num_tokens in req_scheduled_tokens.values():
        positions, padded_len, pad_len = _zigzag_positions(req_num_tokens)
        per_req_lens.append(len(positions))
        # Pad this request's ids to the cp-aligned length before slicing.
        padded_ids = F.pad(
            flat_input_ids[src_offset:src_offset + req_num_tokens],
            (0, pad_len))
        ids_chunks.append(padded_ids[positions])
        # Hidden states already carry the cp padding, so index them with
        # the same positions shifted by the padded offset.
        hidden_chunks.append(
            target_hidden_states[[padded_offset + p for p in positions]])
        src_offset += req_num_tokens
        padded_offset += padded_len

    num_tokens = sum(per_req_lens)
    input_ids = torch.cat(ids_chunks)
    target_hidden_states = torch.cat(hidden_chunks, dim=0)
    max_query_len = max(per_req_lens)
    seq_lens = torch.tensor(per_req_lens, dtype=torch.int32)
    cu_num_tokens = torch.tensor(
        np.insert(np.cumsum(np.array(per_req_lens)), 0, 0))
    return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens
|
||||
|
||||
Reference in New Issue
Block a user