[Refactor] AttentionBuilder inherit from base class in vllm (#5916)

### What this PR does / why we need it?

This PR makes `AscendMLAMetadataBuilder` and `AscendSFAMetadataBuilder`
properly inherit from the base class `MLACommonMetadataBuilder` in vllm
by adding `super().__init__()` calls.

**Changes:**
- Add `super().__init__()` call in `AscendMLAMetadataBuilder.__init__()`
- Add `super().__init__()` call in `AscendSFAMetadataBuilder.__init__()`
- Extract `ascend_chunked_prefill_workspace_size()` to
`vllm_ascend/attention/utils.py` to avoid code duplication
- Override `determine_chunked_prefill_workspace_size()` to support
Ascend-specific 128k tokens workspace size (vs 64k in parent class)
- Update unit tests to mock parent class `__init__` for proper isolation

**Why we need it:**
- Follow proper Python inheritance patterns by calling
`super().__init__()`
- Reduce code duplication by reusing parent class initialization logic
- Better maintainability as parent class changes will be automatically
inherited

Part of issue #5463 item 10

### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring that does not change any user-facing
behavior.

Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
LICO67373
2026-01-21 10:45:45 +08:00
committed by GitHub
parent 839e03cbc9
commit 12a668b1d9
5 changed files with 158 additions and 40 deletions

View File

@@ -99,13 +99,44 @@ class TestAscendSFAMetadataBuilder(TestBase):
return_value=self.mock_cfg)
self.patcher.start()
# Mock parent class __init__ to avoid complex initialization,
# but still set the essential attributes that child class needs
def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
device, metadata_cls, supports_dcp_with_varlen):
self.metadata_cls = metadata_cls
self.kv_cache_spec = kv_cache_spec
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.device = device
self.chunked_prefill_workspace_size = 128 * 1024
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
vllm_config.model_config.get_head_size()),
dtype=vllm_config.model_config.dtype,
device=device,
)
self.parent_init_patcher = patch(
"vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
mock_parent_init)
self.parent_init_patcher.start()
if hasattr(enable_dsa_cp, "cache_clear"):
enable_dsa_cp.cache_clear()
def tearDown(self):
self.patcher.stop()
self.parent_init_patcher.stop()
def test_ascend_sfa_metadata_builder_default(self):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
@@ -138,6 +169,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
@@ -190,6 +226,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
kv_cache_spec = MagicMock()
layer_names = ["layer1", "layer2"]
vllm_config = MagicMock()
vllm_config.cache_config.block_size = 16
vllm_config.model_config.max_model_len = 1024
vllm_config.model_config.get_head_size.return_value = 64
vllm_config.model_config.dtype = torch.float16
vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config