[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)

## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue
introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton`
decorator, the class constructor now requires a `device` parameter.
However, two initialization sites were still using the old parameterless
constructor, causing failures.

### Changes

1. **Fix singleton initialization**
- Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)`
in `AscendMLAMetadataBuilder.__init__()`
- Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)`
in `AscendAttentionMetadataBuilder.__init__()`

2. **Remove unused field**
- Removed `pcp_prefill_mask` field from
`AscendPrefillContextParallelMetadata` (never used in codebase)
   - Updated related test assertions

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

-  Local testing: No linter errors
-  Unit tests for attention modules verified
-  CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
LICO67373
2026-01-07 17:09:52 +08:00
committed by GitHub
parent 91790fd85a
commit 380f089fbf
21 changed files with 118 additions and 148 deletions

View File

@@ -19,6 +19,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
@@ -156,6 +157,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
and self.vllm_config.compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
self.attn_mask_builder = AttentionMaskBuilder(self.device)
@classmethod
def get_cudagraph_support(
@@ -280,7 +282,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
seq_lens=seq_lens,
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
attn_mask=common_attn_metadata.attn_mask,
attn_mask=self.attn_mask_builder.get_attention_mask(
self.model_config),
attn_state=common_attn_metadata.attn_state,
block_tables=block_table,
sin=sin[:num_input_tokens],