[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it? This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field. ### Background After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures. ### Changes 1. **Fix singleton initialization** - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()` - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()` 2. **Remove unused field** - Removed `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in codebase) - Updated related test assertions ### Related - Issue #5463 - PR #4779 (Unify all mask generation methods) - PR #5389 (Make AttentionMaskBuilder singleton) ## Does this PR introduce _any_ user-facing change? No. This is an internal refactoring. ## How was this patch tested? - ✅ Local testing: No linter errors - ✅ Unit tests for attention modules verified - ⏳ CI pipeline Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -340,7 +340,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
|
||||
spec_attn_mask, sparse_mode, scale, block_table, block_size,
|
||||
attn_mask, sparse_mode, scale, block_table, block_size,
|
||||
seq_lens_list, actual_seq_lengths, attn_output,
|
||||
softmax_lse) = param
|
||||
seq_lens_list = forward_context.attn_metadata[
|
||||
@@ -380,7 +380,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
input_layout=input_layout,
|
||||
atten_mask=spec_attn_mask,
|
||||
atten_mask=attn_mask,
|
||||
sparse_mode=sparse_mode,
|
||||
scale=scale,
|
||||
antiquant_mode=0,
|
||||
@@ -480,7 +480,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
|
||||
seq_len = decode_meta.cp_seq_len
|
||||
|
||||
# For pcp + spec decode, we flatten seq_lens
|
||||
# to avoid irregular spec_attn_mask shape,
|
||||
# to avoid irregular attn_mask shape,
|
||||
# so there's no need to divide runtime_shape by spec_multiple
|
||||
pad_length = runtime_shape - len(seq_len)
|
||||
pad_tensor = torch.zeros(pad_length,
|
||||
|
||||
Reference in New Issue
Block a user