[long_seq_Feat] support chunk prefill (#4158)
### What this PR does / why we need it?
1. Qwen GQA attention_v1 optimization.
2. DeepSeek MLA refactor: all-gather q -> all-gather kv (see the sketch below).
3. Model runner refactor for chunked prefill; remove some unused code.
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
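As a rough illustration of point 2 above, here is a single-process sketch with made-up shapes (not the PR's MLA code): the refactor gathers KV across prefill-context-parallel ranks so each rank attends with only its local query shard, instead of gathering every rank's queries against a local KV shard.

```python
import torch

# Illustrative sizes; not taken from the PR.
pcp_size, seq_per_rank, n_heads, d = 2, 4, 8, 16
q_shards = [torch.randn(seq_per_rank, n_heads, d) for _ in range(pcp_size)]
k_shards = [torch.randn(seq_per_rank, n_heads, d) for _ in range(pcp_size)]
v_shards = [torch.randn(seq_per_rank, n_heads, d) for _ in range(pcp_size)]

def attend(q, k, v):
    # Plain scaled dot-product attention, per head.
    scores = torch.einsum("qhd,khd->hqk", q, k) / d ** 0.5
    return torch.einsum("hqk,khd->qhd", scores.softmax(-1), v)

# Old pattern (conceptually): gather all q shards, attend against the local kv shard only.
q_full = torch.cat(q_shards, dim=0)
out_old_rank0 = attend(q_full, k_shards[0], v_shards[0])  # partial rows, still need combining

# New pattern (conceptually): gather all kv shards, attend with the local q shard only,
# producing complete rows for the local queries.
k_full, v_full = torch.cat(k_shards, dim=0), torch.cat(v_shards, dim=0)
out_new_rank0 = attend(q_shards[0], k_full, v_full)
print(out_old_rank0.shape, out_new_rank0.shape)
```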
---------
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
@@ -29,7 +29,7 @@ from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
                    Tuple, Union, cast)
                    Union, cast)

import numpy as np
import numpy.typing as npt
@@ -763,7 +763,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        backward_kwargs["mm_features"] = new_req_data.mm_features

        # Create request state - PCP/DCP tracking will be computed below
        req_state = CachedRequestState(
        self.requests[req_id] = CachedRequestState(
            req_id=req_id,
            prompt_token_ids=new_req_data.prompt_token_ids,
            prompt_embeds=new_req_data.prompt_embeds,
@@ -774,42 +774,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            num_computed_tokens=new_req_data.num_computed_tokens,
            output_token_ids=[],
            lora_request=new_req_data.lora_request,
            local_chunked_kv_lens=None,
            **backward_kwargs,
        )

        # Compute PCP/DCP tracking fields for chunked prefill
        self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs
        if self.pcp_size * self.dcp_size > 1:
            num_computed_tokens = new_req_data.num_computed_tokens
            if num_computed_tokens > 0:
                # Initialize with starting rank 0
                temp_start_rank_dict = {req_id: (0, 0)}

                # Compute token distribution for initial tokens
                current_distribution = self.get_split_computed_tokens(
                    np.array([num_computed_tokens]),
                    request_ids=[req_id],
                    request_start_rank_dict=temp_start_rank_dict,
                    cp_kv_cache_interleave_size=self.parallel_config.
                    cp_kv_cache_interleave_size,
                )[0]

                # Update next_pcp_dcp_start_rank
                req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
                    req_id][0]
                req_state.token_blank_in_last_blk = temp_start_rank_dict[
                    req_id][1]

                req_state.local_chunked_kv_lens = [
                    copy.deepcopy(current_distribution)
                ]
            else:
                # No computed tokens yet
                req_state.local_chunked_kv_lens = []

        self.requests[req_id] = req_state

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._init_mrope_positions(self.requests[req_id])
@@ -826,44 +793,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        resumed_from_preemption = req_data.resumed_from_preemption[i]

        # Update the cached states.
        prev_num_computed_tokens = req_state.num_computed_tokens
        req_state.num_computed_tokens = num_computed_tokens

        # Compute PCP/DCP tracking fields for chunked prefill
        if self.pcp_size * self.dcp_size > 1:
            # If this is the first chunk, initialize tracking fields
            if req_state.local_chunked_kv_lens is None:
                req_state.local_chunked_kv_lens = []

            # Compute tokens added in this chunk (not cumulative)
            chunk_tokens = num_computed_tokens - prev_num_computed_tokens

            if chunk_tokens > 0:
                # Create a temporary dict with this request's starting rank
                temp_start_rank_dict = {
                    req_id: (req_state.next_pcp_dcp_start_rank,
                             req_state.token_blank_in_last_blk)
                }

                # Compute distribution for this chunk only
                chunk_distribution = self.get_split_computed_tokens(
                    np.array([chunk_tokens]),
                    request_ids=[req_id],
                    request_start_rank_dict=temp_start_rank_dict,
                    cp_kv_cache_interleave_size=self.parallel_config.
                    cp_kv_cache_interleave_size,
                )[0]

                # Update next_pcp_dcp_start_rank for this request
                req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
                    req_id][0]
                req_state.token_blank_in_last_blk = temp_start_rank_dict[
                    req_id][1]

                # Append this chunk's distribution to accumulation list
                req_state.local_chunked_kv_lens.append(
                    copy.deepcopy(chunk_distribution))

        if not is_last_rank:
            # When using PP, the scheduler sends the sampled tokens back,
            # because there's no direct communication between the first-
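A simplified, standalone sketch of the per-chunk bookkeeping above; `split_chunk` is a hypothetical stand-in for `get_split_computed_tokens`, and the even split plus the token counts are assumptions for illustration:

```python
from typing import List

def split_chunk(chunk_tokens: int, pcp_size: int, dcp_size: int) -> List[List[int]]:
    # Stand-in splitter: spread the chunk evenly over pcp*dcp ranks,
    # giving the remainder to the first ranks (illustration only).
    total = pcp_size * dcp_size
    base, rem = divmod(chunk_tokens, total)
    flat = [base + (1 if r < rem else 0) for r in range(total)]
    return [flat[p * dcp_size:(p + 1) * dcp_size] for p in range(pcp_size)]

local_chunked_kv_lens: List[List[List[int]]] = []  # one [pcp][dcp] entry per chunk
prev_num_computed_tokens = 0
for num_computed_tokens in (512, 1024, 1300):  # cumulative totals after each chunk
    chunk_tokens = num_computed_tokens - prev_num_computed_tokens
    if chunk_tokens > 0:
        local_chunked_kv_lens.append(split_chunk(chunk_tokens, pcp_size=2, dcp_size=2))
    prev_num_computed_tokens = num_computed_tokens
print(local_chunked_kv_lens)  # three per-chunk [pcp][dcp] distributions
```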
@@ -908,10 +839,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.input_batch.block_table.append_row(
            new_block_ids, req_index)

        # Update PCP/DCP tracking fields in input_batch
        self.input_batch.local_chunked_kv_lens[
            req_index] = req_state.local_chunked_kv_lens

        # For the last rank, we don't need to update the token_ids_cpu
        # because the sampled tokens are already cached.
        if not is_last_rank:
@@ -1477,7 +1404,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            scheduler_output,
            decode_threshold=self.reorder_batch_threshold)

    def generate_kv_idx(self, tokens, scheduler_output):
    def generate_kv_idx(self, scheduler_output):
        if not self.pcp_size > 1:
            return
        self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]
@@ -1548,12 +1475,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            self.input_batch.num_computed_tokens_cpu[req_indices],
            arange,
        )
        self.generate_kv_idx(tokens, scheduler_output)

        self.input_batch.block_table.compute_slot_mapping(
            req_indices, positions_np)
        self.input_batch.block_table.commit_slot_mapping(
            total_num_scheduled_tokens)
        if self.pcp_size > 1:
            if not self.vllm_config.model_config.use_mla:
                self.generate_kv_idx(scheduler_output)
            tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
                tokens)
            num_scheduled_tokens = np.array(tokens, dtype=np.int32)
@@ -1905,18 +1834,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            num_reqs, scheduler_output.total_num_scheduled_tokens,
            scheduler_output.num_scheduled_tokens)

        # prepare pcp meta data
        # For chunked prefill, use num_scheduled_tokens instead of cumulative seq_lens
        # to correctly calculate chunk_len in _generate_pcp_metadata
        if self.vllm_config.scheduler_config.chunked_prefill_enabled and self.pcp_size > 1:
            # In chunked prefill, seq_lens_for_chunk should be the current chunk size
            seq_lens_for_chunk = torch.from_numpy(
                num_scheduled_tokens[:num_reqs])
        else:
            # Normal mode: use cumulative sequence lengths
            seq_lens_for_chunk = seq_lens_cpu
        long_seq_metadata = self._generate_pcp_metadata(
            total_num_scheduled_tokens, seq_lens_for_chunk, seq_lens_cpu)
            total_num_scheduled_tokens)
        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
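A small numeric illustration of the branch above; the request sizes are hypothetical, and `chunked_prefill_enabled` / `pcp_size` are stand-in locals rather than the runner's attributes:

```python
import numpy as np

# Hypothetical batch of two prefill requests under chunked prefill.
num_computed_tokens = np.array([1024, 0])    # tokens already processed in earlier chunks
num_scheduled_tokens = np.array([512, 512])  # tokens scheduled in this step
seq_lens = num_computed_tokens + num_scheduled_tokens  # cumulative lengths: [1536, 512]

chunked_prefill_enabled, pcp_size = True, 2
if chunked_prefill_enabled and pcp_size > 1:
    # Current chunk sizes drive the per-chunk split across pcp ranks.
    seq_lens_for_chunk = num_scheduled_tokens
else:
    # Otherwise the cumulative lengths are used as before.
    seq_lens_for_chunk = seq_lens
print(seq_lens_for_chunk)
```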
@@ -2842,6 +2761,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
        self.query_start_loc_cpu[1:num_reqs +
                                 1] = torch.Tensor(cu_num_tokens)
        self.query_lens = torch.from_numpy(num_scheduled_tokens)

        num_computed_tokens_cpu = (
            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2855,8 +2775,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
                                             dtype=torch.int32,
                                             device=self.device)
        long_seq_metadata = self._generate_pcp_metadata(
            num_tokens, self.seq_lens_cpu, self.seq_lens_cpu)
        long_seq_metadata = self._generate_pcp_metadata(num_tokens)
        if long_seq_metadata is not None:
            pcp_world_size = get_pcp_group(
            ).world_size if prefill_context_parallel_enable() else 1
@@ -4411,7 +4330,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            all_positions_tensor.float().argsort().long(), non_blocking=True)
        return pcp_tokens, positions, unpad_mask

    def _get_pcp_local_seq_lens(
    def _get_cp_local_seq_lens(
        self,
        seq_lens: torch.Tensor,
        pcp_world_size: int = 1,
@@ -4439,139 +4358,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            [-1, pcp_world_size, dcp_world_size])
        return dcp_local_seq_lens

    def get_split_computed_tokens(
        self,
        num_computed_tokens: np.ndarray,
        request_ids: Optional[List[str]] = None,
        request_start_rank_dict: Dict[str, tuple[
            int, int]] = {},  # tuple: start_rank, tokens_blank_in_this_block
        cp_kv_cache_interleave_size: int = 1
    ) -> list[Optional[list[Optional[list[int]]]]]:
        """Splits computed token counts across dcp and sp dimensions for distributed allocation.

        Args:
            num_computed_tokens: Number of tokens for each request (current chunk, not cumulative)
            request_ids: Request IDs to track state
            request_start_rank_dict: Dict mapping req_id to the starting rank for this chunk.
                Will be updated with next starting rank after distribution.

        Returns:
            List of [pcp_size][dcp_size] distribution for each request
        """
        self.pcp_world_size = get_pcp_group(
        ).world_size if prefill_context_parallel_enable() else 1
        self.dcp_world_size = get_dcp_group().world_size
        num_requests = len(num_computed_tokens)
        assert request_start_rank_dict is not None and request_ids is not None and len(
            request_ids) == num_requests
        local_chunked_kv_lens = [[[0] * self.dcp_world_size
                                  for _ in range(self.pcp_world_size)]
                                 for _ in range(num_requests)]
        total_ranks = self.pcp_world_size * self.dcp_world_size

        for req_idx, (req_id, total_tokens) in enumerate(
                zip(request_ids, num_computed_tokens)):
            if total_tokens <= 0:
                continue

            # Get starting rank for this chunk
            start_rank = 0
            tokens_blank = 0
            if request_start_rank_dict is not None:
                start_rank, tokens_blank = request_start_rank_dict.get(
                    req_id, (0, 0))

            if tokens_blank > 0:  # need to continue writing in the last block of previous chunk
                consumed_tokens = min(tokens_blank, total_tokens)
                total_tokens -= consumed_tokens
                tokens_blank -= consumed_tokens
                pcp_idx = start_rank // self.dcp_world_size
                dcp_idx = start_rank % self.dcp_world_size
                local_chunked_kv_lens[req_idx][pcp_idx][
                    dcp_idx] += consumed_tokens
                if tokens_blank == 0:
                    start_rank = (start_rank + 1) % total_ranks
                if total_tokens == 0:
                    request_start_rank_dict[req_id] = (start_rank,
                                                       tokens_blank)
                    continue

            virtual_size = total_ranks * cp_kv_cache_interleave_size
            base = int(total_tokens) // virtual_size

            # Distribute base tokens to all ranks
            for rank_idx in range(total_ranks):
                pcp_idx = rank_idx // self.dcp_world_size
                dcp_idx = rank_idx % self.dcp_world_size
                local_chunked_kv_lens[req_idx][pcp_idx][
                    dcp_idx] += base * cp_kv_cache_interleave_size

            remainder = int(total_tokens) % virtual_size
            if remainder == 0:
                request_start_rank_dict[req_id] = (start_rank, tokens_blank)
                continue
            remain_blocks = cdiv(remainder, cp_kv_cache_interleave_size)
            assert remain_blocks > 0

            # Distribute remainder tokens starting from start_rank
            for i in range(remain_blocks):
                rank = (start_rank + i) % total_ranks
                pcp_idx = rank // self.dcp_world_size
                dcp_idx = rank % self.dcp_world_size
                if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0:  # not last block or divisible
                    local_chunked_kv_lens[req_idx][pcp_idx][
                        dcp_idx] += 1 * cp_kv_cache_interleave_size
                    tokens_blank = 0
                else:  # if last block and undivisible
                    local_chunked_kv_lens[req_idx][pcp_idx][
                        dcp_idx] += remainder % cp_kv_cache_interleave_size
                    tokens_blank = cp_kv_cache_interleave_size - (
                        remainder % cp_kv_cache_interleave_size)
            start_rank = (start_rank + remain_blocks - 1) % total_ranks
            if tokens_blank == 0:
                start_rank = (start_rank + 1) % total_ranks

            # Update next starting rank for this request
            request_start_rank_dict[req_id] = (start_rank, tokens_blank)

        return cast(List[Optional[List[Optional[List[int]]]]],
                    local_chunked_kv_lens)
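For intuition, a self-contained sketch of the round-robin block distribution that the docstring above describes; `split_tokens_round_robin` is a hypothetical simplification that ignores the carry-over into a partially filled block left by a previous chunk:

```python
from typing import List, Tuple

def split_tokens_round_robin(total_tokens: int, pcp_size: int, dcp_size: int,
                             interleave_size: int = 1,
                             start_rank: int = 0) -> Tuple[List[List[int]], int, int]:
    """Hand tokens out in blocks of `interleave_size`, round-robin over the
    pcp*dcp ranks starting at `start_rank`. Returns ([pcp][dcp] counts,
    next start rank, blank slots remaining in the last block)."""
    total_ranks = pcp_size * dcp_size
    counts = [[0] * dcp_size for _ in range(pcp_size)]
    rank, remaining, blank = start_rank, total_tokens, 0
    while remaining > 0:
        take = min(interleave_size, remaining)
        counts[rank // dcp_size][rank % dcp_size] += take
        remaining -= take
        blank = interleave_size - take
        if blank == 0:
            rank = (rank + 1) % total_ranks  # block is full, move to the next rank
    return counts, rank, blank

# 10 tokens over pcp=2 x dcp=2 with interleave 4: blocks of 4, 4, 2 land on ranks 0, 1, 2.
print(split_tokens_round_robin(10, 2, 2, interleave_size=4))
# -> ([[4, 4], [2, 0]], 2, 2): rank 2 still has 2 free slots in its last block
```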

    def _get_chunked_req_mask_and_max_chunk(
        self,
        local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
            Optional[list[int]]]]]]]] = None
    ) -> Tuple[List[bool], int]:
        """
        given 4-d list [req][chunk][pcp][dcp], return:
        1. if each req has any chunk (list[bool])
        2. max chunk num along all reqs (int)
        """
        assert local_chunked_kv_lens is not None
        if len(local_chunked_kv_lens) == 0:
            return ([], 0)
        mask_for_non_zero_chunk = [
            len(req) > 0 for req in local_chunked_kv_lens if req is not None
        ]
        max_chunk_num = max(
            (len(req) for req in local_chunked_kv_lens if req is not None),
            default=0)
        return mask_for_non_zero_chunk, max_chunk_num
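A tiny standalone illustration of the mask/max-chunk computation above, using made-up [req][chunk][pcp][dcp] token counts:

```python
# Hypothetical values: request 0 has two recorded chunks, request 1 has none yet.
local_chunked_kv_lens = [
    [[[256, 256], [256, 256]],   # req 0, chunk 0: [pcp][dcp] token counts
     [[128, 128], [128, 128]]],  # req 0, chunk 1
    [],                          # req 1: no chunks yet
]
mask_for_non_zero_chunk = [
    len(req) > 0 for req in local_chunked_kv_lens if req is not None
]
max_chunk_num = max(
    (len(req) for req in local_chunked_kv_lens if req is not None), default=0)
print(mask_for_non_zero_chunk, max_chunk_num)  # [True, False] 2
```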

    def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
                               seq_lens_origin):
    def _generate_pcp_metadata(self, total_num_scheduled_tokens):
        # In dummy run num_reqs == 0, update it from seq_lens
        num_reqs = self.input_batch.num_reqs or seq_lens.size(0)
        num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
        num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                          >= self.input_batch.num_prompt_tokens[:num_reqs])
        num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
        self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
        local_chunked_kv_lens = self.input_batch.local_chunked_kv_lens[
            num_decodes:num_reqs]
        mask_for_non_zero_chunk, max_chunk_num = self._get_chunked_req_mask_and_max_chunk(
            local_chunked_kv_lens)
        long_seq_metadata = None
        if self.pcp_size * self.dcp_size > 1:
            decode_context_lens = self.input_batch.num_tokens[:num_decodes]
            prefill_context_lens = self.input_batch.num_computed_tokens_cpu[
                num_decodes:num_reqs]
            context_lens = np.concatenate(
                [decode_context_lens, prefill_context_lens])
            num_computed_tokens_of_pcp_dcp = torch.zeros(
                [
                    num_reqs * self.decode_threshold, self.pcp_size,
@@ -4584,8 +4384,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            for decode_idx in range(self.decode_threshold):
                num_computed_tokens_of_pcp_dcp[
                    self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
                    self._get_pcp_local_seq_lens(
                        seq_lens_origin - decode_idx,
                    self._get_cp_local_seq_lens(
                        torch.tensor(context_lens),
                        self.pcp_size,
                        self.dcp_size,
                        self.parallel_config.cp_kv_cache_interleave_size,
@@ -4593,10 +4393,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            long_seq_metadata = AscendPrefillContextParallelMetadata(
                num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
                numpy(),
                local_chunked_kv_lens=local_chunked_kv_lens,
                mask_for_non_zero_chunk=mask_for_non_zero_chunk,
                max_chunk_num=max_chunk_num)
                numpy())
        if self.pcp_size > 1:
            q_head_idx, q_tail_idx = [], []
            kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
@@ -4607,7 +4404,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            kv_req_offset = 0
            q_head_chunk_id = self.pcp_rank
            q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank
            for i, seq_len in enumerate(seq_lens):
            for i, seq_len in enumerate(self.query_lens):
                if i < num_decodes:
                    continue
                chunk_len = seq_len // 2
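A brief numeric sketch of the decode-count and padding arithmetic used at the top of `_generate_pcp_metadata` (all values are hypothetical):

```python
import numpy as np

# A request counts as a decode once its computed tokens cover the prompt.
num_computed_tokens_cpu = np.array([128, 64, 512])
num_prompt_tokens = np.array([100, 64, 2048])
num_decodes = int(np.sum(num_computed_tokens_cpu >= num_prompt_tokens))  # 2

# Each pcp rank contributes a shard, so the gathered/padded token count scales with pcp_size.
total_num_scheduled_tokens, pcp_size = 514, 2
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * pcp_size
print(num_decodes, num_actual_tokens_pcp_padded)  # 2 1028
```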
@@ -73,12 +73,6 @@ class CachedRequestState:
    lora_request: Optional[LoRARequest] = None
    prompt_embeds: Optional[torch.Tensor] = None

    # pcp/dcp param
    local_chunked_kv_lens: Optional[list[Optional[list[Optional[
        list[int]]]]]] = None  # Records computed tokens for each chunk
    next_pcp_dcp_start_rank: int = 0  # Tracks next starting rank for round-robin distribution
    token_blank_in_last_blk: int = 0  # if the last block is not full, how many future tokens can be stored

    def __post_init__(self):
        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
            self.prompt_token_ids, self.prompt_embeds)
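A small arithmetic illustration of what `token_blank_in_last_blk` records, with a hypothetical chunk size and interleave size:

```python
cp_kv_cache_interleave_size = 4
chunk_tokens = 10  # tokens written by the chunk that just finished

# The last block written may be only partially filled.
used_in_last_block = chunk_tokens % cp_kv_cache_interleave_size or cp_kv_cache_interleave_size
token_blank_in_last_blk = cp_kv_cache_interleave_size - used_in_last_block
print(token_blank_in_last_blk)  # 2: the next chunk first fills these slots on the same rank
```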
@@ -319,10 +313,6 @@ class InputBatch:
        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
        self.prev_req_id_to_index: Optional[dict[str, int]] = None

        # pcp/dcp parameters
        self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
            list[int]]]]]]] = [None] * max_num_reqs

    @property
    def req_ids(self) -> list[str]:
        # None elements should only be present transiently
@@ -395,9 +385,6 @@ class InputBatch:
        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)

        # Add PCP/DCP tracking fields
        self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens

        if sampling_params := request.sampling_params:
            if (self.is_spec_decode
                    and is_spec_decode_unsupported(sampling_params)):
@@ -693,8 +680,6 @@ class InputBatch:
            last_req_index]
        self.num_computed_tokens_cpu[
            empty_index] = self.num_computed_tokens_cpu[last_req_index]
        self.local_chunked_kv_lens[
            empty_index] = self.local_chunked_kv_lens[last_req_index]
        self.block_table.move_row(last_req_index, empty_index)
        self.temperature_cpu[empty_index] = self.temperature_cpu[
            last_req_index]