### What this PR does / why we need it? This PR adds the Prefill Context Parallelism (PCP) feature, which is the prefill-stage counterpart of DCP (Decode Context Parallelism). For specific implementation details, please refer to the RFC https://github.com/vllm-project/vllm/issues/25749. TL;DR: PCP enhances long-sequence inference capabilities by partitioning the sequence dimension during the prefill stage. ### Does this PR introduce _any_ user-facing change? The current implementation primarily includes the following changes: modified ModelRunner.py to add the CP token-partitioning logic; modified attention_v1.py and mla_v1.py to adapt the GQA/MLA backends to PCP; modified block_tables.py to extend the KV cache storage for DCP and PCP; added the command-line arguments needed to control PCP parallelism. ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: LookAround <lixushi@huawei.com> Signed-off-by: chenjie <chenjie137@huawei.com> Signed-off-by: Delphine-Nic <tanwenqin@huawei.com> Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com> Signed-off-by: Feng Liu <liufeng248@huawei.com> Signed-off-by: gaojc <1055866782@qq.com> Signed-off-by: weiguihua2 <weiguihua2@huawei.com> Signed-off-by: z50049692 <zhangmingwei11@huawei.com> Co-authored-by: chenjie <chenjie137@huawei.com> Co-authored-by: Delphine-Nic <tanwenqin@huawei.com> Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com> Co-authored-by: Feng Liu <liufeng248@huawei.com> Co-authored-by: gaojc <1055866782@qq.com> Co-authored-by: weiguihua2 <weiguihua2@huawei.com> Co-authored-by: z50049692 <zhangmingwei11@huawei.com> Co-authored-by: w00896881 <wangzixuan40@huawei.com>
217 lines
6.5 KiB
Python
217 lines
6.5 KiB
Python
from dataclasses import dataclass
|
|
from typing import Any, List, Optional
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|
has_kv_transfer_group,
|
|
is_v1_kv_transfer_group)
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
|
|
|
|
|
@dataclass
class AscendPrefillContextParallelMetadata:
    """Optional per-batch state for prefill context parallelism (PCP).

    Every field defaults to ``None``, so an instance can be created empty
    and populated only with the entries a given batch needs. Fields are
    annotated ``Optional[...]`` to match those ``None`` defaults.

    NOTE(review): the meaning, shape and dtype of the individual tensors are
    not established in this file -- confirm against the code that builds and
    consumes this metadata before relying on any of them.
    """

    pcp_allgather_restore_idx: Optional[torch.Tensor] = None

    num_actual_tokens_pcp_padded: Optional[int] = None

    num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
        list[int]]]]]] = None

    q_head_idx_tensor: Optional[torch.Tensor] = None

    q_tail_idx_tensor: Optional[torch.Tensor] = None

    kv_with_q_head_nomask_idx_tensor: Optional[torch.Tensor] = None

    kv_with_q_head_mask_idx_tensor: Optional[torch.Tensor] = None

    kv_with_q_tail_nomask_idx_tensor: Optional[torch.Tensor] = None

    kv_with_q_tail_mask_idx_tensor: Optional[torch.Tensor] = None

    attn_mask_seqlens: Optional[torch.Tensor] = None

    head_attn_nomask_seqlens: Optional[torch.Tensor] = None

    tail_attn_nomask_seqlens: Optional[torch.Tensor] = None

    q_full_idx: Optional[torch.Tensor] = None

    pcp_prefill_mask: Optional[torch.Tensor] = None
|
|
|
|
|
|
@dataclass
class AscendCommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.
    AttentionMetadataBuilder instances use it to construct per-layer metadata.

    For many of the tensors we keep both device and CPU versions.

    Fields whose default is ``None`` are annotated ``Optional[...]`` so the
    annotations agree with the values an instance can actually hold.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""

    seq_lens: torch.Tensor
    """Same as seq_lens_cpu, for compatibility with some new attn metadata
    (such as GDN)."""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""

    max_query_len: int
    """Max token number of request in batch"""

    decode_token_per_req: int
    """decode token number per request"""

    block_table_tensor: torch.Tensor

    slot_mapping: torch.Tensor

    actual_seq_lengths_q: list[int]

    positions: Optional[torch.Tensor] = None

    attn_mask: Optional[torch.Tensor] = None

    spec_attn_mask: Optional[torch.Tensor] = None

    attn_state: Any = None

    enable_dbo_across_dp: bool = False

    is_only_prefill: bool = False

    graph_pad_size: int = -1

    # num_input_tokens refers to total number of tokens including
    # padding tokens. It is used to handle some padding operations.
    num_input_tokens: int = 0

    # NOTE: This is a temporary solution for rotary embedding in MLA
    cos: Optional[torch.Tensor] = None
    sin: Optional[torch.Tensor] = None

    prefill_context_parallel_metadata: Optional[
        AscendPrefillContextParallelMetadata] = None
|
|
|
|
|
|
def split_decodes_and_prefills(
    common_attn_metadata: AscendCommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    """
    Locate the decode/prefill boundary in a batch that has already been
    reordered so that decode requests come first.

    Args:
        common_attn_metadata: AscendCommonAttentionMetadata object containing
            the batch metadata. Only ``max_query_len``, ``num_reqs``,
            ``num_actual_tokens`` and ``query_start_loc_cpu`` are read.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        A 4-tuple ``(num_decodes, num_prefills, num_decode_tokens,
        num_prefill_tokens)``:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    total_reqs = common_attn_metadata.num_reqs
    total_tokens = common_attn_metadata.num_actual_tokens
    # Result used whenever no request exceeds the decode threshold.
    everything_is_decode = (total_reqs, 0, total_tokens, 0)

    # Cheap check first: the longest query already fits the threshold.
    if common_attn_metadata.max_query_len <= decode_threshold:
        return everything_is_decode

    starts = common_attn_metadata.query_start_loc_cpu
    # Per-request query length is the gap between consecutive start offsets.
    prefill_mask = (starts[1:] - starts[:-1]) > decode_threshold
    if not prefill_mask.any():
        return everything_is_decode

    # argmax over the 0/1 mask yields the index of the first prefill request;
    # every request before that index is counted as a decode.
    boundary = prefill_mask.int().argmax(dim=-1).item()
    decode_tokens = starts[boundary].item()
    return (
        boundary,
        total_reqs - boundary,
        decode_tokens,
        total_tokens - decode_tokens,
    )
|
|
|
|
|
|
def wait_for_kv_layer_from_connector(layer_name: str):
    """Ask the KV-transfer connector to wait for ``layer_name`` to load.

    Does nothing unless a KV transfer group exists and it is a v1 group,
    and also does nothing when the current forward context carries no
    attention metadata.

    Args:
        layer_name: Name passed through to ``connector.wait_for_layer_load``.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    connector = get_kv_transfer_group()

    # Without attention metadata there is nothing to wait on for this step.
    if get_forward_context().attn_metadata is None:
        return
    # TODO: assert ascendMetadata
    connector.wait_for_layer_load(layer_name)
|
|
|
|
|
|
def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand one layer's KV cache to the KV-transfer connector, if applicable.

    Does nothing unless a KV transfer group exists and it is a v1 group,
    and also does nothing when the current forward context carries no
    attention metadata.

    Args:
        layer_name: Name passed through to ``connector.save_kv_layer``.
        kv_cache_layer: The layer's KV cache tensors, forwarded unchanged.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    connector = get_kv_transfer_group()

    metadata = get_forward_context().attn_metadata
    # TODO: assert ascendMetadata
    if metadata is not None:
        connector.save_kv_layer(layer_name, kv_cache_layer, metadata)
|
|
|
|
|
|
def round_up(val: int, align: int) -> int:
    """Return the smallest multiple of ``align`` that is >= ``val``.

    An ``align`` of 0 is treated as a special case and yields 0 instead of
    raising ``ZeroDivisionError``.
    """
    if not align:
        return 0
    quotient, remainder = divmod(val, align)
    # A non-zero remainder means val sits between two multiples: step up.
    if remainder:
        quotient += 1
    return quotient * align
|
|
|
|
|
|
def trans_rope_weight(weight, rope_dim):
    """De-interleave the trailing ``rope_dim`` rows of ``weight``.

    Along the second-to-last dimension, the last ``rope_dim`` rows are
    regrouped so that the even-indexed rows come first, followed by the
    odd-indexed rows. The rows before them keep their order.
    NOTE(review): presumably this converts an interleaved rotary layout to a
    half-split one -- confirm with the caller.

    Args:
        weight: Tensor with at least two dimensions.
        rope_dim: Number of trailing rows (dim ``-2``) to reorder. ``0``
            leaves the data unchanged.

    Returns:
        A contiguous tensor with the same shape as ``weight``.
    """
    # Guard first: a slice ending at ``-0`` would select nothing.
    if rope_dim == 0:
        return weight.contiguous()

    split_at = -rope_dim
    untouched = weight[..., :split_at, :]
    rotary = weight[..., split_at:, :]
    even_rows = rotary[..., 0::2, :]
    odd_rows = rotary[..., 1::2, :]
    return torch.cat((untouched, even_rows, odd_rows), dim=-2).contiguous()
|
|
|
|
|
|
def transdata(nd_mat, block_size: tuple = (16, 16)):
    """Re-tile a 2-D matrix into a column-block-major 3-D layout.

    The matrix is zero-padded so both dimensions are multiples of the block
    size, cut into ``block_size[0] x block_size[1]`` tiles, and rearranged
    so that the column-block index becomes the leading dimension.
    NOTE(review): the ``nd``/``nz`` names suggest this is the Ascend ND -> NZ
    format conversion -- confirm.

    Args:
        nd_mat: 2-D tensor of shape ``(rows, cols)``.
        block_size: ``(block_rows, block_cols)`` tile size.

    Returns:
        Tensor of shape ``(padded_cols // block_cols, padded_rows,
        block_cols)``, where ``padded_*`` are the sizes rounded up to the
        block size.
    """
    block_rows, block_cols = block_size[0], block_size[1]
    rows, cols = nd_mat.shape[0], nd_mat.shape[1]
    padded_rows = round_up(rows, block_rows)
    padded_cols = round_up(cols, block_cols)
    row_pad = padded_rows - rows
    col_pad = padded_cols - cols

    # F.pad reads its pad tuple starting from the LAST dimension:
    # (last_left, last_right, second_last_left, second_last_right).
    # The column padding must therefore come first and the row padding
    # second. Passing them the other way round pads the wrong axes and makes
    # the reshape below fail whenever row_pad != col_pad.
    nd_mat = F.pad(nd_mat, (0, col_pad, 0, row_pad))

    # (row_blocks, block_rows, col_blocks, block_cols)
    #   -> (col_blocks, row_blocks, block_rows, block_cols)
    tiled = torch.reshape(
        nd_mat,
        (padded_rows // block_rows, block_rows, padded_cols // block_cols,
         block_cols),
    )
    nz_mat = torch.permute(tiled, [2, 0, 1, 3])

    # Merge (row_blocks, block_rows) back into a single padded-row axis.
    nz_mat = torch.reshape(
        nz_mat,
        (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
    return nz_mat
|