[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)

## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue
introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton via the `@singleton`
decorator, the class constructor requires a `device` parameter. However, two
initialization sites were still calling the old parameterless constructor,
causing failures.
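
For illustration, here is a minimal, hypothetical sketch of the failure mode.
The decorator body, types, and call sites below are simplified stand-ins; only
the `AttentionMaskBuilder(self.device)`-style construction mirrors the real
code in `vllm_ascend.attention.attention_mask`.

```python
# Hypothetical sketch: a generic singleton decorator plus a constructor
# that now requires a device, as described above.
import torch


def singleton(cls):
    """Cache and return a single shared instance of the decorated class."""
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:

    def __init__(self, device: torch.device):
        # `device` is now a required argument, so the old parameterless
        # call `AttentionMaskBuilder()` raises a TypeError.
        self.device = device


# Fixed call sites pass the device explicitly, i.e. AttentionMaskBuilder(self.device)
# inside the metadata builders; shown here with an explicit device object.
builder_a = AttentionMaskBuilder(torch.device("cpu"))
builder_b = AttentionMaskBuilder(torch.device("cpu"))
assert builder_a is builder_b  # singleton: both names refer to one instance
```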

### Changes

1. **Fix singleton initialization**
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)`
     in `AscendMLAMetadataBuilder.__init__()`
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)`
     in `AscendAttentionMetadataBuilder.__init__()`

2. **Remove unused field** (see the sketch after this list)
   - Removed the `pcp_prefill_mask` field from
     `AscendPrefillContextParallelMetadata` (never read anywhere in the codebase)
   - Updated the related test assertions
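
As a rough sketch of change 2 (field names taken from the diff below; the
dataclass form, the elision of all other members, and the tensor types are
assumptions for illustration only), the removal leaves the metadata container
without the never-read attribute:

```python
# Rough, simplified sketch of AscendPrefillContextParallelMetadata after the
# change; unrelated fields are elided and the Optional[torch.Tensor] types are
# illustrative assumptions.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AscendPrefillContextParallelMetadata:
    pcp_allgather_restore_idx: Optional[torch.Tensor] = None
    head_attn_nomask_seqlens: Optional[torch.Tensor] = None
    tail_attn_nomask_seqlens: Optional[torch.Tensor] = None
    # Removed: pcp_prefill_mask. It was populated via extra_long_seq_kwargs in
    # PCPManager.generate_pcp_metadata but never read anywhere else, so both
    # the field and the kwargs entry are dropped.
```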

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- Local testing: no linter errors
- Unit tests for the attention modules verified
- CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>

@@ -77,7 +77,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.utils import AttentionGroup
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
# yapf conflicts with isort for this block
@@ -230,7 +229,6 @@ class NPUModelRunner(GPUModelRunner):
self.positions = self._make_buffer(max_buffer_num_tokens,
dtype=torch.int64)
self.sampler = AscendSampler()
self.attn_mask = None
self.attn_state = None
# Ascend-specific configurations
@@ -264,19 +262,9 @@ class NPUModelRunner(GPUModelRunner):
use_sparse=self.use_sparse,
use_mm_prefix=self.model_config is not None
and self.model_config.is_mm_prefix_lm)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self._set_up_drafter()
# sliding window attn mask
self.swa_mask = None
is_swa = hasattr(self.vllm_config.model_config.hf_text_config,
"sliding_window")
if self.model_config is not None and is_swa:
self.swa_mask = self.attn_mask_builder.get_swa_mask(
self.dtype,
self.vllm_config.model_config.hf_text_config.sliding_window)
# kv role
self.is_kv_producer = False
self.is_kv_consumer = False
@@ -370,7 +358,6 @@ class NPUModelRunner(GPUModelRunner):
def _set_up_drafter(self):
# Set up speculative decoding.
self.spec_attn_mask = None
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
SuffixDecodingProposer]] = None
self.actual_seq_lengths_q: list[int] = []
@@ -379,8 +366,6 @@ class NPUModelRunner(GPUModelRunner):
spec_token_num = self.speculative_config.num_speculative_tokens
assert spec_token_num > 0
self.decode_token_per_req = 1 + spec_token_num
self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
)
if get_pp_group().is_last_rank:
self.drafter = self._get_drafter()
if self.speculative_config.method == "eagle3":
@@ -494,22 +479,6 @@ class NPUModelRunner(GPUModelRunner):
return self.model.unwrap()
return self.model
def _make_attention_mask(self, attn_state) -> torch.Tensor:
# pcp situation.
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
# Pooling situation.
if self.model_config.runner_type == "pooling":
return self.attn_mask_builder.get_attn_mask(2048, torch.bool)
if self.vllm_config.model_config.use_mla:
if self.pcp_size > 1:
return self.attn_mask_builder.get_pcp_mla_mask(self.dtype)
# mla prefill
if attn_state != AscendAttentionState.DecodeOnly:
return self.attn_mask_builder.get_mla_mask(self.dtype)
return self.attn_mask_builder.get_splitfuse_attn_mask()
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@@ -551,7 +520,6 @@ class NPUModelRunner(GPUModelRunner):
with_prefill = attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
self.attn_mask = self._make_attention_mask(attn_state)
# Get positions.
positions_np = self.positions.np[:total_num_scheduled_tokens]
@@ -941,7 +909,7 @@ class NPUModelRunner(GPUModelRunner):
if self.pcp_size * self.dcp_size > 1:
self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata(
total_num_scheduled_tokens, self.query_lens,
self.attn_mask, self.input_batch)
self.input_batch)
blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1)
if self.pcp_size > 1:
slot_mapping = self.pcp_manager.get_padded_slot_mapping(
@@ -997,9 +965,6 @@ class NPUModelRunner(GPUModelRunner):
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
positions=self.positions.gpu,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
swa_mask=self.swa_mask,
attn_state=self.attn_state,
max_query_len=max_num_scheduled_tokens,
decode_token_per_req=self.decode_token_per_req,
@@ -1009,7 +974,7 @@ class NPUModelRunner(GPUModelRunner):
if self.speculative_config and self.pcp_size * self.dcp_size > 1:
# For pcp + spec decode, we flatten block_table
# to avoid irregular spec_attn_mask shape, e.g.,
# to avoid irregular attn_mask shape, e.g.,
# num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
# ori block_table: # [d0, d1, p0, p1, p2]
# (num_reqs_d + num_reqs_p, max_num_blocks),
@@ -1918,7 +1883,6 @@ class NPUModelRunner(GPUModelRunner):
self.query_start_loc.cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -1930,8 +1894,7 @@ class NPUModelRunner(GPUModelRunner):
slot_mapping = self.input_batch.block_table[
kv_cache_group_id].slot_mapping
long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata(
num_tokens, self.query_lens, self.attn_mask,
self.input_batch)
num_tokens, self.query_lens, self.input_batch)
if long_seq_metadata is not None:
pcp_world_size = get_pcp_group().world_size
dcp_world_size = get_dcp_group().world_size
@@ -1954,9 +1917,6 @@ class NPUModelRunner(GPUModelRunner):
slot_mapping=slot_mapping.gpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions.gpu,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
swa_mask=self.swa_mask,
attn_state=self.attn_state,
max_query_len=max_query_len,
decode_token_per_req=self.decode_token_per_req,

@@ -498,7 +498,7 @@ class PCPManager:
torch.float32).argsort().to(torch.int32)
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
attn_mask, input_batch):
input_batch):
from vllm_ascend.attention.utils import \
AscendPrefillContextParallelMetadata
num_reqs = input_batch.num_reqs or query_lens.size(0)
@@ -523,7 +523,7 @@ class PCPManager:
dtype=torch.int32,
)
# For pcp + spec decode, we flatten seq_lens
# to avoid irregular spec_attn_mask shape.
# to avoid irregular attn_mask shape.
# Same as block_table, we flatten decode seq_lens to query_lens,
# and keep prefill seq_lens unchanged.
for decode_idx in range(self.decode_threshold):
@@ -657,13 +657,11 @@ class PCPManager:
split_with_q_head_nomask_idx_reqs,
split_kv_with_q_tail_nomask_idx_reqs,
head_attn_nomask_seqlens, chunk_seqlens)
pcp_prefill_mask = attn_mask
self.extra_long_seq_kwargs = {
'attn_mask_seqlens': attn_mask_seqlens,
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
'pcp_prefill_mask': pcp_prefill_mask
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens
}
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
num_actual_tokens_pcp_padded]
@@ -685,8 +683,6 @@ class PCPManager:
'head_attn_nomask_seqlens']
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
'tail_attn_nomask_seqlens']
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
'pcp_prefill_mask']
if self.vllm_config.model_config.use_mla:
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list

@@ -58,9 +58,6 @@ def build_attn_metadata(
decode_token_per_req: int,
actual_seq_lengths_q: list[int],
positions: torch.Tensor | None = None,
attn_mask: torch.Tensor
| None = None,
spec_attn_mask: torch.Tensor | None = None,
attn_state: Any | None = None,
graph_pad_size: int = -1,
num_input_tokens: int = 0,
@@ -92,8 +89,6 @@ def build_attn_metadata(
slot_mapping=slot_mapping,
actual_seq_lengths_q=actual_seq_lengths_q,
positions=positions,
attn_mask=attn_mask,
spec_attn_mask=spec_attn_mask,
attn_state=attn_state,
graph_pad_size=graph_pad_size,
num_input_tokens=num_input_tokens,

@@ -32,8 +32,7 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
build_attn_state,
make_attention_mask)
build_attn_state)
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
@@ -155,12 +154,6 @@ class NPUModelRunner(GPUModelRunner):
num_scheduled_tokens,
num_valid_tokens,
)
attn_mask = make_attention_mask(
self.vllm_config,
attn_state,
self.dtype,
self.device,
)
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
@@ -284,7 +277,6 @@ class NPUModelRunner(GPUModelRunner):
slot_mappings=slot_mappings.to(torch.int32),
kv_cache_config=self.kv_cache_config,
decode_token_per_req=self.decode_token_per_req,
attn_mask=attn_mask,
attn_state=attn_state,
)