[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)

## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue
introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton via the `@singleton`
decorator, the class constructor requires a `device` parameter. However, two
initialization sites were still calling the old parameterless constructor,
causing initialization failures.
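
To make the failure mode concrete, here is a minimal sketch of the pattern; the decorator body and the constructor signature are assumptions based on the description above, not the actual vllm-ascend code:

```python
# Hypothetical sketch; names and bodies are illustrative, not the
# vllm-ascend implementation.
def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        # The first call constructs the instance; later calls reuse it.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:
    def __init__(self, device):
        # `device` became a required argument with the singleton change.
        self.device = device


# Old call sites still did this, which now fails on first construction:
#   AttentionMaskBuilder()
#   -> TypeError: __init__() missing 1 required positional argument: 'device'
```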

### Changes

1. **Fix singleton initialization**
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)`
     in `AscendMLAMetadataBuilder.__init__()`
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)`
     in `AscendAttentionMetadataBuilder.__init__()` (see the sketch after this list)

2. **Remove unused field**
   - Removed the `pcp_prefill_mask` field from
     `AscendPrefillContextParallelMetadata` (never used anywhere in the codebase)
   - Updated the related test assertions
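
A rough sketch of the corrected call site, reusing the `AttentionMaskBuilder` sketch above; the builder class body and the attribute name are illustrative stand-ins, only the constructor call reflects the change in this PR:

```python
# Illustrative stand-in for AscendAttentionMetadataBuilder /
# AscendMLAMetadataBuilder.
class AscendAttentionMetadataBuilder:
    def __init__(self, device: str):
        self.device = device
        # Before: self.attn_mask_builder = AttentionMaskBuilder()
        #   -> TypeError once the class became a singleton with a
        #      required `device` argument.
        # After: pass the device through explicitly.
        self.attn_mask_builder = AttentionMaskBuilder(self.device)
```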

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- Local testing: no linter errors
- Unit tests for the attention modules verified
- CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>


```diff
@@ -498,7 +498,7 @@ class PCPManager:
             torch.float32).argsort().to(torch.int32)

     def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
-                              attn_mask, input_batch):
+                              input_batch):
         from vllm_ascend.attention.utils import \
             AscendPrefillContextParallelMetadata
         num_reqs = input_batch.num_reqs or query_lens.size(0)
@@ -523,7 +523,7 @@ class PCPManager:
                 dtype=torch.int32,
             )
             # For pcp + spec decode, we flatten seq_lens
-            # to avoid irregular spec_attn_mask shape.
+            # to avoid irregular attn_mask shape.
             # Same as block_table, we flatten decode seq_lens to query_lens,
             # and keep prefill seq_lens unchanged.
             for decode_idx in range(self.decode_threshold):
@@ -657,13 +657,11 @@ class PCPManager:
                 split_with_q_head_nomask_idx_reqs,
                 split_kv_with_q_tail_nomask_idx_reqs,
                 head_attn_nomask_seqlens, chunk_seqlens)
-            pcp_prefill_mask = attn_mask
             self.extra_long_seq_kwargs = {
                 'attn_mask_seqlens': attn_mask_seqlens,
                 'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
-                'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
-                'pcp_prefill_mask': pcp_prefill_mask
+                'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens
             }
             long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
                 num_actual_tokens_pcp_padded]
@@ -685,8 +683,6 @@ class PCPManager:
                 'head_attn_nomask_seqlens']
             long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
                 'tail_attn_nomask_seqlens']
-            long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
-                'pcp_prefill_mask']
             if self.vllm_config.model_config.use_mla:
                 long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
                 long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
```
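
For orientation, a rough sketch of the metadata container after the field removal; the field list is inferred from the hunks above, and the real definition in `vllm_ascend/attention/utils.py` likely carries more fields:

```python
# Sketch only: fields inferred from the diff above, not the full definition.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AscendPrefillContextParallelMetadata:
    pcp_allgather_restore_idx: Optional[torch.Tensor] = None
    attn_mask_seqlens: Optional[torch.Tensor] = None
    head_attn_nomask_seqlens: Optional[torch.Tensor] = None
    tail_attn_nomask_seqlens: Optional[torch.Tensor] = None
    kv_with_q_head_nomask_idx_tensor: Optional[list] = None
    kv_with_q_tail_nomask_idx_tensor: Optional[list] = None
    # pcp_prefill_mask is gone: it was populated by PCPManager but never consumed.
```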