[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton via the `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures (see the sketch below this description).

### Changes

1. **Fix singleton initialization**
   - Changed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()`
   - Changed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()`
2. **Remove unused field**
   - Removed the `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in the codebase)
   - Updated the related test assertions

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- ✅ Local testing: no linter errors
- ✅ Unit tests for the attention modules verified
- ⏳ CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
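For context, a minimal sketch of the failure mode, assuming a conventional cache-on-first-call `@singleton` decorator; the actual decorator in vllm-ascend may differ in detail, and the device string is illustrative:

```python
# Hedged sketch only: a conventional cache-on-first-call singleton
# decorator. The real vllm-ascend decorator may differ.
def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        # The constructor runs only on the first call; later calls
        # return the cached instance regardless of their arguments.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:

    def __init__(self, device):  # `device` became required in PR #4779
        self.device = device


AttentionMaskBuilder("npu:0")  # OK: the first call supplies a device
# AttentionMaskBuilder()       # TypeError when it runs first:
#                              # __init__() missing 1 required
#                              # positional argument: 'device'
```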
```diff
@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.context_parallel.common_cp import (
     AscendMetadataForDecode, AscendMetadataForPrefill)
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
```
```diff
@@ -219,6 +220,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
 
         scheduler_config = vllm_config.scheduler_config
         self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)
 
     @classmethod
     def get_cudagraph_support(
```
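A side effect worth noting, shown here as an illustrative sketch rather than vllm-ascend code: with the singleton in place, the builder constructed in this `__init__` and the one constructed in `AscendMLAMetadataBuilder.__init__()` are the same object, so mask tensors are built and cached once per process.

```python
# Illustrative stand-in for the @singleton behavior using __new__;
# the class and device names mirror the diff above, the body is assumed.
class AttentionMaskBuilder:
    _instance = None  # stand-in for the decorator's instance cache

    def __new__(cls, device):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.device = device
        return cls._instance


a = AttentionMaskBuilder("npu:0")  # e.g. from AscendMLAMetadataBuilder
b = AttentionMaskBuilder("npu:0")  # e.g. from AscendAttentionMetadataBuilder
assert a is b  # one shared mask builder (and mask cache) per process
```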
```diff
@@ -253,10 +255,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]):
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
 
         slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
-        attn_mask = common_attn_metadata.attn_mask
-        swa_mask = common_attn_metadata.swa_mask
         attn_state = common_attn_metadata.attn_state
 
+        # Get attn_mask and swa_mask from singleton AttentionMaskBuilder
+        attn_mask = self.attn_mask_builder.get_attention_mask(
+            self.model_config)
+
+        swa_mask = None
+        is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window')
+        if self.model_config is not None and is_swa:
+            swa_mask = self.attn_mask_builder.get_swa_mask(
+                self.model_config.dtype,
+                self.model_config.hf_text_config.sliding_window)
+
         # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device
         query_start_loc = query_start_loc_cpu.pin_memory().to(
             self.device, non_blocking=True)
```
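For readers unfamiliar with the `get_swa_mask` call above: a sliding-window attention mask restricts each token to a causal window of recent positions. A self-contained sketch of that mask shape follows; `build_swa_mask` is a hypothetical helper for illustration, not the `AttentionMaskBuilder` API.

```python
import torch


def build_swa_mask(seq_len: int, window: int,
                   dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """Causal sliding-window mask: token i may attend to token j only
    when j <= i and i - j < window; everything else is set to -inf."""
    idx = torch.arange(seq_len)
    visible = (idx[None, :] <= idx[:, None]) & \
              (idx[:, None] - idx[None, :] < window)
    mask = torch.zeros(seq_len, seq_len, dtype=dtype)
    return mask.masked_fill(~visible, float("-inf"))


print(build_swa_mask(4, 2))  # each row shows one token's visible window
```

The `TODO` in the last hunk refers to the pinned-memory copy pattern on the final two lines: page-locking the host tensor lets the host-to-device transfer run asynchronously instead of blocking. A generic PyTorch sketch of that pattern, with the device string assumed (Ascend would target an `npu` device):

```python
import torch

cpu_tensor = torch.arange(8, dtype=torch.int32)
# pin_memory() page-locks the host buffer so .to(..., non_blocking=True)
# can overlap the H2D copy with other work on the host.
dev_tensor = cpu_tensor.pin_memory().to("cuda", non_blocking=True)
```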