[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)

## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue
introduced by the singleton conversion (PR #5389, building on PR #4779)
and removes the redundant `pcp_prefill_mask` field.

### Background

After PR #5389 made `AttentionMaskBuilder` a singleton via the `@singleton`
decorator, the class constructor requires a `device` parameter. However,
two initialization sites still used the old parameterless constructor,
which raises a `TypeError` whenever one of them happens to be the first
call that constructs the singleton.
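
For illustration, a minimal sketch of how a `@singleton` class decorator typically behaves (the real helper lives in `vllm_ascend.utils` and may differ in detail); it shows why a parameterless `AttentionMaskBuilder()` call is order-dependent:

```python
# Hypothetical sketch of a singleton decorator; the actual helper in
# vllm_ascend.utils may differ.
def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        # The first call constructs the instance; later calls return the
        # cached one and silently ignore their arguments.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:

    def __init__(self, device):
        self.device = device


AttentionMaskBuilder("npu:0")  # ok: the first call supplies the device
AttentionMaskBuilder()         # returns the cached instance here, but would
                               # raise TypeError if it were the first call
```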

### Changes

1. **Fix singleton initialization**
   - `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in
     `AscendMLAMetadataBuilder.__init__()`
   - `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in
     `AscendAttentionMetadataBuilder.__init__()`

2. **Remove redundant field**
   - Removed the `pcp_prefill_mask` field from
     `AscendPrefillContextParallelMetadata` and `AscendPCPMetadata`; its
     consumers now read the equivalent mask via `attn_metadata.attn_mask`
   - Updated the related test assertions (see the sketch after this list)
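
A hypothetical sketch of the kind of assertion such a test update implies (not the PR's actual test code; assumes `AscendPCPMetadata` is a dataclass, as its field syntax suggests):

```python
import dataclasses

from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata


def test_pcp_prefill_mask_removed():
    # After this PR the field should be gone from the metadata dataclass.
    field_names = {f.name for f in dataclasses.fields(AscendPCPMetadata)}
    assert "pcp_prefill_mask" not in field_names
```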

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- Local testing: no linter errors
- Unit tests for the attention modules pass
- CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
Commit 380f089fbf (parent 91790fd85a), authored by LICO67373 on
2026-01-07 17:09:52 +08:00, committed by GitHub.
21 changed files with 118 additions and 148 deletions.

File: vllm_ascend/attention/attention_mask.py

@@ -13,6 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from vllm.distributed import get_pcp_group
+
+from vllm_ascend.platform import ModelConfig
+from vllm_ascend.utils import singleton


 def _generate_attn_mask(max_seq_len, dtype):
@@ -29,6 +33,7 @@ def _generate_attn_mask(max_seq_len, dtype):
     return attn_mask


+@singleton
 class AttentionMaskBuilder:

     def __init__(self, device: torch.device):
@@ -82,4 +87,16 @@ class AttentionMaskBuilder:
         triu_mask = torch.triu(mask, diagonal=1).to(self.device)
         tril_mask = torch.tril(mask, -sliding_window).to(self.device)
         self.swa_mask = triu_mask + tril_mask
         return self.swa_mask
+
+    def get_attention_mask(self, model_config: ModelConfig):
+        if model_config.runner_type == "pooling":
+            return self.get_attn_mask(2048, torch.bool)
+        return self.get_splitfuse_attn_mask()
+
+    def get_final_mla_mask(self, model_config: ModelConfig):
+        if get_pcp_group().world_size > 1:
+            return self.get_pcp_mla_mask(model_config.dtype)
+        # Prefill stages use 512x512 mask with appropriate dtype
+        return self.get_mla_mask(model_config.dtype)
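
To make the intent of the two new accessors concrete, a hedged usage sketch; `model_config` is stubbed here, and `get_final_mla_mask` additionally assumes an initialized distributed pcp group:

```python
import torch
from types import SimpleNamespace

# Stub standing in for vLLM's ModelConfig; only the attributes the two
# accessors read are filled in.
model_config = SimpleNamespace(runner_type="generate", dtype=torch.bfloat16)

builder = AttentionMaskBuilder(torch.device("npu"))  # first call binds the device

# "pooling" runners get a 2048x2048 bool mask; everything else falls
# through to the split-fuse mask.
mask = builder.get_attention_mask(model_config)

# MLA additionally branches on the prefill-context-parallel world size:
# pcp > 1 selects the pcp MLA mask, otherwise the 512x512 MLA mask.
mla_mask = builder.get_final_mla_mask(model_config)
```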

File: vllm_ascend/attention/attention_v1.py

@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec

+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.context_parallel.common_cp import (
     AscendMetadataForDecode, AscendMetadataForPrefill)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
@@ -219,6 +220,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
         scheduler_config = vllm_config.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)

     @classmethod
     def get_cudagraph_support(
@@ -253,10 +255,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
-        attn_mask = common_attn_metadata.attn_mask
-        swa_mask = common_attn_metadata.swa_mask
         attn_state = common_attn_metadata.attn_state
+        # Get attn_mask and swa_mask from the singleton AttentionMaskBuilder
+        attn_mask = self.attn_mask_builder.get_attention_mask(
+            self.model_config)
+        swa_mask = None
+        # Guard before touching hf_text_config so a missing model_config
+        # cannot raise inside hasattr()
+        if self.model_config is not None and hasattr(
+                self.model_config.hf_text_config, 'sliding_window'):
+            swa_mask = self.attn_mask_builder.get_swa_mask(
+                self.model_config.dtype,
+                self.model_config.hf_text_config.sliding_window)

         # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
         query_start_loc = query_start_loc_cpu.pin_memory().to(
             self.device, non_blocking=True)
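
For intuition about the `swa_mask` returned by `get_swa_mask`, a standalone sketch of the triu/tril construction shown in `attention_mask.py` above (toy sizes, CPU tensors):

```python
import torch

max_seq_len, sliding_window = 8, 3
mask = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)

# True entries are masked out: all future positions (triu above the
# diagonal) plus anything further back than the sliding window (tril).
triu_mask = torch.triu(mask, diagonal=1)
tril_mask = torch.tril(mask, -sliding_window)
swa_mask = triu_mask + tril_mask  # bool addition behaves as logical OR

# Row i can now attend only to positions i-2..i (a window of 3).
print(swa_mask.int())
```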

File: context-parallel attention backend (path not preserved)

@@ -121,7 +121,8 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
-        attn_mask = common_attn_metadata.attn_mask
+        attn_mask = self.attn_mask_builder.get_attention_mask(
+            self.model_config)
         attn_state = common_attn_metadata.attn_state
         num_computed_tokens_cpu = (seq_lens - query_lens)
@@ -212,7 +213,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
             head_attn_nomask_seqlens=head_attn_nomask_seqlens,
             tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
             q_full_idx=common_long_seq_metadata.q_full_idx,
-            pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
             pcp_allgather_restore_idx=common_long_seq_metadata.
             pcp_allgather_restore_idx)
@@ -433,13 +433,12 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
         attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
         nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens \
             if is_head else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
-        mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask

         output, lse = self._attention_with_nomask_and_mask(
             **data,
             q_seqlens=attn_mask_seqlens,
             kv_seqlens_nomask=nomask_seqlens,
             kv_seqlens_mask=attn_mask_seqlens,
-            mask=mask,
+            mask=attn_metadata.attn_mask,
             attn_metadata=attn_metadata)
         return output, lse

File: vllm_ascend/attention/context_parallel/common_cp.py

@@ -21,7 +21,6 @@ class AscendPCPMetadata:
     head_attn_nomask_seqlens: torch.Tensor = None
     tail_attn_nomask_seqlens: torch.Tensor = None
     q_full_idx: torch.Tensor = None
-    pcp_prefill_mask: torch.Tensor = None
     pcp_allgather_restore_idx: Optional[list[int]] = None

File: context-parallel MLA backend (path not preserved)

@@ -118,7 +118,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 tail_attn_nomask_seqlens=common_long_seq_metadata.
                 tail_attn_nomask_seqlens,
                 q_full_idx=common_long_seq_metadata.q_full_idx,
-                pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
                 pcp_allgather_restore_idx=common_long_seq_metadata.
                 pcp_allgather_restore_idx)
@@ -195,7 +194,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
             ).item()
         if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
             # For pcp + spec decode, we flatten seq_lens and block_table
-            # to avoid irregular spec_attn_mask shape
+            # to avoid irregular attn_mask shape
             return self.num_decodes_flatten + self.num_prefills
         else:
             return self.num_decodes_flatten
@@ -420,7 +419,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
         attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
         head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
         tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
-        mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask

         output_head, lse_head = self._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_head_idx),
             q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -431,7 +429,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             kv_nomask_idx=kv_with_q_head_nomask_idx,
             attn_mask_seqlens=attn_mask_seqlens,
             attn_nomask_seqlens=head_attn_nomask_seqlens,
-            mask=mask)
+            mask=attn_metadata.attn_mask)

         output_tail, lse_tail = self._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_tail_idx),
@@ -443,7 +441,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             kv_nomask_idx=kv_with_q_tail_nomask_idx,
             attn_mask_seqlens=attn_mask_seqlens,
             attn_nomask_seqlens=tail_attn_nomask_seqlens,
-            mask=mask)
+            mask=attn_metadata.attn_mask)

         q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
         attn_output = torch.index_select(

File: vllm_ascend/attention/mla_v1.py

@@ -18,6 +18,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.context_parallel.common_cp import (
     AscendPCPMetadata, CPChunkedContextMetadata)
@@ -263,6 +264,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
         self.graph_pad_size = 0
         self.query_lens: torch.Tensor = None
         self.seq_lens: torch.Tensor = None
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)

     @classmethod
     def get_cudagraph_support(
@@ -448,7 +450,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             num_decodes=self.num_decodes,
             num_decode_tokens=self.num_decode_tokens,
             num_prefills=self.num_prefills,
-            attn_mask=common_attn_metadata.attn_mask,
+            attn_mask=self.attn_mask_builder.get_final_mla_mask(
+                self.model_config),
             attn_state=common_attn_metadata.attn_state,
             prefill=prefill_metadata,
             decode=decode_metadata,
@@ -542,7 +545,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
             prefill_input_positions = input_positions[tokens_start:]
             cos, sin = get_cos_and_sin_mla(prefill_input_positions)
             return AscendMLAPrefillMetadata(
-                attn_mask=common_attn_metadata.attn_mask,
+                attn_mask=self.attn_mask_builder.get_final_mla_mask(
+                    self.model_config),
                 query_lens=self.query_lens[reqs_start:].to(torch.int32),
                 seq_lens=self.seq_lens,
                 context_lens=self.seq_lens[reqs_start:],
@@ -643,7 +647,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
                 seq_lens=self.seq_lens,
                 seq_lens_list=seq_lens_list,
                 max_seq_lens=max_seq_lens,
-                attn_mask=common_attn_metadata.spec_attn_mask,
+                attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
                 actual_seq_lengths_q=actual_seq_lengths_q,
                 sin=sin[:self.num_decode_tokens, ...],
                 cos=cos[:self.num_decode_tokens, ...],
@@ -1197,7 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             # Output shape: [num_heads, num_tokens, dim]
             attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
             sparse_mode = 3
-            spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+            attn_mask = attn_metadata.decode.attn_mask  # type:ignore
             actual_seq_lengths = decode_meta.actual_seq_lengths_q
         else:
             # The output layout is set to NBSD to eliminate the need for a
@@ -1218,7 +1222,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             attn_output_shape = (self.num_heads, num_tokens, 1,
                                  self.kv_lora_rank)
             sparse_mode = 0
-            spec_attn_mask = None
+            attn_mask = None

         common_kwargs = {
             'query_rope': q_pe,
@@ -1226,7 +1230,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             'num_heads': self.num_heads,
             'num_key_value_heads': self.num_kv_heads,
             'input_layout': input_layout,
-            'atten_mask': spec_attn_mask,
+            'atten_mask': attn_mask,
             'sparse_mode': sparse_mode,
             'scale': self.scale,
             'antiquant_mode': 0,
@@ -1269,8 +1273,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                 (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
                  weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
                  self.num_heads, self.num_kv_heads, input_layout,
-                 weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
-                 else None, sparse_mode, self.scale, decode_meta.block_table,
+                 weak_ref_tensors(attn_mask) if attn_mask is not None else
+                 None, sparse_mode, self.scale, decode_meta.block_table,
                  block_size, decode_meta.seq_lens_list, actual_seq_lengths,
                  weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))

File: SFA attention backend (path not preserved)

@@ -19,6 +19,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
@@ -156,6 +157,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
                 and self.vllm_config.compilation_config.cudagraph_mode
                 == CUDAGraphMode.FULL_DECODE_ONLY
             ), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)

     @classmethod
     def get_cudagraph_support(
@@ -280,7 +282,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
             seq_lens=seq_lens,
             slot_mapping=slot_mapping,
             head_dim=self.model_config.get_head_size(),
-            attn_mask=common_attn_metadata.attn_mask,
+            attn_mask=self.attn_mask_builder.get_attention_mask(
+                self.model_config),
             attn_state=common_attn_metadata.attn_state,
             block_tables=block_table,
             sin=sin[:num_input_tokens],

File: vllm_ascend/attention/utils.py

@@ -66,8 +66,6 @@ class AscendPrefillContextParallelMetadata:
     q_full_idx: torch.Tensor = None
-    pcp_prefill_mask: torch.Tensor = None
-
     # original query_lens before pcp split
     query_lens_pcp_full_cpu: torch.Tensor = None
@@ -93,12 +91,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
     positions: torch.Tensor = None
-    attn_mask: torch.Tensor = None
-    spec_attn_mask: torch.Tensor = None
-    swa_mask: torch.Tensor = None
-
     attn_state: Any = None
     graph_pad_size: int = -1
@@ -130,9 +122,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
             causal=self.causal,
             actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
             positions=self.positions[:num_actual_tokens],
-            attn_mask=self.attn_mask,
-            spec_attn_mask=self.spec_attn_mask,
-            swa_mask=self.swa_mask,
             attn_state=self.attn_state,
             graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
             num_input_tokens=num_actual_tokens,