From 12a668b1d9696163c6fc48f391d1ea18e7b42f5d Mon Sep 17 00:00:00 2001
From: LICO67373 <110013619+LICO1314@users.noreply.github.com>
Date: Wed, 21 Jan 2026 10:45:45 +0800
Subject: [PATCH] [Refactor] AttentionBuilder inherits from base class in vllm (#5916)

### What this PR does / why we need it?

This PR makes `AscendMLAMetadataBuilder` and `AscendSFAMetadataBuilder` properly inherit from the vllm base class `MLACommonMetadataBuilder` by adding `super().__init__()` calls.

**Changes:**
- Add a `super().__init__()` call in `AscendMLAMetadataBuilder.__init__()`
- Add a `super().__init__()` call in `AscendSFAMetadataBuilder.__init__()`
- Extract `ascend_chunked_prefill_workspace_size()` into `vllm_ascend/attention/utils.py` to avoid code duplication
- Override `determine_chunked_prefill_workspace_size()` to use the Ascend-specific 128k-token workspace size (vs. 64k in the parent class)
- Update unit tests to mock the parent class `__init__` for proper isolation

**Why we need it:**
- Follows proper Python inheritance patterns by calling `super().__init__()`
- Reduces code duplication by reusing the parent class initialization logic
- Improves maintainability, since parent class changes are picked up automatically

Part of issue #5463, item 10.
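For reference, the sketch below shows the resulting pattern in miniature. It is illustrative only: `SimplifiedCommonBuilder` is a stand-in for vllm's `MLACommonMetadataBuilder` (whose real `__init__` does far more work), and the `SimpleNamespace` config is an assumption made for this example, not vllm's actual config objects.

```python
# Illustrative sketch of the super().__init__() + hook-override pattern.
# SimplifiedCommonBuilder stands in for vllm's MLACommonMetadataBuilder;
# attribute names mirror this PR, but the base class body is simplified.
from types import SimpleNamespace


class SimplifiedCommonBuilder:

    def __init__(self, vllm_config):
        self.vllm_config = vllm_config
        # The base class sizes the chunked-prefill workspace through a hook,
        # so subclasses only need to override the hook, not __init__.
        self.chunked_prefill_workspace_size = (
            self.determine_chunked_prefill_workspace_size(vllm_config))

    @staticmethod
    def determine_chunked_prefill_workspace_size(vllm_config) -> int:
        # Parent behaviour: cap the workspace at 64k tokens.
        return min(
            max(8 * vllm_config.max_model_len,
                4 * vllm_config.max_num_seqs * vllm_config.block_size),
            64 * 1024)


class SimplifiedAscendBuilder(SimplifiedCommonBuilder):

    def __init__(self, vllm_config):
        super().__init__(vllm_config)  # reuse the parent initialization
        # Ascend-specific initialization would continue here.

    @staticmethod
    def determine_chunked_prefill_workspace_size(vllm_config) -> int:
        # Ascend override: raise the cap to 128k tokens, as the new helper in
        # this PR does (the real helper also enforces a lower bound of
        # max_num_seqs * block_size).
        return min(
            max(8 * vllm_config.max_model_len,
                4 * vllm_config.max_num_seqs * vllm_config.block_size),
            128 * 1024)


if __name__ == "__main__":
    cfg = SimpleNamespace(max_model_len=32768, max_num_seqs=256, block_size=128)
    print(SimplifiedCommonBuilder(cfg).chunked_prefill_workspace_size)   # 65536
    print(SimplifiedAscendBuilder(cfg).chunked_prefill_workspace_size)   # 131072
```

Because the base `__init__` resolves the hook through the instance, overriding the staticmethod in the subclass is enough to change the workspace size without re-implementing any of the parent's setup.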
### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring that does not change any user-facing behavior.

Signed-off-by: lico67373 <918688502@qq.com>
---
 tests/ut/attention/test_mla_v1.py | 63 +++++++++++++++++++++++++++++++
 tests/ut/attention/test_sfa_v1.py | 41 ++++++++++++++++++++
 vllm_ascend/attention/mla_v1.py   | 49 ++++++++----------
 vllm_ascend/attention/sfa_v1.py   | 17 ++++++---
 vllm_ascend/attention/utils.py    | 28 ++++++++++++++
 5 files changed, 158 insertions(+), 40 deletions(-)

diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 6d25fbba..4f76691b 100755
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -208,11 +208,38 @@ class TestAscendMLAMetadata(TestBase):
 
 class TestAscendMLAMetadataBuilder(TestBase):
 
+    def setUp(self):
+        # Mock parent class __init__ to avoid complex initialization,
+        # but still set the essential attributes that child class needs
+        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
+                             device, metadata_cls, supports_dcp_with_varlen):
+            self.metadata_cls = metadata_cls
+            self.kv_cache_spec = kv_cache_spec
+            self.model_config = vllm_config.model_config
+            self.vllm_config = vllm_config
+            self.device = device
+            self.chunked_prefill_workspace_size = 128 * 1024
+            self.chunked_prefill_workspace = torch.empty(
+                (self.chunked_prefill_workspace_size,
+                 vllm_config.model_config.get_head_size()),
+                dtype=vllm_config.model_config.dtype,
+                device=device,
+            )
+
+        self.parent_init_patcher = patch(
+            "vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
+            mock_parent_init)
+        self.parent_init_patcher.start()
+
+    def tearDown(self):
+        self.parent_init_patcher.stop()
+
     def test_ascend_mla_metadata_builder_default(self):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
         mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -238,6 +265,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
         mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -274,10 +302,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
         mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
         mock_device = 'cpu'
 
         torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -314,6 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):
 
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.model_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
         mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -352,10 +385,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
         mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
         mock_device = 'cpu'
 
         mock_vllm_config.speculative_config = None
@@ -374,10 +409,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
         mock_vllm_config.model_config.dtype = torch.float16
+        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         mock_vllm_config.cache_config.block_size = 16
         mock_vllm_config.scheduler_config.max_num_seqs = 4
         mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
         mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
         mock_device = 'cpu'
 
         mock_vllm_config.speculative_config = None
@@ -398,11 +435,34 @@ class TestAscendMLAMetadataBuilder(TestBase):
 
 class TestAscendMLAMetadataBuilderBuild(TestBase):
 
     def setUp(self):
+        # Mock parent class __init__ to avoid complex initialization,
+        # but still set the essential attributes that child class needs
+        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
+                             device, metadata_cls, supports_dcp_with_varlen):
+            self.metadata_cls = metadata_cls
+            self.kv_cache_spec = kv_cache_spec
+            self.model_config = vllm_config.model_config
+            self.vllm_config = vllm_config
+            self.device = device
+            self.chunked_prefill_workspace_size = 128 * 1024
+            self.chunked_prefill_workspace = torch.empty(
+                (self.chunked_prefill_workspace_size,
+                 vllm_config.model_config.get_head_size()),
+                dtype=vllm_config.model_config.dtype,
+                device=device,
+            )
+
+        self.parent_init_patcher = patch(
+            "vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
+            mock_parent_init)
+        self.parent_init_patcher.start()
+
         self.mock_vllm_config = MagicMock(spec=VllmConfig)
         self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
         mock_scheduler_config = MagicMock(spec=SchedulerConfig)
         mock_scheduler_config.max_num_seqs = 8
         mock_scheduler_config.chunked_prefill_enabled = True
+        mock_scheduler_config.enable_chunked_prefill = True
         self.mock_vllm_config.scheduler_config = mock_scheduler_config
         self.mock_vllm_config.speculative_config = None
         self.mock_device = torch.device("cpu")
@@ -423,6 +483,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.kv_cache_spec.head_size = 64
         self.kv_cache_spec.num_heads = 32
 
+    def tearDown(self):
+        self.parent_init_patcher.stop()
+
     @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
     @patch('vllm_ascend.attention.attention_mask.get_pcp_group')
     @patch('vllm.distributed.parallel_state.get_pcp_group')
diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py
index 4bcfd3c6..aca95244 100644
--- a/tests/ut/attention/test_sfa_v1.py
+++ b/tests/ut/attention/test_sfa_v1.py
@@ -99,13 +99,44 @@ class TestAscendSFAMetadataBuilder(TestBase):
                                      return_value=self.mock_cfg)
         self.patcher.start()
 
+        # Mock parent class __init__ to avoid complex initialization,
+        # but still set the essential attributes that child class needs
+        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
+                             device, metadata_cls, supports_dcp_with_varlen):
+            self.metadata_cls = metadata_cls
+            self.kv_cache_spec = kv_cache_spec
+            self.model_config = vllm_config.model_config
+            self.vllm_config = vllm_config
+            self.device = device
+            self.chunked_prefill_workspace_size = 128 * 1024
+            self.chunked_prefill_workspace = torch.empty(
+                (self.chunked_prefill_workspace_size,
+                 vllm_config.model_config.get_head_size()),
+                dtype=vllm_config.model_config.dtype,
+                device=device,
+            )
+
+        self.parent_init_patcher = patch(
+            "vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
+            mock_parent_init)
+        self.parent_init_patcher.start()
+
         if hasattr(enable_dsa_cp, "cache_clear"):
             enable_dsa_cp.cache_clear()
 
+    def tearDown(self):
+        self.patcher.stop()
+        self.parent_init_patcher.stop()
+
     def test_ascend_sfa_metadata_builder_default(self):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
+        vllm_config.cache_config.block_size = 16
+        vllm_config.model_config.max_model_len = 1024
+        vllm_config.model_config.get_head_size.return_value = 64
+        vllm_config.model_config.dtype = torch.float16
+        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         speculative_config = MagicMock()
         speculative_config.num_speculative_tokens = 4
         vllm_config.speculative_config = speculative_config
@@ -138,6 +169,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
+        vllm_config.cache_config.block_size = 16
+        vllm_config.model_config.max_model_len = 1024
+        vllm_config.model_config.get_head_size.return_value = 64
+        vllm_config.model_config.dtype = torch.float16
+        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         speculative_config = MagicMock()
         speculative_config.num_speculative_tokens = 4
         vllm_config.speculative_config = speculative_config
@@ -190,6 +226,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
         kv_cache_spec = MagicMock()
         layer_names = ["layer1", "layer2"]
         vllm_config = MagicMock()
+        vllm_config.cache_config.block_size = 16
+        vllm_config.model_config.max_model_len = 1024
+        vllm_config.model_config.get_head_size.return_value = 64
+        vllm_config.model_config.dtype = torch.float16
+        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
         speculative_config = MagicMock()
         speculative_config.num_speculative_tokens = 4
         vllm_config.speculative_config = speculative_config
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 128d7547..e76c64a7 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -19,12 +19,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.context_parallel.common_cp import (
     AscendPCPMetadata, CPChunkedContextMetadata)
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         enable_cp,
-                                         maybe_save_kv_layer_to_connector,
-                                         split_decodes_and_prefills,
-                                         trans_rope_weight, transdata,
-                                         wait_for_kv_layer_from_connector)
+from vllm_ascend.attention.utils import (
+    AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size,
+    enable_cp, maybe_save_kv_layer_to_connector, split_decodes_and_prefills,
+    trans_rope_weight, transdata, wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import (
     get_draft_graph_params, get_graph_params,
     update_draft_graph_params_workspaces, update_graph_params_workspaces)
@@ -215,11 +213,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         metadata_cls: type[AscendMLAMetadata] | None = None,
         supports_dcp_with_varlen: bool = False,
     ):
-        self.metadata_cls = (metadata_cls if metadata_cls is not None else
-                             AscendMLAMetadata)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.device = device
+        super().__init__(
+            kv_cache_spec, layer_names, vllm_config, device,
+            metadata_cls if metadata_cls is not None else AscendMLAMetadata,
+            supports_dcp_with_varlen)
+
         scheduler_config = vllm_config.scheduler_config
         self.block_size = vllm_config.cache_config.block_size
         self.max_blocks = (vllm_config.model_config.max_model_len +
@@ -236,29 +234,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             got {self.decode_threshold}"
         self.reorder_batch_threshold = self.decode_threshold
 
-        if self.chunked_prefill_enabled:
-            self.chunked_prefill_workspace_size = min(
-                # Max sure there is enough for 8 full length request or at least
-                # 4 pages of cache per request
-                max(8 * self.model_config.max_model_len,
-                    4 * scheduler_config.max_num_seqs * self.block_size),
-                # For long-context models try not to over-allocate limiting
-                # kv-cache space, limiting it to 64k tokens,
-                # which would result in the workspace being:
-                # 2*(576)*(64*1024) = 144mb
-                # (assuming 576 MLA head dim, and fp16)
-                # which would result in up-projected context being
-                # 2*(192*128)*(64*1024) = 3gb
-                # (assuming 192 QK head dim, 128 heads, and fp16)
-                128 * 1024)
-            assert self.chunked_prefill_workspace_size >= \
-                scheduler_config.max_num_seqs * self.block_size
-            self.chunked_prefill_workspace = torch.empty(
-                (self.chunked_prefill_workspace_size,
-                 self.model_config.get_head_size()),
-                dtype=self.model_config.dtype,
-                device=device,
-            )
+        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
@@ -280,6 +256,11 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         self.seq_lens: torch.Tensor = None
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
 
+    @staticmethod
+    def determine_chunked_prefill_workspace_size(
+            vllm_config: VllmConfig) -> int:
+        return ascend_chunked_prefill_workspace_size(vllm_config)
+
     @classmethod
     def get_cudagraph_support(
         cls: type["AscendMLAMetadataBuilder"],
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 3cf1d9e3..f362e62c 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -20,6 +20,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         ascend_chunked_prefill_workspace_size,
                                          maybe_save_kv_layer_to_connector,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
@@ -131,7 +132,6 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
     understand this class
     """
 
-    # _attn_mask_builder = None
     def __init__(
         self,
         kv_cache_spec,
@@ -141,11 +141,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
         metadata_cls: type[AscendSFAMetadata] | None = None,
         supports_dcp_with_varlen: bool = False,
     ):
-        self.metadata_cls = (metadata_cls if metadata_cls is not None else
-                             AscendSFAMetadata)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.device = device
+        super().__init__(
+            kv_cache_spec, layer_names, vllm_config, device,
+            metadata_cls if metadata_cls is not None else AscendSFAMetadata,
+            supports_dcp_with_varlen)
+
         self.block_size = vllm_config.cache_config.block_size
         self.max_blocks = (vllm_config.model_config.max_model_len +
                            self.block_size - 1) // self.block_size
@@ -169,6 +169,11 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             ), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
 
+    @staticmethod
+    def determine_chunked_prefill_workspace_size(
+            vllm_config: VllmConfig) -> int:
+        return ascend_chunked_prefill_workspace_size(vllm_config)
+
     @classmethod
     def get_cudagraph_support(
         cls: type["AscendSFAMetadataBuilder"],
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 50be439b..3c6ba213 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -12,6 +12,34 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 
 from vllm_ascend.utils import AscendDeviceType, get_ascend_config, get_ascend_device_type
 
 
+def ascend_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
+    scheduler_config = vllm_config.scheduler_config
+    cache_config = vllm_config.cache_config
+    model_config = vllm_config.model_config
+
+    chunked_prefill_workspace_size = min(
+        # Make sure there is enough for 8 full length request or at least
+        # 4 pages of cache per request
+        max(8 * model_config.max_model_len, 4 * scheduler_config.max_num_seqs * cache_config.block_size),
+        # For long-context models try not to over-allocate limiting
+        # kv-cache space, limiting it to 128k tokens,
+        # which would result in the workspace being:
+        # 2*(576)*(128*1024) = 288mb
+        # (assuming 576 MLA head dim, and fp16)
+        # which would result in up-projected context being
+        # 2*(192*128)*(128*1024) = 6gb
+        # (assuming 192 QK head dim, 128 heads, and fp16)
+        128 * 1024,
+    )
+
+    chunked_prefill_workspace_size = max(
+        chunked_prefill_workspace_size,
+        scheduler_config.max_num_seqs * cache_config.block_size,
+    )
+
+    return chunked_prefill_workspace_size
+
+
 def using_paged_attention(runtime_shape: int,
                           vllm_config: VllmConfig) -> bool:
     if vllm_config.speculative_config is not None:
         return False
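A quick worked example of the lower-bound clamp in the new helper (the config numbers are made up for illustration; the function below merely mirrors the logic added to `vllm_ascend/attention/utils.py` so the snippet runs standalone):

```python
def workspace_size(max_model_len: int, max_num_seqs: int, block_size: int) -> int:
    # Mirror of ascend_chunked_prefill_workspace_size(): cap at 128k tokens,
    # then never return less than one block per sequence.
    size = min(max(8 * max_model_len, 4 * max_num_seqs * block_size), 128 * 1024)
    return max(size, max_num_seqs * block_size)


# 2048 sequences * 128-token blocks = 262144 tokens of minimum workspace.
# The removed in-class code asserted 128k >= max_num_seqs * block_size and
# would have failed for this configuration; the helper now clamps upward instead.
print(workspace_size(max_model_len=4096, max_num_seqs=2048, block_size=128))  # 262144
```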