[BugFix][MLA] Fix attn_mask bug for ring mla (#2704)

This PR fix a bug related to attention mask used in ring mla. Current
ring mla has supported compressed mask, so we can directly use a 512 *
512 attention mask.

- vLLM version: v0.10.1.1
- vLLM main:
b5ee1e3261

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-09-04 10:22:46 +08:00
committed by GitHub
parent e11a1bbfc1
commit a58013440a
2 changed files with 15 additions and 11 deletions

View File

@@ -483,10 +483,12 @@ class AscendMLAImpl(MLAAttentionImpl):
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
vllm_config = get_current_vllm_config()
self.ring_mla_mask_size = 512
self.prefill_mask = None
# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
speculative_config = vllm_config.speculative_config
if speculative_config is not None:
self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0
@@ -681,14 +683,10 @@ class AscendMLAImpl(MLAAttentionImpl):
device=q_nope.device)
if self.prefill_mask is None:
self.prefill_mask = torch.triu(
torch.ones(512,
512,
torch.ones(self.ring_mla_mask_size,
self.ring_mla_mask_size,
device=q_nope.device,
dtype=q_nope.dtype),
1) # 512: mask only support 512
if attn_metadata.num_prefills > 1:
self.prefill_mask = self.prefill_mask.unsqueeze(0).repeat(
attn_metadata.num_prefills, 1, 1)
dtype=q_nope.dtype), 1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,