[Refactor] 7/N Extract common code to common_cp (#5490)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
Eliminate the code duplicated across two files (mla_cp.py and
attention_cp.py) by extracting it into common_cp.py.
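
As a sketch of the extraction pattern (module path and class names are taken from the diff below; the fields are hypothetical, shown only to make the snippet self-contained), the shared CP metadata types now live once under context_parallel/common_cp.py and both backends import that single definition:

# vllm_ascend/attention/context_parallel/common_cp.py -- illustrative sketch,
# not the real definitions; only the module path and class name match the diff.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CPChunkedContextMetadata:
    # Hypothetical fields: the real class carries the chunked-prefill state
    # shared by the MLA and standard attention CP metadata builders.
    chunk_seq_lens: Optional[torch.Tensor] = None
    workspace: Optional[torch.Tensor] = None


# mla_cp.py and attention_cp.py then both do:
# from vllm_ascend.attention.context_parallel.common_cp import (
#     AscendPCPMetadata, CPChunkedContextMetadata)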

- vLLM version: v0.13.0
- vLLM main:
5326c89803

---------

Signed-off-by: wujinyuan1 <wjy9595@qq.com>
Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com>
Co-authored-by: wujinyuan1 <wjy9595@qq.com>
Author: wujinyuan1
Date: 2026-01-05 17:41:12 +08:00
Committed by: GitHub
Parent: 755caeb06e
Commit: 4a3663327b
10 changed files with 252 additions and 301 deletions

@@ -19,8 +19,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
-                                             CPChunkedContextMetadata)
+from vllm_ascend.attention.context_parallel.common_cp import (
+    AscendPCPMetadata, CPChunkedContextMetadata)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          enable_cp,
                                          maybe_save_kv_layer_to_connector,
@@ -46,6 +46,8 @@ if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
 MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
+BUILD_METADATA_STEP_PREFILL = 0
+BUILD_METADATA_STEP_DECODE = 1
 
 
 class AscendMLABackend(AttentionBackend):
@@ -61,7 +63,8 @@ class AscendMLABackend(AttentionBackend):
     @staticmethod
     def get_builder_cls():
         if enable_cp():
-            from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
+            from vllm_ascend.attention.context_parallel.mla_cp import \
+                AscendMlaCPMetadataBuilder
             return AscendMlaCPMetadataBuilder
         return AscendMLAMetadataBuilder
 
@@ -73,7 +76,8 @@ class AscendMLABackend(AttentionBackend):
     @staticmethod
     def get_impl_cls() -> Type["MLAAttentionImpl"]:
         if enable_cp():
-            from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
+            from vllm_ascend.attention.context_parallel.mla_cp import \
+                AscendMlaCPImpl
             return AscendMlaCPImpl
         return AscendMLAImpl
 
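
Both hunks above keep the mla_cp import inside the enable_cp() branch. Below is a runnable toy version of that deferred-import dispatch (stdlib colorsys stands in for the optional CP module, float for the default builder; nothing here is vllm-ascend's real API): the optional module is loaded only when the branch that needs it actually runs.

import importlib
import sys


def load_builder(cp_enabled: bool):
    # Toy version of get_builder_cls: defer the optional import until the
    # branch that needs it runs ("colorsys" stands in for ...mla_cp).
    if cp_enabled:
        mod = importlib.import_module("colorsys")  # deferred optional import
        return mod.rgb_to_hsv  # stand-in for AscendMlaCPMetadataBuilder
    return float  # stand-in for AscendMLAMetadataBuilder


assert load_builder(False) is float  # non-CP path never touches the module
load_builder(True)  # CP path triggers the import
assert "colorsys" in sys.modules  # loaded only once the CP branch ran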
@@ -418,7 +422,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         self.query_lens = query_seq_lens_cpu[:num_reqs]
         self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
-        self.set_prefill_block_table(common_attn_metadata)
+        self.graph_pad_size = common_attn_metadata.graph_pad_size
+        block_table_size = self.get_block_table_size(
+            common_attn_metadata, BUILD_METADATA_STEP_PREFILL)
+        self.block_table = common_attn_metadata.block_table_tensor[:
+                                                                   block_table_size]
 
         prefill_metadata = None
         if self.num_prefills > 0:
@@ -499,23 +507,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             workspace=self.chunked_prefill_workspace,
         )
 
-    def set_prefill_block_table(
-        self,
-        common_attn_metadata: AscendCommonAttentionMetadata,
-    ):
-        # If graph_pad_size > -1, mean is running in fullgraph mode.
-        self.graph_pad_size = common_attn_metadata.graph_pad_size
-        # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
-        if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
-            self.block_table = (
-                common_attn_metadata.block_table_tensor[:self.graph_pad_size])
-        else:
-            self.block_table = (
-                common_attn_metadata.block_table_tensor[:common_attn_metadata.
-                                                        num_reqs])
-
-    def set_decode_block_table(self):
-        self.block_table = self.block_table[:self.num_decodes, ...]
+    def get_block_table_size(
+            self, common_attn_metadata: AscendCommonAttentionMetadata,
+            build_metadata_step: int):
+        if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
+            # If graph_pad_size > -1, mean is running in fullgraph mode.
+            # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
+            if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
+                return self.graph_pad_size
+            return common_attn_metadata.num_reqs
+        return self.num_decodes
 
     def build_prefill_metadata(
         self,
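
Restated as a standalone function so the consolidation is easy to check (a sketch: the builder state becomes explicit parameters, and the two constants come from the hunk further up; graph_pad_size > -1 signals fullgraph mode):

BUILD_METADATA_STEP_PREFILL = 0
BUILD_METADATA_STEP_DECODE = 1


def get_block_table_size(step: int, *, graph_pad_size: int, num_reqs: int,
                         num_decodes: int,
                         disable_padded_drafter_batch: bool) -> int:
    if step == BUILD_METADATA_STEP_PREFILL:
        # Fullgraph mode pads the block table up to the graph size when
        # padded drafter batches are disabled.
        if graph_pad_size > num_reqs and disable_padded_drafter_batch:
            return graph_pad_size
        return num_reqs
    # Decode step: one block-table row per decode request.
    return num_decodes


assert get_block_table_size(BUILD_METADATA_STEP_PREFILL, graph_pad_size=8,
                            num_reqs=4, num_decodes=2,
                            disable_padded_drafter_batch=True) == 8
assert get_block_table_size(BUILD_METADATA_STEP_DECODE, graph_pad_size=8,
                            num_reqs=4, num_decodes=2,
                            disable_padded_drafter_batch=True) == 2

The prefill call site above and the decode call site below then just slice block_table to the returned size.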
@@ -574,7 +575,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             self.seq_lens = self.seq_lens[:self.num_decodes]
             input_positions = input_positions[:self.num_decode_tokens]
-            self.set_decode_block_table()
+            block_table_size = self.get_block_table_size(
+                common_attn_metadata, BUILD_METADATA_STEP_DECODE)
+            self.block_table = self.block_table[:block_table_size]
             # NOTE: Currently, MTP-fullgraph is incompatibility pcp
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.