xc-llm-ascend/vllm_ascend/attention/common_cp.py
wujinyuan1 7ff1db4b84 [Refactor]5/N Extract common code of mla_v1.py & extract mla_cp (#5097)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
The context-parallel (CP) functions differ significantly from those of
regular MLA attention, yet the two are tightly coupled in the same code.

Steps:
1) Extract the common code of AscendMLAMetadataBuilder.build into 4 functions:
build_prefill_metadata, build_decode_metadata, build_cp_metadata,
build_chunked_metadata
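
The split described in the step above can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: only the four helper names come from the commit message, while `SketchMetadata`, `SketchMetadataBuilder`, `SketchCPMetadataBuilder`, and every signature and return value are hypothetical placeholders.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchMetadata:
    """Hypothetical container; the real metadata classes hold tensors."""
    prefill: Optional[dict] = None
    decode: Optional[dict] = None
    cp: Optional[dict] = None
    chunked: Optional[dict] = None


class SketchMetadataBuilder:
    """Hypothetical builder: build() only dispatches to four small helpers."""

    def build(self, num_prefills: int, num_decodes: int,
              has_context: bool) -> SketchMetadata:
        meta = SketchMetadata()
        if num_prefills > 0:
            meta.prefill = self.build_prefill_metadata(num_prefills)
            if has_context:
                meta.chunked = self.build_chunked_metadata(num_prefills)
        if num_decodes > 0:
            meta.decode = self.build_decode_metadata(num_decodes)
        meta.cp = self.build_cp_metadata()
        return meta

    def build_prefill_metadata(self, num_prefills: int) -> dict:
        return {"num_prefills": num_prefills}

    def build_decode_metadata(self, num_decodes: int) -> dict:
        return {"num_decodes": num_decodes}

    def build_chunked_metadata(self, num_prefills: int) -> dict:
        return {"num_requests": num_prefills}

    def build_cp_metadata(self) -> Optional[dict]:
        # Assumption: the base builder has no context parallelism.
        return None


class SketchCPMetadataBuilder(SketchMetadataBuilder):
    """Hypothetical CP variant: overrides only the CP-specific helper."""

    def __init__(self, cp_world_size: int) -> None:
        self.cp_world_size = cp_world_size

    def build_cp_metadata(self) -> Optional[dict]:
        return {"cp_world_size": self.cp_world_size}


base_meta = SketchMetadataBuilder().build(2, 0, has_context=True)
cp_meta = SketchCPMetadataBuilder(cp_world_size=2).build(0, 3, has_context=False)
```

With this shape, a CP-specific builder can override one helper instead of duplicating the whole build method, which is one way to loosen the coupling named under "Reason".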

todo:
1) Refactor function _compute_prefill_context.
2) Refactor functions _mla_preprocess and _mla_decode_preprocess.
3) Extract shared data and processing functions from the attention_cp.py
and mla_cp.py files into the common_cp file.

vLLM version: 0.13.0rc3
vLLM main: ad32e3e19c

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
2025-12-24 10:25:19 +08:00

41 lines
1.3 KiB
Python

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AscendPCPMetadata:
    q_head_idx: torch.Tensor = None
    q_tail_idx: torch.Tensor = None
    kv_with_q_head_nomask_idx: torch.Tensor = None
    kv_with_q_head_mask_idx: torch.Tensor = None
    kv_with_q_tail_nomask_idx: torch.Tensor = None
    kv_with_q_tail_mask_idx: torch.Tensor = None
    attn_mask_seqlens: torch.Tensor = None
    head_attn_nomask_seqlens: torch.Tensor = None
    tail_attn_nomask_seqlens: torch.Tensor = None
    q_full_idx: torch.Tensor = None
    pcp_prefill_mask: torch.Tensor = None
    pcp_allgather_restore_idx: Optional[list[int]] = None


@dataclass
class CPChunkedContextMetadata:
    # New for MLA (compared to FlashAttention)
    # For handling chunked prefill
    cu_seq_lens: torch.Tensor
    starts: torch.Tensor
    seq_tot: list[int]
    max_seq_lens: list[int]
    workspace: torch.Tensor
    chunk_seq_lens: torch.Tensor
    chunk_seq_lens_npu: torch.Tensor
    # for mla DCP & PCP
    padded_chunk_seq_lens_npu: torch.Tensor = None
    padded_local_chunk_seq_lens: Optional[list[list[int]]] = None
    local_context_lens_allranks: Optional[list[list[int]]] = None
    padded_local_cu_seq_lens: torch.Tensor = None
    cu_seq_lens_lst: Optional[list[list[int]]] = None
    chunk_size: Optional[int] = None
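
The file itself does not show how the chunked-prefill fields of CPChunkedContextMetadata relate to one another. The sketch below illustrates one plausible relationship, assuming they follow the convention used for chunked-context metadata in upstream vLLM's MLA backend: one row per chunk, with cu_seq_lens holding prefix sums of chunk_seq_lens. Plain Python lists stand in for tensors, and `ChunkPlan` and `plan_chunks` are hypothetical names that do not appear in this file.

```python
from dataclasses import dataclass


@dataclass
class ChunkPlan:
    """Hypothetical list-based stand-in for a few CPChunkedContextMetadata fields."""
    starts: list[list[int]]          # per chunk: start offset for each request
    chunk_seq_lens: list[list[int]]  # per chunk: tokens each request contributes
    cu_seq_lens: list[list[int]]     # per chunk: prefix sums of chunk_seq_lens
    seq_tot: list[int]               # per chunk: total number of tokens
    max_seq_lens: list[int]          # per chunk: longest per-request slice


def plan_chunks(context_lens: list[int], chunk_size: int) -> ChunkPlan:
    """Split each request's context into fixed-size chunks (assumed convention)."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    longest = max(context_lens, default=0)
    num_chunks = -(-longest // chunk_size)  # ceiling division
    plan = ChunkPlan([], [], [], [], [])
    for c in range(num_chunks):
        begin = c * chunk_size
        end = begin + chunk_size
        # Tokens of each request that fall inside [begin, end); 0 if exhausted.
        lens = [max(0, min(n, end) - begin) for n in context_lens]
        cu = [0]
        for n in lens:
            cu.append(cu[-1] + n)
        plan.starts.append([begin] * len(context_lens))
        plan.chunk_seq_lens.append(lens)
        plan.cu_seq_lens.append(cu)
        plan.seq_tot.append(cu[-1])
        plan.max_seq_lens.append(max(lens))
    return plan


plan = plan_chunks(context_lens=[5, 2], chunk_size=4)
print(plan.chunk_seq_lens)
print(plan.cu_seq_lens)
```

Here two requests with context lengths 5 and 2 are split with a chunk size of 4: the first chunk takes 4 and 2 tokens, and the second takes the remaining 1 token of the first request and none of the second.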