[Refactor] Make AttentionBuilder inherit from the base class in vllm (#5916)
### What this PR does / why we need it?

This PR makes `AscendMLAMetadataBuilder` and `AscendSFAMetadataBuilder` properly inherit from the base class `MLACommonMetadataBuilder` in vllm by adding `super().__init__()` calls.

**Changes:**

- Add a `super().__init__()` call in `AscendMLAMetadataBuilder.__init__()`
- Add a `super().__init__()` call in `AscendSFAMetadataBuilder.__init__()`
- Extract `ascend_chunked_prefill_workspace_size()` to `vllm_ascend/attention/utils.py` to avoid code duplication
- Override `determine_chunked_prefill_workspace_size()` to support the Ascend-specific 128k-token workspace size (vs. 64k in the parent class)
- Update unit tests to mock the parent class `__init__` for proper isolation

**Why we need it:**

- Follows proper Python inheritance patterns by calling `super().__init__()`
- Reduces code duplication by reusing the parent class initialization logic
- Improves maintainability, since parent class changes are automatically inherited

Part of issue #5463, item 10.

### Does this PR introduce _any_ user-facing change?

No, this is an internal refactoring that does not change any user-facing behavior.

Signed-off-by: lico67373 <918688502@qq.com>
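Purely as illustration, the sketch below shows the inheritance pattern described above. It is a minimal sketch, not the actual vllm / vllm-ascend code: the constructor signatures, the stand-in base-class body, and the 64k upstream default are assumptions based on this description.

```python
import torch


class MLACommonMetadataBuilder:
    """Stand-in for vllm's base class (simplified; the real signature differs)."""

    def __init__(self, kv_cache_spec, vllm_config, device):
        self.kv_cache_spec = kv_cache_spec
        self.vllm_config = vllm_config
        self.device = device
        # The parent sizes the chunked-prefill workspace through a hook,
        # so a subclass only overrides the hook, not the whole __init__.
        self.chunked_prefill_workspace_size = (
            self.determine_chunked_prefill_workspace_size())
        self.chunked_prefill_workspace = torch.empty(
            (self.chunked_prefill_workspace_size,
             vllm_config.model_config.get_head_size()),
            dtype=vllm_config.model_config.dtype,
            device=device,
        )

    def determine_chunked_prefill_workspace_size(self) -> int:
        return 64 * 1024  # assumed upstream default (64k tokens)


def ascend_chunked_prefill_workspace_size() -> int:
    # Shared helper (the PR extracts this into vllm_ascend/attention/utils.py)
    # so the MLA and SFA builders do not duplicate the constant.
    return 128 * 1024


class AscendMLAMetadataBuilder(MLACommonMetadataBuilder):
    def __init__(self, kv_cache_spec, vllm_config, device):
        # Reuse the parent initialization instead of re-implementing it.
        super().__init__(kv_cache_spec, vllm_config, device)
        # ... Ascend-specific setup would continue here ...

    def determine_chunked_prefill_workspace_size(self) -> int:
        # Ascend uses the larger 128k-token workspace.
        return ascend_chunked_prefill_workspace_size()
```

The point of the pattern is that shared setup lives in `super().__init__()`, and the Ascend builders only override the one small method that differs.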
@@ -208,11 +208,38 @@ class TestAscendMLAMetadata(TestBase):

class TestAscendMLAMetadataBuilder(TestBase):

    def setUp(self):
        # Mock parent class __init__ to avoid complex initialization,
        # but still set the essential attributes that child class needs
        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
                             device, metadata_cls, supports_dcp_with_varlen):
            self.metadata_cls = metadata_cls
            self.kv_cache_spec = kv_cache_spec
            self.model_config = vllm_config.model_config
            self.vllm_config = vllm_config
            self.device = device
            self.chunked_prefill_workspace_size = 128 * 1024
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 vllm_config.model_config.get_head_size()),
                dtype=vllm_config.model_config.dtype,
                device=device,
            )

        self.parent_init_patcher = patch(
            "vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
            mock_parent_init)
        self.parent_init_patcher.start()

    def tearDown(self):
        self.parent_init_patcher.stop()

    def test_ascend_mla_metadata_builder_default(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
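As a side note, the sketch below is a minimal standalone illustration (with hypothetical `Base`/`Child` classes, not the real builders) of the patching technique used in the `setUp` above: `unittest.mock.patch` can swap a base-class `__init__` for a plain function so the subclass under test never runs the real, expensive initialization.

```python
from unittest.mock import patch


class Base:
    def __init__(self):
        raise RuntimeError("expensive real initialization")


class Child(Base):
    def __init__(self):
        super().__init__()  # would normally run the heavy Base.__init__
        self.ready = True


def fake_base_init(self):
    # Provide only the attributes the child actually relies on.
    self.from_base = "stubbed"


# patch.object swaps Base.__init__ for fake_base_init inside the with-block
# and restores the original afterwards.
with patch.object(Base, "__init__", fake_base_init):
    c = Child()

assert c.ready and c.from_base == "stubbed"
```

Using a plain replacement function (rather than a `MagicMock`) keeps the attribute assignment explicit, so the test only supplies what the child class actually reads.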
@@ -238,6 +265,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -274,10 +302,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'
        torch.Tensor.pin_memory = lambda x: x  # noqa
@@ -314,6 +344,9 @@ class TestAscendMLAMetadataBuilder(TestBase):

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
@@ -352,10 +385,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'
        mock_vllm_config.speculative_config = None
@@ -374,10 +409,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.decode_max_num_seqs = 4
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_vllm_config.scheduler_config.enable_chunked_prefill = False
        mock_device = 'cpu'
        mock_vllm_config.speculative_config = None
@@ -398,11 +435,34 @@ class TestAscendMLAMetadataBuilder(TestBase):

class TestAscendMLAMetadataBuilderBuild(TestBase):

    def setUp(self):
        # Mock parent class __init__ to avoid complex initialization,
        # but still set the essential attributes that child class needs
        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
                             device, metadata_cls, supports_dcp_with_varlen):
            self.metadata_cls = metadata_cls
            self.kv_cache_spec = kv_cache_spec
            self.model_config = vllm_config.model_config
            self.vllm_config = vllm_config
            self.device = device
            self.chunked_prefill_workspace_size = 128 * 1024
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 vllm_config.model_config.get_head_size()),
                dtype=vllm_config.model_config.dtype,
                device=device,
            )

        self.parent_init_patcher = patch(
            "vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
            mock_parent_init)
        self.parent_init_patcher.start()

        self.mock_vllm_config = MagicMock(spec=VllmConfig)
        self.mock_vllm_config.cache_config = CacheConfig(block_size=32)
        mock_scheduler_config = MagicMock(spec=SchedulerConfig)
        mock_scheduler_config.max_num_seqs = 8
        mock_scheduler_config.chunked_prefill_enabled = True
        mock_scheduler_config.enable_chunked_prefill = True
        self.mock_vllm_config.scheduler_config = mock_scheduler_config
        self.mock_vllm_config.speculative_config = None
        self.mock_device = torch.device("cpu")
@@ -423,6 +483,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
        self.kv_cache_spec.head_size = 64
        self.kv_cache_spec.num_heads = 32

    def tearDown(self):
        self.parent_init_patcher.stop()

    @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
    @patch('vllm_ascend.attention.attention_mask.get_pcp_group')
    @patch('vllm.distributed.parallel_state.get_pcp_group')
@@ -99,13 +99,44 @@ class TestAscendSFAMetadataBuilder(TestBase):
            return_value=self.mock_cfg)
        self.patcher.start()

        # Mock parent class __init__ to avoid complex initialization,
        # but still set the essential attributes that child class needs
        def mock_parent_init(self, kv_cache_spec, layer_names, vllm_config,
                             device, metadata_cls, supports_dcp_with_varlen):
            self.metadata_cls = metadata_cls
            self.kv_cache_spec = kv_cache_spec
            self.model_config = vllm_config.model_config
            self.vllm_config = vllm_config
            self.device = device
            self.chunked_prefill_workspace_size = 128 * 1024
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 vllm_config.model_config.get_head_size()),
                dtype=vllm_config.model_config.dtype,
                device=device,
            )

        self.parent_init_patcher = patch(
            "vllm.v1.attention.backends.mla.common.MLACommonMetadataBuilder.__init__",
            mock_parent_init)
        self.parent_init_patcher.start()

        if hasattr(enable_dsa_cp, "cache_clear"):
            enable_dsa_cp.cache_clear()

    def tearDown(self):
        self.patcher.stop()
        self.parent_init_patcher.stop()

    def test_ascend_sfa_metadata_builder_default(self):
        kv_cache_spec = MagicMock()
        layer_names = ["layer1", "layer2"]
        vllm_config = MagicMock()
        vllm_config.cache_config.block_size = 16
        vllm_config.model_config.max_model_len = 1024
        vllm_config.model_config.get_head_size.return_value = 64
        vllm_config.model_config.dtype = torch.float16
        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
@@ -138,6 +169,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
        kv_cache_spec = MagicMock()
        layer_names = ["layer1", "layer2"]
        vllm_config = MagicMock()
        vllm_config.cache_config.block_size = 16
        vllm_config.model_config.max_model_len = 1024
        vllm_config.model_config.get_head_size.return_value = 64
        vllm_config.model_config.dtype = torch.float16
        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config
@@ -190,6 +226,11 @@ class TestAscendSFAMetadataBuilder(TestBase):
        kv_cache_spec = MagicMock()
        layer_names = ["layer1", "layer2"]
        vllm_config = MagicMock()
        vllm_config.cache_config.block_size = 16
        vllm_config.model_config.max_model_len = 1024
        vllm_config.model_config.get_head_size.return_value = 64
        vllm_config.model_config.dtype = torch.float16
        vllm_config.model_config.hf_text_config.qk_rope_head_dim = 64
        speculative_config = MagicMock()
        speculative_config.num_speculative_tokens = 4
        vllm_config.speculative_config = speculative_config