[feature] support pcp + mtp (with pd disaggregate) (#3822)
### What this PR does / why we need it?
Support pcp (prefill context parallel) + mtp (multi-token prediction) with pd (prefill-decode) disaggregation; only pcp is enabled on the P (prefill) nodes.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
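Background for the hunks below: with prefill context parallelism (pcp), each request's prompt is padded to a multiple of 2 * pcp_size and every rank keeps one chunk from the head plus the mirrored chunk from the tail, which balances causal-attention work across ranks. A minimal standalone sketch of that index selection (the function name and shapes here are illustrative, not part of this PR):

```python
def pcp_chunk_indices(num_tokens: int, pcp_size: int, pcp_rank: int) -> list[int]:
    # Pad to a multiple of 2 * pcp_size, then take one head chunk and the
    # mirrored tail chunk for this rank (same scheme as _pcp_pad_and_split below).
    padded = -(-num_tokens // (2 * pcp_size)) * (2 * pcp_size)  # ceil-div * unit
    chunk = padded // (2 * pcp_size)
    head = list(range(pcp_rank * chunk, (pcp_rank + 1) * chunk))
    tail = list(range(padded - (pcp_rank + 1) * chunk, padded - pcp_rank * chunk))
    return head + tail

# 6 tokens, pcp_size = 2 -> padded to 8:
# rank 0 keeps positions [0, 1, 6, 7] (6 and 7 are padding), rank 1 keeps [2, 3, 4, 5]
assert pcp_chunk_indices(6, 2, 0) == [0, 1, 6, 7]
assert pcp_chunk_indices(6, 2, 1) == [2, 3, 4, 5]
```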
@@ -3,6 +3,7 @@ from typing import Optional
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from vllm.config import (CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor
@@ -13,6 +14,7 @@ from vllm.model_executor.model_loader.utils import \
     process_weights_after_loading
 from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -22,7 +24,8 @@ from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         AscendPrefillContextParallelMetadata)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
                                vllm_version_is)
@@ -67,6 +70,10 @@ class MtpProposer(Proposer):
         # hidden size (e.g., Llama 3.3 70B).
         self.hidden_size = self.draft_model_config.get_hidden_size()
 
+        self.pcp_size = self.runner.pcp_size
+        self.dcp_size = self.runner.dcp_size
+        self.pcp_rank = self.runner.pcp_rank
+
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
         self.draft_indexer_metadata_builder: Optional[
             AttentionMetadataBuilder] = None
@@ -99,6 +106,9 @@ class MtpProposer(Proposer):
             (self.max_num_tokens, self.hidden_size),
             dtype=self.dtype,
             device=device)
+        self.full_indices = range(
+            self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
+            self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
 
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -238,12 +248,24 @@ class MtpProposer(Proposer):
             self.runner.num_discarded_requests
         )
 
+        is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
+        req_scheduled_tokens = scheduler_output.num_scheduled_tokens
+        long_seq_metadata: AscendPrefillContextParallelMetadata = \
+            self.runner.long_seq_metadata if self.pcp_size > 1 else None
         if spec_decode_metadata is None:
-            token_indices_to_sample = None
-            # input_ids can be None for multimodal models.
-            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
-            target_positions = positions[:num_scheduled_tokens]
-            target_hidden_states = hidden_states[:num_scheduled_tokens]
+            # update pcp related params
+            if self.pcp_size > 1 and is_prefill:
+                token_indices_to_sample = None
+                target_token_ids = self.runner.input_ids_pcp_full[:
+                                                                  num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states
+            else:
+                token_indices_to_sample = None
+                # input_ids can be None for multimodal models.
+                target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                target_hidden_states = hidden_states[:num_scheduled_tokens]
         else:
             if self.speculative_config.disable_padded_drafter_batch:
                 token_indices_to_sample = None
@@ -271,6 +293,9 @@ class MtpProposer(Proposer):
             last_token_indices=token_indices_to_sample,
             common_attn_metadata=common_attn_metadata,
             sampling_metadata=sampling_metadata,
+            is_prefill=is_prefill,
+            req_scheduled_tokens=req_scheduled_tokens,
+            long_seq_metadata=long_seq_metadata,
         )
 
         return draft_token_ids
@@ -397,6 +422,9 @@ class MtpProposer(Proposer):
         sampling_metadata: SamplingMetadata,
         mm_embed_inputs: Optional[tuple[list[torch.Tensor],
                                         torch.Tensor]] = None,
+        is_prefill=False,
+        req_scheduled_tokens=None,
+        long_seq_metadata=None,
     ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
@@ -417,6 +445,22 @@ class MtpProposer(Proposer):
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
         self.input_ids[last_token_indices] = next_token_ids
 
+        # update pcp related params
+        if self.pcp_size > 1 and is_prefill:
+            num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens = \
+                self._split_pcp_input(req_scheduled_tokens, num_tokens, target_hidden_states)
+            # graph mode padding not considered now
+            num_input_tokens = num_tokens
+            self.input_ids[:num_input_tokens].copy_(input_ids)
+            common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+            common_attn_metadata.num_actual_tokens = num_tokens
+            common_attn_metadata.max_query_len = max_query_len
+            common_attn_metadata.seq_lens_cpu = seq_lens.cpu()
+            common_attn_metadata.query_start_loc = \
+                cu_num_tokens[:batch_size + 1]
+            common_attn_metadata.query_start_loc_cpu = \
+                cu_num_tokens[:batch_size + 1].cpu()
+
         assert self.runner is not None
 
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
@@ -767,3 +811,65 @@ class MtpProposer(Proposer):
                                      1 - num_rejected_tokens_gpu)
 
         return spec_common_attn_metadata, token_indices, token_indices_to_sample
+
+    def _split_pcp_input(self, req_scheduled_tokens, num_tokens,
+                         target_hidden_states):
+        """
+        Split input_ids and target_hidden_states in pcp group.
+        1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
+        2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
+        3. split target_hidden_states (already include cp padding):
+           [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
+        4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split.
+        """
+
+        def _pcp_pad_and_split(num_tokens):
+            num_pcp_padded_scheduled_tokens = cdiv(
+                num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size
+            pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens
+            chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size)
+
+            # split position_ids (and use split position_ids to split input_ids afterwards)
+            req_position_cp: list[int] = []
+            req_position_cp.extend(
+                self.full_indices[self.pcp_rank *
+                                  chunk_size:(self.pcp_rank + 1) * chunk_size])
+            req_position_cp.extend(
+                self.full_indices[num_pcp_padded_scheduled_tokens -
+                                  (self.pcp_rank + 1) *
+                                  chunk_size:num_pcp_padded_scheduled_tokens -
+                                  self.pcp_rank * chunk_size])
+
+            return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad
+
+        num_pcp_scheduled_tokens = []
+        input_ids_list = self.input_ids[:num_tokens]
+        ori_start_index = 0
+        pad_start_index = 0
+        pcp_split_input_ids_list = []
+        pcp_split_hidden_states_list = []
+        for ori_num_tokens in req_scheduled_tokens.values():
+            req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \
+                _pcp_pad_and_split(ori_num_tokens)
+            actual_num_tokens = len(req_position_pcp)
+            num_pcp_scheduled_tokens.append(actual_num_tokens)
+            pad_input_ids = F.pad(
+                input_ids_list[ori_start_index:ori_start_index +
+                               ori_num_tokens], (0, num_pcp_pad))
+            ori_start_index += ori_num_tokens
+            pcp_chunk_indices = [
+                pad_start_index + pos for pos in req_position_pcp
+            ]
+            pcp_split_input_ids = pad_input_ids[req_position_pcp]
+            pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices]
+            pcp_split_input_ids_list.append(pcp_split_input_ids)
+            pcp_split_hidden_states_list.append(pcp_split_hidden_states)
+            pad_start_index += num_pcp_padded_scheduled_tokens
+        num_tokens = sum(num_pcp_scheduled_tokens)
+        input_ids = torch.cat(pcp_split_input_ids_list)
+        target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0)
+        max_query_len = max(num_pcp_scheduled_tokens)
+        seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
+        cu_num_tokens = torch.tensor(
+            np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
+        return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens
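To make the bookkeeping at the end of `_split_pcp_input` concrete, here is a hedged, standalone sketch (toy values) of how the per-rank, per-request token counts become `seq_lens`, `max_query_len`, and the `cu_num_tokens` used as `query_start_loc`:

```python
import numpy as np
import torch

# Two prefill requests scheduled 6 and 10 tokens with pcp_size = 2.
# After padding to a multiple of 2 * pcp_size, each rank keeps half:
# request 0 -> 8 padded tokens, 4 per rank; request 1 -> 12 padded, 6 per rank.
num_pcp_scheduled_tokens = [4, 6]

max_query_len = max(num_pcp_scheduled_tokens)              # 6
seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
cu_num_tokens = torch.tensor(
    np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
assert cu_num_tokens.tolist() == [0, 4, 10]   # used as query_start_loc[:batch_size + 1]
```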
@@ -479,6 +479,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
                                                        dtype=torch.int32,
                                                        device=self.device)
+        if self.speculative_config and self.pcp_size > 1:
+            self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
+                                                  dtype=torch.int32,
+                                                  device="cpu",
+                                                  pin_memory=True)
+            self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1,
+                                                        dtype=torch.int32,
+                                                        device="cpu",
+                                                        pin_memory=True)
+            self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy(
+            )
+            self.positions_pcp_full = torch.zeros(self.max_num_tokens,
+                                                  dtype=torch.int64,
+                                                  device="cpu",
+                                                  pin_memory=True)
+            self.positions_np_pcp_full = self.positions_pcp_full.numpy()
 
         self.use_aclgraph = self._use_aclgraph()
         self.aclgraph_batch_sizes = list(
@@ -1598,7 +1614,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         ]
         num_tokens_np = np.array(num_tokens, dtype=np.int32)
         num_reqs = self.input_batch.num_reqs
-        discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
+        if self.pcp_size == 1:
+            discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
+        else:
+            # while pcp > 1, we need the original num_scheduled_tokens before split
+            # to calculate discard_requests_mask
+            original_seq_lens_np = (
+                self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+                np.array(list(scheduler_output.num_scheduled_tokens.values())))
+            discard_requests_mask = original_seq_lens_np < num_tokens_np
         discard_request_indices = np.nonzero(discard_requests_mask)[0]
         self.num_discarded_requests = len(discard_request_indices)
         self.discard_request_indices.np[:self.num_discarded_requests] = (
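Why the mask needs pre-split lengths when pcp > 1: after the context-parallel split, `seq_lens_np` only reflects this rank's share of each request, so comparing it against the per-request token requirement would discard requests too aggressively. A toy illustration (the thresholds are made up; the real values come from scheduler_output):

```python
import numpy as np

# Two requests, pcp_size = 2; full scheduled lengths before the split.
num_computed_tokens = np.array([0, 0], dtype=np.int32)
num_scheduled_full = np.array([6, 10], dtype=np.int32)     # from scheduler_output
original_seq_lens_np = num_computed_tokens + num_scheduled_full   # [6, 10]

per_rank_seq_lens = np.array([4, 6], dtype=np.int32)       # after the pcp split (padded)
num_tokens_np = np.array([8, 8], dtype=np.int32)           # toy per-request threshold

assert (original_seq_lens_np < num_tokens_np).tolist() == [True, False]
assert (per_rank_seq_lens < num_tokens_np).tolist() == [True, True]   # would over-discard
```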
@@ -1730,6 +1754,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.num_accepted_tokens.np[num_reqs:].fill(1)
             self.num_accepted_tokens.copy_to_gpu()
 
+        is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
+        if self.speculative_config and self.pcp_size > 1 and is_prefill:
+            self._generate_pcp_mtp_input(
+                num_reqs, scheduler_output.total_num_scheduled_tokens,
+                scheduler_output.num_scheduled_tokens)
+
         # prepare pcp meta data
         long_seq_metadata = self._generate_pcp_metadata(
             total_num_scheduled_tokens, seq_lens_cpu)
@@ -4419,4 +4449,46 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 'tail_attn_nomask_seqlens']
             long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
                 'pcp_prefill_mask']
+        self.long_seq_metadata = long_seq_metadata
         return long_seq_metadata
+
+    def _generate_pcp_mtp_input(
+        self,
+        num_reqs: int,
+        total_num_scheduled_tokens: int,
+        num_scheduled_tokens: dict[str, int],
+    ):
+        """
+        While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
+        but mtp need to shift original input_ids before pcp splitting,
+        so we record original input_ids here.
+        """
+        total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
+        num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
+        for i, req_id in enumerate(self.input_batch.req_ids):
+            num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
+        req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
+                                         num_scheduled_tokens_pcp_full)
+        cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
+        self.query_start_loc_pcp_full_np[0] = 0
+        self.query_start_loc_pcp_full_np[1:num_reqs +
+                                         1] = cu_num_tokens_pcp_full
+        self.query_start_loc_pcp_full_np[num_reqs + 1:].fill(-1)
+        cumsums_offsets_pcp_full = np.repeat(
+            cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
+            num_scheduled_tokens_pcp_full)
+        arange_pcp_full = self.arange_np[:
+                                         total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
+        positions_np_pcp_full = self.positions_np_pcp_full[:
+                                                           total_num_scheduled_tokens_pcp_full]
+        np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
+               arange_pcp_full,
+               out=positions_np_pcp_full)
+        token_indices_pcp_full = (
+            positions_np_pcp_full +
+            req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
+        torch.index_select(
+            self.input_batch.token_ids_cpu_tensor.flatten(),
+            0,
+            torch.from_numpy(token_indices_pcp_full),
+            out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full])
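The core of `_generate_pcp_mtp_input` is a flattened gather: per-token positions plus `request_index * row_width` index into the 2-D CPU token table viewed as 1-D, recovering the unsplit input_ids that the MTP proposer later shifts. A small self-contained illustration with toy tensors (not the runner's actual buffers; num_computed_tokens is assumed to be 0 here):

```python
import numpy as np
import torch

# Toy token table: 2 requests, up to 5 token slots per request.
token_ids_cpu_tensor = torch.tensor([[11, 12, 13, 14, 0],
                                     [21, 22, 23, 0, 0]], dtype=torch.int32)
num_scheduled = np.array([3, 2])                         # tokens scheduled per request
req_indices = np.repeat(np.arange(2), num_scheduled)     # [0, 0, 0, 1, 1]
cu = np.cumsum(num_scheduled)                            # [3, 5]
positions = np.arange(num_scheduled.sum()) - np.repeat(cu - num_scheduled, num_scheduled)
# positions within each request: [0, 1, 2, 0, 1]

row_width = token_ids_cpu_tensor.shape[1]
flat_indices = torch.from_numpy(positions + req_indices * row_width)
input_ids_full = torch.index_select(token_ids_cpu_tensor.flatten(), 0, flat_indices)
assert input_ids_full.tolist() == [11, 12, 13, 21, 22]
```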