[BugFix] [310p] Fix attention accuracy issue (#6803)

### What this PR does / why we need it?
This pull request resolves an attention accuracy issue by making
`AttentionMaskBuilder310` handle the model's maximum sequence length
correctly. Attention mask generation is now parameterized by the
model's configuration rather than a fixed internal value (previously
hard-coded to 2048), so the mask covers the full length the model
supports. This is crucial for the attention mechanism to produce
correct results once sequences exceed the old fixed limit.

It also updates `fused_moe` to match the main branch.
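
For context, here is a minimal sketch of how a caller might construct the builder from the model configuration. The helper, its import path, and the wiring below are hypothetical; `max_model_len` is vLLM's standard model-config field, but the actual integration point in vllm-ascend may differ:

```python
import torch

def make_mask_builder(vllm_config, device: torch.device):
    # Hypothetical wiring sketch -- not the actual vllm-ascend call site.
    # The import path below is assumed; AttentionMaskBuilder310 is the
    # class touched by this PR.
    from vllm_ascend.attention.attention_mask import AttentionMaskBuilder310
    return AttentionMaskBuilder310(
        device=device,
        # Pass the model's configured maximum length instead of relying
        # on the builder's fixed internal default.
        max_seqlen=vllm_config.model_config.max_model_len,
    )
```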
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Qwen3 dense & MoE model e2e tests
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
Author: pu-zhe (committed via GitHub)
Date: 2026-02-26 14:30:39 +08:00
Commit: e76b69b9ef (parent: 9f8b84e5fc)
8 changed files with 76 additions and 43 deletions


```diff
@@ -24,19 +24,20 @@ from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, nd_to_nz_2d, nd_to_nz_spec
 class AttentionMaskBuilder310:
     chunked_prefill_attn_mask = None
-    max_seqlen = 2048
+    max_seqlen = 16384

-    def __init__(self, device: torch.device):
+    def __init__(self, device: torch.device, max_seqlen: int):
         """
         Initializes the AttentionMaskBuilder for the 310P device.

         Args:
             device (torch.device): The device on which tensors will be allocated.
+            max_seqlen (int): Maximum length of a sequence (including prompt and generated text).
         """
+        AttentionMaskBuilder310.max_seqlen = max_seqlen
         self.attn_mask_cache = None
         self.device = device
         self.swa_mask = None
         self._seq_len_cached = 0

     @staticmethod
     def gen_causal_additive_mask(max_seq_len: int, device: torch.device):
```
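
One detail worth noting in the hunk above: the constructor writes the caller-supplied value onto the class attribute, so any code reading `AttentionMaskBuilder310.max_seqlen` observes the configured length, and the `16384` literal only acts as the default before construction. A tiny illustration (the device string and length are made up):

```python
import torch

# Before construction, the class-level default from the diff (16384) applies.
# After construction, the class attribute reflects the configured length.
builder = AttentionMaskBuilder310(torch.device("npu"), max_seqlen=32768)
assert AttentionMaskBuilder310.max_seqlen == 32768
```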
```diff
@@ -147,8 +148,7 @@ class AttentionMaskBuilder310:
         Returns:
             torch.Tensor: The cached causal mask in ACL_FORMAT_FRACTAL_NZ.
         """
-        if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
+        if self.attn_mask_cache is None:
             attn_mask = self.gen_causal_additive_mask(max_seq_len, self.device)
             self.attn_mask_cache = torch_npu.npu_format_cast(nd_to_nz_2d(attn_mask), ACL_FORMAT_FRACTAL_NZ)
-            self._seq_len_cached = max_seq_len
         return self.attn_mask_cache
```
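
With the length check gone, the cached mask is generated exactly once; the builder is expected to receive the full model length up front (via the new constructor parameter), after which shorter sequences are served from the same cached tensor, which presumably is why the `_seq_len_cached` bookkeeping could be dropped. `gen_causal_additive_mask` itself is not shown in this hunk; a standard causal additive mask of the kind being cached can be sketched in plain PyTorch as follows (illustrative only, not the repository's NPU-specific implementation):

```python
import torch

def causal_additive_mask_sketch(max_seq_len: int,
                                device: torch.device,
                                dtype: torch.dtype = torch.float16) -> torch.Tensor:
    """Additive causal mask: 0 on and below the diagonal, a large negative
    value strictly above it, so future positions vanish after softmax."""
    mask = torch.zeros(max_seq_len, max_seq_len, dtype=dtype, device=device)
    future = torch.ones(max_seq_len, max_seq_len, dtype=torch.bool,
                        device=device).triu_(diagonal=1)
    return mask.masked_fill_(future, torch.finfo(dtype).min)

# Row i may attend to columns 0..i only; a 4x4 example on CPU:
print(causal_additive_mask_sketch(4, torch.device("cpu")))
```

Building the mask once at the full model length trades a one-time O(max_seqlen²) allocation for never having to regenerate the mask and re-cast it to ACL_FORMAT_FRACTAL_NZ while serving.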