[Lint]Style: Convert vllm-ascend/ to ruff format (Batch #9) (#6135)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|`vllm_ascend/worker/model_runner_v1.py`|
|`vllm_ascend/worker/pcp_utils.py`|

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: d68209402d

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Authored by SILONG ZENG on 2026-02-01 23:20:20 +08:00, committed by GitHub
parent f7dc7d9b86
commit 347eb36a59
3 changed files with 802 additions and 1041 deletions


@@ -54,9 +54,7 @@ exclude = [
     # (7)
     "vllm_ascend/quantization/**",
     "vllm_ascend/sample/*.py",
-    "vllm_ascend/worker/v2/**",
     "vllm_ascend/worker/block_table.py",
-    "vllm_ascend/worker/npu_input_batch.py",
     # (8)
     "vllm_ascend/ops/__init__.py",
     "vllm_ascend/ops/activation.py",
@@ -65,13 +63,9 @@ exclude = [
     "vllm_ascend/ops/mla.py",
     "vllm_ascend/ops/mm_encoder_attention.py",
     "vllm_ascend/ops/register_custom_ops.py",
-    "vllm_ascend/ops/rotary_embedding.py",
     "vllm_ascend/ops/vocab_parallel_embedding.py",
     "vllm_ascend/ops/weight_prefetch.py",
     "vllm_ascend/spec_decode/**",
-    # (9)
-    "vllm_ascend/worker/model_runner_v1.py",
-    "vllm_ascend/worker/pcp_utils.py",
     # (10)
     "vllm_ascend/ops/*linear*.py",
     "vllm_ascend/worker/worker.py",
@@ -79,6 +73,9 @@ exclude = [
     "vllm_ascend/distributed/utils.py",
     "vllm_ascend/xlite/*.py",
     "vllm_ascend/patch/worker/patch_*.py",
+    "vllm_ascend/worker/v2/**",
+    "vllm_ascend/worker/npu_input_batch.py",
+    "vllm_ascend/ops/rotary_embedding.py",
     # (11)
     "vllm_ascend/ops/fused_moe/**",
]
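
This batch presumably amounts to dropping the two files from ruff's exclude list and reformatting them; a hypothetical local verification sketch (assumes `ruff` is on PATH and is run from the repo root):

```python
import subprocess

# Hypothetical check, not from this PR: confirm both files now pass ruff's formatter.
for path in ("vllm_ascend/worker/model_runner_v1.py", "vllm_ascend/worker/pcp_utils.py"):
    subprocess.run(["ruff", "format", "--check", path], check=True)
```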

File diff suppressed because it is too large (vllm_ascend/worker/model_runner_v1.py).


@@ -17,12 +17,11 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py
#
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
if TYPE_CHECKING:
@@ -36,6 +35,7 @@ class PCPManager:
This manager encapsulates all PCP-related buffers and logic so that the
ModelRunner can access them via `self.pcp_manager`.
"""
num_reqs: int = 0
num_decode_reqs: int = 0
num_prefill_reqs: int = 0
@@ -59,9 +59,7 @@ class PCPManager:
self.dcp_world_size = dcp_world_size
self.dcp_world_rank = dcp_rank
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1 + (
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
self.decode_threshold = 1 + (self.speculative_config.num_speculative_tokens if self.speculative_config else 0)
self.vllm_config = vllm_config
self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
@@ -74,46 +72,42 @@ class PCPManager:
pin_memory=pin_memory,
)
self.pcp_padded_slot_mapping = torch.full(
(max_buffer_num_tokens, ),
(max_buffer_num_tokens,),
fill_value=-1,
dtype=torch.int32,
device=device,
)
self.pcp_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.total_num_sampled_tokens_pcp = 0
self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ),
device="cpu",
dtype=torch.int64)
self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs,), device="cpu", dtype=torch.int64)
self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy()
self.pcp_unpad_mask_cpu_tensor = torch.ones(
(max_buffer_num_tokens, ),
(max_buffer_num_tokens,),
device="cpu",
dtype=torch.bool,
)
self.num_actual_tokens_pcp_padded = 0
self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
self.full_indices = list(
range(self.max_num_tokens * self.pcp_world_size *
self.dcp_world_size + self.pcp_world_size *
self.dcp_world_size * self.max_num_reqs))
range(
self.max_num_tokens * self.pcp_world_size * self.dcp_world_size
+ self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
)
)
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens,
dtype=torch.int32,
device=device,
pin_memory=pin_memory)
self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1,
dtype=torch.int32,
device=device,
pin_memory=pin_memory)
self.positions_pcp_full = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory)
self.input_ids_pcp_full = CpuGpuBuffer(
self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
)
self.query_start_loc_pcp_full = CpuGpuBuffer(
self.max_num_reqs + 1, dtype=torch.int32, device=device, pin_memory=pin_memory
)
self.positions_pcp_full = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=pin_memory
)
self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs,
dtype=torch.int32,
device=device,
pin_memory=pin_memory)
self.query_lens_pcp_full = CpuGpuBuffer(
self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory
)
def _get_cumsum_and_arange(
self,
@@ -130,8 +124,7 @@ class PCPManager:
cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype)
total_num_tokens = cu_num_tokens[-1]
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
num_scheduled_tokens)
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = arange_np[:total_num_tokens] - cumsums_offsets
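
For reference, a minimal standalone sketch of the three steps above, using the values from the inline comments (`num_scheduled_tokens = [2, 5, 3]` is an assumed input that reproduces them):

```python
import numpy as np

# Assumed example input: three requests scheduling 2, 5 and 3 tokens.
num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
arange_np = np.arange(16, dtype=np.int32)  # stand-in for the preallocated buffer

# Step 1: cumulative token counts -> [2, 7, 10]
cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=np.int32)
total_num_tokens = cu_num_tokens[-1]
# Step 2: per-token start offsets -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens)
# Step 3: per-request aranges -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = arange_np[:total_num_tokens] - cumsums_offsets
```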
@@ -143,15 +136,15 @@ class PCPManager:
num_reqs: int,
) -> None:
self.num_reqs = num_reqs
is_prefill = (num_scheduled_tokens[:num_reqs] > self.decode_threshold)
is_prefill = num_scheduled_tokens[:num_reqs] > self.decode_threshold
if not any(is_prefill):
first_prefill = num_reqs
else:
first_prefill = is_prefill.argmax()
self.num_decode_reqs = first_prefill
self.num_prefill_reqs = num_reqs - self.num_decode_reqs
self.num_decode_tokens = num_scheduled_tokens[:self.num_decode_reqs].sum()
self.num_decode_tokens = num_scheduled_tokens[: self.num_decode_reqs].sum()
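
A small sketch of this decode/prefill split, assuming the batch is ordered decodes-first and `decode_threshold = 2` (toy values, not from the diff):

```python
import numpy as np

decode_threshold = 2  # assumed: 1 sampled token + 1 speculative token
num_scheduled_tokens = np.array([1, 2, 9, 7], dtype=np.int32)

# Requests scheduling more than decode_threshold tokens are prefills.
is_prefill = num_scheduled_tokens > decode_threshold
# argmax returns the first True; if there is none, all requests are decodes.
first_prefill = int(is_prefill.argmax()) if is_prefill.any() else len(num_scheduled_tokens)
num_decode_reqs = first_prefill                                   # 2
num_decode_tokens = num_scheduled_tokens[:num_decode_reqs].sum()  # 3
```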
def update_tokens_for_pcp(
self,
num_scheduled_tokens: np.ndarray,
@@ -208,32 +201,29 @@ class PCPManager:
# DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
# We first pad each request's token count up to that multiple.
num_padded_scheduled_tokens = np.ceil(
num_scheduled_tokens / (2 * self.pcp_world_size)).astype(
np.int32) * (2 * self.pcp_world_size)
num_padded_scheduled_tokens = np.ceil(num_scheduled_tokens / (2 * self.pcp_world_size)).astype(np.int32) * (
2 * self.pcp_world_size
)
# PCP does not split decode requests. For decode requests, we instead
# duplicate the scheduled tokens across the pcp_world_size ranks.
num_padded_scheduled_tokens[:self.num_decode_reqs] = (
num_scheduled_tokens[:self.num_decode_reqs] * self.pcp_world_size)
num_padded_scheduled_tokens[: self.num_decode_reqs] = (
num_scheduled_tokens[: self.num_decode_reqs] * self.pcp_world_size
)
# Record how many pads were added per request (padded - original).
self.num_pcp_pads_cpu[:self.num_reqs] = (num_padded_scheduled_tokens -
num_scheduled_tokens)
self.num_pcp_pads_cpu[: self.num_reqs] = num_padded_scheduled_tokens - num_scheduled_tokens
# cu_padded_tokens: cumulative sum of padded token counts,
# pcp_padded_arange: per-request arange flattened for padded tokens.
cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(
num_padded_scheduled_tokens, arange_np)
cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np)
# Build the mask that marks which positions in the padded allgather buffer
# correspond to real (unpadded) tokens.
self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = (
pcp_padded_arange < np.repeat(num_scheduled_tokens,
num_padded_scheduled_tokens))
unpad_mask_decode = self.pcp_unpad_mask_cpu[:self.num_decode_tokens *
self.pcp_world_size]
unpad_mask_decode = unpad_mask_decode.reshape(
[-1, self.pcp_world_size])
self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = pcp_padded_arange < np.repeat(
num_scheduled_tokens, num_padded_scheduled_tokens
)
unpad_mask_decode = self.pcp_unpad_mask_cpu[: self.num_decode_tokens * self.pcp_world_size]
unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_world_size])
unpad_mask_decode[:, 0] = True
unpad_mask_decode[:, 1:] = False
pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size
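
A hedged sketch of the padding rule above, assuming `pcp_world_size = 2` with one decode request (1 token) followed by one prefill request (7 tokens):

```python
import numpy as np

pcp_world_size = 2
num_decode_reqs = 1
num_scheduled_tokens = np.array([1, 7], dtype=np.int32)

# Pad each request up to a multiple of 2 * pcp_world_size (= 4): [4, 8]
num_padded = np.ceil(num_scheduled_tokens / (2 * pcp_world_size)).astype(np.int32) * (2 * pcp_world_size)
# Decode requests are duplicated across ranks instead of split: [2, 8]
num_padded[:num_decode_reqs] = num_scheduled_tokens[:num_decode_reqs] * pcp_world_size
num_pcp_pads = num_padded - num_scheduled_tokens  # pads per request: [1, 1]
pcp_tokens = num_padded // pcp_world_size         # per-rank token counts: [1, 4]
```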
@@ -242,23 +232,20 @@ class PCPManager:
# For prefill requests, we further split the pcp_tokens into two chunks
# (head and tail). For decode requests, the chunk equals pcp_tokens.
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
pcp_chunk_sizes[:self.num_decode_reqs] = pcp_tokens[:self.num_decode_reqs]
pcp_chunk_sizes[: self.num_decode_reqs] = pcp_tokens[: self.num_decode_reqs]
# Build arange-style helpers for pcp tokens and chunk sizes:
# - pcp_arange gives indices repeated for each token in pcp_tokens
# - pcp_chunk_arange gives indices repeated for each position inside chunks
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np)
_, pcp_chunk_arange = self._get_cumsum_and_arange(
pcp_chunk_sizes, arange_np)
_, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes, arange_np)
# Mask that marks whether a position belongs to the head chunk (True)
# or the tail chunk (False). For decode requests, tail chunk won't exist
# and is handled specially below.
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
pcp_tokens)
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens)
def get_current_rank_positions(positions_start_loc: int | np.ndarray,
rank: int):
def get_current_rank_positions(positions_start_loc: int | np.ndarray, rank: int):
"""
Compute flattened positions for the given rank with a given start
offset for each request (positions_start_loc).
@@ -271,59 +258,53 @@ class PCPManager:
"""
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
tail_start_loc = (
positions_start_loc +
(2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes)
tail_start_loc = positions_start_loc + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes
# Fill head positions using chunk arange offset by head_start_loc.
positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat(
head_start_loc, pcp_chunk_sizes)
positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat(head_start_loc, pcp_chunk_sizes)
# Fill tail positions. Note decode requests do not have tail chunks,
# so the tail filling is only for prefill positions.
positions[~pcp_head_chunk_mask] = (
pcp_chunk_arange[self.num_decode_tokens:] +
np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens:])
pcp_chunk_arange[self.num_decode_tokens :]
+ np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens :]
)
return positions
positions = get_current_rank_positions(0, self.pcp_world_rank)
# Decode tokens are duplicated only after AG. But their positions are
# same without prefill context parallel.
if self.num_decode_reqs > 0:
positions[:self.num_decode_tokens] = self._get_cumsum_and_arange(
num_scheduled_tokens[:self.num_decode_reqs], arange_np)[1]
positions[: self.num_decode_tokens] = self._get_cumsum_and_arange(
num_scheduled_tokens[: self.num_decode_reqs], arange_np
)[1]
# Build the restore index used after allgather.
padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
padded_pos_start_loc[0] = 0
all_positions_lst = [
get_current_rank_positions(padded_pos_start_loc, rank_i)
for rank_i in range(self.pcp_world_size)
get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size)
]
all_positions = np.concatenate(all_positions_lst)
self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = (
all_positions.argsort())
self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort()
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
self.pcp_tokens[:self.num_reqs] = pcp_tokens[:self.num_reqs]
self.total_num_sampled_tokens_pcp = pcp_tokens[:self.num_reqs].sum()
self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum()
return (
pcp_tokens[:self.num_reqs],
pcp_tokens[: self.num_reqs],
positions,
)
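
To make the head/tail (DualChunkSwap) layout concrete, an illustrative sketch (not part of the diff) for one prefill request with `pcp_world_size = 2` and 8 padded tokens, i.e. `pcp_tokens = 4` and a chunk size of 2 per rank:

```python
import numpy as np

pcp_world_size = 2
chunk = 2  # pcp_chunk_sizes = pcp_tokens // 2

for rank in range(pcp_world_size):
    head = np.arange(chunk) + rank * chunk
    tail = np.arange(chunk) + (2 * pcp_world_size - rank - 1) * chunk
    print(rank, np.concatenate([head, tail]))
# rank 0 -> [0 1 6 7]  (head chunk 0, tail chunk 3)
# rank 1 -> [2 3 4 5]  (head chunk 1, tail chunk 2)
# Concatenating every rank's positions and argsort-ing them yields the
# pcp_allgather_restore_idx used to restore request order after allgather.
```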
def get_logits_indices(self, cu_num_tokens: np.ndarray):
return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size -
self.num_pcp_pads_cpu_tensor[:self.num_reqs] - 1)
return torch.from_numpy(cu_num_tokens) * self.pcp_world_size - self.num_pcp_pads_cpu_tensor[: self.num_reqs] - 1
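
A small sketch of `get_logits_indices`, reusing the assumed values from the padding sketch above (`pcp_tokens = [1, 4]`, so `cu_num_tokens = [1, 5]`, one pad per request):

```python
import numpy as np
import torch

pcp_world_size = 2
cu_num_tokens = np.array([1, 5], dtype=np.int32)        # assumed: cumsum of per-rank pcp_tokens
num_pcp_pads = torch.tensor([1, 1], dtype=torch.int64)  # pads added per request

logits_indices = torch.from_numpy(cu_num_tokens) * pcp_world_size - num_pcp_pads - 1
# tensor([0, 8]): index of each request's last real token in the gathered, padded buffer
```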
def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int,
slot_mapping: torch.Tensor):
def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int, slot_mapping: torch.Tensor):
# After pcp allgather and restore, there are padded tokens in kv,
# so we need pad slotmapping for alignment.
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens_padded * self.pcp_world_size]
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: num_tokens_padded * self.pcp_world_size]
cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens *
self.pcp_world_size]
cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[: num_tokens * self.pcp_world_size]
pcp_padded_slot_mapping.fill_(-1)
pcp_padded_slot_mapping[:num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping
pcp_padded_slot_mapping[: num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping
return pcp_padded_slot_mapping
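
A minimal sketch of this slot-mapping padding with assumed toy values: real slots are scattered into a `-1`-filled buffer at the positions the unpad mask marks as real.

```python
import torch

num_tokens, pcp_world_size = 3, 2
slot_mapping = torch.tensor([10, 11, 12], dtype=torch.int32)
# Mask marking which positions in the padded buffer carry real tokens.
cp_unpad_mask = torch.tensor([True, False, True, False, True, False])

padded = torch.full((num_tokens * pcp_world_size,), -1, dtype=torch.int32)
padded[cp_unpad_mask] = slot_mapping
# padded == [10, -1, 11, -1, 12, -1]; the -1 slots are the pcp pads.
```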
def get_restore_hidden_states(
@@ -333,13 +314,12 @@ class PCPManager:
# NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
# ignores the padding from CUDA Graph.
from vllm.distributed.parallel_state import get_pcp_group
hidden_states = get_pcp_group().all_gather(
hidden_states[:self.num_actual_tokens_pcp_padded //
self.pcp_world_size],
hidden_states[: self.num_actual_tokens_pcp_padded // self.pcp_world_size],
0,
)
restore_idx = self.pcp_allgather_restore_idx.gpu[:hidden_states.
shape[0]]
restore_idx = self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
return torch.index_select(
hidden_states,
0,
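
The restore step above pairs an all-gather with an `index_select` (the call continues past this hunk); a standalone sketch with assumed toy positions of how the argsort-based restore index reorders gathered tokens:

```python
import torch

# Positions owned by rank 0 then rank 1 under the head/tail layout above.
all_positions = torch.tensor([0, 1, 6, 7, 2, 3, 4, 5])
restore_idx = all_positions.argsort()      # [0, 1, 4, 5, 6, 7, 2, 3]

gathered = torch.arange(8).unsqueeze(-1)   # stand-in for gathered hidden states
restored = torch.index_select(gathered, 0, restore_idx)
# restored rows are back in request/position order 0..7.
```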
@@ -369,73 +349,61 @@ class PCPManager:
num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
for i, req_id in enumerate(input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
self.query_lens_pcp_full.cpu[:self.num_reqs] = torch.from_numpy(
num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(arange_np[:self.num_reqs],
num_scheduled_tokens_pcp_full)
self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(arange_np[: self.num_reqs], num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
self.query_start_loc_pcp_full.np[0] = 0
self.query_start_loc_pcp_full.np[1:self.num_reqs +
1] = cu_num_tokens_pcp_full
self.query_start_loc_pcp_full.np[self.num_reqs + 1:].fill(-1)
self.query_start_loc_pcp_full.np[1 : self.num_reqs + 1] = cu_num_tokens_pcp_full
self.query_start_loc_pcp_full.np[self.num_reqs + 1 :].fill(-1)
cumsums_offsets_pcp_full = np.repeat(
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, num_scheduled_tokens_pcp_full
)
arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
positions_pcp_full_np = self.positions_pcp_full_np[:
total_num_scheduled_tokens_pcp_full]
np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
arange_pcp_full,
out=positions_pcp_full_np)
token_indices_pcp_full = (
positions_pcp_full_np +
req_indices_pcp_full * input_batch.token_ids_cpu.shape[1])
torch.index_select(input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full.
cpu[:total_num_scheduled_tokens_pcp_full])
positions_pcp_full_np = self.positions_pcp_full_np[:total_num_scheduled_tokens_pcp_full]
np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full], arange_pcp_full, out=positions_pcp_full_np)
token_indices_pcp_full = positions_pcp_full_np + req_indices_pcp_full * input_batch.token_ids_cpu.shape[1]
torch.index_select(
input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
)
if self.use_async_scheduling:
self._update_input_ids_pcp_full_ids(input_batch, draft_token_ids,
scheduler_output,
total_num_scheduled_tokens,
cu_num_tokens_pcp_full,
num_spec_tokens)
self._update_input_ids_pcp_full_ids(
input_batch,
draft_token_ids,
scheduler_output,
total_num_scheduled_tokens,
cu_num_tokens_pcp_full,
num_spec_tokens,
)
self.query_lens_pcp_full.copy_to_gpu()
self.query_start_loc_pcp_full.copy_to_gpu()
self.input_ids_pcp_full.copy_to_gpu(
total_num_scheduled_tokens_pcp_full)
self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full)
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
# For mtpx, pre-allocate mtp slot_mapping here
if self.decode_threshold > 2 and not with_prefill:
num_tokens_ori = sum(list(num_scheduled_tokens.values()))
num_tokens_mtp = \
num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
num_tokens_mtp = num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
req_indices_split = np.array_split(req_indices,
cu_num_tokens)[:self.num_reqs]
positions_split = np.array_split(positions_np,
cu_num_tokens)[:self.num_reqs]
req_indices_split = np.array_split(req_indices, cu_num_tokens)[: self.num_reqs]
positions_split = np.array_split(positions_np, cu_num_tokens)[: self.num_reqs]
for req_idx in range(self.num_reqs):
ori_req_indice = req_indices_split[req_idx]
ori_position = positions_split[req_idx]
req_indices_split[req_idx] = np.append(
ori_req_indice,
np.repeat(ori_req_indice[-1], self.decode_threshold - 2))
ori_req_indice, np.repeat(ori_req_indice[-1], self.decode_threshold - 2)
)
positions_split[req_idx] = np.append(
ori_position,
np.arange(ori_position[-1] + 1,
ori_position[-1] + self.decode_threshold - 1))
ori_position, np.arange(ori_position[-1] + 1, ori_position[-1] + self.decode_threshold - 1)
)
req_indices_mtp = np.concatenate(req_indices_split)
positions_mtp = np.concatenate(positions_split)
input_batch.block_table.compute_slot_mapping(
req_indices_mtp, positions_mtp)
mtp_slot_ori = input_batch.block_table.block_tables[
0].slot_mapping.cpu[:num_tokens_mtp]
input_batch.block_table.compute_slot_mapping(req_indices_mtp, positions_mtp)
mtp_slot_ori = input_batch.block_table.block_tables[0].slot_mapping.cpu[:num_tokens_mtp]
unpad_mask = np.repeat(False, num_tokens_mtp_pad)
unpad_mask[::self.pcp_world_size] = True
mtp_slot_pad = \
torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
unpad_mask[:: self.pcp_world_size] = True
mtp_slot_pad = torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
mtp_slot_pad[unpad_mask] = mtp_slot_ori
self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
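
A sketch of the MTP slot padding just above, with assumed toy slots: original slots land on every `pcp_world_size`-th position and the remaining positions are filled with `-1`.

```python
import torch

pcp_world_size = 2
mtp_slot_ori = torch.tensor([5, 6, 7], dtype=torch.int32)
num_tokens_mtp_pad = mtp_slot_ori.numel() * pcp_world_size

unpad_mask = torch.zeros(num_tokens_mtp_pad, dtype=torch.bool)
unpad_mask[::pcp_world_size] = True
mtp_slot_pad = torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
mtp_slot_pad[unpad_mask] = mtp_slot_ori
# mtp_slot_pad == [5, -1, 6, -1, 7, -1]
```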
@@ -454,8 +422,7 @@ class PCPManager:
from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots into input_ids."""
if (input_batch.prev_sampled_token_ids is None
or input_batch.prev_req_id_to_index is None):
if input_batch.prev_sampled_token_ids is None or input_batch.prev_req_id_to_index is None:
return
# Async scheduling case, where some decode requests from the previous
@@ -481,9 +448,7 @@ class PCPManager:
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1,
flattened_index + 1))
spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
start = prev_index * num_spec_tokens
# prev_draft_token_indices is used to find which draft_tokens_id
# should be copied to input_ids
@@ -491,8 +456,7 @@ class PCPManager:
# flatten draft_tokens_id [1,2,3,4,5,6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start,
start + draft_len))
prev_draft_token_indices.extend(range(start, start + draft_len))
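
Putting the worked examples from the two comments together, a standalone sketch assuming each request contributes one sampled token plus its accepted draft tokens, with draft lengths `[1, 2, 1]` and `num_spec_tokens = 2`:

```python
draft_lens = [1, 2, 1]   # assumed accepted-draft length per previous request
num_spec_tokens = 2

sample_flattened_indices, spec_flattened_indices = [], []
prev_draft_token_indices = []
flattened_index = -1
for prev_index, draft_len in enumerate(draft_lens):
    # Each request occupies 1 sampled token + draft_len draft tokens.
    flattened_index += 1 + draft_len
    sample_flattened_indices.append(flattened_index - draft_len)
    spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
    start = prev_index * num_spec_tokens
    prev_draft_token_indices.extend(range(start, start + draft_len))

# sample_flattened_indices == [0, 2, 5]
# spec_flattened_indices   == [1, 3, 4, 6, 7]
# prev_draft_token_indices == [0, 2, 3, 4]
```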
num_commmon_tokens = len(sample_flattened_indices)
if num_commmon_tokens == 0:
@@ -500,15 +464,12 @@ class PCPManager:
# So input_ids.cpu will have all the input ids.
return
# Upload the index tensors asynchronously so the scatter can be non-blocking.
sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices,
dtype=torch.int64)
prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
dtype=torch.int64)
sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices, dtype=torch.int64)
prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices, dtype=torch.int64)
self.input_ids_pcp_full.cpu.scatter_(
dim=0,
index=sampled_tokens_index_tensor,
src=input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0].cpu(),
src=input_batch.prev_sampled_token_ids[prev_common_req_indices_tensor, 0].cpu(),
)
# Scatter the draft tokens after the sampled tokens are scattered.
@@ -516,10 +477,8 @@ class PCPManager:
return
assert isinstance(draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(spec_flattened_indices,
dtype=torch.int64)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices, dtype=torch.int64)
draft_tokens_index_tensor = torch.tensor(spec_flattened_indices, dtype=torch.int64)
prev_draft_token_indices_tensor = torch.tensor(prev_draft_token_indices, dtype=torch.int64)
# because input_ids dtype is torch.int32,
# so convert draft_token_ids to torch.int32 here.
@@ -528,8 +487,7 @@ class PCPManager:
self.input_ids_pcp_full.cpu.scatter_(
dim=0,
index=draft_tokens_index_tensor,
src=draft_token_ids.flatten()
[prev_draft_token_indices_tensor].cpu(),
src=draft_token_ids.flatten()[prev_draft_token_indices_tensor].cpu(),
)
def _get_cp_local_seq_lens(
@@ -545,41 +503,32 @@ class PCPManager:
num_requests = seq_lens.size(0)
total_world_size = pcp_world_size * dcp_world_size
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
rank_offsets = (torch.arange(total_world_size,
dtype=torch.int32).unsqueeze(0).repeat(
num_requests, 1))
base = (seq_lens_tiled // cp_kv_cache_interleave_size //
total_world_size * cp_kv_cache_interleave_size)
rank_offsets = torch.arange(total_world_size, dtype=torch.int32).unsqueeze(0).repeat(num_requests, 1)
base = seq_lens_tiled // cp_kv_cache_interleave_size // total_world_size * cp_kv_cache_interleave_size
remainder = seq_lens_tiled - base * total_world_size
remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size,
0,
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = (base + remainder).reshape(
[-1, pcp_world_size, dcp_world_size])
dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
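
A standalone sketch of `_get_cp_local_seq_lens` with assumed toy values: a 10-token sequence split across 4 (pcp x dcp) ranks with `cp_kv_cache_interleave_size = 1`:

```python
import torch

seq_lens = torch.tensor([10], dtype=torch.int32)
total_world_size, interleave = 4, 1

seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
rank_offsets = torch.arange(total_world_size, dtype=torch.int32).unsqueeze(0)
# Every rank gets the same full-interleave base ...
base = seq_lens_tiled // interleave // total_world_size * interleave
# ... and lower ranks absorb the remainder, one interleave unit at a time.
remainder = torch.clip(seq_lens_tiled - base * total_world_size - rank_offsets * interleave, 0, interleave)
local = base + remainder   # tensor([[3, 3, 2, 2]]) -- sums back to 10
```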
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
input_batch, num_scheduled_tokens):
from vllm_ascend.attention.utils import \
AscendPrefillContextParallelMetadata
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
if self.pcp_world_size * self.dcp_world_size > 1:
decode_context_lens = input_batch.num_computed_tokens_cpu[:
self.num_decode_reqs] + num_scheduled_tokens[:
self.num_decode_reqs]
prefill_context_lens = input_batch.num_computed_tokens_cpu[
self.num_decode_reqs:self.num_reqs]
context_lens = np.concatenate(
[decode_context_lens, prefill_context_lens])
decode_context_lens = (
input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
+ num_scheduled_tokens[: self.num_decode_reqs]
)
prefill_context_lens = input_batch.num_computed_tokens_cpu[self.num_decode_reqs : self.num_reqs]
context_lens = np.concatenate([decode_context_lens, prefill_context_lens])
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
self.num_reqs * self.decode_threshold, self.pcp_world_size,
self.dcp_world_size
],
[self.num_reqs * self.decode_threshold, self.pcp_world_size, self.dcp_world_size],
dtype=torch.int32,
)
# For pcp + spec decode, we flatten seq_lens
@@ -587,41 +536,37 @@ class PCPManager:
# Same as block_table, we flatten decode seq_lens to query_lens,
# and keep prefill seq_lens unchanged.
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
num_computed_tokens_of_pcp_dcp[self.decode_threshold - 1 - decode_idx :: self.decode_threshold] = (
self._get_cp_local_seq_lens(
torch.tensor(context_lens) - decode_idx,
self.pcp_world_size,
self.dcp_world_size,
self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
)
)
if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = []
if self.num_decode_reqs:
num_decodes_flatten = \
query_lens[:self.num_decode_reqs].sum().item()
if query_lens[:self.num_decode_reqs].min().item(
) == self.decode_threshold:
num_decodes_flatten = query_lens[: self.num_decode_reqs].sum().item()
if query_lens[: self.num_decode_reqs].min().item() == self.decode_threshold:
decode_flatten_idx = list(range(num_decodes_flatten))
else:
decode_flatten_idx = []
for req_id in range(self.num_decode_reqs):
offset = (req_id + 1) * self.decode_threshold
decode_flatten_idx += \
list(range(offset - query_lens[req_id], offset))
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
decode_flatten_idx += list(range(offset - query_lens[req_id], offset))
num_computed_tokens_of_pcp_dcp_list.append(num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
if self.num_prefill_reqs:
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[
(self.num_decode_reqs + 1) * self.decode_threshold -
1::self.decode_threshold])
num_computed_tokens_of_pcp_dcp = torch.cat(
num_computed_tokens_of_pcp_dcp_list, dim=0)
(self.num_decode_reqs + 1) * self.decode_threshold - 1 :: self.decode_threshold
]
)
num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
numpy())
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
)
if self.pcp_world_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -639,109 +584,102 @@ class PCPManager:
continue
chunk_len = seq_len // 2
chunk_seqlens.append(chunk_len)
q_head_idx.extend(
list(range(q_req_offset, q_req_offset + chunk_len)))
q_head_idx.extend(list(range(q_req_offset, q_req_offset + chunk_len)))
kv_with_q_head_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_head_chunk_id)))
list(range(kv_req_offset, kv_req_offset + chunk_len * q_head_chunk_id))
)
kv_with_q_head_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_head_chunk_id,
kv_req_offset + chunk_len *
(q_head_chunk_id + 1))))
kv_with_q_head_nomask_seqlens.append(chunk_len *
q_head_chunk_id)
kv_req_offset + chunk_len * (q_head_chunk_id + 1),
)
)
)
kv_with_q_head_nomask_seqlens.append(chunk_len * q_head_chunk_id)
split_with_q_head_nomask_idx_reqs.append(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_head_chunk_id)))
q_tail_idx.extend(
list(
range(q_req_offset + chunk_len,
q_req_offset + chunk_len * 2)))
list(range(kv_req_offset, kv_req_offset + chunk_len * q_head_chunk_id))
)
q_tail_idx.extend(list(range(q_req_offset + chunk_len, q_req_offset + chunk_len * 2)))
kv_with_q_tail_nomask_idx.extend(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_tail_chunk_id)))
list(range(kv_req_offset, kv_req_offset + chunk_len * q_tail_chunk_id))
)
kv_with_q_tail_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_tail_chunk_id,
kv_req_offset + chunk_len *
(q_tail_chunk_id + 1))))
kv_with_q_tail_nomask_seqlens.append(chunk_len *
q_tail_chunk_id)
kv_req_offset + chunk_len * (q_tail_chunk_id + 1),
)
)
)
kv_with_q_tail_nomask_seqlens.append(chunk_len * q_tail_chunk_id)
split_kv_with_q_tail_nomask_idx_reqs.append(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_tail_chunk_id)))
list(range(kv_req_offset, kv_req_offset + chunk_len * q_tail_chunk_id))
)
q_req_offset += seq_len
kv_req_offset += seq_len * self.pcp_world_size
q_head_idx_tensor = self._list_to_tensor(
q_head_idx, self.device)
q_tail_idx_tensor = self._list_to_tensor(
q_tail_idx, self.device)
q_head_idx_tensor = self._list_to_tensor(q_head_idx, self.device)
q_tail_idx_tensor = self._list_to_tensor(q_tail_idx, self.device)
self.q_head_idx_tensor = q_head_idx_tensor
self.q_tail_idx_tensor = q_tail_idx_tensor
q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
q_full_idx = q_full_idx.to(torch.float32).argsort().to(
torch.int32)
q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32)
self.q_full_idx = q_full_idx
self.kv_idx_names = {
'kv_with_q_head_nomask_idx_tensor':
kv_with_q_head_nomask_idx,
'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
'kv_with_q_tail_nomask_idx_tensor':
kv_with_q_tail_nomask_idx,
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
"kv_with_q_head_nomask_idx_tensor": kv_with_q_head_nomask_idx,
"kv_with_q_head_mask_idx_tensor": kv_with_q_head_mask_idx,
"kv_with_q_tail_nomask_idx_tensor": kv_with_q_tail_nomask_idx,
"kv_with_q_tail_mask_idx_tensor": kv_with_q_tail_mask_idx,
}
for key, value in self.kv_idx_names.items():
tensor_npu = self._list_to_tensor(value, self.device)
self.kv_idx_names[key] = tensor_npu
attn_mask_seqlens = torch.tensor(
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], dtype=torch.int32)
head_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_head_nomask_seqlens],
dtype=torch.int32)
[chunk_seqlens, kv_with_q_head_nomask_seqlens], dtype=torch.int32
)
tail_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
dtype=torch.int32)
[chunk_seqlens, kv_with_q_tail_nomask_seqlens], dtype=torch.int32
)
if self.vllm_config.model_config.use_mla:
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = self._split_nomask_idx_tensor_list(
(
split_q_head_nomask_idx_tensor_list,
split_q_tail_nomask_idx_tensor_list,
head_attn_nomask_seqlens_list,
tail_attn_nomask_seqlens_list,
) = self._split_nomask_idx_tensor_list(
split_with_q_head_nomask_idx_reqs,
split_kv_with_q_tail_nomask_idx_reqs,
head_attn_nomask_seqlens, chunk_seqlens)
head_attn_nomask_seqlens,
chunk_seqlens,
)
self.extra_long_seq_kwargs = {
'attn_mask_seqlens': attn_mask_seqlens,
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens
"attn_mask_seqlens": attn_mask_seqlens,
"head_attn_nomask_seqlens": head_attn_nomask_seqlens,
"tail_attn_nomask_seqlens": tail_attn_nomask_seqlens,
}
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
num_actual_tokens_pcp_padded]
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
:num_actual_tokens_pcp_padded
]
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_nomask_idx_tensor']
long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_mask_idx_tensor']
"kv_with_q_head_nomask_idx_tensor"
]
long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names["kv_with_q_head_mask_idx_tensor"]
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_nomask_idx_tensor']
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_mask_idx_tensor']
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
'attn_mask_seqlens']
long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'head_attn_nomask_seqlens']
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'tail_attn_nomask_seqlens']
"kv_with_q_tail_nomask_idx_tensor"
]
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names["kv_with_q_tail_mask_idx_tensor"]
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs["attn_mask_seqlens"]
long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs["head_attn_nomask_seqlens"]
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs["tail_attn_nomask_seqlens"]
if self.vllm_config.model_config.use_mla:
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
@@ -755,46 +693,53 @@ class PCPManager:
tensor_npu.copy_(torch.tensor(lst, dtype=dtype), non_blocking=True)
return tensor_npu
def _split_nomask_idx_tensor_list(self, split_with_q_head_nomask_idx_reqs,
split_kv_with_q_tail_nomask_idx_reqs,
head_attn_nomask_seqlens, chunk_seqlens):
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list= [], []
def _split_nomask_idx_tensor_list(
self,
split_with_q_head_nomask_idx_reqs,
split_kv_with_q_tail_nomask_idx_reqs,
head_attn_nomask_seqlens,
chunk_seqlens,
):
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list = [], []
head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = [], []
if split_with_q_head_nomask_idx_reqs:
#In long-sequence scenarios, the computational cost and latency
#of the _npu_ring_mla operator are not proportional, so we split
#long sequences into shorter ones to improve performance.
# In long-sequence scenarios, the computational cost and latency
# of the _npu_ring_mla operator are not proportional, so we split
# long sequences into shorter ones to improve performance.
split_size = 16 * 1024
if self.pcp_world_rank == 0:
split_q_head_nomask_idx_list = [
self.kv_idx_names['kv_with_q_head_nomask_idx_tensor']
]
split_q_head_nomask_idx_list = [self.kv_idx_names["kv_with_q_head_nomask_idx_tensor"]]
else:
split_q_head_nomask_idx_list, split_q_head_nomask_lens_list = self._split_multi_batch_kv_idx(
split_with_q_head_nomask_idx_reqs, split_size)
split_with_q_head_nomask_idx_reqs, split_size
)
split_q_tail_nomask_idx_list, split_q_tail_nomask_lens_list = self._split_multi_batch_kv_idx(
split_kv_with_q_tail_nomask_idx_reqs, split_size)
split_kv_with_q_tail_nomask_idx_reqs, split_size
)
for q_head_nomask_idx in split_q_head_nomask_idx_list:
split_q_head_nomask_idx_tensor_list.append(
self._list_to_tensor(q_head_nomask_idx, self.device))
split_q_head_nomask_idx_tensor_list.append(self._list_to_tensor(q_head_nomask_idx, self.device))
for q_tail_nomask_idx in split_q_tail_nomask_idx_list:
split_q_tail_nomask_idx_tensor_list.append(
self._list_to_tensor(q_tail_nomask_idx, self.device))
split_q_tail_nomask_idx_tensor_list.append(self._list_to_tensor(q_tail_nomask_idx, self.device))
if self.pcp_world_rank == 0:
head_attn_nomask_seqlens_list = [head_attn_nomask_seqlens]
else:
for q_head_nomask_lens in split_q_head_nomask_lens_list:
head_attn_nomask_seqlens_list.append(
torch.tensor([chunk_seqlens, q_head_nomask_lens],
dtype=torch.int32))
torch.tensor([chunk_seqlens, q_head_nomask_lens], dtype=torch.int32)
)
for q_tail_nomask_lens in split_q_tail_nomask_lens_list:
tail_attn_nomask_seqlens_list.append(
torch.tensor([chunk_seqlens, q_tail_nomask_lens],
dtype=torch.int32))
return split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list
torch.tensor([chunk_seqlens, q_tail_nomask_lens], dtype=torch.int32)
)
return (
split_q_head_nomask_idx_tensor_list,
split_q_tail_nomask_idx_tensor_list,
head_attn_nomask_seqlens_list,
tail_attn_nomask_seqlens_list,
)
def _split_multi_batch_kv_idx(
self,
@@ -813,7 +758,7 @@ class PCPManager:
current_batch_len = []
for t in range(time):
start = t * split_size
current_segment = single_batch[start:start + split_size]
current_segment = single_batch[start : start + split_size]
current_batch_split.append(current_segment)
current_batch_len.append(len(current_segment))
@@ -829,8 +774,9 @@ class PCPManager:
def reshape_kv_len_to_time_first(split_kv_len_2d):
if not split_kv_len_2d or not split_kv_len_2d[0]:
return []
return [[batch_len[time_idx] for batch_len in split_kv_len_2d]
for time_idx in range(len(split_kv_len_2d[0]))]
return [
[batch_len[time_idx] for batch_len in split_kv_len_2d] for time_idx in range(len(split_kv_len_2d[0]))
]
merged_split_kv_len_2d = reshape_kv_len_to_time_first(split_kv_len_2d)
return merged_split_kv_idx_3d, merged_split_kv_len_2d
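
Finally, a small sketch of `reshape_kv_len_to_time_first` above, with assumed per-batch split lengths: the 2-D list is transposed so each entry groups the lengths of one split step across batches.

```python
# Two batches, each split into two segments of these lengths:
split_kv_len_2d = [[3, 2], [4, 1]]
time_first = [
    [batch_len[time_idx] for batch_len in split_kv_len_2d]
    for time_idx in range(len(split_kv_len_2d[0]))
]
# time_first == [[3, 4], [2, 1]]
```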