[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it? This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field. ### Background After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures. ### Changes 1. **Fix singleton initialization** - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()` - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()` 2. **Remove unused field** - Removed `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in codebase) - Updated related test assertions ### Related - Issue #5463 - PR #4779 (Unify all mask generation methods) - PR #5389 (Make AttentionMaskBuilder singleton) ## Does this PR introduce _any_ user-facing change? No. This is an internal refactoring. ## How was this patch tested? - ✅ Local testing: No linter errors - ✅ Unit tests for attention modules verified - ⏳ CI pipeline Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -13,6 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from vllm.distributed import get_pcp_group
|
||||
|
||||
from vllm_ascend.platform import ModelConfig
|
||||
from vllm_ascend.utils import singleton
|
||||
|
||||
|
||||
def _generate_attn_mask(max_seq_len, dtype):
|
||||
@@ -29,6 +33,7 @@ def _generate_attn_mask(max_seq_len, dtype):
|
||||
return attn_mask
|
||||
|
||||
|
||||
@singleton
|
||||
class AttentionMaskBuilder:
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
@@ -82,4 +87,16 @@ class AttentionMaskBuilder:
|
||||
triu_mask = torch.triu(mask, diagonal=1).to(self.device)
|
||||
tril_mask = torch.tril(mask, -sliding_window).to(self.device)
|
||||
self.swa_mask = triu_mask + tril_mask
|
||||
return self.swa_mask
|
||||
return self.swa_mask
|
||||
|
||||
def get_attention_mask(self, model_config: ModelConfig):
|
||||
if model_config.runner_type == "pooling":
|
||||
return self.get_attn_mask(2048, torch.bool)
|
||||
|
||||
return self.get_splitfuse_attn_mask()
|
||||
|
||||
def get_final_mla_mask(self, model_config: ModelConfig):
|
||||
if get_pcp_group().world_size > 1:
|
||||
return self.get_pcp_mla_mask(model_config.dtype)
|
||||
# Prefill stages use 512x512 mask with appropriate dtype
|
||||
return self.get_mla_mask(model_config.dtype)
|
||||
@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||
AscendMetadataForDecode, AscendMetadataForPrefill)
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
@@ -219,6 +220,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
@@ -253,10 +255,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
swa_mask = common_attn_metadata.swa_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
|
||||
# Get attn_mask and swa_mask from singleton AttentionMaskBuilder
|
||||
attn_mask = self.attn_mask_builder.get_attention_mask(
|
||||
self.model_config)
|
||||
|
||||
swa_mask = None
|
||||
is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window')
|
||||
if self.model_config is not None and is_swa:
|
||||
swa_mask = self.attn_mask_builder.get_swa_mask(
|
||||
self.model_config.dtype,
|
||||
self.model_config.hf_text_config.sliding_window)
|
||||
|
||||
# TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
|
||||
query_start_loc = query_start_loc_cpu.pin_memory().to(
|
||||
self.device, non_blocking=True)
|
||||
|
||||
@@ -121,7 +121,8 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
attn_mask = self.attn_mask_builder.get_attention_mask(
|
||||
self.model_config)
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||
|
||||
@@ -212,7 +213,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
head_attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
|
||||
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||
pcp_allgather_restore_idx)
|
||||
|
||||
@@ -433,13 +433,12 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
||||
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
|
||||
nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens \
|
||||
if is_head else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
output, lse = self._attention_with_nomask_and_mask(
|
||||
**data,
|
||||
q_seqlens=attn_mask_seqlens,
|
||||
kv_seqlens_nomask=nomask_seqlens,
|
||||
kv_seqlens_mask=attn_mask_seqlens,
|
||||
mask=mask,
|
||||
mask=attn_metadata.attn_mask,
|
||||
attn_metadata=attn_metadata)
|
||||
return output, lse
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ class AscendPCPMetadata:
|
||||
head_attn_nomask_seqlens: torch.Tensor = None
|
||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||
q_full_idx: torch.Tensor = None
|
||||
pcp_prefill_mask: torch.Tensor = None
|
||||
pcp_allgather_restore_idx: Optional[list[int]] = None
|
||||
|
||||
|
||||
|
||||
@@ -118,7 +118,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
tail_attn_nomask_seqlens=common_long_seq_metadata.
|
||||
tail_attn_nomask_seqlens,
|
||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
|
||||
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||
pcp_allgather_restore_idx)
|
||||
|
||||
@@ -195,7 +194,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
|
||||
).item()
|
||||
if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
|
||||
# For pcp + spec decode, we flatten seq_lens and block_table
|
||||
# to avoid irregular spec_attn_mask shape
|
||||
# to avoid irregular attn_mask shape
|
||||
return self.num_decodes_flatten + self.num_prefills
|
||||
else:
|
||||
return self.num_decodes_flatten
|
||||
@@ -420,7 +419,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
|
||||
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
|
||||
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
output_head, lse_head = self._attention_with_mask_and_nomask(
|
||||
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
||||
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
||||
@@ -431,7 +429,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
kv_nomask_idx=kv_with_q_head_nomask_idx,
|
||||
attn_mask_seqlens=attn_mask_seqlens,
|
||||
attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||
mask=mask)
|
||||
mask=attn_metadata.attn_mask)
|
||||
|
||||
output_tail, lse_tail = self._attention_with_mask_and_nomask(
|
||||
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
||||
@@ -443,7 +441,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
kv_nomask_idx=kv_with_q_tail_nomask_idx,
|
||||
attn_mask_seqlens=attn_mask_seqlens,
|
||||
attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
||||
mask=mask)
|
||||
mask=attn_metadata.attn_mask)
|
||||
|
||||
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
|
||||
attn_output = torch.index_select(
|
||||
|
||||
@@ -18,6 +18,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.context_parallel.common_cp import (
|
||||
AscendPCPMetadata, CPChunkedContextMetadata)
|
||||
@@ -263,6 +264,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
self.graph_pad_size = 0
|
||||
self.query_lens: torch.Tensor = None
|
||||
self.seq_lens: torch.Tensor = None
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
@@ -448,7 +450,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
num_decodes=self.num_decodes,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
num_prefills=self.num_prefills,
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_mask=self.attn_mask_builder.get_final_mla_mask(
|
||||
self.model_config),
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
@@ -542,7 +545,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos, sin = get_cos_and_sin_mla(prefill_input_positions)
|
||||
return AscendMLAPrefillMetadata(
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_mask=self.attn_mask_builder.get_final_mla_mask(
|
||||
self.model_config),
|
||||
query_lens=self.query_lens[reqs_start:].to(torch.int32),
|
||||
seq_lens=self.seq_lens,
|
||||
context_lens=self.seq_lens[reqs_start:],
|
||||
@@ -643,7 +647,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(),
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin[:self.num_decode_tokens, ...],
|
||||
cos=cos[:self.num_decode_tokens, ...],
|
||||
@@ -1197,7 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# Output shape: [num_heads, num_tokens, dim]
|
||||
attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank)
|
||||
sparse_mode = 3
|
||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||
else:
|
||||
# The output layout is set to NBSD to eliminate the need for a
|
||||
@@ -1218,7 +1222,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_output_shape = (self.num_heads, num_tokens, 1,
|
||||
self.kv_lora_rank)
|
||||
sparse_mode = 0
|
||||
spec_attn_mask = None
|
||||
attn_mask = None
|
||||
|
||||
common_kwargs = {
|
||||
'query_rope': q_pe,
|
||||
@@ -1226,7 +1230,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
'num_heads': self.num_heads,
|
||||
'num_key_value_heads': self.num_kv_heads,
|
||||
'input_layout': input_layout,
|
||||
'atten_mask': spec_attn_mask,
|
||||
'atten_mask': attn_mask,
|
||||
'sparse_mode': sparse_mode,
|
||||
'scale': self.scale,
|
||||
'antiquant_mode': 0,
|
||||
@@ -1269,8 +1273,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
|
||||
weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
|
||||
self.num_heads, self.num_kv_heads, input_layout,
|
||||
weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
|
||||
else None, sparse_mode, self.scale, decode_meta.block_table,
|
||||
weak_ref_tensors(attn_mask) if attn_mask is not None else
|
||||
None, sparse_mode, self.scale, decode_meta.block_table,
|
||||
block_size, decode_meta.seq_lens_list, actual_seq_lengths,
|
||||
weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
@@ -156,6 +157,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
and self.vllm_config.compilation_config.cudagraph_mode
|
||||
== CUDAGraphMode.FULL_DECODE_ONLY
|
||||
), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
|
||||
@classmethod
|
||||
def get_cudagraph_support(
|
||||
@@ -280,7 +282,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
||||
seq_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_mask=self.attn_mask_builder.get_attention_mask(
|
||||
self.model_config),
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
block_tables=block_table,
|
||||
sin=sin[:num_input_tokens],
|
||||
|
||||
@@ -66,8 +66,6 @@ class AscendPrefillContextParallelMetadata:
|
||||
|
||||
q_full_idx: torch.Tensor = None
|
||||
|
||||
pcp_prefill_mask: torch.Tensor = None
|
||||
|
||||
# original query_lens before pcp split
|
||||
query_lens_pcp_full_cpu: torch.Tensor = None
|
||||
|
||||
@@ -93,12 +91,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
|
||||
|
||||
positions: torch.Tensor = None
|
||||
|
||||
attn_mask: torch.Tensor = None
|
||||
|
||||
spec_attn_mask: torch.Tensor = None
|
||||
|
||||
swa_mask: torch.Tensor = None
|
||||
|
||||
attn_state: Any = None
|
||||
|
||||
graph_pad_size: int = -1
|
||||
@@ -130,9 +122,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
|
||||
causal=self.causal,
|
||||
actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
|
||||
positions=self.positions[:num_actual_tokens],
|
||||
attn_mask=self.attn_mask,
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
swa_mask=self.swa_mask,
|
||||
attn_state=self.attn_state,
|
||||
graph_pad_size=-1, # It should be -1 when not run in fullgraph mode.
|
||||
num_input_tokens=num_actual_tokens,
|
||||
|
||||
Reference in New Issue
Block a user