[Refactor]7/N Extract common code to common_cp (#5490)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629 Reason: Eliminate duplicate code for two file(mla_cp.py attention_cp.py) to common_cp.py. vLLM version: 0.13.0rc3 vLLM main:ad32e3e19cvLLM version: release/v0.13.0 vLLM main:5fbfa8d9ef- vLLM version: v0.13.0 - vLLM main:5326c89803--------- Signed-off-by: wujinyuan1 <wjy9595@qq.com> Signed-off-by: wujinyuan1 <wujinyuan1@huawei.com> Co-authored-by: wujinyuan1 <wjy9595@qq.com>
This commit is contained in:
@@ -19,8 +19,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.common_cp import (AscendPCPMetadata,
|
||||
CPChunkedContextMetadata)
|
||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||
AscendPCPMetadata, CPChunkedContextMetadata)
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
enable_cp,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
@@ -46,6 +46,8 @@ if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
|
||||
BUILD_METADATA_STEP_PREFILL = 0
|
||||
BUILD_METADATA_STEP_DECODE = 1
|
||||
|
||||
|
||||
class AscendMLABackend(AttentionBackend):
|
||||
@@ -61,7 +63,8 @@ class AscendMLABackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_builder_cls():
|
||||
if enable_cp():
|
||||
from vllm_ascend.attention.mla_cp import AscendMlaCPMetadataBuilder
|
||||
from vllm_ascend.attention.context_parallel.mla_cp import \
|
||||
AscendMlaCPMetadataBuilder
|
||||
return AscendMlaCPMetadataBuilder
|
||||
return AscendMLAMetadataBuilder
|
||||
|
||||
@@ -73,7 +76,8 @@ class AscendMLABackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["MLAAttentionImpl"]:
|
||||
if enable_cp():
|
||||
from vllm_ascend.attention.mla_cp import AscendMlaCPImpl
|
||||
from vllm_ascend.attention.context_parallel.mla_cp import \
|
||||
AscendMlaCPImpl
|
||||
return AscendMlaCPImpl
|
||||
return AscendMLAImpl
|
||||
|
||||
@@ -418,7 +422,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
self.query_lens = query_seq_lens_cpu[:num_reqs]
|
||||
self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
|
||||
self.set_prefill_block_table(common_attn_metadata)
|
||||
self.graph_pad_size = common_attn_metadata.graph_pad_size
|
||||
block_table_size = self.get_block_table_size(
|
||||
common_attn_metadata, BUILD_METADATA_STEP_PREFILL)
|
||||
self.block_table = common_attn_metadata.block_table_tensor[:
|
||||
block_table_size]
|
||||
|
||||
prefill_metadata = None
|
||||
if self.num_prefills > 0:
|
||||
@@ -499,23 +507,16 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
def set_prefill_block_table(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
):
|
||||
# If graph_pad_size > -1, mean is running in fullgraph mode.
|
||||
self.graph_pad_size = common_attn_metadata.graph_pad_size
|
||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
||||
if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
|
||||
self.block_table = (
|
||||
common_attn_metadata.block_table_tensor[:self.graph_pad_size])
|
||||
else:
|
||||
self.block_table = (
|
||||
common_attn_metadata.block_table_tensor[:common_attn_metadata.
|
||||
num_reqs])
|
||||
|
||||
def set_decode_block_table(self):
|
||||
self.block_table = self.block_table[:self.num_decodes, ...]
|
||||
def get_block_table_size(
|
||||
self, common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
build_metadata_step: int):
|
||||
if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
|
||||
# If graph_pad_size > -1, mean is running in fullgraph mode.
|
||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
||||
if self.graph_pad_size > common_attn_metadata.num_reqs and self.speculative_config.disable_padded_drafter_batch:
|
||||
return self.graph_pad_size
|
||||
return common_attn_metadata.num_reqs
|
||||
return self.num_decodes
|
||||
|
||||
def build_prefill_metadata(
|
||||
self,
|
||||
@@ -574,7 +575,9 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
self.seq_lens = self.seq_lens[:self.num_decodes]
|
||||
input_positions = input_positions[:self.num_decode_tokens]
|
||||
|
||||
self.set_decode_block_table()
|
||||
block_table_size = self.get_block_table_size(
|
||||
common_attn_metadata, BUILD_METADATA_STEP_DECODE)
|
||||
self.block_table = self.block_table[:block_table_size]
|
||||
|
||||
# NOTE: Currently, MTP-fullgraph is incompatibility pcp
|
||||
# NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
|
||||
|
||||
Reference in New Issue
Block a user