[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it? This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field. ### Background After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures. ### Changes 1. **Fix singleton initialization** - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()` - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()` 2. **Remove unused field** - Removed `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in codebase) - Updated related test assertions ### Related - Issue #5463 - PR #4779 (Unify all mask generation methods) - PR #5389 (Make AttentionMaskBuilder singleton) ## Does this PR introduce _any_ user-facing change? No. This is an internal refactoring. ## How was this patch tested? - ✅ Local testing: No linter errors - ✅ Unit tests for attention modules verified - ⏳ CI pipeline Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -244,8 +244,15 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
def test_ascend_mla_metadata_builder_build_full_graph(
|
||||
self, mock_get_cos_and_sin_mla):
|
||||
self, mock_get_pcp_group, mock_get_pcp_group_mask,
|
||||
mock_get_cos_and_sin_mla):
|
||||
pcp_group = MagicMock()
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
mock_get_pcp_group_mask.return_value = pcp_group
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
@@ -400,14 +407,21 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.kv_cache_spec.num_heads = 32
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||
@patch("torch.npu.is_available")
|
||||
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
|
||||
mock_zeros,
|
||||
mock_zeros, mock_get_pcp_group,
|
||||
mock_get_pcp_group_mask,
|
||||
mock_get_cos_and_sin_mla):
|
||||
mock_npu_available.return_value = False
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock()
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
mock_get_pcp_group_mask.return_value = pcp_group
|
||||
|
||||
def zeros_override(*args, **kwargs):
|
||||
kwargs.pop('pin_memory', None)
|
||||
@@ -426,8 +440,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((10, 10)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None,
|
||||
@@ -458,14 +470,21 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||
@patch("torch.npu.is_available")
|
||||
def test_build_chunked_prefix_metadata(self, mock_npu_available,
|
||||
mock_zeros,
|
||||
mock_zeros, mock_get_pcp_group,
|
||||
mock_get_pcp_group_mask,
|
||||
mock_get_cos_and_sin_mla):
|
||||
mock_npu_available.return_value = False
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock()
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
mock_get_pcp_group_mask.return_value = pcp_group
|
||||
|
||||
def zeros_override(*args, **kwargs):
|
||||
kwargs.pop('pin_memory', None)
|
||||
@@ -485,8 +504,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((15, 15)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None,
|
||||
@@ -517,8 +534,16 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
|
||||
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
def test_build_decode_only_metadata(self, mock_get_pcp_group,
|
||||
mock_get_pcp_group_mask,
|
||||
mock_get_cos_and_sin_mla):
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock()
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
mock_get_pcp_group_mask.return_value = pcp_group
|
||||
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||
@@ -532,8 +557,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((3, 3)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None,
|
||||
@@ -563,9 +586,16 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
def test_build_for_graph_capture_decode_only(self,
|
||||
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
def test_build_for_graph_capture_decode_only(self, mock_get_pcp_group,
|
||||
mock_get_pcp_group_mask,
|
||||
mock_get_cos_and_sin_mla):
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock()
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
mock_get_pcp_group_mask.return_value = pcp_group
|
||||
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||
@@ -579,8 +609,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((3, 3)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None,
|
||||
@@ -625,8 +653,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((10, 10)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None,
|
||||
|
||||
Reference in New Issue
Block a user