[feature] support pcp + mtp (in pd co-locate scenario) (#4098)

1. support pcp + mtp in pd co-locate scenario
2. llmdatadist connector: PCP-related bugfix and code cleanup

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
This commit is contained in:
zhangsicheng5
2025-11-12 17:22:21 +08:00
committed by GitHub
parent 1b4ce63ec9
commit a123f355e9
6 changed files with 246 additions and 97 deletions

View File

@@ -164,7 +164,7 @@ def create_request(
remote_host="my-host", remote_host="my-host",
remote_port=1234, remote_port=1234,
remote_tp_size=1, remote_tp_size=1,
remote_cp_size=1, remote_pcp_size=1,
remote_dcp_size=1) remote_dcp_size=1)
max_tokens = 1 if do_remote_decode else max_tokens max_tokens = 1 if do_remote_decode else max_tokens

View File

@@ -281,14 +281,17 @@ class AscendMLAMetadataBuilder:
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs', decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0) 0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs, self.batch_seq_mask_buf = torch.empty(max_num_seqs *
self.decode_threshold,
dtype=torch.uint8, dtype=torch.uint8,
device=device) device=device)
self.seq_mask_pcp_buf = torch.empty(max_num_seqs, self.seq_mask_pcp_buf = torch.empty(max_num_seqs *
self.decode_threshold,
self.pcp_size, self.pcp_size,
dtype=torch.uint8, dtype=torch.uint8,
device=device) device=device)
self.seq_mask_dcp_buf = torch.empty(max_num_seqs, self.seq_mask_dcp_buf = torch.empty(max_num_seqs *
self.decode_threshold,
self.dcp_size, self.dcp_size,
dtype=torch.uint8, dtype=torch.uint8,
device=device) device=device)
@@ -504,12 +507,18 @@ class AscendMLAMetadataBuilder:
seq_lens = seq_lens[:num_decodes] seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens] input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decodes, ...] block_table = block_table[:num_decodes, ...]
# For pcp + spec decode, we flatten seq_lens and block_table
# to avoid irregular spec_attn_mask shape
if self.pcp_size > 1:
block_table = block_table.repeat_interleave(
self.decode_threshold, dim=0)
seq_lens_list = seq_lens.tolist() seq_lens_list = seq_lens.tolist()
if num_computed_tokens_of_pcp_dcp is not None: if num_computed_tokens_of_pcp_dcp is not None:
# [bs, pcp_size, dcp_size]
num_computed_tokens_of_cp_dcp_array = np.array( num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp num_computed_tokens_of_pcp_dcp)[:num_decodes *
)[:num_decodes] # [bs, pcp_size, dcp_size] self.decode_threshold]
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
self.pcp_rank, self.pcp_rank,

View File

@@ -371,7 +371,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
def update_mla_attn_dcp_pcp_params(update_stream, forward_context, def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
runtime_shape, speculative_config): runtime_shape):
graph_params = get_graph_params() graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args # FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph. # for each layer's attention op in the graph.
@@ -388,16 +388,14 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
decode_meta = forward_context.attn_metadata[key].decode decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len seq_len = decode_meta.cp_seq_len
if speculative_config and speculative_config.method == "deepseek_mtp": # For pcp + spec decode, we flatten seq_lens
spec_multiple = speculative_config.num_speculative_tokens + 1 # to avoid irregular spec_attn_mask shape,
seq_len = seq_len + [0] * (runtime_shape // spec_multiple - # so there's no need to divide runtime_shape by spec_multiple
len(seq_len)) pad_length = runtime_shape - len(seq_len)
else: pad_tensor = torch.zeros(pad_length,
pad_length = runtime_shape - len(seq_len) dtype=seq_len.dtype,
pad_tensor = torch.zeros(pad_length, device=seq_len.device)
dtype=seq_len.dtype, seq_len = torch.cat([seq_len, pad_tensor], dim=0)
device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
torch.npu.graph_task_update_begin(update_stream, handle) torch.npu.graph_task_update_begin(update_stream, handle)

View File

@@ -79,7 +79,7 @@ class ReqMeta:
remote_port: str remote_port: str
engine_id: str engine_id: str
remote_tp_size: str remote_tp_size: str
remote_cp_size: str remote_pcp_size: str
remote_dcp_size: str remote_dcp_size: str
@@ -97,7 +97,7 @@ class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
remote_host=kv_transfer_params["remote_host"], remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"], remote_port=kv_transfer_params["remote_port"],
remote_tp_size=kv_transfer_params["remote_tp_size"], remote_tp_size=kv_transfer_params["remote_tp_size"],
remote_cp_size=kv_transfer_params["remote_cp_size"], remote_pcp_size=kv_transfer_params["remote_pcp_size"],
remote_dcp_size=kv_transfer_params["remote_dcp_size"], remote_dcp_size=kv_transfer_params["remote_dcp_size"],
) )
@@ -318,7 +318,7 @@ class LLMDataDistCMgrConnectorScheduler():
remote_port=self.port, remote_port=self.port,
remote_tp_size=str( remote_tp_size=str(
self.vllm_config.parallel_config.tensor_parallel_size), self.vllm_config.parallel_config.tensor_parallel_size),
remote_cp_size=str(self.pcp_size), remote_pcp_size=str(self.pcp_size),
remote_dcp_size=str(self.dcp_size), remote_dcp_size=str(self.dcp_size),
) )
@@ -677,7 +677,7 @@ class LLMDataDistCMgrConnectorWorker():
remote_engine_id=meta.engine_id, remote_engine_id=meta.engine_id,
request_id=req_id, request_id=req_id,
remote_tp_size=meta.remote_tp_size, remote_tp_size=meta.remote_tp_size,
remote_cp_size=meta.remote_cp_size, remote_pcp_size=meta.remote_pcp_size,
remote_dcp_size=meta.remote_dcp_size, remote_dcp_size=meta.remote_dcp_size,
) )
futures.append(future) futures.append(future)
@@ -876,39 +876,40 @@ class LLMDataDistCMgrConnectorWorker():
remote_block_ids: list[int], remote_block_ids: list[int],
remote_port: int, remote_port: int,
remote_tp_size: int, remote_tp_size: int,
remote_cp_size: int, remote_pcp_size: int,
remote_dcp_size: int, remote_dcp_size: int,
) -> tuple[int, list[int], list[int]]: ) -> tuple[int, list[int], list[int]]:
""" """
In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node. In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
Use this function to calculate remote port and remote block number of each remote P node that we need to pull. Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
""" """
if self.pcp_size == remote_cp_size and self.dcp_size == remote_dcp_size: if self.pcp_size == remote_pcp_size and self.dcp_size == remote_dcp_size:
# remote & local cp/dcp are equal, do kv transfer point-to-point # remote & local cp/dcp are equal, do kv transfer point-to-point
remote_kv_num = 1 remote_kv_num = 1
remote_ports = [remote_port + self.pcp_rank * self.tp_size + tp_offset \ remote_ports = [remote_port + self.pcp_rank * self.tp_size + tp_offset \
for tp_offset in range(self.tp_rank, int(remote_tp_size), self.tp_size)] for tp_offset in range(self.tp_rank, int(remote_tp_size), self.tp_size)]
remote_block_nums = [len(remote_block_ids)] remote_block_nums = [len(remote_block_ids)]
elif (self.use_mla and self.pcp_size == 1 and self.dcp_size == 1) \ elif (self.use_mla and self.pcp_size == 1 and self.dcp_size == 1) \
or (not self.use_mla and self.pcp_size == 1 and remote_tp_size == self.tp_size and remote_dcp_size == self.dcp_size): or (not self.use_mla and self.pcp_size == 1 and self.dcp_size == 1 and remote_tp_size == self.tp_size):
# remote & local cp/dcp are not equal, each D node needs to pull from cp(*dcp) P nodes # remote & local cp/dcp are not equal, each D node needs to pull from cp(*dcp) P nodes
# 1. for mla, support D cp_size = dcp_size = 1 # 1. for mla, support D cp_size = dcp_size = 1
# 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size # 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
remote_dcp_size = remote_dcp_size // self.dcp_size remote_dcp_size = remote_dcp_size // self.dcp_size
remote_kv_num = remote_cp_size * remote_dcp_size remote_kv_num = remote_pcp_size * remote_dcp_size
cp_dcp_offsets = [] cp_dcp_offsets = []
for cp_idx in range(remote_cp_size): for cp_idx in range(remote_pcp_size):
cp_offset = cp_idx * remote_tp_size cp_offset = cp_idx * remote_tp_size
cp_dcp_offsets += list( cp_dcp_offsets += list(
range(cp_offset, cp_offset + remote_dcp_size)) range(cp_offset, cp_offset + remote_dcp_size))
remote_ports = [remote_port + cp_dcp_offset + (self.tp_rank if not self.use_mla else 0) \ tp_offset = 0 if self.use_mla else self.tp_rank // remote_dcp_size * remote_dcp_size
remote_ports = [remote_port + cp_dcp_offset + tp_offset \
for cp_dcp_offset in cp_dcp_offsets] for cp_dcp_offset in cp_dcp_offsets]
# recompute cp/dcp block assign here, maybe we can also pass it from P node meta # recompute cp/dcp block assign here, maybe we can also pass it from P node meta
local_block_num = len(local_block_ids) local_block_num = len(local_block_ids)
remote_block_nums = [ remote_block_nums = [
local_block_num // (remote_cp_size * remote_dcp_size) local_block_num // (remote_pcp_size * remote_dcp_size)
] * remote_cp_size * remote_dcp_size ] * remote_pcp_size * remote_dcp_size
num_remain_blocks = local_block_num % (remote_cp_size * num_remain_blocks = local_block_num % (remote_pcp_size *
remote_dcp_size) remote_dcp_size)
for i in range(num_remain_blocks): for i in range(num_remain_blocks):
remote_block_nums[i] += 1 remote_block_nums[i] += 1
@@ -921,7 +922,7 @@ class LLMDataDistCMgrConnectorWorker():
# Other cases are not supported now, maybe need to reshard kv_cache. # Other cases are not supported now, maybe need to reshard kv_cache.
raise NotImplementedError( raise NotImplementedError(
f'Current case is not supported now: use_mla={self.use_mla}, ' f'Current case is not supported now: use_mla={self.use_mla}, '
f'P tp={remote_tp_size}, pcp={remote_cp_size}, dcp={remote_dcp_size}, ' f'P tp={remote_tp_size}, pcp={remote_pcp_size}, dcp={remote_dcp_size}, '
f'D tp={self.tp_size}, pcp={self.pcp_size}, dcp={self.dcp_size}' f'D tp={self.tp_size}, pcp={self.pcp_size}, dcp={self.dcp_size}'
) )
return remote_kv_num, remote_ports, remote_block_nums return remote_kv_num, remote_ports, remote_block_nums
@@ -935,7 +936,7 @@ class LLMDataDistCMgrConnectorWorker():
remote_engine_id: str, remote_engine_id: str,
request_id: str, request_id: str,
remote_tp_size: str, remote_tp_size: str,
remote_cp_size: str, remote_pcp_size: str,
remote_dcp_size: str, remote_dcp_size: str,
): ):
remote_kv_num, remote_ports, remote_block_nums = self._get_kv_split_metadata( remote_kv_num, remote_ports, remote_block_nums = self._get_kv_split_metadata(
@@ -943,7 +944,7 @@ class LLMDataDistCMgrConnectorWorker():
remote_block_ids=remote_block_ids, remote_block_ids=remote_block_ids,
remote_port=remote_port, remote_port=remote_port,
remote_tp_size=int(remote_tp_size), remote_tp_size=int(remote_tp_size),
remote_cp_size=int(remote_cp_size), remote_pcp_size=int(remote_pcp_size),
remote_dcp_size=int(remote_dcp_size), remote_dcp_size=int(remote_dcp_size),
) )
logger.debug( logger.debug(

View File

@@ -25,12 +25,15 @@ from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
AscendPrefillContextParallelMetadata)
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
prefill_context_parallel_enable,
vllm_version_is) vllm_version_is)
if prefill_context_parallel_enable():
from vllm.distributed import get_pcp_group
if vllm_version_is("0.11.0"): if vllm_version_is("0.11.0"):
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
@@ -84,6 +87,7 @@ class MtpProposer(Proposer):
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
self.decode_threshold = 1 + self.num_speculative_tokens
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.token_arange_np = np.arange(self.max_num_tokens) self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because # We need to get the hidden size from the draft model config because
@@ -276,16 +280,28 @@ class MtpProposer(Proposer):
self.runner.num_discarded_requests self.runner.num_discarded_requests
) )
is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
req_scheduled_tokens = scheduler_output.num_scheduled_tokens req_scheduled_tokens = scheduler_output.num_scheduled_tokens
long_seq_metadata: AscendPrefillContextParallelMetadata = \ if self.pcp_size > 1:
self.runner.long_seq_metadata if self.pcp_size > 1 else None long_seq_metadata = self.runner.long_seq_metadata
input_ids_pcp_full = self.runner.input_ids_pcp_full
query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
else:
long_seq_metadata = None
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None: if spec_decode_metadata is None:
# update pcp related params # update pcp related params
if self.pcp_size > 1 and is_prefill: if self.pcp_size > 1:
token_indices_to_sample = None token_indices_to_sample = \
target_token_ids = self.runner.input_ids_pcp_full[: query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
num_scheduled_tokens] target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states target_hidden_states = hidden_states
else: else:
@@ -295,6 +311,11 @@ class MtpProposer(Proposer):
target_positions = positions[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens]
else: else:
if self.pcp_size > 1:
common_attn_metadata.query_start_loc_cpu = \
query_start_loc_pcp_full_cpu[:num_reqs + 1]
common_attn_metadata.query_start_loc = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.speculative_config.disable_padded_drafter_batch: if self.speculative_config.disable_padded_drafter_batch:
token_indices_to_sample = None token_indices_to_sample = None
common_attn_metadata, token_indices =\ common_attn_metadata, token_indices =\
@@ -309,9 +330,14 @@ class MtpProposer(Proposer):
common_attn_metadata, common_attn_metadata,
spec_decode_metadata, spec_decode_metadata,
valid_sampled_tokens_count) valid_sampled_tokens_count)
target_token_ids = self.runner.input_ids[token_indices] if self.pcp_size > 1:
target_positions = positions[token_indices] target_token_ids = input_ids_pcp_full[token_indices]
target_hidden_states = hidden_states[token_indices] target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.runner.input_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
draft_token_ids = self._propose( draft_token_ids = self._propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
@@ -321,9 +347,10 @@ class MtpProposer(Proposer):
last_token_indices=token_indices_to_sample, last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
is_prefill=is_prefill,
req_scheduled_tokens=req_scheduled_tokens, req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata, long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
) )
return draft_token_ids return draft_token_ids
@@ -464,9 +491,10 @@ class MtpProposer(Proposer):
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
mm_embed_inputs: Optional[tuple[list[torch.Tensor], mm_embed_inputs: Optional[tuple[list[torch.Tensor],
torch.Tensor]] = None, torch.Tensor]] = None,
is_prefill=False,
req_scheduled_tokens=None, req_scheduled_tokens=None,
long_seq_metadata=None, long_seq_metadata=None,
num_prefill_reqs=0,
num_decode_reqs=0,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = target_token_ids.shape[0] num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
@@ -488,20 +516,65 @@ class MtpProposer(Proposer):
self.input_ids[last_token_indices] = next_token_ids self.input_ids[last_token_indices] = next_token_ids
# update pcp related params # update pcp related params
if self.pcp_size > 1 and is_prefill: if self.pcp_size > 1:
num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens = \ assert long_seq_metadata is not None
self._split_pcp_input(req_scheduled_tokens, num_tokens, target_hidden_states)
# graph mode padding not considered now
num_input_tokens = num_tokens
self.input_ids[:num_input_tokens].copy_(input_ids)
common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
# 1. preprocess decode/prefill input_ids & target_hidden_states
# decode input_ids: keep unchanged
# decode target_hidden_states: remove padding
# prefill input_ids: add padding and pcp split
# prefill target_hidden_states: pcp split
num_tokens_d = num_decode_reqs * self.decode_threshold
num_tokens_d_padded = num_tokens_d * self.pcp_size
input_ids_d = self.input_ids[:num_tokens_d]
input_ids_p = self.input_ids[num_tokens_d:num_tokens]
target_hidden_states_d_padded = \
target_hidden_states[:num_tokens_d_padded]
if num_tokens_d:
# remove padding (from pcp all-gather) in decode part
target_hidden_states_d = target_hidden_states_d_padded.reshape(
[
num_decode_reqs, self.decode_threshold * self.pcp_size,
-1
])[:, :self.decode_threshold, :].reshape(
[num_tokens_d, -1])
else:
target_hidden_states_d = target_hidden_states_d_padded
target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
req_scheduled_tokens_p = {}
for i, req_id in enumerate(self.runner.input_batch.req_ids):
if i >= num_decode_reqs:
req_scheduled_tokens_p[req_id] = \
req_scheduled_tokens[req_id]
(num_tokens_p, input_ids_p, target_hidden_states_p,
max_query_len_p, seq_lens_p, cu_num_tokens_p) = \
self._split_pcp_input(
req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
num_tokens = num_tokens_d + num_tokens_p
target_positions = target_positions[:num_tokens]
self.input_ids[:num_tokens].copy_(
torch.cat([input_ids_d, input_ids_p], dim=0))
target_hidden_states = torch.cat(
[target_hidden_states_d, target_hidden_states_p], dim=0)
# 2. update attn_metadata params that may be influenced by pcp
common_attn_metadata.num_actual_tokens = num_tokens common_attn_metadata.num_actual_tokens = num_tokens
common_attn_metadata.max_query_len = max_query_len common_attn_metadata.max_query_len = max(self.decode_threshold,
common_attn_metadata.seq_lens_cpu = seq_lens.cpu() max_query_len_p)
common_attn_metadata.query_start_loc = \ common_attn_metadata.seq_lens[num_decode_reqs:] = seq_lens_p
cu_num_tokens[:batch_size + 1] common_attn_metadata.seq_lens_cpu[num_decode_reqs:] = seq_lens_p
common_attn_metadata.query_start_loc_cpu = \ query_start_loc_p = cu_num_tokens_p[1:] + \
cu_num_tokens[:batch_size + 1].cpu() common_attn_metadata.query_start_loc[num_decode_reqs].item()
common_attn_metadata.query_start_loc[num_decode_reqs + 1:] = \
query_start_loc_p
common_attn_metadata.query_start_loc_cpu[num_decode_reqs + 1:] = \
query_start_loc_p
# 3. update sample_indices according to main model
if num_decode_reqs:
last_token_indices[:num_decode_reqs] = \
self.runner.logits_indices[last_token_indices[:num_decode_reqs]]
if num_prefill_reqs:
last_token_indices[-num_prefill_reqs:] = \
self.runner.logits_indices[-num_prefill_reqs:]
assert self.runner is not None assert self.runner is not None
@@ -575,6 +648,12 @@ class MtpProposer(Proposer):
last_token_indices, last_token_indices,
(0, max_num_reqs_across_dp - num_indices)) (0, max_num_reqs_across_dp - num_indices))
if self.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(hidden_states, 0)
hidden_states = torch.index_select(
hidden_states, 0, self.runner.
pcp_allgather_restore_idx[:hidden_states.shape[0]])
sample_hidden_states = hidden_states[last_token_indices] sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0]: if lmhead_tp_enable() and num_indices < logits.shape[0]:
@@ -854,16 +933,26 @@ class MtpProposer(Proposer):
return spec_common_attn_metadata, token_indices, token_indices_to_sample return spec_common_attn_metadata, token_indices, token_indices_to_sample
def _split_pcp_input(self, req_scheduled_tokens, num_tokens, def _split_pcp_input(self, req_scheduled_tokens, input_ids,
target_hidden_states): target_hidden_states):
""" """
Split input_ids and target_hidden_states in pcp group. Split prefill input_ids and target_hidden_states in pcp group.
1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad] 1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5] 2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
3. split target_hidden_states (already include cp padding): 3. split target_hidden_states (already include pcp padding):
[h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5] [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split. 4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split.
""" """
if len(req_scheduled_tokens) == 0:
# no prefill inputs to split, return empty result
return (
0,
torch.zeros([0], device='npu'),
torch.zeros([0, target_hidden_states.size(1)], device='npu'),
0,
torch.zeros([0]),
torch.tensor([0], dtype=torch.int32),
)
def _pcp_pad_and_split(num_tokens): def _pcp_pad_and_split(num_tokens):
num_pcp_padded_scheduled_tokens = cdiv( num_pcp_padded_scheduled_tokens = cdiv(
@@ -885,7 +974,6 @@ class MtpProposer(Proposer):
return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad
num_pcp_scheduled_tokens = [] num_pcp_scheduled_tokens = []
input_ids_list = self.input_ids[:num_tokens]
ori_start_index = 0 ori_start_index = 0
pad_start_index = 0 pad_start_index = 0
pcp_split_input_ids_list = [] pcp_split_input_ids_list = []
@@ -896,8 +984,8 @@ class MtpProposer(Proposer):
actual_num_tokens = len(req_position_pcp) actual_num_tokens = len(req_position_pcp)
num_pcp_scheduled_tokens.append(actual_num_tokens) num_pcp_scheduled_tokens.append(actual_num_tokens)
pad_input_ids = F.pad( pad_input_ids = F.pad(
input_ids_list[ori_start_index:ori_start_index + input_ids[ori_start_index:ori_start_index + ori_num_tokens],
ori_num_tokens], (0, num_pcp_pad)) (0, num_pcp_pad))
ori_start_index += ori_num_tokens ori_start_index += ori_num_tokens
pcp_chunk_indices = [ pcp_chunk_indices = [
pad_start_index + pos for pos in req_position_pcp pad_start_index + pos for pos in req_position_pcp

View File

@@ -487,19 +487,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.speculative_config and self.pcp_size > 1: if self.speculative_config and self.pcp_size > 1:
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens, self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
dtype=torch.int32, dtype=torch.int32,
device="cpu", device=self.device)
pin_memory=True) self.input_ids_pcp_full_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=True)
self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1, self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32, dtype=torch.int32,
device="cpu", device=self.device)
pin_memory=True) self.query_start_loc_pcp_full_cpu = \
self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy( torch.zeros(self.max_num_reqs + 1,
) dtype=torch.int32,
device="cpu",
pin_memory=True)
self.query_start_loc_pcp_full_np = \
self.query_start_loc_pcp_full_cpu.numpy()
self.positions_pcp_full = torch.zeros(self.max_num_tokens, self.positions_pcp_full = torch.zeros(self.max_num_tokens,
dtype=torch.int64, dtype=torch.int64,
device="cpu", device="cpu",
pin_memory=True) pin_memory=True)
self.positions_np_pcp_full = self.positions_pcp_full.numpy() self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.decode_threshold = 1 + (
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
self.use_aclgraph = self._use_aclgraph() self.use_aclgraph = self._use_aclgraph()
self.aclgraph_batch_sizes = list( self.aclgraph_batch_sizes = list(
@@ -1854,8 +1864,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logits_indices = torch.from_numpy(cu_num_tokens - 1).to( logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
self.device, non_blocking=True) self.device, non_blocking=True)
else: else:
# pcp not supported now
assert self.pcp_size == 1
# Get the number of draft tokens for each request. # Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all # Iterate over the dictionary rather than all requests since not all
# requests have draft tokens. # requests have draft tokens.
@@ -1866,11 +1874,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_draft_tokens[req_idx] = len(draft_token_ids) num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata( spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens) num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs])
logits_indices = spec_decode_metadata.logits_indices logits_indices = spec_decode_metadata.logits_indices
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
self.num_draft_tokens.np[num_reqs:].fill(0) self.num_draft_tokens.np[num_reqs:].fill(0)
self.num_draft_tokens.copy_to_gpu() self.num_draft_tokens.copy_to_gpu()
# save logits_indices for pcp spec decode usage
self.logits_indices = logits_indices
# Used in the below loop. # Used in the below loop.
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@@ -1883,8 +1893,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu() self.num_accepted_tokens.copy_to_gpu()
is_prefill = len(scheduler_output.scheduled_new_reqs) > 0 if self.speculative_config and self.pcp_size > 1:
if self.speculative_config and self.pcp_size > 1 and is_prefill:
self._generate_pcp_mtp_input( self._generate_pcp_mtp_input(
num_reqs, scheduler_output.total_num_scheduled_tokens, num_reqs, scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens) scheduler_output.num_scheduled_tokens)
@@ -2040,8 +2049,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# FIXME: Try using `auto_dispatch_capture=True` # FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_dcp_pcp_params(self.update_stream, update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context, forward_context,
maybe_padded_num_tokens, maybe_padded_num_tokens)
self.speculative_config)
else: else:
# FIXME: Try using `auto_dispatch_capture=True` # FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context, update_mla_attn_params(self.update_stream, forward_context,
@@ -2110,6 +2118,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self, self,
num_draft_tokens: np.ndarray, num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray, cu_num_scheduled_tokens: np.ndarray,
num_pcp_pads: np.ndarray,
) -> SpecDecodeMetadata: ) -> SpecDecodeMetadata:
# Inputs: # Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
@@ -2138,6 +2147,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange logits_indices += arange
# while pcp > 1, decode results may contain padding (from pcp all-gather),
# update logits_indices after getting draft_token_ids from ori logits_indices
if self.pcp_size > 1:
cu_num_scheduled_tokens = cu_num_scheduled_tokens * self.pcp_size - num_pcp_pads
logits_indices_pcp = np.repeat(
cu_num_scheduled_tokens - num_sampled_tokens,
num_sampled_tokens)
logits_indices_pcp += arange
logits_indices_pcp = torch.from_numpy(logits_indices_pcp).to(
self.device, non_blocking=True)
# Compute the bonus logits indices. # Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1 bonus_logits_indices = cu_num_sampled_tokens - 1
@@ -2173,6 +2193,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# draft_token_indices: [ 1, 2, 3, 105, 106, 208] # draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids[logits_indices] draft_token_ids = self.input_ids[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1] draft_token_ids = draft_token_ids[target_logits_indices + 1]
if self.pcp_size > 1:
logits_indices = logits_indices_pcp
if vllm_version_is("0.11.0"): if vllm_version_is("0.11.0"):
metadata = SpecDecodeMetadata( metadata = SpecDecodeMetadata(
draft_token_ids=draft_token_ids, draft_token_ids=draft_token_ids,
@@ -2920,8 +2942,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# FIXME: Try using `auto_dispatch_capture=True` # FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_dcp_pcp_params(self.update_stream, update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context, forward_context,
positions.shape[0], positions.shape[0])
self.speculative_config)
else: else:
# FIXME: Try using `auto_dispatch_capture=True` # FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context, update_mla_attn_params(self.update_stream, forward_context,
@@ -4328,18 +4349,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_decode_reqs = sum( num_decode_reqs = sum(
self.input_batch.num_computed_tokens_cpu[:num_reqs] >= self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
self.input_batch.num_prompt_tokens[:num_reqs]) self.input_batch.num_prompt_tokens[:num_reqs])
num_decode_tokens = sum(tokens[:num_decode_reqs])
num_padded_scheduled_tokens = np.ceil( num_padded_scheduled_tokens = np.ceil(
tokens / tokens /
(2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size) (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_size num_padded_scheduled_tokens[:num_decode_reqs] = (
tokens[:num_decode_reqs] * self.pcp_size)
self.num_pcp_pads = num_padded_scheduled_tokens - tokens self.num_pcp_pads = num_padded_scheduled_tokens - tokens
cu_padded_tokens, pcp_padded_arange = \ cu_padded_tokens, pcp_padded_arange = \
self._get_cumsum_and_arange(num_padded_scheduled_tokens) self._get_cumsum_and_arange(num_padded_scheduled_tokens)
unpad_mask = torch.from_numpy( unpad_mask = torch.from_numpy(
pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens)) pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size]
unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size])
unpad_mask_decode[:, 0] = True
unpad_mask_decode[:, 1:] = False
pcp_tokens = num_padded_scheduled_tokens // self.pcp_size pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
_, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
@@ -4356,14 +4384,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
np.repeat(head_start_loc, pcp_chunk_sizes) np.repeat(head_start_loc, pcp_chunk_sizes)
# Decode reqs do not have tail chunks. # Decode reqs do not have tail chunks.
positions[~pcp_head_chunk_mask] = \ positions[~pcp_head_chunk_mask] = \
pcp_chunk_arange[num_decode_reqs:] + \ pcp_chunk_arange[num_decode_tokens:] + \
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:] np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]
return positions return positions
positions = get_current_rank_positions( positions = get_current_rank_positions(
np.zeros(num_reqs, dtype=np.int32), self.pcp_rank) np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
# Decode tokens are duplicate and their positions always be 0. # Decode tokens are duplicate and their positions always be 0.
positions[:num_decode_reqs] = 0 if num_decode_reqs > 0:
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
tokens[:num_decode_reqs])[1]
all_positions = [ all_positions = [
get_current_rank_positions(cu_padded_tokens, rank_i) get_current_rank_positions(cu_padded_tokens, rank_i)
@@ -4372,7 +4402,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
all_positions_tensor = torch.from_numpy(np.concatenate(all_positions)) all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_( self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
all_positions_tensor.float().argsort().long(), non_blocking=True) all_positions_tensor.float().argsort().long(), non_blocking=True)
pcp_tokens[:num_decode_reqs] = 1
return pcp_tokens, positions, unpad_mask return pcp_tokens, positions, unpad_mask
def _get_pcp_local_seq_lens( def _get_pcp_local_seq_lens(
@@ -4524,7 +4553,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens, def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
seq_lens_origin): seq_lens_origin):
num_reqs = self.input_batch.num_reqs # In dummy run num_reqs == 0, update it from seq_lens
num_reqs = self.input_batch.num_reqs or seq_lens.size(0)
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs] num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
>= self.input_batch.num_prompt_tokens[:num_reqs]) >= self.input_batch.num_prompt_tokens[:num_reqs])
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
@@ -4535,14 +4565,28 @@ class NPUModelRunner(LoRAModelRunnerMixin):
local_chunked_kv_lens) local_chunked_kv_lens)
long_seq_metadata = None long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1: if self.pcp_size * self.dcp_size > 1:
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
num_reqs * self.decode_threshold, self.pcp_size,
self.dcp_size
],
dtype=torch.int32,
)
# For pcp + spec decode, we flatten seq_lens
# to avoid irregular spec_attn_mask shape
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_pcp_local_seq_lens(
seq_lens_origin - decode_idx,
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
)
long_seq_metadata = AscendPrefillContextParallelMetadata( long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens( num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
seq_lens_origin, numpy(),
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
).numpy(),
local_chunked_kv_lens=local_chunked_kv_lens, local_chunked_kv_lens=local_chunked_kv_lens,
mask_for_non_zero_chunk=mask_for_non_zero_chunk, mask_for_non_zero_chunk=mask_for_non_zero_chunk,
max_chunk_num=max_chunk_num) max_chunk_num=max_chunk_num)
@@ -4706,16 +4750,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens_pcp_full) num_scheduled_tokens_pcp_full)
arange_pcp_full = self.arange_np[: arange_pcp_full = self.arange_np[:
total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
positions_np_pcp_full = self.positions_np_pcp_full[: positions_pcp_full_np = self.positions_pcp_full_np[:
total_num_scheduled_tokens_pcp_full] total_num_scheduled_tokens_pcp_full]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full], np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
arange_pcp_full, arange_pcp_full,
out=positions_np_pcp_full) out=positions_pcp_full_np)
token_indices_pcp_full = ( token_indices_pcp_full = (
positions_np_pcp_full + positions_pcp_full_np +
req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]) req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
torch.index_select( torch.index_select(
self.input_batch.token_ids_cpu_tensor.flatten(), self.input_batch.token_ids_cpu_tensor.flatten(),
0, 0,
torch.from_numpy(token_indices_pcp_full), torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full]) out=self.
input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full])
self.query_start_loc_pcp_full[:num_reqs + 1].copy_(
self.query_start_loc_pcp_full_cpu[:num_reqs + 1],
non_blocking=True,
)
self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full].copy_(
self.input_ids_pcp_full_cpu[:total_num_scheduled_tokens_pcp_full],
non_blocking=True,
)