[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it? This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field. ### Background After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures. ### Changes 1. **Fix singleton initialization** - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()` - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()` 2. **Remove unused field** - Removed `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in codebase) - Updated related test assertions ### Related - Issue #5463 - PR #4779 (Unify all mask generation methods) - PR #5389 (Make AttentionMaskBuilder singleton) ## Does this PR introduce _any_ user-facing change? No. This is an internal refactoring. ## How was this patch tested? - ✅ Local testing: No linter errors - ✅ Unit tests for attention modules verified - ⏳ CI pipeline Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||
AscendPCPMetadata, CPChunkedContextMetadata)
|
||||
@@ -263,6 +264,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
self.graph_pad_size = 0
|
||||
self.query_lens: torch.Tensor = None
|
||||
self.seq_lens: torch.Tensor = None
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
@@ -448,7 +450,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
num_decodes=self.num_decodes,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
num_prefills=self.num_prefills,
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_mask=self.attn_mask_builder.get_final_mla_mask(
|
||||
self.model_config),
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
@@ -542,7 +545,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
|
||||
return AscendMLAPrefillMetadata(
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_mask=self.attn_mask_builder.get_final_mla_mask(
|
||||
self.model_config),
|
||||
query_lens=self.query_lens[reqs_start:].to(torch.int32),
|
||||
seq_lens=self.seq_lens,
|
||||
context_lens=self.seq_lens[reqs_start:],
|
||||
@@ -643,7 +647,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin[:self.num_decode_tokens, ...],
|
||||
cos=cos[:self.num_decode_tokens, ...],
|
||||
@@ -1197,7 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# Output shape: [num_heads, num_tokens, dim]
|
||||
attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
|
||||
sparse_mode = 3
|
||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||
else:
|
||||
# The output layout is set to NBSD to eliminate the need for a
|
||||
@@ -1218,7 +1222,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_output_shape = (self.num_heads, num_tokens, 1,
|
||||
self.kv_lora_rank)
|
||||
sparse_mode = 0
|
||||
spec_attn_mask = None
|
||||
attn_mask = None
|
||||
|
||||
common_kwargs = {
|
||||
'query_rope': q_pe,
|
||||
@@ -1226,7 +1230,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
'num_heads': self.num_heads,
|
||||
'num_key_value_heads': self.num_kv_heads,
|
||||
'input_layout': input_layout,
|
||||
'atten_mask': spec_attn_mask,
|
||||
'atten_mask': attn_mask,
|
||||
'sparse_mode': sparse_mode,
|
||||
'scale': self.scale,
|
||||
'antiquant_mode': 0,
|
||||
@@ -1269,8 +1273,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
|
||||
weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
|
||||
self.num_heads, self.num_kv_heads, input_layout,
|
||||
weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
|
||||
else None, sparse_mode, self.scale, decode_meta.block_table,
|
||||
weak_ref_tensors(attn_mask) if attn_mask is not None else
|
||||
None, sparse_mode, self.scale, decode_meta.block_table,
|
||||
block_size, decode_meta.seq_lens_list, actual_seq_lengths,
|
||||
weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user