RFC: https://github.com/vllm-project/vllm-ascend/issues/4629 Reason: The CP-related functions differ significantly from those of normal MLA attention, yet the two are tightly coupled. Steps: 1) Extract the common code of AscendMLAMetadataBuilder.build into 4 functions: build_prefill_metadata, build_decode_metadata, build_cp_metadata, build_chunked_metadata. TODO: 1) refactor function _compute_prefill_context; 2) refactor functions _mla_preprocess and _mla_decode_preprocess; 3) extract public data and processing functions from the attention_cp.py and mla_cp.py files into the common_cp file. - vLLM version: 0.13.0rc3 - vLLM main: ad32e3e19c --------- Signed-off-by: wujinyuan1 <wjy9595@qq.com> Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com> Co-authored-by: wujinyuan1 <wjy9595@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
@dataclass
class AscendPCPMetadata:
    """Bundle of optional index tensors, sequence lengths and a mask for PCP.

    Every field defaults to ``None``, so the record can be created empty and
    filled in selectively; any field may therefore be ``None`` when read.

    NOTE(review): the meaning of each field is inferred from its name only
    (query head/tail indices, KV indices for masked / non-masked attention,
    per-part sequence lengths). Shapes, dtypes and devices are not visible in
    this file -- confirm against the code that builds this metadata.
    """

    # Query index tensors (head part / tail part).
    q_head_idx: Optional[torch.Tensor] = None
    q_tail_idx: Optional[torch.Tensor] = None

    # KV index tensors paired with the query head / tail parts, split into a
    # "nomask" and a "mask" variant.
    kv_with_q_head_nomask_idx: Optional[torch.Tensor] = None
    kv_with_q_head_mask_idx: Optional[torch.Tensor] = None
    kv_with_q_tail_nomask_idx: Optional[torch.Tensor] = None
    kv_with_q_tail_mask_idx: Optional[torch.Tensor] = None

    # Sequence-length tensors for the masked and non-masked attention parts.
    attn_mask_seqlens: Optional[torch.Tensor] = None
    head_attn_nomask_seqlens: Optional[torch.Tensor] = None
    tail_attn_nomask_seqlens: Optional[torch.Tensor] = None

    q_full_idx: Optional[torch.Tensor] = None
    pcp_prefill_mask: Optional[torch.Tensor] = None

    # Unlike the other fields this one is a plain Python list, not a tensor.
    pcp_allgather_restore_idx: Optional[list[int]] = None
|
|
|
|
|
|
@dataclass
class CPChunkedContextMetadata:
    """Chunked-context metadata record with extra optional DCP/PCP fields.

    The first seven fields have no default and must always be supplied; the
    remaining fields default to ``None`` and may be left unset.

    NOTE(review): tensor shapes, dtypes and devices are not visible in this
    file; the ``_npu`` suffix presumably marks a device-resident copy --
    confirm against the code that builds this metadata.
    """

    # New for MLA (compared to FlashAttention)
    # For handling chunked prefill
    # Required fields: dataclasses require these to precede defaulted fields.
    cu_seq_lens: torch.Tensor
    starts: torch.Tensor
    seq_tot: list[int]
    max_seq_lens: list[int]
    workspace: torch.Tensor
    chunk_seq_lens: torch.Tensor
    chunk_seq_lens_npu: torch.Tensor

    # for mla DCP & PCP
    # Optional fields: all default to None.
    padded_chunk_seq_lens_npu: Optional[torch.Tensor] = None
    padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
    local_context_lens_allranks: Optional[list[list[int]]] = None
    padded_local_cu_seq_lens: Optional[torch.Tensor] = None
    cu_seq_lens_lst: Optional[list[list[int]]] = None
    chunk_size: Optional[int] = None
|