[long_seq_Feat] support chunk prefill (#4158)

### What this PR does / why we need it?
1、qwen GQA attention_v1 optim
2、DeepSeek MLA refactor, all gather q -> all gather kv 
3、modelrunner refactor for chunked prefill; removed some unused code

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
This commit is contained in:
LookAround0301
2025-11-14 08:43:37 +08:00
committed by GitHub
parent 7294f89e43
commit 5ec96fd46c
6 changed files with 419 additions and 941 deletions

View File

@@ -29,7 +29,7 @@ from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
Tuple, Union, cast)
Union, cast)
import numpy as np
import numpy.typing as npt
@@ -763,7 +763,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
backward_kwargs["mm_features"] = new_req_data.mm_features
# Create request state - PCP/DCP tracking will be computed below
req_state = CachedRequestState(
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
@@ -774,42 +774,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens=new_req_data.num_computed_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
local_chunked_kv_lens=None,
**backward_kwargs,
)
# Compute PCP/DCP tracking fields for chunked prefill
self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs
if self.pcp_size * self.dcp_size > 1:
num_computed_tokens = new_req_data.num_computed_tokens
if num_computed_tokens > 0:
# Initialize with starting rank 0
temp_start_rank_dict = {req_id: (0, 0)}
# Compute token distribution for initial tokens
current_distribution = self.get_split_computed_tokens(
np.array([num_computed_tokens]),
request_ids=[req_id],
request_start_rank_dict=temp_start_rank_dict,
cp_kv_cache_interleave_size=self.parallel_config.
cp_kv_cache_interleave_size,
)[0]
# Update next_pcp_dcp_start_rank
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
req_id][0]
req_state.token_blank_in_last_blk = temp_start_rank_dict[
req_id][1]
req_state.local_chunked_kv_lens = [
copy.deepcopy(current_distribution)
]
else:
# No computed tokens yet
req_state.local_chunked_kv_lens = []
self.requests[req_id] = req_state
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(self.requests[req_id])
@@ -826,44 +793,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
prev_num_computed_tokens = req_state.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Compute PCP/DCP tracking fields for chunked prefill
if self.pcp_size * self.dcp_size > 1:
# If this is the first chunk, initialize tracking fields
if req_state.local_chunked_kv_lens is None:
req_state.local_chunked_kv_lens = []
# Compute tokens added in this chunk (not cumulative)
chunk_tokens = num_computed_tokens - prev_num_computed_tokens
if chunk_tokens > 0:
# Create a temporary dict with this request's starting rank
temp_start_rank_dict = {
req_id: (req_state.next_pcp_dcp_start_rank,
req_state.token_blank_in_last_blk)
}
# Compute distribution for this chunk only
chunk_distribution = self.get_split_computed_tokens(
np.array([chunk_tokens]),
request_ids=[req_id],
request_start_rank_dict=temp_start_rank_dict,
cp_kv_cache_interleave_size=self.parallel_config.
cp_kv_cache_interleave_size,
)[0]
# Update next_pcp_dcp_start_rank for this request
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
req_id][0]
req_state.token_blank_in_last_blk = temp_start_rank_dict[
req_id][1]
# Append this chunk's distribution to accumulation list
req_state.local_chunked_kv_lens.append(
copy.deepcopy(chunk_distribution))
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
@@ -908,10 +839,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.block_table.append_row(
new_block_ids, req_index)
# Update PCP/DCP tracking fields in input_batch
self.input_batch.local_chunked_kv_lens[
req_index] = req_state.local_chunked_kv_lens
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
if not is_last_rank:
@@ -1477,7 +1404,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output,
decode_threshold=self.reorder_batch_threshold)
def generate_kv_idx(self, tokens, scheduler_output):
def generate_kv_idx(self, scheduler_output):
if not self.pcp_size > 1:
return
self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]
@@ -1548,12 +1475,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
)
self.generate_kv_idx(tokens, scheduler_output)
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
if self.pcp_size > 1:
if not self.vllm_config.model_config.use_mla:
self.generate_kv_idx(scheduler_output)
tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
tokens)
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
@@ -1905,18 +1834,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_reqs, scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens)
# prepare pcp meta data
# For chunked prefill, use num_scheduled_tokens instead of cumulative seq_lens
# to correctly calculate chunk_len in _generate_pcp_metadata
if self.vllm_config.scheduler_config.chunked_prefill_enabled and self.pcp_size > 1:
# In chunked prefill, seq_lens_for_chunk should be the current chunk size
seq_lens_for_chunk = torch.from_numpy(
num_scheduled_tokens[:num_reqs])
else:
# Normal mode: use cumulative sequence lengths
seq_lens_for_chunk = seq_lens_cpu
long_seq_metadata = self._generate_pcp_metadata(
total_num_scheduled_tokens, seq_lens_for_chunk, seq_lens_cpu)
total_num_scheduled_tokens)
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -2842,6 +2761,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
self.query_start_loc_cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2855,8 +2775,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
long_seq_metadata = self._generate_pcp_metadata(
num_tokens, self.seq_lens_cpu, self.seq_lens_cpu)
long_seq_metadata = self._generate_pcp_metadata(num_tokens)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group(
).world_size if prefill_context_parallel_enable() else 1
@@ -4411,7 +4330,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
all_positions_tensor.float().argsort().long(), non_blocking=True)
return pcp_tokens, positions, unpad_mask
def _get_pcp_local_seq_lens(
def _get_cp_local_seq_lens(
self,
seq_lens: torch.Tensor,
pcp_world_size: int = 1,
@@ -4439,139 +4358,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
[-1, pcp_world_size, dcp_world_size])
return dcp_local_seq_lens
def get_split_computed_tokens(
        self,
        num_computed_tokens: np.ndarray,
        request_ids: Optional[List[str]] = None,
        request_start_rank_dict: Optional[Dict[str, tuple[
            int,
            int]]] = None,  # tuple: start_rank, tokens_blank_in_this_block
        cp_kv_cache_interleave_size: int = 1
) -> list[Optional[list[Optional[list[int]]]]]:
    """Splits computed token counts across pcp and dcp ranks for distributed allocation.

    Tokens are dealt out round-robin in blocks of ``cp_kv_cache_interleave_size``,
    resuming from the rank (and partially-filled block) recorded for each request
    in ``request_start_rank_dict``.

    Args:
        num_computed_tokens: Number of tokens for each request (current chunk, not cumulative)
        request_ids: Request IDs to track state; must match ``num_computed_tokens`` in length.
        request_start_rank_dict: Dict mapping req_id to (start_rank, tokens_blank_in_last_block)
            for this chunk. Mutated in place: updated with the next starting rank
            after distribution.
        cp_kv_cache_interleave_size: Block size of the round-robin interleave.
    Returns:
        List of [pcp_size][dcp_size] distribution for each request
    """
    # NOTE: the original default was a shared mutable `{}` that this function
    # mutates below — a fresh dict per call avoids state leaking across calls.
    if request_start_rank_dict is None:
        request_start_rank_dict = {}
    self.pcp_world_size = get_pcp_group(
    ).world_size if prefill_context_parallel_enable() else 1
    self.dcp_world_size = get_dcp_group().world_size
    num_requests = len(num_computed_tokens)
    assert request_start_rank_dict is not None and request_ids is not None and len(
        request_ids) == num_requests
    # Per-request [pcp][dcp] token counters, all initialized to zero.
    local_chunked_kv_lens = [[[0] * self.dcp_world_size
                              for _ in range(self.pcp_world_size)]
                             for _ in range(num_requests)]
    total_ranks = self.pcp_world_size * self.dcp_world_size
    for req_idx, (req_id, total_tokens) in enumerate(
            zip(request_ids, num_computed_tokens)):
        if total_tokens <= 0:
            continue
        # Get starting rank for this chunk
        start_rank = 0
        tokens_blank = 0
        if request_start_rank_dict is not None:
            start_rank, tokens_blank = request_start_rank_dict.get(
                req_id, (0, 0))
        if tokens_blank > 0:  # need to continue writing in the last block of previous chunk
            consumed_tokens = min(tokens_blank, total_tokens)
            total_tokens -= consumed_tokens
            tokens_blank -= consumed_tokens
            # Flat rank index -> (pcp, dcp) coordinates.
            pcp_idx = start_rank // self.dcp_world_size
            dcp_idx = start_rank % self.dcp_world_size
            local_chunked_kv_lens[req_idx][pcp_idx][
                dcp_idx] += consumed_tokens
            if tokens_blank == 0:
                # Last block is now full; the next block goes to the next rank.
                start_rank = (start_rank + 1) % total_ranks
            if total_tokens == 0:
                # Entire chunk fit into the leftover space of the previous block.
                request_start_rank_dict[req_id] = (start_rank,
                                                   tokens_blank)
                continue
        # One full round over all ranks covers `virtual_size` tokens.
        virtual_size = total_ranks * cp_kv_cache_interleave_size
        base = int(total_tokens) // virtual_size
        # Distribute base tokens to all ranks
        for rank_idx in range(total_ranks):
            pcp_idx = rank_idx // self.dcp_world_size
            dcp_idx = rank_idx % self.dcp_world_size
            local_chunked_kv_lens[req_idx][pcp_idx][
                dcp_idx] += base * cp_kv_cache_interleave_size
        remainder = int(total_tokens) % virtual_size
        if remainder == 0:
            request_start_rank_dict[req_id] = (start_rank, tokens_blank)
            continue
        remain_blocks = cdiv(remainder, cp_kv_cache_interleave_size)
        assert remain_blocks > 0
        # Distribute remainder tokens starting from start_rank
        for i in range(remain_blocks):
            rank = (start_rank + i) % total_ranks
            pcp_idx = rank // self.dcp_world_size
            dcp_idx = rank % self.dcp_world_size
            if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0:  # not last block or divisible
                local_chunked_kv_lens[req_idx][pcp_idx][
                    dcp_idx] += 1 * cp_kv_cache_interleave_size
                tokens_blank = 0
            else:  # if last block and undivisible
                local_chunked_kv_lens[req_idx][pcp_idx][
                    dcp_idx] += remainder % cp_kv_cache_interleave_size
                # Remember how much of this block is still writable next chunk.
                tokens_blank = cp_kv_cache_interleave_size - (
                    remainder % cp_kv_cache_interleave_size)
        start_rank = (start_rank + remain_blocks - 1) % total_ranks
        if tokens_blank == 0:
            start_rank = (start_rank + 1) % total_ranks
        # Update next starting rank for this request
        request_start_rank_dict[req_id] = (start_rank, tokens_blank)
    return cast(List[Optional[List[Optional[List[int]]]]],
                local_chunked_kv_lens)
def _get_chunked_req_mask_and_max_chunk(
self,
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
Optional[list[int]]]]]]]] = None
) -> Tuple[List[bool], int]:
"""
given 4-d list [req][chunk][pcp][dcp], return:
1. if each req has any chunk (list[bool])
2. max chunk num along all reqs (int)
"""
assert local_chunked_kv_lens is not None
if len(local_chunked_kv_lens) == 0:
return ([], 0)
mask_for_non_zero_chunk = [
len(req) > 0 for req in local_chunked_kv_lens if req is not None
]
max_chunk_num = max(
(len(req) for req in local_chunked_kv_lens if req is not None),
default=0)
return mask_for_non_zero_chunk, max_chunk_num
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
seq_lens_origin):
def _generate_pcp_metadata(self, total_num_scheduled_tokens):
# In dummy run num_reqs == 0, update it from seq_lens
num_reqs = self.input_batch.num_reqs or seq_lens.size(0)
num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
>= self.input_batch.num_prompt_tokens[:num_reqs])
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
local_chunked_kv_lens = self.input_batch.local_chunked_kv_lens[
num_decodes:num_reqs]
mask_for_non_zero_chunk, max_chunk_num = self._get_chunked_req_mask_and_max_chunk(
local_chunked_kv_lens)
long_seq_metadata = None
if self.pcp_size * self.dcp_size > 1:
decode_context_lens = self.input_batch.num_tokens[:num_decodes]
prefill_context_lens = self.input_batch.num_computed_tokens_cpu[
num_decodes:num_reqs]
context_lens = np.concatenate(
[decode_context_lens, prefill_context_lens])
num_computed_tokens_of_pcp_dcp = torch.zeros(
[
num_reqs * self.decode_threshold, self.pcp_size,
@@ -4584,8 +4384,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for decode_idx in range(self.decode_threshold):
num_computed_tokens_of_pcp_dcp[
self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
self._get_pcp_local_seq_lens(
seq_lens_origin - decode_idx,
self._get_cp_local_seq_lens(
torch.tensor(context_lens),
self.pcp_size,
self.dcp_size,
self.parallel_config.cp_kv_cache_interleave_size,
@@ -4593,10 +4393,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
long_seq_metadata = AscendPrefillContextParallelMetadata(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
numpy(),
local_chunked_kv_lens=local_chunked_kv_lens,
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
max_chunk_num=max_chunk_num)
numpy())
if self.pcp_size > 1:
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -4607,7 +4404,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_req_offset = 0
q_head_chunk_id = self.pcp_rank
q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
for i, seq_len in enumerate(seq_lens):
for i, seq_len in enumerate(self.query_lens):
if i < num_decodes:
continue
chunk_len = seq_len // 2

View File

@@ -73,12 +73,6 @@ class CachedRequestState:
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
# pcp/dcp param
local_chunked_kv_lens: Optional[list[Optional[list[Optional[
list[int]]]]]] = None # Records computed tokens for each chunk
next_pcp_dcp_start_rank: int = 0 # Tracks next starting rank for round-robin distribution
token_blank_in_last_blk: int = 0 # if the last block is not full, how many future tokens can be stored
def __post_init__(self):
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
@@ -319,10 +313,6 @@ class InputBatch:
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
self.prev_req_id_to_index: Optional[dict[str, int]] = None
# pcp/dcp parameters
self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
list[int]]]]]]] = [None] * max_num_reqs
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
@@ -395,9 +385,6 @@ class InputBatch:
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.block_table.add_row(request.block_ids, req_index)
# Add PCP/DCP tracking fields
self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
@@ -693,8 +680,6 @@ class InputBatch:
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.local_chunked_kv_lens[
empty_index] = self.local_chunked_kv_lens[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]