[Feature] Refactor PCP & DCP related code (#5214)

### What this PR does / why we need it?
Refactor the PCP & DCP related code. We introduce a `PCPManager` class to manage PCP & DCP in a unified way. As a result, a lot of code can be removed from `model_runner`, and other development work is less likely to break PCP & DCP.
RFC: https://github.com/vllm-project/vllm-ascend/issues/5449
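
As a rough, hypothetical outline of the new structure: the runner now delegates all PCP/DCP bookkeeping to the manager. The method names below are taken from the call sites in the diff; the signatures are abbreviated and the bodies omitted, so this is not the actual implementation in `vllm_ascend/worker/pcp_utils.py`.

```python
# Hypothetical sketch only; the real class is defined in
# vllm_ascend/worker/pcp_utils.py and imported as PCPManager.
import numpy as np
import torch


class PCPManager:
    """Owns the PCP/DCP bookkeeping previously scattered across NPUModelRunner."""

    def __init__(self, pcp_size: int, pcp_rank: int, dcp_size: int,
                 dcp_rank: int, max_buffer_num_tokens: int, max_num_reqs: int,
                 device: torch.device, vllm_config, pin_memory: bool) -> None:
        ...

    def generate_kv_idx(self, scheduler_output, input_batch) -> None:
        """Build the KV recover indices for chunked prefill (non-MLA path)."""

    def update_tokens_for_pcp(self, num_scheduled_tokens: np.ndarray, arange_np,
                              num_reqs, reorder_batch_threshold):
        """Split scheduled tokens across PCP ranks; return per-rank tokens and positions."""

    def generate_pcp_mtp_input(self, num_reqs, total_num_scheduled_tokens,
                               num_scheduled_tokens, with_prefill, input_batch,
                               arange_np, req_indices, positions_np, cu_num_tokens):
        """Record the un-split inputs that MTP needs before the PCP split."""

    def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
                              attn_mask, input_batch):
        """Produce the prefill-context-parallel metadata for the current step."""

    def get_padded_slot_mapping(self, total_num_scheduled_tokens, slot_mapping): ...

    def get_logits_indices(self, cu_num_tokens, num_reqs): ...

    def get_restore_hidden_states(self, hidden_states):
        """All-gather across the PCP group and restore the original token order."""
```

This keeps the PCP/DCP state (e.g. `num_pcp_pads_cpu`, `num_actual_tokens_pcp_padded`, `query_lens_pcp_full`) in one object instead of on the runner.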
### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com>
Authored by zhenwenqi2024 on 2025-12-31 09:29:57 +08:00, committed by GitHub
parent 46862ce1af
commit 5d9fde9819
7 changed files with 1156 additions and 1047 deletions


@@ -24,7 +24,7 @@ from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Union
import numpy as np
import torch
@@ -78,8 +78,7 @@ from vllm.v1.worker.utils import AttentionGroup
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
# yapf conflicts with isort for this block
# yapf: disable
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
@@ -109,6 +108,7 @@ from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
lmhead_tp_enable, maybe_trans_nz,
set_weight_prefetch_method)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
from vllm_ascend.worker.pcp_utils import PCPManager
from vllm_ascend.ascend_forward_context import ( # isort: skip
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
@@ -202,6 +202,26 @@ class NPUModelRunner(GPUModelRunner):
self.pcp_rank = 0
if self.pcp_size > 1:
self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs
max_buffer_num_tokens = self.max_num_tokens
if self.pcp_size * self.dcp_size > 1:
max_buffer_num_tokens = (self.max_num_tokens +
self.max_num_reqs * 2 * self.pcp_size)
self.pcp_manager = PCPManager(
self.pcp_size,
self.pcp_rank,
self.dcp_size,
self.dcp_rank,
max_buffer_num_tokens,
self.max_num_reqs,
self.device,
self.vllm_config,
self.pin_memory,
)
# TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this
self.input_ids = self._make_buffer(max_buffer_num_tokens,
dtype=torch.int32)
self.positions = self._make_buffer(max_buffer_num_tokens,
dtype=torch.int64)
self.sampler = AscendSampler()
self.attn_mask = None
self.attn_state = None
@@ -262,32 +282,6 @@ class NPUModelRunner(GPUModelRunner):
set_mc2_tokens_capacity(vllm_config, self.max_num_reqs,
self.uniform_decode_query_len)
set_mc2_mask(vllm_config, self.device)
self.pcp_allgather_restore_idx = torch.zeros(
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
[] for _ in range(self.pcp_size)
]
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
self.pcp_padded_slot_mapping = torch.zeros(
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.num_actual_tokens_pcp_padded = 0
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens,
dtype=torch.int32)
self.query_start_loc_pcp_full = self._make_buffer(
self.max_num_reqs + 1, dtype=torch.int32)
self.positions_pcp_full = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=True)
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.decode_threshold = 1 + (
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
@@ -359,6 +353,7 @@ class NPUModelRunner(GPUModelRunner):
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
self.reorder_batch_threshold: int | None = None
self.long_seq_metadata = None
def _init_device_properties(self) -> None:
self.num_sms = None
@@ -508,49 +503,6 @@ class NPUModelRunner(GPUModelRunner):
return self.attn_mask_builder.get_mla_mask(self.dtype)
return self.attn_mask_builder.get_splitfuse_attn_mask()
def generate_kv_idx(self, scheduler_output):
if not self.pcp_size > 1:
return
self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]
for i, req_id in enumerate(self.input_batch.req_ids):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
is_prefill = self.input_batch.num_computed_tokens_cpu[
i] < self.input_batch.num_prompt_tokens[i]
if is_prefill:
num_cp_padded_scheduled_tokens = cdiv(
num_scheduled_tokens,
2 * self.pcp_size) * (2 * self.pcp_size)
full_indices = list(
range(self.max_num_tokens * self.pcp_size * self.dcp_size +
self.pcp_size * self.dcp_size * self.max_num_reqs))
chunk_size = num_cp_padded_scheduled_tokens // (2 *
self.pcp_size)
num_added_recover_tokens = len(
self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size
for rank in range(self.pcp_size):
self.cp_kv_recover_idx_for_chunk[rank].extend(
full_indices[rank * chunk_size +
num_added_recover_tokens:(rank + 1) *
chunk_size + num_added_recover_tokens])
self.cp_kv_recover_idx_for_chunk[rank].extend(
full_indices[num_cp_padded_scheduled_tokens -
(rank + 1) * chunk_size +
num_added_recover_tokens:
num_cp_padded_scheduled_tokens -
rank * chunk_size +
num_added_recover_tokens])
cp_kv_recover_idx_for_chunk = torch.from_numpy(
np.concatenate(
self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
non_blocking=True)
self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
torch.float32).argsort().to(torch.int32)
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@@ -574,43 +526,70 @@ class NPUModelRunner(GPUModelRunner):
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
_, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
positions_np = np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
total_num_pcp_pads = 0
if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
self.generate_kv_idx(scheduler_output)
tokens_before_update = tokens.copy()
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
total_num_pcp_pads = torch.sum(self.num_pcp_pads[:num_reqs]).item()
else:
position_pcp, pcp_unpad_mask = None, None
self.num_pcp_pads[:num_reqs] = 0
max_num_scheduled_tokens = max(tokens)
if not scheduler_output.scheduled_spec_decode_tokens:
num_valid_tokens = np.array(tokens, dtype=np.int32)
else:
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
for num_tokens, i in zip((tokens_before_update if self.
pcp_size > 1 else tokens), req_ids)
for num_tokens, i in zip(tokens, req_ids)
],
dtype=np.int32)
# Get the attention state.
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
self.attn_mask = self._make_attention_mask(attn_state)
# Get positions.
positions_np = self.positions.np[:total_num_scheduled_tokens]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
# For PCP, prefill MTP should use the original scheduler output.
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self.pcp_manager.generate_pcp_mtp_input(
num_reqs, total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens, with_prefill,
self.input_batch, self.arange_np, req_indices, positions_np,
cu_num_tokens)
if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
self.pcp_manager.generate_kv_idx(scheduler_output,
self.input_batch)
num_scheduled_tokens[:num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
num_scheduled_tokens[:num_reqs],
self.arange_np,
self.input_batch.num_reqs,
self.reorder_batch_threshold,
)
# Re-update after PCP split sequences.
total_num_scheduled_tokens = sum(num_scheduled_tokens)
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
cu_num_tokens, _ = self._get_cumsum_and_arange(
num_scheduled_tokens)
positions_np = self.positions.np[:total_num_scheduled_tokens]
np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
position_pcp[:total_num_scheduled_tokens],
out=positions_np,
)
max_num_scheduled_tokens = max(tokens)
if (self.use_aclgraph and total_num_scheduled_tokens
<= self.cudagraph_batch_sizes[-1]):
# Add padding to the batch size.
@@ -627,17 +606,6 @@ class NPUModelRunner(GPUModelRunner):
else:
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
# Get the attention state.
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Get info across DP ranks.
@@ -646,7 +614,7 @@ class NPUModelRunner(GPUModelRunner):
(maybe_padded_num_tokens, num_tokens_across_dp,
with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
with_prefill)
self.with_prefill = with_prefill
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
# We should consider removing maybe_padded_num_tokens later
num_input_tokens = maybe_padded_num_tokens
@@ -655,24 +623,8 @@ class NPUModelRunner(GPUModelRunner):
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
if self.pcp_size > 1:
positions_np = self.positions.np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
position_pcp[:total_num_scheduled_tokens],
out=positions_np)
else:
self.positions.np[:total_num_scheduled_tokens] = positions_np
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self._calc_mrope_positions(scheduler_output)
@@ -766,21 +718,11 @@ class NPUModelRunner(GPUModelRunner):
self.seq_lens.gpu[num_reqs:].fill_(0)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Copy the tensors to the NPU.
self._prepare_input_ids(scheduler_output, total_num_scheduled_tokens,
cu_num_tokens)
self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions.copy_to_gpu()
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(attn_state)
self.attn_state = attn_state # type: ignore
self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp
attn_metadata: dict[str, Any] = {}
# Record the index of requests that should not be sampled,
@@ -914,9 +856,8 @@ class NPUModelRunner(GPUModelRunner):
# TODO: Support prompt logprobs.
spec_decode_metadata = None
if self.pcp_size * self.dcp_size > 1:
logits_indices = torch.from_numpy(
cu_num_tokens
) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
logits_indices = self.pcp_manager.get_logits_indices(
cu_num_tokens, num_reqs)
logits_indices = logits_indices.pin_memory().to(
self.device, non_blocking=True)
else:
@@ -938,8 +879,10 @@ class NPUModelRunner(GPUModelRunner):
>= self.input_batch.num_prompt_tokens[req_idx]) else -1)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens,
self.num_pcp_pads[:num_reqs].numpy())
num_draft_tokens,
cu_num_tokens,
num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs]
if self.pcp_size > 1 else None)
logits_indices = spec_decode_metadata.logits_indices
# For DECODE only cuda graph of some attention backends (e.g., GDN).
@@ -961,23 +904,10 @@ class NPUModelRunner(GPUModelRunner):
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
self._generate_pcp_mtp_input(
num_reqs, scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens, with_prefill,
req_indices, positions_np, cu_num_tokens)
long_seq_metadata = self._generate_pcp_metadata(
total_num_scheduled_tokens)
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
# NOTE: This is strange, why did we use total_num_scheduled_tokens before?
slot_mapping_size = (total_num_scheduled_tokens
if self.pcp_size == 1 else
total_num_scheduled_tokens * self.pcp_size -
total_num_pcp_pads)
if isinstance(kv_cache_group_spec.kv_cache_spec,
EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
@@ -993,30 +923,30 @@ class NPUModelRunner(GPUModelRunner):
device=self.device,
)
else:
maybe_pcp_full_tokens = (
num_input_tokens if self.pcp_size == 1 else
total_num_scheduled_tokens * self.pcp_size -
sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs]))
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()
blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(0)
if self.pcp_size > 1:
slot_mapping_for_pcp = blk_table.slot_mapping.gpu[:long_seq_metadata.num_actual_tokens_pcp_padded]
slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
assert pcp_unpad_mask is not None
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:pcp_unpad_mask.shape[0]]
pcp_padded_slot_mapping.fill_(-1)
pcp_padded_slot_mapping[pcp_unpad_mask] = slot_mapping_for_pcp[:slot_mapping_size]
slot_mapping_for_pcp[:long_seq_metadata.num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
blk_table.slot_mapping.gpu[:long_seq_metadata.num_actual_tokens_pcp_padded] = \
slot_mapping_for_pcp
slot_mapping = blk_table.slot_mapping.gpu[:maybe_pcp_full_tokens]
if self.pcp_size * self.dcp_size == 1:
slot_mapping[
total_num_scheduled_tokens:num_input_tokens].fill_(-1)
slot_mapping = blk_table.slot_mapping.gpu
if self.pcp_size * self.dcp_size > 1:
self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
total_num_scheduled_tokens, self.query_lens,
self.attn_mask, self.input_batch)
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
slot_mapping = slot_mapping[:maybe_pcp_full_tokens]
slot_mapping = self.pcp_manager.get_padded_slot_mapping(
total_num_scheduled_tokens,
slot_mapping,
)
blk_table.slot_mapping.gpu[:self.pcp_manager.num_actual_tokens_pcp_padded] = slot_mapping
# NOTE: This is a temporary hack. In GPUModelRunner, prepare_inputs has now
# been split into multiple parts, and three of them are related to this
@@ -1055,7 +985,7 @@ class NPUModelRunner(GPUModelRunner):
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
seq_lens=self.seq_lens.gpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=slot_mapping_size,
num_actual_tokens=total_num_scheduled_tokens,
num_input_tokens=num_input_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
# TODO: change this to the right block table for linear attn
@@ -1069,8 +999,9 @@ class NPUModelRunner(GPUModelRunner):
attn_state=self.attn_state,
max_query_len=max_num_scheduled_tokens,
decode_token_per_req=self.decode_token_per_req,
prefill_context_parallel_metadata=long_seq_metadata,
max_seq_len=0)
prefill_context_parallel_metadata=self.long_seq_metadata,
max_seq_len=0,
)
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
# For pcp + spec decode, we flatten block_table
@@ -1080,8 +1011,10 @@ class NPUModelRunner(GPUModelRunner):
# (num_reqs_d + num_reqs_p, max_num_blocks),
# flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
# (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs]
ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs]
ori_query_lens_cpu = self.pcp_manager.query_lens_pcp_full.cpu[:num_reqs]
ori_query_lens = self.pcp_manager.query_lens_pcp_full.gpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
@@ -1097,13 +1030,17 @@ class NPUModelRunner(GPUModelRunner):
ori_query_lens[:num_decode_reqs], dim=0))
common_attn_metadata.block_table_tensor = \
blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
assert self.long_seq_metadata is not None
self.long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
if 'pad_size' in locals() and pad_size > 0:
ori_query_lens_cpu[-pad_size:] = \
torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
long_seq_metadata.max_query_len_pcp_full = \
self.long_seq_metadata.max_query_len_pcp_full = \
ori_query_lens_cpu.max().item()
if self.speculative_config and \
self.spec_decode_common_attn_metadata is None:
self.spec_decode_common_attn_metadata = common_attn_metadata
@@ -1193,19 +1130,12 @@ class NPUModelRunner(GPUModelRunner):
pad_size = get_forward_context().pad_size
if pad_size > 0:
hidden_states = hidden_states[:-pad_size, :]
if self.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states[:self.num_actual_tokens_pcp_padded //
self.pcp_size], 0)
hidden_states = torch.index_select(
hidden_states, 0,
self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
return hidden_states
return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states(
hidden_states)
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
if np.array_equal(self.seq_lens.np[:num_reqs], num_scheduled_tokens):
if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
@@ -1231,7 +1161,7 @@ class NPUModelRunner(GPUModelRunner):
self,
num_draft_tokens: np.ndarray,
cu_num_scheduled_tokens: np.ndarray,
num_pcp_pads: np.ndarray,
num_pcp_pads: np.ndarray | None,
) -> SpecDecodeMetadata:
# Inputs:
# cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
@@ -1846,7 +1776,9 @@ class NPUModelRunner(GPUModelRunner):
self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
long_seq_metadata = self._generate_pcp_metadata(num_tokens)
long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
num_tokens, self.query_lens, self.attn_mask,
self.input_batch)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group().world_size
dcp_world_size = get_dcp_group().world_size
@@ -2890,365 +2822,6 @@ class NPUModelRunner(GPUModelRunner):
parent_module_name):
super().capture_model()
def _update_tokens_for_pcp(self, tokens):
num_reqs = self.input_batch.num_reqs
tokens = np.array(tokens, dtype=np.int32)
num_decode_reqs = (np.array(tokens) <= self.decode_threshold).sum()
num_decode_tokens = sum(tokens[:num_decode_reqs])
num_padded_scheduled_tokens = np.ceil(
tokens /
(2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
num_padded_scheduled_tokens[:num_decode_reqs] = (
tokens[:num_decode_reqs] * self.pcp_size)
self.num_pcp_pads[:num_reqs] = torch.tensor(
num_padded_scheduled_tokens - tokens)
cu_padded_tokens, pcp_padded_arange = \
self._get_cumsum_and_arange(num_padded_scheduled_tokens)
unpad_mask = torch.from_numpy(
pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size]
unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size])
unpad_mask_decode[:, 0] = True
unpad_mask_decode[:, 1:] = False
pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
_, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
pcp_tokens)
def get_current_rank_positions(cu_tokens, rank):
positions_start_loc = np.zeros_like(cu_tokens)
positions_start_loc[1:] = cu_tokens[:-1]
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
tail_start_loc = positions_start_loc + \
(2 * self.pcp_size - rank - 1) * pcp_chunk_sizes
positions[pcp_head_chunk_mask] = pcp_chunk_arange + \
np.repeat(head_start_loc, pcp_chunk_sizes)
# Decode reqs do not have tail chunks.
positions[~pcp_head_chunk_mask] = \
pcp_chunk_arange[num_decode_tokens:] + \
np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]
return positions
positions = get_current_rank_positions(
np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
# Decode tokens are duplicated and their positions are always 0.
if num_decode_reqs > 0:
positions[:num_decode_tokens] = self._get_cumsum_and_arange(
tokens[:num_decode_reqs])[1]
all_positions = [
get_current_rank_positions(cu_padded_tokens, rank_i)
for rank_i in range(self.pcp_size)
]
all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
all_positions_tensor.float().argsort().long(), non_blocking=True)
return pcp_tokens, positions, unpad_mask
def _get_cp_local_seq_lens(
self,
seq_lens: torch.Tensor,
pcp_world_size: int = 1,
dcp_world_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using pcp or dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each (p/d)cp rank.
"""
num_requests = seq_lens.size(0)
total_world_size = pcp_world_size * dcp_world_size
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
rank_offsets = (torch.arange(total_world_size,
dtype=torch.int32).unsqueeze(0).repeat(
num_requests, 1))
base = (seq_lens_tiled // cp_kv_cache_interleave_size //
total_world_size * cp_kv_cache_interleave_size)
remainder = seq_lens_tiled - base * total_world_size
remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = (base + remainder).reshape(
[-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
def _generate_pcp_metadata(self, total_num_scheduled_tokens):
# In dummy run num_reqs == 0, update it from seq_lens
num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \
if self.pcp_size > 1 and self.speculative_config else self.query_lens
num_decodes = (query_lens <= self.decode_threshold).sum().item()
num_prefills = num_reqs - num_decodes
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
decode_context_lens = self.input_batch.num_tokens[:num_decodes]
prefill_context_lens = self.input_batch.num_computed_tokens_cpu[
num_decodes:num_reqs]
context_lens = np.concatenate(
[decode_context_lens, prefill_context_lens])
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
num_reqs * self.decode_threshold, self.pcp_size,
self.dcp_size
],
dtype=torch.int32,
)
# For pcp + spec decode, we flatten seq_lens
# to avoid irregular spec_attn_mask shape.
# Same as block_table, we flatten decode seq_lens to query_lens,
# and keep prefill seq_lens unchanged.
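# Hypothetical example with decode_threshold = 2: two decode requests with
# context lens [8, 9] (and full query_lens) expand to per-rank splits of
# [7, 8, 8, 9], one row per speculative position, while each prefill request
# keeps a single row, mirroring the flattened block_table example.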
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_cp_local_seq_lens(
torch.tensor(context_lens) - decode_idx,
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
)
if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = []
if num_decodes:
num_decodes_flatten = \
self.query_lens[:num_decodes].sum().item()
if self.query_lens[:num_decodes].min().item(
) == self.decode_threshold:
decode_flatten_idx = list(range(num_decodes_flatten))
else:
decode_flatten_idx = []
for req_id in range(num_decodes):
offset = (req_id + 1) * self.decode_threshold
decode_flatten_idx += \
list(range(offset - self.query_lens[req_id], offset))
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
if num_prefills:
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[
(num_decodes + 1) * self.decode_threshold -
1::self.decode_threshold])
num_computed_tokens_of_pcp_dcp = torch.cat(
num_computed_tokens_of_pcp_dcp_list, dim=0)
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
numpy())
if self.pcp_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
chunk_seqlens = []
kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
q_req_offset = 0
kv_req_offset = 0
q_head_chunk_id = self.pcp_rank
q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
for i, seq_len in enumerate(self.query_lens):
if i < num_decodes:
continue
chunk_len = seq_len // 2
chunk_seqlens.append(chunk_len)
q_head_idx.extend(
list(range(q_req_offset, q_req_offset + chunk_len)))
kv_with_q_head_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_head_chunk_id)))
kv_with_q_head_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_head_chunk_id,
kv_req_offset + chunk_len *
(q_head_chunk_id + 1))))
kv_with_q_head_nomask_seqlens.append(chunk_len *
q_head_chunk_id)
q_tail_idx.extend(
list(
range(q_req_offset + chunk_len,
q_req_offset + chunk_len * 2)))
kv_with_q_tail_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_tail_chunk_id)))
kv_with_q_tail_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_tail_chunk_id,
kv_req_offset + chunk_len *
(q_tail_chunk_id + 1))))
kv_with_q_tail_nomask_seqlens.append(chunk_len *
q_tail_chunk_id)
q_req_offset += seq_len
kv_req_offset += seq_len * self.pcp_size
# Convert lists to tensors and move to device
def _list_to_tensor(lst, device, dtype=torch.int32):
tensor_npu = torch.zeros(len(lst),
dtype=dtype,
device=device)
tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
non_blocking=True)
return tensor_npu
q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
self.q_head_idx_tensor = q_head_idx_tensor
self.q_tail_idx_tensor = q_tail_idx_tensor
q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
q_full_idx = q_full_idx.to(torch.float32).argsort().to(
torch.int32)
self.q_full_idx = q_full_idx
self.kv_idx_names = {
'kv_with_q_head_nomask_idx_tensor':
kv_with_q_head_nomask_idx,
'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
'kv_with_q_tail_nomask_idx_tensor':
kv_with_q_tail_nomask_idx,
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
}
for key, value in self.kv_idx_names.items():
tensor_npu = _list_to_tensor(value, self.device)
self.kv_idx_names[key] = tensor_npu
attn_mask_seqlens = torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
head_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_head_nomask_seqlens],
dtype=torch.int32)
tail_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
dtype=torch.int32)
pcp_prefill_mask = self.attn_mask
self.extra_long_seq_kwargs = {
'attn_mask_seqlens': attn_mask_seqlens,
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
'pcp_prefill_mask': pcp_prefill_mask
}
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
num_actual_tokens_pcp_padded]
long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_nomask_idx_tensor']
long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_mask_idx_tensor']
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_nomask_idx_tensor']
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_mask_idx_tensor']
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
'attn_mask_seqlens']
long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'head_attn_nomask_seqlens']
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'tail_attn_nomask_seqlens']
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
'pcp_prefill_mask']
self.long_seq_metadata = long_seq_metadata
return long_seq_metadata
def _generate_pcp_mtp_input(
self,
num_reqs: int,
total_num_scheduled_tokens: int,
num_scheduled_tokens: dict[str, int],
with_prefill: bool = True,
req_indices=None,
positions_np=None,
cu_num_tokens=None,
):
"""
While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group,
but mtp need to shift original input_ids before pcp splitting,
so we record original input_ids here.
"""
total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
for i, req_id in enumerate(self.input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
self.query_start_loc_pcp_full.np[0] = 0
self.query_start_loc_pcp_full.np[1:num_reqs +
1] = cu_num_tokens_pcp_full
self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1)
cumsums_offsets_pcp_full = np.repeat(
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
num_scheduled_tokens_pcp_full)
arange_pcp_full = self.arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
positions_pcp_full_np = self.positions_pcp_full_np[:total_num_scheduled_tokens_pcp_full]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
arange_pcp_full,
out=positions_pcp_full_np)
token_indices_pcp_full = (
positions_pcp_full_np +
req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full])
self.query_lens_pcp_full.copy_to_gpu()
self.query_start_loc_pcp_full.copy_to_gpu()
self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_(
self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
non_blocking=True,
)
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
# For mtpx, pre-allocate mtp slot_mapping here
if self.decode_threshold > 2 and not with_prefill:
num_tokens_ori = sum(list(num_scheduled_tokens.values()))
num_tokens_mtp = \
num_tokens_ori + num_reqs * (self.decode_threshold - 2)
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size
req_indices_split = np.array_split(req_indices,
cu_num_tokens)[:num_reqs]
positions_split = np.array_split(positions_np,
cu_num_tokens)[:num_reqs]
for req_idx in range(num_reqs):
ori_req_indice = req_indices_split[req_idx]
ori_position = positions_split[req_idx]
req_indices_split[req_idx] = np.append(
ori_req_indice,
np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
positions_split[req_idx] = np.append(
ori_position,
np.arange(ori_position[-1] + 1,
ori_position[-1] + self.decode_threshold - 1))
req_indices_mtp = np.concatenate(req_indices_split)
positions_mtp = np.concatenate(positions_split)
self.input_batch.block_table.compute_slot_mapping(
req_indices_mtp, positions_mtp)
mtp_slot_ori = self.input_batch.block_table.block_tables[
0].slot_mapping.cpu[:num_tokens_mtp]
unpad_mask = np.repeat(False, num_tokens_mtp_pad)
unpad_mask[::self.pcp_size] = True
mtp_slot_pad = \
torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
mtp_slot_pad[unpad_mask] = mtp_slot_ori
self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
def _prepare_multimodal_fields(self):
"""
Ensures specific multimodal tensors are on CPU.