[Refactor] AttentionBuilder inherit from base class in vllm (#5916)

### What this PR does / why we need it?

This PR makes `AscendMLAMetadataBuilder` and `AscendSFAMetadataBuilder`
properly inherit from the base class `MLACommonMetadataBuilder` in vllm
by adding `super().__init__()` calls.

**Changes:**
- Add `super().__init__()` call in `AscendMLAMetadataBuilder.__init__()`
- Add `super().__init__()` call in `AscendSFAMetadataBuilder.__init__()`
- Extract `ascend_chunked_prefill_workspace_size()` to
`vllm_ascend/attention/utils.py` to avoid code duplication
- Override `determine_chunked_prefill_workspace_size()` to support
Ascend-specific 128k tokens workspace size (vs 64k in parent class)
- Update unit tests to mock parent class `__init__` for proper isolation

**Why we need it:**
- Follow proper Python inheritance patterns by calling
`super().__init__()`
- Reduce code duplication by reusing parent class initialization logic
- Better maintainability as parent class changes will be automatically
inherited

Part of issue #5463 item 10

### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring that does not change any user-facing
behavior.

Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
LICO67373
2026-01-21 10:45:45 +08:00
committed by GitHub
parent 839e03cbc9
commit 12a668b1d9
5 changed files with 158 additions and 40 deletions

View File

@@ -20,6 +20,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
ascend_chunked_prefill_workspace_size,
maybe_save_kv_layer_to_connector,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
@@ -131,7 +132,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
understand this class
"""
# _attn_mask_builder = None
def __init__(
self,
kv_cache_spec,
@@ -141,11 +141,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
metadata_cls: type[AscendSFAMetadata] | None = None,
supports_dcp_with_varlen: bool = False,
):
self.metadata_cls = (metadata_cls if metadata_cls is not None else
AscendSFAMetadata)
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
super().__init__(
kv_cache_spec, layer_names, vllm_config, device,
metadata_cls if metadata_cls is not None else AscendSFAMetadata,
supports_dcp_with_varlen)
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size
@@ -169,6 +169,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
self.attn_mask_builder = AttentionMaskBuilder(self.device)
@staticmethod
def determine_chunked_prefill_workspace_size(
vllm_config: VllmConfig) -> int:
return ascend_chunked_prefill_workspace_size(vllm_config)
@classmethod
def get_cudagraph_support(
cls: type["AscendSFAMetadataBuilder"],