[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)

## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue
introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton via the `@singleton`
decorator, the class constructor requires a `device` parameter. However, two
initialization sites were still calling the old parameterless constructor,
causing failures.
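
A minimal sketch of the pattern, assuming a conventional caching `@singleton` decorator (the real decorator and the rest of `AttentionMaskBuilder` in vllm-ascend differ in detail):

```python
# Illustrative sketch only; not the actual vllm-ascend implementation.
import torch


def singleton(cls):
    """Cache the first instance and return it on every subsequent call."""
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:
    def __init__(self, device: torch.device):
        # `device` is now a required argument, so the old parameterless call
        # `AttentionMaskBuilder()` raises a TypeError when it happens to be
        # the first call that constructs the singleton.
        self.device = device


# Fixed call sites pass the builder's device and share one instance:
builder_a = AttentionMaskBuilder(torch.device("cpu"))
builder_b = AttentionMaskBuilder(torch.device("cpu"))
assert builder_a is builder_b
```

Because the first call determines the cached instance, every initialization site must be able to pass the device.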

### Changes

1. **Fix singleton initialization**
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in
     `AscendMLAMetadataBuilder.__init__()`
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in
     `AscendAttentionMetadataBuilder.__init__()`

2. **Remove unused field**
   - Removed the `pcp_prefill_mask` field from
     `AscendPrefillContextParallelMetadata`; it was never read anywhere in the
     codebase (a trimmed-down sketch follows this list)
   - Updated the related test assertions
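
A trimmed-down sketch of the field removal, with the class body elided (the real dataclass defines the actual prefill-context-parallel fields):

```python
# Illustrative sketch only; the real class defines the actual PCP fields.
from dataclasses import dataclass


@dataclass
class AscendPrefillContextParallelMetadata:
    ...  # existing prefill-context-parallel fields are unchanged
    # pcp_prefill_mask: Optional[torch.Tensor] = None  # removed: never read anywhere
```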

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- Local testing: no linter errors
- Unit tests for the attention modules pass
- CI pipeline passes

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Commit 380f089fbf (parent 91790fd85a), authored by LICO67373 on 2026-01-07 17:09:52 +08:00, committed by GitHub
21 changed files with 118 additions and 148 deletions

@@ -18,6 +18,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.context_parallel.common_cp import (
     AscendPCPMetadata, CPChunkedContextMetadata)
@@ -263,6 +264,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         self.graph_pad_size = 0
         self.query_lens: torch.Tensor = None
         self.seq_lens: torch.Tensor = None
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)

     @classmethod
     def get_cudagraph_support(
@@ -448,7 +450,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             num_decodes=self.num_decodes,
             num_decode_tokens=self.num_decode_tokens,
             num_prefills=self.num_prefills,
-            attn_mask=common_attn_metadata.attn_mask,
+            attn_mask=self.attn_mask_builder.get_final_mla_mask(
+                self.model_config),
             attn_state=common_attn_metadata.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
@@ -542,7 +545,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             prefill_input_positions = input_positions[tokens_start:]
             cos, sin = get_cos_and_sin_mla(prefill_input_positions)
             return AscendMLAPrefillMetadata(
-                attn_mask=common_attn_metadata.attn_mask,
+                attn_mask=self.attn_mask_builder.get_final_mla_mask(
+                    self.model_config),
                 query_lens=self.query_lens[reqs_start:].to(torch.int32),
                 seq_lens=self.seq_lens,
                 context_lens=self.seq_lens[reqs_start:],
@@ -643,7 +647,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             seq_lens=self.seq_lens,
             seq_lens_list=seq_lens_list,
             max_seq_lens=max_seq_lens,
-            attn_mask=common_attn_metadata.spec_attn_mask,
+            attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
             actual_seq_lengths_q=actual_seq_lengths_q,
             sin=sin[:self.num_decode_tokens, ...],
             cos=cos[:self.num_decode_tokens, ...],
@@ -1197,7 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             # Output shape: [num_heads, num_tokens, dim]
             attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
             sparse_mode = 3
-            spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
+            attn_mask = attn_metadata.decode.attn_mask # type:ignore
             actual_seq_lengths = decode_meta.actual_seq_lengths_q
         else:
             # The output layout is set to NBSD to eliminate the need for a
@@ -1218,7 +1222,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             attn_output_shape = (self.num_heads, num_tokens, 1,
                                  self.kv_lora_rank)
             sparse_mode = 0
-            spec_attn_mask = None
+            attn_mask = None

         common_kwargs = {
             'query_rope': q_pe,
@@ -1226,7 +1230,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             'num_heads': self.num_heads,
             'num_key_value_heads': self.num_kv_heads,
             'input_layout': input_layout,
-            'atten_mask': spec_attn_mask,
+            'atten_mask': attn_mask,
             'sparse_mode': sparse_mode,
             'scale': self.scale,
             'antiquant_mode': 0,
@@ -1269,8 +1273,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                 (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
                  weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
                  self.num_heads, self.num_kv_heads, input_layout,
-                 weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
-                 else None, sparse_mode, self.scale, decode_meta.block_table,
+                 weak_ref_tensors(attn_mask) if attn_mask is not None else
+                 None, sparse_mode, self.scale, decode_meta.block_table,
                  block_size, decode_meta.seq_lens_list, actual_seq_lengths,
                  weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))