[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton via the `@singleton` decorator, the class constructor requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures.

### Changes

1. **Fix singleton initialization** (illustrated in the sketch below)
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()`
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()`
2. **Remove unused field**
   - Removed the `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in the codebase)
   - Updated the related test assertions

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- ✅ Local testing: no linter errors
- ✅ Unit tests for the attention modules verified
- ⏳ CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
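For context, here is a minimal sketch of the constructor contract the fix relies on. Only `AttentionMaskBuilder(self.device)` comes from this PR; the `@singleton` decorator body, the `attn_mask_builder` attribute name, and the stripped-down `AscendMLAMetadataBuilder` are illustrative assumptions, not the actual vllm-ascend implementation.

```python
# Minimal sketch, not the real vllm-ascend code: the decorator body and
# attribute names below are assumptions used only to illustrate why the
# parameterless AttentionMaskBuilder() calls had to be updated.
def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        # The first caller constructs the shared instance, so it must pass
        # every required constructor argument (here, the device).
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:

    def __init__(self, device):
        # Bound to a device at construction time since PR #4779, which is
        # why AttentionMaskBuilder() without arguments now fails.
        self.device = device


class AscendMLAMetadataBuilder:
    # Stripped-down stand-in for the real builder, for illustration only.

    def __init__(self, device: str = "npu"):
        self.device = device
        # Fixed call site: pass the device instead of calling the
        # constructor with no arguments.
        self.attn_mask_builder = AttentionMaskBuilder(self.device)
```

In this sketch, any later `AttentionMaskBuilder(...)` call returns the instance created first regardless of its arguments, so the first call site must supply the correct device.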
@@ -118,7 +118,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
                 tail_attn_nomask_seqlens=common_long_seq_metadata.
                 tail_attn_nomask_seqlens,
                 q_full_idx=common_long_seq_metadata.q_full_idx,
-                pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
                 pcp_allgather_restore_idx=common_long_seq_metadata.
                 pcp_allgather_restore_idx)
 
@@ -195,7 +194,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder):
             ).item()
         if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
             # For pcp + spec decode, we flatten seq_lens and block_table
-            # to avoid irregular spec_attn_mask shape
+            # to avoid irregular attn_mask shape
             return self.num_decodes_flatten + self.num_prefills
         else:
             return self.num_decodes_flatten
@@ -420,7 +419,6 @@ class AscendMlaCPImpl(AscendMLAImpl):
         attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
         head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
         tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
-        mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
         output_head, lse_head = self._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_head_idx),
             q_pe=torch.index_select(q_pe, 0, q_head_idx),
@@ -431,7 +429,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             kv_nomask_idx=kv_with_q_head_nomask_idx,
             attn_mask_seqlens=attn_mask_seqlens,
             attn_nomask_seqlens=head_attn_nomask_seqlens,
-            mask=mask)
+            mask=attn_metadata.attn_mask)
 
         output_tail, lse_tail = self._attention_with_mask_and_nomask(
             q_nope=torch.index_select(q_nope, 0, q_tail_idx),
@@ -443,7 +441,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
             kv_nomask_idx=kv_with_q_tail_nomask_idx,
             attn_mask_seqlens=attn_mask_seqlens,
             attn_nomask_seqlens=tail_attn_nomask_seqlens,
-            mask=mask)
+            mask=attn_metadata.attn_mask)
 
         q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
         attn_output = torch.index_select(