[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it?

This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field.

### Background

After PR #4779 made `AttentionMaskBuilder` a singleton via the `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures.

### Changes

1. **Fix singleton initialization**
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()`
   - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()`
2. **Remove unused field**
   - Removed the `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in the codebase)
   - Updated the related test assertions

### Related

- Issue #5463
- PR #4779 (Unify all mask generation methods)
- PR #5389 (Make AttentionMaskBuilder singleton)

## Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring.

## How was this patch tested?

- ✅ Local testing: no linter errors
- ✅ Unit tests for the attention modules verified
- ⏳ CI pipeline

Signed-off-by: lico67373 <918688502@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
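To illustrate the failure mode described in the Background, here is a minimal sketch of how a typical `@singleton` decorator caches the first constructed instance. It is not the project's actual implementation; only the `AttentionMaskBuilder` name and its `device` parameter are taken from this PR.

```python
# Minimal sketch of a typical @singleton decorator (not vllm-ascend's actual
# implementation). Because the first call constructs the shared instance,
# every call site must pass the now-required `device` argument.
import torch


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class AttentionMaskBuilder:
    def __init__(self, device: torch.device):
        self.device = device  # masks are allocated on this device


# If the *first* call is parameterless, construction fails:
# AttentionMaskBuilder()  # TypeError: __init__() missing 1 required
#                         # positional argument: 'device'
builder = AttentionMaskBuilder(torch.device("cpu"))       # OK
same_builder = AttentionMaskBuilder(torch.device("cpu"))
assert builder is same_builder                            # one shared instance
```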
```diff
@@ -66,8 +66,6 @@ class AscendPrefillContextParallelMetadata:
     q_full_idx: torch.Tensor = None
-    pcp_prefill_mask: torch.Tensor = None
     # original query_lens before pcp split
     query_lens_pcp_full_cpu: torch.Tensor = None
@@ -93,12 +91,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
     positions: torch.Tensor = None
     attn_mask: torch.Tensor = None
     spec_attn_mask: torch.Tensor = None
     swa_mask: torch.Tensor = None
     attn_state: Any = None
     graph_pad_size: int = -1
@@ -130,9 +122,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata):
             causal=self.causal,
             actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens],
             positions=self.positions[:num_actual_tokens],
             attn_mask=self.attn_mask,
             spec_attn_mask=self.spec_attn_mask,
             swa_mask=self.swa_mask,
             attn_state=self.attn_state,
             graph_pad_size=-1,  # It should be -1 when not run in fullgraph mode.
             num_input_tokens=num_actual_tokens,
```
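The hunks above only show the metadata dataclasses; the constructor fix itself lands in the two builder `__init__` methods. A rough sketch of that change follows, with both builder classes reduced to stubs: only the class names and the `AttentionMaskBuilder(self.device)` call come from the PR description, and the `attn_mask_builder` attribute name is hypothetical.

```python
# Heavily simplified stand-ins for the two initialization sites; the real
# signatures in vllm-ascend differ. The attribute name `attn_mask_builder`
# is hypothetical.
import torch


class AttentionMaskBuilder:  # stub for the real singleton class
    def __init__(self, device: torch.device):
        self.device = device


class AscendMLAMetadataBuilder:
    def __init__(self, device: torch.device):
        self.device = device
        # Before: self.attn_mask_builder = AttentionMaskBuilder()  # missing device
        self.attn_mask_builder = AttentionMaskBuilder(self.device)


class AscendAttentionMetadataBuilder:
    def __init__(self, device: torch.device):
        self.device = device
        # Same fix: pass the builder's device through to the singleton.
        self.attn_mask_builder = AttentionMaskBuilder(self.device)
```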