[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #9) (#6135)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|`vllm_ascend/worker/model_runner_v1.py`|
|`vllm_ascend/worker/pcp_utils.py`|

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
d68209402d

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-02-01 23:20:20 +08:00
committed by GitHub
parent f7dc7d9b86
commit 347eb36a59
3 changed files with 802 additions and 1041 deletions

View File

@@ -54,9 +54,7 @@ exclude = [
# (7) # (7)
"vllm_ascend/quantization/**", "vllm_ascend/quantization/**",
"vllm_ascend/sample/*.py", "vllm_ascend/sample/*.py",
"vllm_ascend/worker/v2/**",
"vllm_ascend/worker/block_table.py", "vllm_ascend/worker/block_table.py",
"vllm_ascend/worker/npu_input_batch.py",
# (8) # (8)
"vllm_ascend/ops/__init__.py", "vllm_ascend/ops/__init__.py",
"vllm_ascend/ops/activation.py", "vllm_ascend/ops/activation.py",
@@ -65,13 +63,9 @@ exclude = [
"vllm_ascend/ops/mla.py", "vllm_ascend/ops/mla.py",
"vllm_ascend/ops/mm_encoder_attention.py", "vllm_ascend/ops/mm_encoder_attention.py",
"vllm_ascend/ops/register_custom_ops.py", "vllm_ascend/ops/register_custom_ops.py",
"vllm_ascend/ops/rotary_embedding.py",
"vllm_ascend/ops/vocab_parallel_embedding.py", "vllm_ascend/ops/vocab_parallel_embedding.py",
"vllm_ascend/ops/weight_prefetch.py", "vllm_ascend/ops/weight_prefetch.py",
"vllm_ascend/spec_decode/**", "vllm_ascend/spec_decode/**",
# (9)
"vllm_ascend/worker/model_runner_v1.py",
"vllm_ascend/worker/pcp_utils.py",
# (10) # (10)
"vllm_ascend/ops/*linear*.py", "vllm_ascend/ops/*linear*.py",
"vllm_ascend/worker/worker.py", "vllm_ascend/worker/worker.py",
@@ -79,6 +73,9 @@ exclude = [
"vllm_ascend/distributed/utils.py", "vllm_ascend/distributed/utils.py",
"vllm_ascend/xlite/*.py", "vllm_ascend/xlite/*.py",
"vllm_ascend/patch/worker/patch_*.py", "vllm_ascend/patch/worker/patch_*.py",
"vllm_ascend/worker/v2/**",
"vllm_ascend/worker/npu_input_batch.py",
"vllm_ascend/ops/rotary_embedding.py",
# (11) # (11)
"vllm_ascend/ops/fused_moe/**", "vllm_ascend/ops/fused_moe/**",
] ]

File diff suppressed because it is too large Load Diff

View File

@@ -17,12 +17,11 @@
# Adapted from vllm-project/vllm/vllm/worker/worker.py # Adapted from vllm-project/vllm/vllm/worker/worker.py
# #
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING
import numpy as np import numpy as np
import torch import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -36,6 +35,7 @@ class PCPManager:
This manager encapsulates all PCP-related buffers and logic so that the This manager encapsulates all PCP-related buffers and logic so that the
ModelRunner can access them via `self.pcp_manager`. ModelRunner can access them via `self.pcp_manager`.
""" """
num_reqs: int = 0 num_reqs: int = 0
num_decode_reqs: int = 0 num_decode_reqs: int = 0
num_prefill_reqs: int = 0 num_prefill_reqs: int = 0
@@ -59,9 +59,7 @@ class PCPManager:
self.dcp_world_size = dcp_world_size self.dcp_world_size = dcp_world_size
self.dcp_world_rank = dcp_rank self.dcp_world_rank = dcp_rank
self.speculative_config = vllm_config.speculative_config self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1 + ( self.decode_threshold = 1 + (self.speculative_config.num_speculative_tokens if self.speculative_config else 0)
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs
@@ -74,46 +72,42 @@ class PCPManager:
pin_memory=pin_memory, pin_memory=pin_memory,
) )
self.pcp_padded_slot_mapping = torch.full( self.pcp_padded_slot_mapping = torch.full(
(max_buffer_num_tokens, ), (max_buffer_num_tokens,),
fill_value=-1, fill_value=-1,
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
self.pcp_tokens = np.zeros(self.max_num_reqs, dtype=np.int32) self.pcp_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
self.total_num_sampled_tokens_pcp = 0 self.total_num_sampled_tokens_pcp = 0
self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ), self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs,), device="cpu", dtype=torch.int64)
device="cpu",
dtype=torch.int64)
self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy()
self.pcp_unpad_mask_cpu_tensor = torch.ones( self.pcp_unpad_mask_cpu_tensor = torch.ones(
(max_buffer_num_tokens, ), (max_buffer_num_tokens,),
device="cpu", device="cpu",
dtype=torch.bool, dtype=torch.bool,
) )
self.num_actual_tokens_pcp_padded = 0 self.num_actual_tokens_pcp_padded = 0
self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy()
self.full_indices = list( self.full_indices = list(
range(self.max_num_tokens * self.pcp_world_size * range(
self.dcp_world_size + self.pcp_world_size * self.max_num_tokens * self.pcp_world_size * self.dcp_world_size
self.dcp_world_size * self.max_num_reqs)) + self.pcp_world_size * self.dcp_world_size * self.max_num_reqs
)
)
if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1: if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1:
self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens, self.input_ids_pcp_full = CpuGpuBuffer(
dtype=torch.int32, self.max_num_tokens, dtype=torch.int32, device=device, pin_memory=pin_memory
device=device, )
pin_memory=pin_memory) self.query_start_loc_pcp_full = CpuGpuBuffer(
self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1, self.max_num_reqs + 1, dtype=torch.int32, device=device, pin_memory=pin_memory
dtype=torch.int32, )
device=device, self.positions_pcp_full = torch.zeros(
pin_memory=pin_memory) self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=pin_memory
self.positions_pcp_full = torch.zeros(self.max_num_tokens, )
dtype=torch.int64,
device="cpu",
pin_memory=pin_memory)
self.positions_pcp_full_np = self.positions_pcp_full.numpy() self.positions_pcp_full_np = self.positions_pcp_full.numpy()
self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs, self.query_lens_pcp_full = CpuGpuBuffer(
dtype=torch.int32, self.max_num_reqs, dtype=torch.int32, device=device, pin_memory=pin_memory
device=device, )
pin_memory=pin_memory)
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
self, self,
@@ -130,8 +124,7 @@ class PCPManager:
cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype) cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype)
total_num_tokens = cu_num_tokens[-1] total_num_tokens = cu_num_tokens[-1]
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens)
num_scheduled_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = arange_np[:total_num_tokens] - cumsums_offsets arange = arange_np[:total_num_tokens] - cumsums_offsets
@@ -143,15 +136,15 @@ class PCPManager:
num_reqs: int, num_reqs: int,
) -> None: ) -> None:
self.num_reqs = num_reqs self.num_reqs = num_reqs
is_prefill = (num_scheduled_tokens[:num_reqs] > self.decode_threshold) is_prefill = num_scheduled_tokens[:num_reqs] > self.decode_threshold
if not any(is_prefill): if not any(is_prefill):
first_prefill = num_reqs first_prefill = num_reqs
else: else:
first_prefill = is_prefill.argmax() first_prefill = is_prefill.argmax()
self.num_decode_reqs = first_prefill self.num_decode_reqs = first_prefill
self.num_prefill_reqs = num_reqs - self.num_decode_reqs self.num_prefill_reqs = num_reqs - self.num_decode_reqs
self.num_decode_tokens = num_scheduled_tokens[:self.num_decode_reqs].sum() self.num_decode_tokens = num_scheduled_tokens[: self.num_decode_reqs].sum()
def update_tokens_for_pcp( def update_tokens_for_pcp(
self, self,
num_scheduled_tokens: np.ndarray, num_scheduled_tokens: np.ndarray,
@@ -208,32 +201,29 @@ class PCPManager:
# DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size). # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
# We first pad each request's token count up to that multiple. # We first pad each request's token count up to that multiple.
num_padded_scheduled_tokens = np.ceil( num_padded_scheduled_tokens = np.ceil(num_scheduled_tokens / (2 * self.pcp_world_size)).astype(np.int32) * (
num_scheduled_tokens / (2 * self.pcp_world_size)).astype( 2 * self.pcp_world_size
np.int32) * (2 * self.pcp_world_size) )
# PCP does not split decode requests. For decode requests, we instead # PCP does not split decode requests. For decode requests, we instead
# duplicate the scheduled tokens across the pcp_world_size ranks. # duplicate the scheduled tokens across the pcp_world_size ranks.
num_padded_scheduled_tokens[:self.num_decode_reqs] = ( num_padded_scheduled_tokens[: self.num_decode_reqs] = (
num_scheduled_tokens[:self.num_decode_reqs] * self.pcp_world_size) num_scheduled_tokens[: self.num_decode_reqs] * self.pcp_world_size
)
# Record how many pads were added per request (padded - original). # Record how many pads were added per request (padded - original).
self.num_pcp_pads_cpu[:self.num_reqs] = (num_padded_scheduled_tokens - self.num_pcp_pads_cpu[: self.num_reqs] = num_padded_scheduled_tokens - num_scheduled_tokens
num_scheduled_tokens)
# cu_padded_tokens: cumulative sum of padded token counts, # cu_padded_tokens: cumulative sum of padded token counts,
# pcp_padded_arange: per-request arange flattened for padded tokens. # pcp_padded_arange: per-request arange flattened for padded tokens.
cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(num_padded_scheduled_tokens, arange_np)
num_padded_scheduled_tokens, arange_np)
# Build the mask that marks which positions in the padded allgather buffer # Build the mask that marks which positions in the padded allgather buffer
# correspond to real (unpadded) tokens. # correspond to real (unpadded) tokens.
self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = ( self.pcp_unpad_mask_cpu[: pcp_padded_arange.shape[0]] = pcp_padded_arange < np.repeat(
pcp_padded_arange < np.repeat(num_scheduled_tokens, num_scheduled_tokens, num_padded_scheduled_tokens
num_padded_scheduled_tokens)) )
unpad_mask_decode = self.pcp_unpad_mask_cpu[:self.num_decode_tokens * unpad_mask_decode = self.pcp_unpad_mask_cpu[: self.num_decode_tokens * self.pcp_world_size]
self.pcp_world_size] unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_world_size])
unpad_mask_decode = unpad_mask_decode.reshape(
[-1, self.pcp_world_size])
unpad_mask_decode[:, 0] = True unpad_mask_decode[:, 0] = True
unpad_mask_decode[:, 1:] = False unpad_mask_decode[:, 1:] = False
pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size
@@ -242,23 +232,20 @@ class PCPManager:
# For prefill requests, we further split the pcp_tokens into two chunks # For prefill requests, we further split the pcp_tokens into two chunks
# (head and tail). For decode requests, the chunk equals pcp_tokens. # (head and tail). For decode requests, the chunk equals pcp_tokens.
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
pcp_chunk_sizes[:self.num_decode_reqs] = pcp_tokens[:self.num_decode_reqs] pcp_chunk_sizes[: self.num_decode_reqs] = pcp_tokens[: self.num_decode_reqs]
# Build arange-style helpers for pcp tokens and chunk sizes: # Build arange-style helpers for pcp tokens and chunk sizes:
# - pcp_arange gives indices repeated for each token in pcp_tokens # - pcp_arange gives indices repeated for each token in pcp_tokens
# - pcp_chunk_arange gives indices repeated for each position inside chunks # - pcp_chunk_arange gives indices repeated for each position inside chunks
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np) _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np)
_, pcp_chunk_arange = self._get_cumsum_and_arange( _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes, arange_np)
pcp_chunk_sizes, arange_np)
# Mask that marks whether a position belongs to the head chunk (True) # Mask that marks whether a position belongs to the head chunk (True)
# or the tail chunk (False). For decode requests, tail chunk won't exist # or the tail chunk (False). For decode requests, tail chunk won't exist
# and is handled specially below. # and is handled specially below.
pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens)
pcp_tokens)
def get_current_rank_positions(positions_start_loc: int | np.ndarray, def get_current_rank_positions(positions_start_loc: int | np.ndarray, rank: int):
rank: int):
""" """
Compute flattened positions for the given rank with a given start Compute flattened positions for the given rank with a given start
offset for each request (positions_start_loc). offset for each request (positions_start_loc).
@@ -271,59 +258,53 @@ class PCPManager:
""" """
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
tail_start_loc = ( tail_start_loc = positions_start_loc + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes
positions_start_loc +
(2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes)
# Fill head positions using chunk arange offset by head_start_loc. # Fill head positions using chunk arange offset by head_start_loc.
positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat(head_start_loc, pcp_chunk_sizes)
head_start_loc, pcp_chunk_sizes)
# Fill tail positions. Note decode requests do not have tail chunks, # Fill tail positions. Note decode requests do not have tail chunks,
# so the tail filling is only for prefill positions. # so the tail filling is only for prefill positions.
positions[~pcp_head_chunk_mask] = ( positions[~pcp_head_chunk_mask] = (
pcp_chunk_arange[self.num_decode_tokens:] + pcp_chunk_arange[self.num_decode_tokens :]
np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens:]) + np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens :]
)
return positions return positions
positions = get_current_rank_positions(0, self.pcp_world_rank) positions = get_current_rank_positions(0, self.pcp_world_rank)
# Decode tokens are duplicated only after AG. But their positions are # Decode tokens are duplicated only after AG. But their positions are
# same without prefill context parallel. # same without prefill context parallel.
if self.num_decode_reqs > 0: if self.num_decode_reqs > 0:
positions[:self.num_decode_tokens] = self._get_cumsum_and_arange( positions[: self.num_decode_tokens] = self._get_cumsum_and_arange(
num_scheduled_tokens[:self.num_decode_reqs], arange_np)[1] num_scheduled_tokens[: self.num_decode_reqs], arange_np
)[1]
# Build the restore index used after allgather. # Build the restore index used after allgather.
padded_pos_start_loc = np.roll(cu_padded_tokens, 1) padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
padded_pos_start_loc[0] = 0 padded_pos_start_loc[0] = 0
all_positions_lst = [ all_positions_lst = [
get_current_rank_positions(padded_pos_start_loc, rank_i) get_current_rank_positions(padded_pos_start_loc, rank_i) for rank_i in range(self.pcp_world_size)
for rank_i in range(self.pcp_world_size)
] ]
all_positions = np.concatenate(all_positions_lst) all_positions = np.concatenate(all_positions_lst)
self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = ( self.pcp_allgather_restore_idx.np[: all_positions.shape[0]] = all_positions.argsort()
all_positions.argsort())
self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
self.pcp_tokens[:self.num_reqs] = pcp_tokens[:self.num_reqs] self.pcp_tokens[: self.num_reqs] = pcp_tokens[: self.num_reqs]
self.total_num_sampled_tokens_pcp = pcp_tokens[:self.num_reqs].sum() self.total_num_sampled_tokens_pcp = pcp_tokens[: self.num_reqs].sum()
return ( return (
pcp_tokens[:self.num_reqs], pcp_tokens[: self.num_reqs],
positions, positions,
) )
def get_logits_indices(self, cu_num_tokens: np.ndarray): def get_logits_indices(self, cu_num_tokens: np.ndarray):
return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size - return torch.from_numpy(cu_num_tokens) * self.pcp_world_size - self.num_pcp_pads_cpu_tensor[: self.num_reqs] - 1
self.num_pcp_pads_cpu_tensor[:self.num_reqs] - 1)
def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int, def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int, slot_mapping: torch.Tensor):
slot_mapping: torch.Tensor):
# After pcp allgather and restore, there are padded tokens in kv, # After pcp allgather and restore, there are padded tokens in kv,
# so we need pad slotmapping for alignment. # so we need pad slotmapping for alignment.
pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens_padded * self.pcp_world_size] pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: num_tokens_padded * self.pcp_world_size]
cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens * cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[: num_tokens * self.pcp_world_size]
self.pcp_world_size]
pcp_padded_slot_mapping.fill_(-1) pcp_padded_slot_mapping.fill_(-1)
pcp_padded_slot_mapping[:num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping pcp_padded_slot_mapping[: num_tokens * self.pcp_world_size][cp_unpad_mask] = slot_mapping
return pcp_padded_slot_mapping return pcp_padded_slot_mapping
def get_restore_hidden_states( def get_restore_hidden_states(
@@ -333,13 +314,12 @@ class PCPManager:
# NOTE we must `slice` hidden_states because pcp_allgather_restore_idx # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
# ignores the padding from CUDA Graph. # ignores the padding from CUDA Graph.
from vllm.distributed.parallel_state import get_pcp_group from vllm.distributed.parallel_state import get_pcp_group
hidden_states = get_pcp_group().all_gather( hidden_states = get_pcp_group().all_gather(
hidden_states[:self.num_actual_tokens_pcp_padded // hidden_states[: self.num_actual_tokens_pcp_padded // self.pcp_world_size],
self.pcp_world_size],
0, 0,
) )
restore_idx = self.pcp_allgather_restore_idx.gpu[:hidden_states. restore_idx = self.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
shape[0]]
return torch.index_select( return torch.index_select(
hidden_states, hidden_states,
0, 0,
@@ -369,73 +349,61 @@ class PCPManager:
num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32) num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
for i, req_id in enumerate(input_batch.req_ids): for i, req_id in enumerate(input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
self.query_lens_pcp_full.cpu[:self.num_reqs] = torch.from_numpy( self.query_lens_pcp_full.cpu[: self.num_reqs] = torch.from_numpy(num_scheduled_tokens_pcp_full)
num_scheduled_tokens_pcp_full) req_indices_pcp_full = np.repeat(arange_np[: self.num_reqs], num_scheduled_tokens_pcp_full)
req_indices_pcp_full = np.repeat(arange_np[:self.num_reqs],
num_scheduled_tokens_pcp_full)
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
self.query_start_loc_pcp_full.np[0] = 0 self.query_start_loc_pcp_full.np[0] = 0
self.query_start_loc_pcp_full.np[1:self.num_reqs + self.query_start_loc_pcp_full.np[1 : self.num_reqs + 1] = cu_num_tokens_pcp_full
1] = cu_num_tokens_pcp_full self.query_start_loc_pcp_full.np[self.num_reqs + 1 :].fill(-1)
self.query_start_loc_pcp_full.np[self.num_reqs + 1:].fill(-1)
cumsums_offsets_pcp_full = np.repeat( cumsums_offsets_pcp_full = np.repeat(
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, num_scheduled_tokens_pcp_full
num_scheduled_tokens_pcp_full) )
arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
positions_pcp_full_np = self.positions_pcp_full_np[: positions_pcp_full_np = self.positions_pcp_full_np[:total_num_scheduled_tokens_pcp_full]
total_num_scheduled_tokens_pcp_full] np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full], arange_pcp_full, out=positions_pcp_full_np)
np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full], token_indices_pcp_full = positions_pcp_full_np + req_indices_pcp_full * input_batch.token_ids_cpu.shape[1]
arange_pcp_full, torch.index_select(
out=positions_pcp_full_np) input_batch.token_ids_cpu_tensor.flatten(),
token_indices_pcp_full = ( 0,
positions_pcp_full_np + torch.from_numpy(token_indices_pcp_full),
req_indices_pcp_full * input_batch.token_ids_cpu.shape[1]) out=self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full],
torch.index_select(input_batch.token_ids_cpu_tensor.flatten(), )
0,
torch.from_numpy(token_indices_pcp_full),
out=self.input_ids_pcp_full.
cpu[:total_num_scheduled_tokens_pcp_full])
if self.use_async_scheduling: if self.use_async_scheduling:
self._update_input_ids_pcp_full_ids(input_batch, draft_token_ids, self._update_input_ids_pcp_full_ids(
scheduler_output, input_batch,
total_num_scheduled_tokens, draft_token_ids,
cu_num_tokens_pcp_full, scheduler_output,
num_spec_tokens) total_num_scheduled_tokens,
cu_num_tokens_pcp_full,
num_spec_tokens,
)
self.query_lens_pcp_full.copy_to_gpu() self.query_lens_pcp_full.copy_to_gpu()
self.query_start_loc_pcp_full.copy_to_gpu() self.query_start_loc_pcp_full.copy_to_gpu()
self.input_ids_pcp_full.copy_to_gpu( self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full)
total_num_scheduled_tokens_pcp_full)
self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full
# For mtpx, pre-allocate mtp slot_mapping here # For mtpx, pre-allocate mtp slot_mapping here
if self.decode_threshold > 2 and not with_prefill: if self.decode_threshold > 2 and not with_prefill:
num_tokens_ori = sum(list(num_scheduled_tokens.values())) num_tokens_ori = sum(list(num_scheduled_tokens.values()))
num_tokens_mtp = \ num_tokens_mtp = num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
req_indices_split = np.array_split(req_indices, req_indices_split = np.array_split(req_indices, cu_num_tokens)[: self.num_reqs]
cu_num_tokens)[:self.num_reqs] positions_split = np.array_split(positions_np, cu_num_tokens)[: self.num_reqs]
positions_split = np.array_split(positions_np,
cu_num_tokens)[:self.num_reqs]
for req_idx in range(self.num_reqs): for req_idx in range(self.num_reqs):
ori_req_indice = req_indices_split[req_idx] ori_req_indice = req_indices_split[req_idx]
ori_position = positions_split[req_idx] ori_position = positions_split[req_idx]
req_indices_split[req_idx] = np.append( req_indices_split[req_idx] = np.append(
ori_req_indice, ori_req_indice, np.repeat(ori_req_indice[-1], self.decode_threshold - 2)
np.repeat(ori_req_indice[-1], self.decode_threshold - 2)) )
positions_split[req_idx] = np.append( positions_split[req_idx] = np.append(
ori_position, ori_position, np.arange(ori_position[-1] + 1, ori_position[-1] + self.decode_threshold - 1)
np.arange(ori_position[-1] + 1, )
ori_position[-1] + self.decode_threshold - 1))
req_indices_mtp = np.concatenate(req_indices_split) req_indices_mtp = np.concatenate(req_indices_split)
positions_mtp = np.concatenate(positions_split) positions_mtp = np.concatenate(positions_split)
input_batch.block_table.compute_slot_mapping( input_batch.block_table.compute_slot_mapping(req_indices_mtp, positions_mtp)
req_indices_mtp, positions_mtp) mtp_slot_ori = input_batch.block_table.block_tables[0].slot_mapping.cpu[:num_tokens_mtp]
mtp_slot_ori = input_batch.block_table.block_tables[
0].slot_mapping.cpu[:num_tokens_mtp]
unpad_mask = np.repeat(False, num_tokens_mtp_pad) unpad_mask = np.repeat(False, num_tokens_mtp_pad)
unpad_mask[::self.pcp_world_size] = True unpad_mask[:: self.pcp_world_size] = True
mtp_slot_pad = \ mtp_slot_pad = torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32)
mtp_slot_pad[unpad_mask] = mtp_slot_ori mtp_slot_pad[unpad_mask] = mtp_slot_ori
self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True) self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True)
@@ -454,8 +422,7 @@ class PCPManager:
from the previous engine iteration, in which case those tokens on the from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots into input_ids.""" GPU need to be copied into the corresponding slots into input_ids."""
if (input_batch.prev_sampled_token_ids is None if input_batch.prev_sampled_token_ids is None or input_batch.prev_req_id_to_index is None:
or input_batch.prev_req_id_to_index is None):
return return
# Async scheduling case, where some decode requests from the previous # Async scheduling case, where some decode requests from the previous
@@ -481,9 +448,7 @@ class PCPManager:
# sample_flattened_indices = [0, 2, 5] # sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7] # spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len) sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend( spec_flattened_indices.extend(range(flattened_index - draft_len + 1, flattened_index + 1))
range(flattened_index - draft_len + 1,
flattened_index + 1))
start = prev_index * num_spec_tokens start = prev_index * num_spec_tokens
# prev_draft_token_indices is used to find which draft_tokens_id # prev_draft_token_indices is used to find which draft_tokens_id
# should be copied to input_ids # should be copied to input_ids
@@ -491,8 +456,7 @@ class PCPManager:
# flatten draft_tokens_id [1,2,3,4,5,6] # flatten draft_tokens_id [1,2,3,4,5,6]
# draft_len of each request [1, 2, 1] # draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4] # then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start, prev_draft_token_indices.extend(range(start, start + draft_len))
start + draft_len))
num_commmon_tokens = len(sample_flattened_indices) num_commmon_tokens = len(sample_flattened_indices)
if num_commmon_tokens == 0: if num_commmon_tokens == 0:
@@ -500,15 +464,12 @@ class PCPManager:
# So input_ids.cpu will have all the input ids. # So input_ids.cpu will have all the input ids.
return return
# Upload the index tensors asynchronously so the scatter can be non-blocking. # Upload the index tensors asynchronously so the scatter can be non-blocking.
sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices, sampled_tokens_index_tensor = torch.tensor(sample_flattened_indices, dtype=torch.int64)
dtype=torch.int64) prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices, dtype=torch.int64)
prev_common_req_indices_tensor = torch.tensor(prev_common_req_indices,
dtype=torch.int64)
self.input_ids_pcp_full.cpu.scatter_( self.input_ids_pcp_full.cpu.scatter_(
dim=0, dim=0,
index=sampled_tokens_index_tensor, index=sampled_tokens_index_tensor,
src=input_batch.prev_sampled_token_ids[ src=input_batch.prev_sampled_token_ids[prev_common_req_indices_tensor, 0].cpu(),
prev_common_req_indices_tensor, 0].cpu(),
) )
# Scatter the draft tokens after the sampled tokens are scattered. # Scatter the draft tokens after the sampled tokens are scattered.
@@ -516,10 +477,8 @@ class PCPManager:
return return
assert isinstance(draft_token_ids, torch.Tensor) assert isinstance(draft_token_ids, torch.Tensor)
draft_tokens_index_tensor = torch.tensor(spec_flattened_indices, draft_tokens_index_tensor = torch.tensor(spec_flattened_indices, dtype=torch.int64)
dtype=torch.int64) prev_draft_token_indices_tensor = torch.tensor(prev_draft_token_indices, dtype=torch.int64)
prev_draft_token_indices_tensor = torch.tensor(
prev_draft_token_indices, dtype=torch.int64)
# because input_ids dtype is torch.int32, # because input_ids dtype is torch.int32,
# so convert draft_token_ids to torch.int32 here. # so convert draft_token_ids to torch.int32 here.
@@ -528,8 +487,7 @@ class PCPManager:
self.input_ids_pcp_full.cpu.scatter_( self.input_ids_pcp_full.cpu.scatter_(
dim=0, dim=0,
index=draft_tokens_index_tensor, index=draft_tokens_index_tensor,
src=draft_token_ids.flatten() src=draft_token_ids.flatten()[prev_draft_token_indices_tensor].cpu(),
[prev_draft_token_indices_tensor].cpu(),
) )
def _get_cp_local_seq_lens( def _get_cp_local_seq_lens(
@@ -545,41 +503,32 @@ class PCPManager:
num_requests = seq_lens.size(0) num_requests = seq_lens.size(0)
total_world_size = pcp_world_size * dcp_world_size total_world_size = pcp_world_size * dcp_world_size
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size) seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
rank_offsets = (torch.arange(total_world_size, rank_offsets = torch.arange(total_world_size, dtype=torch.int32).unsqueeze(0).repeat(num_requests, 1)
dtype=torch.int32).unsqueeze(0).repeat( base = seq_lens_tiled // cp_kv_cache_interleave_size // total_world_size * cp_kv_cache_interleave_size
num_requests, 1))
base = (seq_lens_tiled // cp_kv_cache_interleave_size //
total_world_size * cp_kv_cache_interleave_size)
remainder = seq_lens_tiled - base * total_world_size remainder = seq_lens_tiled - base * total_world_size
remainder = torch.clip( remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size, remainder - rank_offsets * cp_kv_cache_interleave_size,
0, 0,
cp_kv_cache_interleave_size, cp_kv_cache_interleave_size,
) )
dcp_local_seq_lens = (base + remainder).reshape( dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
[-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens return dcp_local_seq_lens
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
input_batch, num_scheduled_tokens): from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
from vllm_ascend.attention.utils import \
AscendPrefillContextParallelMetadata
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None long_seq_metadata = None
if self.pcp_world_size * self.dcp_world_size > 1: if self.pcp_world_size * self.dcp_world_size > 1:
decode_context_lens = input_batch.num_computed_tokens_cpu[: decode_context_lens = (
self.num_decode_reqs] + num_scheduled_tokens[: input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
self.num_decode_reqs] + num_scheduled_tokens[: self.num_decode_reqs]
prefill_context_lens = input_batch.num_computed_tokens_cpu[ )
self.num_decode_reqs:self.num_reqs] prefill_context_lens = input_batch.num_computed_tokens_cpu[self.num_decode_reqs : self.num_reqs]
context_lens = np.concatenate( context_lens = np.concatenate([decode_context_lens, prefill_context_lens])
[decode_context_lens, prefill_context_lens])
num_computed_tokens_of_pcp_dcp = torch.zeros( num_computed_tokens_of_pcp_dcp = torch.zeros(
[ [self.num_reqs * self.decode_threshold, self.pcp_world_size, self.dcp_world_size],
self.num_reqs * self.decode_threshold, self.pcp_world_size,
self.dcp_world_size
],
dtype=torch.int32, dtype=torch.int32,
) )
# For pcp + spec decode, we flatten seq_lens # For pcp + spec decode, we flatten seq_lens
@@ -587,41 +536,37 @@ class PCPManager:
# Same as block_table, we flatten decode seq_lens to query_lens, # Same as block_table, we flatten decode seq_lens to query_lens,
# and keep prefill seq_lens unchanged. # and keep prefill seq_lens unchanged.
for decode_idx in range(self.decode_threshold): for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[ num_computed_tokens_of_pcp_dcp[self.decode_threshold - 1 - decode_idx :: self.decode_threshold] = (
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_cp_local_seq_lens( self._get_cp_local_seq_lens(
torch.tensor(context_lens) - decode_idx, torch.tensor(context_lens) - decode_idx,
self.pcp_world_size, self.pcp_world_size,
self.dcp_world_size, self.dcp_world_size,
self.vllm_config.parallel_config.cp_kv_cache_interleave_size, self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
) )
)
if self.decode_threshold > 1: if self.decode_threshold > 1:
num_computed_tokens_of_pcp_dcp_list = [] num_computed_tokens_of_pcp_dcp_list = []
if self.num_decode_reqs: if self.num_decode_reqs:
num_decodes_flatten = \ num_decodes_flatten = query_lens[: self.num_decode_reqs].sum().item()
query_lens[:self.num_decode_reqs].sum().item() if query_lens[: self.num_decode_reqs].min().item() == self.decode_threshold:
if query_lens[:self.num_decode_reqs].min().item(
) == self.decode_threshold:
decode_flatten_idx = list(range(num_decodes_flatten)) decode_flatten_idx = list(range(num_decodes_flatten))
else: else:
decode_flatten_idx = [] decode_flatten_idx = []
for req_id in range(self.num_decode_reqs): for req_id in range(self.num_decode_reqs):
offset = (req_id + 1) * self.decode_threshold offset = (req_id + 1) * self.decode_threshold
decode_flatten_idx += \ decode_flatten_idx += list(range(offset - query_lens[req_id], offset))
list(range(offset - query_lens[req_id], offset)) num_computed_tokens_of_pcp_dcp_list.append(num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
if self.num_prefill_reqs: if self.num_prefill_reqs:
num_computed_tokens_of_pcp_dcp_list.append( num_computed_tokens_of_pcp_dcp_list.append(
num_computed_tokens_of_pcp_dcp[ num_computed_tokens_of_pcp_dcp[
(self.num_decode_reqs + 1) * self.decode_threshold - (self.num_decode_reqs + 1) * self.decode_threshold - 1 :: self.decode_threshold
1::self.decode_threshold]) ]
num_computed_tokens_of_pcp_dcp = torch.cat( )
num_computed_tokens_of_pcp_dcp_list, dim=0) num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
long_seq_metadata = AscendPrefillContextParallelMetadata( long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
numpy()) )
if self.pcp_world_size > 1: if self.pcp_world_size > 1:
q_head_idx, q_tail_idx = [], [] q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -639,109 +584,102 @@ class PCPManager:
continue continue
chunk_len = seq_len // 2 chunk_len = seq_len // 2
chunk_seqlens.append(chunk_len) chunk_seqlens.append(chunk_len)
q_head_idx.extend( q_head_idx.extend(list(range(q_req_offset, q_req_offset + chunk_len)))
list(range(q_req_offset, q_req_offset + chunk_len)))
kv_with_q_head_nomask_idx.extend( kv_with_q_head_nomask_idx.extend(
list( list(range(kv_req_offset, kv_req_offset + chunk_len * q_head_chunk_id))
range(kv_req_offset, kv_req_offset + )
chunk_len * q_head_chunk_id)))
kv_with_q_head_mask_idx.extend( kv_with_q_head_mask_idx.extend(
list( list(
range( range(
kv_req_offset + chunk_len * q_head_chunk_id, kv_req_offset + chunk_len * q_head_chunk_id,
kv_req_offset + chunk_len * kv_req_offset + chunk_len * (q_head_chunk_id + 1),
(q_head_chunk_id + 1)))) )
kv_with_q_head_nomask_seqlens.append(chunk_len * )
q_head_chunk_id) )
kv_with_q_head_nomask_seqlens.append(chunk_len * q_head_chunk_id)
split_with_q_head_nomask_idx_reqs.append( split_with_q_head_nomask_idx_reqs.append(
list( list(range(kv_req_offset, kv_req_offset + chunk_len * q_head_chunk_id))
range(kv_req_offset, kv_req_offset + )
chunk_len * q_head_chunk_id))) q_tail_idx.extend(list(range(q_req_offset + chunk_len, q_req_offset + chunk_len * 2)))
q_tail_idx.extend(
list(
range(q_req_offset + chunk_len,
q_req_offset + chunk_len * 2)))
kv_with_q_tail_nomask_idx.extend( kv_with_q_tail_nomask_idx.extend(
list( list(range(kv_req_offset, kv_req_offset + chunk_len * q_tail_chunk_id))
range(kv_req_offset, kv_req_offset + )
chunk_len * q_tail_chunk_id)))
kv_with_q_tail_mask_idx.extend( kv_with_q_tail_mask_idx.extend(
list( list(
range( range(
kv_req_offset + chunk_len * q_tail_chunk_id, kv_req_offset + chunk_len * q_tail_chunk_id,
kv_req_offset + chunk_len * kv_req_offset + chunk_len * (q_tail_chunk_id + 1),
(q_tail_chunk_id + 1)))) )
kv_with_q_tail_nomask_seqlens.append(chunk_len * )
q_tail_chunk_id) )
kv_with_q_tail_nomask_seqlens.append(chunk_len * q_tail_chunk_id)
split_kv_with_q_tail_nomask_idx_reqs.append( split_kv_with_q_tail_nomask_idx_reqs.append(
list( list(range(kv_req_offset, kv_req_offset + chunk_len * q_tail_chunk_id))
range(kv_req_offset, kv_req_offset + )
chunk_len * q_tail_chunk_id)))
q_req_offset += seq_len q_req_offset += seq_len
kv_req_offset += seq_len * self.pcp_world_size kv_req_offset += seq_len * self.pcp_world_size
q_head_idx_tensor = self._list_to_tensor( q_head_idx_tensor = self._list_to_tensor(q_head_idx, self.device)
q_head_idx, self.device) q_tail_idx_tensor = self._list_to_tensor(q_tail_idx, self.device)
q_tail_idx_tensor = self._list_to_tensor(
q_tail_idx, self.device)
self.q_head_idx_tensor = q_head_idx_tensor self.q_head_idx_tensor = q_head_idx_tensor
self.q_tail_idx_tensor = q_tail_idx_tensor self.q_tail_idx_tensor = q_tail_idx_tensor
q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
q_full_idx = q_full_idx.to(torch.float32).argsort().to( q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32)
torch.int32)
self.q_full_idx = q_full_idx self.q_full_idx = q_full_idx
self.kv_idx_names = { self.kv_idx_names = {
'kv_with_q_head_nomask_idx_tensor': "kv_with_q_head_nomask_idx_tensor": kv_with_q_head_nomask_idx,
kv_with_q_head_nomask_idx, "kv_with_q_head_mask_idx_tensor": kv_with_q_head_mask_idx,
'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, "kv_with_q_tail_nomask_idx_tensor": kv_with_q_tail_nomask_idx,
'kv_with_q_tail_nomask_idx_tensor': "kv_with_q_tail_mask_idx_tensor": kv_with_q_tail_mask_idx,
kv_with_q_tail_nomask_idx,
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
} }
for key, value in self.kv_idx_names.items(): for key, value in self.kv_idx_names.items():
tensor_npu = self._list_to_tensor(value, self.device) tensor_npu = self._list_to_tensor(value, self.device)
self.kv_idx_names[key] = tensor_npu self.kv_idx_names[key] = tensor_npu
attn_mask_seqlens = torch.tensor( attn_mask_seqlens = torch.tensor([chunk_seqlens, chunk_seqlens], dtype=torch.int32)
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
head_attn_nomask_seqlens = torch.tensor( head_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_head_nomask_seqlens], [chunk_seqlens, kv_with_q_head_nomask_seqlens], dtype=torch.int32
dtype=torch.int32) )
tail_attn_nomask_seqlens = torch.tensor( tail_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_tail_nomask_seqlens], [chunk_seqlens, kv_with_q_tail_nomask_seqlens], dtype=torch.int32
dtype=torch.int32) )
if self.vllm_config.model_config.use_mla: if self.vllm_config.model_config.use_mla:
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = self._split_nomask_idx_tensor_list( (
split_q_head_nomask_idx_tensor_list,
split_q_tail_nomask_idx_tensor_list,
head_attn_nomask_seqlens_list,
tail_attn_nomask_seqlens_list,
) = self._split_nomask_idx_tensor_list(
split_with_q_head_nomask_idx_reqs, split_with_q_head_nomask_idx_reqs,
split_kv_with_q_tail_nomask_idx_reqs, split_kv_with_q_tail_nomask_idx_reqs,
head_attn_nomask_seqlens, chunk_seqlens) head_attn_nomask_seqlens,
chunk_seqlens,
)
self.extra_long_seq_kwargs = { self.extra_long_seq_kwargs = {
'attn_mask_seqlens': attn_mask_seqlens, "attn_mask_seqlens": attn_mask_seqlens,
'head_attn_nomask_seqlens': head_attn_nomask_seqlens, "head_attn_nomask_seqlens": head_attn_nomask_seqlens,
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens "tail_attn_nomask_seqlens": tail_attn_nomask_seqlens,
} }
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[: long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[
num_actual_tokens_pcp_padded] :num_actual_tokens_pcp_padded
]
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx long_seq_metadata.q_full_idx = self.q_full_idx
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[ long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_head_nomask_idx_tensor'] "kv_with_q_head_nomask_idx_tensor"
long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[ ]
'kv_with_q_head_mask_idx_tensor'] long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names["kv_with_q_head_mask_idx_tensor"]
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[ long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
'kv_with_q_tail_nomask_idx_tensor'] "kv_with_q_tail_nomask_idx_tensor"
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[ ]
'kv_with_q_tail_mask_idx_tensor'] long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names["kv_with_q_tail_mask_idx_tensor"]
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[ long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs["attn_mask_seqlens"]
'attn_mask_seqlens'] long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs["head_attn_nomask_seqlens"]
long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[ long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs["tail_attn_nomask_seqlens"]
'head_attn_nomask_seqlens']
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'tail_attn_nomask_seqlens']
if self.vllm_config.model_config.use_mla: if self.vllm_config.model_config.use_mla:
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
@@ -755,46 +693,53 @@ class PCPManager:
tensor_npu.copy_(torch.tensor(lst, dtype=dtype), non_blocking=True) tensor_npu.copy_(torch.tensor(lst, dtype=dtype), non_blocking=True)
return tensor_npu return tensor_npu
def _split_nomask_idx_tensor_list(self, split_with_q_head_nomask_idx_reqs, def _split_nomask_idx_tensor_list(
split_kv_with_q_tail_nomask_idx_reqs, self,
head_attn_nomask_seqlens, chunk_seqlens): split_with_q_head_nomask_idx_reqs,
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list= [], [] split_kv_with_q_tail_nomask_idx_reqs,
head_attn_nomask_seqlens,
chunk_seqlens,
):
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list = [], []
head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = [], [] head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = [], []
if split_with_q_head_nomask_idx_reqs: if split_with_q_head_nomask_idx_reqs:
#In long-sequence scenarios, the computational cost and latency # In long-sequence scenarios, the computational cost and latency
#of the _npu_ring_mla operator are not proportional, so we split # of the _npu_ring_mla operator are not proportional, so we split
#long sequences into shorter ones to improve performance. # long sequences into shorter ones to improve performance.
split_size = 16 * 1024 split_size = 16 * 1024
if self.pcp_world_rank == 0: if self.pcp_world_rank == 0:
split_q_head_nomask_idx_list = [ split_q_head_nomask_idx_list = [self.kv_idx_names["kv_with_q_head_nomask_idx_tensor"]]
self.kv_idx_names['kv_with_q_head_nomask_idx_tensor']
]
else: else:
split_q_head_nomask_idx_list, split_q_head_nomask_lens_list = self._split_multi_batch_kv_idx( split_q_head_nomask_idx_list, split_q_head_nomask_lens_list = self._split_multi_batch_kv_idx(
split_with_q_head_nomask_idx_reqs, split_size) split_with_q_head_nomask_idx_reqs, split_size
)
split_q_tail_nomask_idx_list, split_q_tail_nomask_lens_list = self._split_multi_batch_kv_idx( split_q_tail_nomask_idx_list, split_q_tail_nomask_lens_list = self._split_multi_batch_kv_idx(
split_kv_with_q_tail_nomask_idx_reqs, split_size) split_kv_with_q_tail_nomask_idx_reqs, split_size
)
for q_head_nomask_idx in split_q_head_nomask_idx_list: for q_head_nomask_idx in split_q_head_nomask_idx_list:
split_q_head_nomask_idx_tensor_list.append( split_q_head_nomask_idx_tensor_list.append(self._list_to_tensor(q_head_nomask_idx, self.device))
self._list_to_tensor(q_head_nomask_idx, self.device))
for q_tail_nomask_idx in split_q_tail_nomask_idx_list: for q_tail_nomask_idx in split_q_tail_nomask_idx_list:
split_q_tail_nomask_idx_tensor_list.append( split_q_tail_nomask_idx_tensor_list.append(self._list_to_tensor(q_tail_nomask_idx, self.device))
self._list_to_tensor(q_tail_nomask_idx, self.device))
if self.pcp_world_rank == 0: if self.pcp_world_rank == 0:
head_attn_nomask_seqlens_list = [head_attn_nomask_seqlens] head_attn_nomask_seqlens_list = [head_attn_nomask_seqlens]
else: else:
for q_head_nomask_lens in split_q_head_nomask_lens_list: for q_head_nomask_lens in split_q_head_nomask_lens_list:
head_attn_nomask_seqlens_list.append( head_attn_nomask_seqlens_list.append(
torch.tensor([chunk_seqlens, q_head_nomask_lens], torch.tensor([chunk_seqlens, q_head_nomask_lens], dtype=torch.int32)
dtype=torch.int32)) )
for q_tail_nomask_lens in split_q_tail_nomask_lens_list: for q_tail_nomask_lens in split_q_tail_nomask_lens_list:
tail_attn_nomask_seqlens_list.append( tail_attn_nomask_seqlens_list.append(
torch.tensor([chunk_seqlens, q_tail_nomask_lens], torch.tensor([chunk_seqlens, q_tail_nomask_lens], dtype=torch.int32)
dtype=torch.int32)) )
return split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list return (
split_q_head_nomask_idx_tensor_list,
split_q_tail_nomask_idx_tensor_list,
head_attn_nomask_seqlens_list,
tail_attn_nomask_seqlens_list,
)
def _split_multi_batch_kv_idx( def _split_multi_batch_kv_idx(
self, self,
@@ -813,7 +758,7 @@ class PCPManager:
current_batch_len = [] current_batch_len = []
for t in range(time): for t in range(time):
start = t * split_size start = t * split_size
current_segment = single_batch[start:start + split_size] current_segment = single_batch[start : start + split_size]
current_batch_split.append(current_segment) current_batch_split.append(current_segment)
current_batch_len.append(len(current_segment)) current_batch_len.append(len(current_segment))
@@ -829,8 +774,9 @@ class PCPManager:
def reshape_kv_len_to_time_first(split_kv_len_2d): def reshape_kv_len_to_time_first(split_kv_len_2d):
if not split_kv_len_2d or not split_kv_len_2d[0]: if not split_kv_len_2d or not split_kv_len_2d[0]:
return [] return []
return [[batch_len[time_idx] for batch_len in split_kv_len_2d] return [
for time_idx in range(len(split_kv_len_2d[0]))] [batch_len[time_idx] for batch_len in split_kv_len_2d] for time_idx in range(len(split_kv_len_2d[0]))
]
merged_split_kv_len_2d = reshape_kv_len_to_time_first(split_kv_len_2d) merged_split_kv_len_2d = reshape_kv_len_to_time_first(split_kv_len_2d)
return merged_split_kv_idx_3d, merged_split_kv_len_2d return merged_split_kv_idx_3d, merged_split_kv_len_2d