[Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870)
## What this PR does / why we need it? This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field. ### Background After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures. ### Changes 1. **Fix singleton initialization** - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()` - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()` 2. **Remove unused field** - Removed `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in codebase) - Updated related test assertions ### Related - Issue #5463 - PR #4779 (Unify all mask generation methods) - PR #5389 (Make AttentionMaskBuilder singleton) ## Does this PR introduce _any_ user-facing change? No. This is an internal refactoring. ## How was this patch tested? - ✅ Local testing: No linter errors - ✅ Unit tests for attention modules verified - ⏳ CI pipeline Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
@@ -70,7 +70,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
||||
input_batch.num_tokens = torch.tensor(num_tokens)
|
||||
|
||||
query_lens = torch.tensor(query_lens)
|
||||
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None,
|
||||
result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
|
||||
input_batch)
|
||||
|
||||
if not expect_not_none:
|
||||
@@ -97,21 +97,6 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
|
||||
assert hasattr(result, 'head_attn_nomask_seqlens')
|
||||
assert hasattr(result, 'tail_attn_nomask_seqlens')
|
||||
|
||||
if hasattr(result, 'pcp_prefill_mask'
|
||||
) and result.pcp_prefill_mask is not None:
|
||||
if use_mla:
|
||||
assert result.pcp_prefill_mask.shape == (512, 512)
|
||||
else:
|
||||
assert result.pcp_prefill_mask.shape == (2048, 2048)
|
||||
else:
|
||||
if hasattr(result, 'pcp_prefill_mask'):
|
||||
if result.pcp_prefill_mask is not None:
|
||||
if use_mla:
|
||||
assert result.pcp_prefill_mask.shape == (512, 512)
|
||||
else:
|
||||
assert result.pcp_prefill_mask.shape == (2048,
|
||||
2048)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens",
|
||||
|
||||
Reference in New Issue
Block a user