[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)

## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue
introduced in PR #5389 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #5389 made `AttentionMaskBuilder` a singleton via the `@singleton`
decorator, the class constructor requires a `device` parameter. However,
two initialization sites were still using the old parameterless
constructor, causing failures.
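
The project's actual decorator lives in vllm-ascend; the sketch below is a
hypothetical stand-in that reproduces the failure mode (everything except
the `AttentionMaskBuilder` name and its `device` parameter is illustrative):

```python
def singleton(cls):
    """Hypothetical stand-in for the project's @singleton decorator:
    cache one instance per class; later calls return the cached
    instance, and their arguments are ignored."""
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:
    def __init__(self, device):  # `device` is now required
        self.device = device


# If a parameterless call site is the first one reached, construction
# fails immediately:
#   AttentionMaskBuilder()
#   TypeError: __init__() missing 1 required positional argument: 'device'
AttentionMaskBuilder("npu:0")  # correct: pass the builder's device
```

Because later calls return the cached instance and ignore their arguments,
passing `self.device` at every call site is safe.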

### Changes

1. **Fix singleton initialization** (see the sketch after this list)
   - Changed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)`
     in `AscendMLAMetadataBuilder.__init__()`
   - Changed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)`
     in `AscendAttentionMetadataBuilder.__init__()`

2. **Remove unused field**
   - Removed the `pcp_prefill_mask` field from
     `AscendPrefillContextParallelMetadata` (it was never read anywhere
     in the codebase)
   - Updated the related test assertions
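
A condensed sketch of both changes, assuming the import path below and
treating the surrounding fields and signatures as illustrative (the real
classes carry many more arguments and fields):

```python
from dataclasses import dataclass
from typing import Optional

import torch

# Assumed import path; in-tree the builder lives in vllm-ascend's
# attention package.
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder


class AscendAttentionMetadataBuilder:
    """Condensed: the real __init__ takes the vLLM config and more."""

    def __init__(self, device: torch.device):
        self.device = device
        # Before: AttentionMaskBuilder() -- broken once the singleton
        # constructor requires a device. After: pass the builder's device.
        self.attn_mask_builder = AttentionMaskBuilder(self.device)


@dataclass
class AscendPrefillContextParallelMetadata:
    # pcp_prefill_mask was removed here: it was assigned in tests but
    # never read by the forward path.
    head_attn_nomask_seqlens: Optional[torch.Tensor] = None
    tail_attn_nomask_seqlens: Optional[torch.Tensor] = None
```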

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- Local testing: no linter errors
- Unit tests for the attention modules verified
- CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Commit 380f089fbf (parent 91790fd85a), authored by LICO67373 on
2026-01-07 17:09:52 +08:00, committed by GitHub.
21 changed files with 118 additions and 148 deletions.


@@ -53,6 +53,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
self.mock_vllm_config = MagicMock()
self.mock_vllm_config.speculative_config = None
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.model_config.hf_text_config.sliding_window = None
self.mock_vllm_config.cache_config.block_size = 64
self.mock_vllm_config.compilation_config.cudagraph_mode = None
self.mock_vllm_config.scheduler_config.max_num_seqs = 10
@@ -89,8 +90,6 @@ class TestAscendAttentionMetadataBuilder(TestBase):
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None,


@@ -1004,8 +1004,6 @@ class TestAscendMLAImpl(TestBase):
[chunk_seqlens, chunk_seqlens], dtype=torch.int32)
attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens = kv_with_q_head_nomask_seqlens
attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens = kv_with_q_tail_nomask_seqlens
attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu(
torch.ones(10, 10, dtype=torch.float16), 1)
output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe,
value, kv_c_and_k_pe_cache,


@@ -244,8 +244,15 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.enable_chunked_prefill)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
def test_ascend_mla_metadata_builder_build_full_graph(
self, mock_get_cos_and_sin_mla):
self, mock_get_pcp_group, mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -400,14 +407,21 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.kv_cache_spec.num_heads = 32
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
@patch("torch.Tensor.npu", new=lambda self: self)
@patch("torch.npu.is_available")
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
mock_zeros,
mock_zeros, mock_get_pcp_group,
mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group
def zeros_override(*args, **kwargs):
kwargs.pop('pin_memory', None)
@@ -426,8 +440,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
seq_lens=None,
@@ -458,14 +470,21 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
@patch("torch.Tensor.npu", new=lambda self: self)
@patch("torch.npu.is_available")
def test_build_chunked_prefix_metadata(self, mock_npu_available,
mock_zeros,
mock_zeros, mock_get_pcp_group,
mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group
def zeros_override(*args, **kwargs):
kwargs.pop('pin_memory', None)
@@ -485,8 +504,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None,
@@ -517,8 +534,16 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
def test_build_decode_only_metadata(self, mock_get_pcp_group,
mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
@@ -532,8 +557,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
decode_token_per_req=torch.tensor([1, 1, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((3, 3)),
spec_attn_mask=None,
attn_state=AscendAttentionState.DecodeOnly,
num_computed_tokens_cpu=None,
seq_lens=None,
@@ -563,9 +586,16 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
def test_build_for_graph_capture_decode_only(self,
@patch('vllm_ascend.attention.attention_mask.get_pcp_group')
@patch('vllm.distributed.parallel_state.get_pcp_group')
def test_build_for_graph_capture_decode_only(self, mock_get_pcp_group,
mock_get_pcp_group_mask,
mock_get_cos_and_sin_mla):
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock()
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
mock_get_pcp_group_mask.return_value = pcp_group
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
@@ -579,8 +609,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
decode_token_per_req=torch.tensor([1, 1, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((3, 3)),
spec_attn_mask=None,
attn_state=AscendAttentionState.DecodeOnly,
num_computed_tokens_cpu=None,
seq_lens=None,
@@ -625,8 +653,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_computed_tokens_cpu=None,
seq_lens=None,


@@ -291,8 +291,6 @@ class TestMtpProposer:
mock_runner = MagicMock()
mock_runner.actual_seq_lengths_q = MagicMock()
mock_runner.attn_mask = MagicMock()
mock_runner.spec_attn_mask = MagicMock()
mock_runner.attn_state = MagicMock()
mock_runner.graph_pad_size = 0
mock_runner.decode_token_per_req = MagicMock()
@@ -334,5 +332,3 @@ class TestMtpProposer:
assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens
assert spec_common_attn_metadata.max_query_len == 8
assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q
assert spec_common_attn_metadata.attn_mask == proposer.runner.attn_mask
assert spec_common_attn_metadata.spec_attn_mask == proposer.runner.spec_attn_mask


@@ -70,7 +70,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
input_batch.num_tokens = torch.tensor(num_tokens)
query_lens = torch.tensor(query_lens)
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None,
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
input_batch)
if not expect_not_none:
@@ -97,21 +97,6 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
assert hasattr(result, 'head_attn_nomask_seqlens')
assert hasattr(result, 'tail_attn_nomask_seqlens')
if hasattr(result, 'pcp_prefill_mask'
) and result.pcp_prefill_mask is not None:
if use_mla:
assert result.pcp_prefill_mask.shape == (512, 512)
else:
assert result.pcp_prefill_mask.shape == (2048, 2048)
else:
if hasattr(result, 'pcp_prefill_mask'):
if result.pcp_prefill_mask is not None:
if use_mla:
assert result.pcp_prefill_mask.shape == (512, 512)
else:
assert result.pcp_prefill_mask.shape == (2048,
2048)
@pytest.mark.parametrize(
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",