[Refactor]5/N Extract common code of mla_v1.py & extract mla_cp (#5097)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629 Reason: The functions related to Cp differ significantly from those of normal MLA-Attention, but the coupling is quite severe. Steps: 1)Extract common code AscendMLAMetadataBuilder.build to 4 functions: build_prefill_metadata, build_decode_metadata,build_cp_metadata, build_chunked_metadata todo: 1)refactor function _compute_prefill_context; 2)refactor function _mla_preprocess,_mla_decode_preprocess 3)Extract public data and processing functions from the attention_cp.py and mla_cp.py files to the common_cp file. vLLM version: 0.13.0rc3 vLLM main:ad32e3e19c- vLLM version: 0.13.0rc3 - vLLM main:ad32e3e19c--------- Signed-off-by: wujinyuan1 <wjy9595@qq.com> Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com> Co-authored-by: wujinyuan1 <wjy9595@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -6,8 +6,9 @@ from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.common_cp import CPChunkedContextMetadata
|
||||
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAPrefillMetadata
|
||||
from vllm_ascend.attention.mla_v1 import ChunkedContextMetadata
|
||||
|
||||
|
||||
def get_pcp_split_info(pcp_rank, pcp_size, seq_lens):
|
||||
@@ -127,7 +128,7 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
|
||||
out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
chunked_context_metadata = CPChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||
starts=local_chunk_starts.to(non_blocking=True),
|
||||
seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
|
||||
@@ -144,16 +145,15 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
|
||||
chunk_size=padded_local_max_context_chunk_across_ranks,
|
||||
)
|
||||
else:
|
||||
chunked_context_metadata = (
|
||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||
starts=chunk_starts.to(non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens,
|
||||
workspace=None,
|
||||
))
|
||||
chunked_context_metadata = (ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(non_blocking=True),
|
||||
starts=chunk_starts.to(non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens,
|
||||
workspace=None,
|
||||
))
|
||||
return chunked_context_metadata
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user