[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it? This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field. ### Background After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures. ### Changes 1. **Fix singleton initialization** - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()` - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()` 2. **Remove unused field** - Removed `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in codebase) - Updated related test assertions ### Related - Issue #5463 - PR #4779 (Unify all mask generation methods) - PR #5389 (Make AttentionMaskBuilder singleton) ## Does this PR introduce _any_ user-facing change? No. This is an internal refactoring. ## How was this patch tested? - ✅ Local testing: No linter errors - ✅ Unit tests for attention modules verified - ⏳ CI pipeline Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -498,7 +498,7 @@ class PCPManager:
|
||||
torch.float32).argsort().to(torch.int32)
|
||||
|
||||
def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
|
||||
attn_mask, input_batch):
|
||||
input_batch):
|
||||
from vllm_ascend.attention.utils import \
|
||||
AscendPrefillContextParallelMetadata
|
||||
num_reqs = input_batch.num_reqs or query_lens.size(0)
|
||||
@@ -523,7 +523,7 @@ class PCPManager:
|
||||
dtype=torch.int32,
|
||||
)
|
||||
# For pcp + spec decode, we flatten seq_lens
|
||||
# to avoid irregular spec_attn_mask shape.
|
||||
# to avoid irregular attn_mask shape.
|
||||
# Same as block_table, we flatten decode seq_lens to query_lens,
|
||||
# and keep prefill seq_lens unchanged.
|
||||
for decode_idx in range(self.decode_threshold):
|
||||
@@ -657,13 +657,11 @@ class PCPManager:
|
||||
split_with_q_head_nomask_idx_reqs,
|
||||
split_kv_with_q_tail_nomask_idx_reqs,
|
||||
head_attn_nomask_seqlens, chunk_seqlens)
|
||||
pcp_prefill_mask = attn_mask
|
||||
|
||||
self.extra_long_seq_kwargs = {
|
||||
'attn_mask_seqlens': attn_mask_seqlens,
|
||||
'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
|
||||
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
|
||||
'pcp_prefill_mask': pcp_prefill_mask
|
||||
'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens
|
||||
}
|
||||
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
@@ -685,8 +683,6 @@ class PCPManager:
|
||||
'head_attn_nomask_seqlens']
|
||||
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
|
||||
'tail_attn_nomask_seqlens']
|
||||
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
|
||||
'pcp_prefill_mask']
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
|
||||
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
|
||||
|
||||
Reference in New Issue
Block a user