[Refactor] Make AttentionBuilder inherit from the base class in vllm (#5916)
### What this PR does / why we need it?

This PR makes `AscendMLAMetadataBuilder` and `AscendSFAMetadataBuilder` properly inherit from the base class `MLACommonMetadataBuilder` in vllm by adding `super().__init__()` calls.

**Changes:**

- Add a `super().__init__()` call in `AscendMLAMetadataBuilder.__init__()`
- Add a `super().__init__()` call in `AscendSFAMetadataBuilder.__init__()`
- Extract `ascend_chunked_prefill_workspace_size()` into `vllm_ascend/attention/utils.py` to avoid code duplication
- Override `determine_chunked_prefill_workspace_size()` to support the Ascend-specific 128k-token workspace size (vs. 64k in the parent class)
- Update the unit tests to mock the parent class's `__init__` for proper isolation (see the sketch below)

**Why we need it:**

- Follows proper Python inheritance patterns by calling `super().__init__()`
- Reduces code duplication by reusing the parent class's initialization logic
- Improves maintainability, since parent-class changes are picked up automatically

Part of issue #5463, item 10.

### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring that does not change any user-facing behavior.

Signed-off-by: lico67373 <918688502@qq.com>
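For context on the test bullet, here is a minimal, self-contained sketch of the isolation pattern; the toy `Base`/`Child` classes stand in for `MLACommonMetadataBuilder` and `AscendMLAMetadataBuilder`, and this is not the PR's actual test code:

```python
from unittest import mock


class Base:
    def __init__(self, config, device):
        # Stands in for MLACommonMetadataBuilder.__init__, which does
        # config- and device-heavy setup we don't want in a unit test.
        self.config = config
        self.device = device


class Child(Base):
    def __init__(self, config, device):
        super().__init__(config, device)  # the call this PR adds
        self.workspace_cap = 128 * 1024   # Ascend-specific value


def test_child_init_in_isolation():
    # Patch the parent __init__ so only the child's own logic runs.
    with mock.patch.object(Base, "__init__",
                           return_value=None) as parent_init:
        child = Child(config=object(), device="npu")
    parent_init.assert_called_once()
    assert child.workspace_cap == 128 * 1024
```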
```diff
@@ -19,12 +19,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.context_parallel.common_cp import (
     AscendPCPMetadata, CPChunkedContextMetadata)
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         enable_cp,
-                                         maybe_save_kv_layer_to_connector,
-                                         split_decodes_and_prefills,
-                                         trans_rope_weight, transdata,
-                                         wait_for_kv_layer_from_connector)
+from vllm_ascend.attention.utils import (
+    AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size,
+    enable_cp, maybe_save_kv_layer_to_connector, split_decodes_and_prefills,
+    trans_rope_weight, transdata, wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import (
     get_draft_graph_params, get_graph_params,
     update_draft_graph_params_workspaces, update_graph_params_workspaces)
```
```diff
@@ -215,11 +213,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         metadata_cls: type[AscendMLAMetadata] | None = None,
         supports_dcp_with_varlen: bool = False,
     ):
-        self.metadata_cls = (metadata_cls if metadata_cls is not None else
-                             AscendMLAMetadata)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.device = device
+        super().__init__(
+            kv_cache_spec, layer_names, vllm_config, device,
+            metadata_cls if metadata_cls is not None else AscendMLAMetadata,
+            supports_dcp_with_varlen)
+
         scheduler_config = vllm_config.scheduler_config
         self.block_size = vllm_config.cache_config.block_size
         self.max_blocks = (vllm_config.model_config.max_model_len +
```
```diff
@@ -236,29 +234,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             got {self.decode_threshold}"
 
         self.reorder_batch_threshold = self.decode_threshold
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
-                # Max sure there is enough for 8 full length request or at least
-                # 4 pages of cache per request
-                max(8 * self.model_config.max_model_len,
-                    4 * scheduler_config.max_num_seqs * self.block_size),
-                # For long-context models try not to over-allocate limiting
-                # kv-cache space, limiting it to 64k tokens,
-                # which would result in the workspace being:
-                # 2*(576)*(64*1024) = 144mb
-                # (assuming 576 MLA head dim, and fp16)
-                # which would result in up-projected context being
-                # 2*(192*128)*(64*1024) = 3gb
-                # (assuming 192 QK head dim, 128 heads, and fp16)
-                128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
-                scheduler_config.max_num_seqs * self.block_size
-            self.chunked_prefill_workspace = torch.empty(
-                (self.chunked_prefill_workspace_size,
-                 self.model_config.get_head_size()),
-                dtype=self.model_config.dtype,
-                device=device,
-            )
 
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
```
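For scale: the Ascend cap of `128 * 1024` tokens (vs. upstream vLLM's 64k, per the description above) makes the workspace tensor `(131072, 576)` in fp16, i.e. 2 × 576 × 131072 bytes ≈ 144 MiB, which matches the `144mb` figure in the removed comment; the "64k tokens" wording there appears to be carried over verbatim from upstream.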
```diff
@@ -280,6 +256,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         self.seq_lens: torch.Tensor = None
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
 
+    @staticmethod
+    def determine_chunked_prefill_workspace_size(
+            vllm_config: VllmConfig) -> int:
+        return ascend_chunked_prefill_workspace_size(vllm_config)
+
     @classmethod
     def get_cudagraph_support(
         cls: type["AscendMLAMetadataBuilder"],
```
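The diff is truncated before the new helper's own definition; as a rough sketch, here is what `ascend_chunked_prefill_workspace_size()` in `vllm_ascend/attention/utils.py` plausibly looks like, reconstructed from the inline logic removed above (the exact signature and config-field access are assumptions, not the PR's verbatim code):

```python
from vllm.config import VllmConfig


def ascend_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
    """Sketch of the extracted helper, reconstructed from the removed
    inline logic; not the PR's verbatim code."""
    scheduler_config = vllm_config.scheduler_config
    model_config = vllm_config.model_config
    block_size = vllm_config.cache_config.block_size
    return min(
        # Enough for 8 full-length requests, or at least 4 pages of
        # KV cache per request (mirrors the removed inline comment).
        max(8 * model_config.max_model_len,
            4 * scheduler_config.max_num_seqs * block_size),
        # Ascend-specific cap: 128k tokens, vs. 64k in the parent class.
        128 * 1024)
```

With the computation in `utils.py`, both the MLA and SFA builders can override the parent's static `determine_chunked_prefill_workspace_size()` with a one-line delegation, as the hunk above shows for `AscendMLAMetadataBuilder`.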