From 9af34755ff87c4884e1b949311fabe7797703a74 Mon Sep 17 00:00:00 2001
From: Ting FU
Date: Sat, 29 Nov 2025 09:20:22 +0800
Subject: [PATCH] [Bugfix] Fix model run _npu_flash_attention hang issue
 (#4410)

Fix a hang when the model runs _npu_flash_attention in
_forward_prefill_no_cache; it was caused by a wrong attention mask dtype.

### How was this patch tested?
Tested on Qwen2.5-VL and Qwen2.5-Omni

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379

Signed-off-by: Ting FU
---
 tests/ut/attention/test_attention_mask.py | 7 ++++---
 vllm_ascend/attention/attention_mask.py   | 2 --
 vllm_ascend/worker/model_runner_v1.py     | 4 ++--
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/tests/ut/attention/test_attention_mask.py b/tests/ut/attention/test_attention_mask.py
index c8139b71..9bd4cd0e 100644
--- a/tests/ut/attention/test_attention_mask.py
+++ b/tests/ut/attention/test_attention_mask.py
@@ -74,10 +74,11 @@ class TestAttentionMaskBuilder(TestBase):
         attn_mask = attention_mask_builder.get_attn_mask(
             max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
         self.assertEqual(attn_mask.shape, (2048, 2048))
-        self.assertEqual(attn_mask[0][-1], torch.tensor(True))
-        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
+        self.assertEqual(attn_mask[0][-1],
+                         torch.tensor(float("-inf"), dtype=torch.float16))
+        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (1024, 1024))
+                         (2048, 2048))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))
 
diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py
index 3514984d..2c963b5c 100644
--- a/vllm_ascend/attention/attention_mask.py
+++ b/vllm_ascend/attention/attention_mask.py
@@ -67,8 +67,6 @@ class AttentionMaskBuilder:
 
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
-        if max_seq_len == 2048:
-            return self.chunked_prefill_attn_mask.to(torch.bool)
         self._update_attn_cache(max_seq_len, dtype)
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
         ).to(device, non_blocking=True)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index ff55d1d1..2e7c4ea2 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -991,8 +991,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                2048, self.dtype, self.device)
+            return self.attn_mask_builder.get_splitfuse_attn_mask().to(
+                torch.bool)
         # Decode-only situation.
         else:
             return None
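
Note: for context on the dtype issue, the following is a minimal, self-contained sketch of the two mask flavors involved; the helper names below are illustrative and are not the repository's `AttentionMaskBuilder` API. Per the commit message and the updated unit test, the no-cache prefill path expects an additive mask in the compute dtype (`-inf` at masked positions) rather than a boolean mask, which is why the `max_seq_len == 2048` boolean shortcut was removed from `get_attn_mask`.

```python
# Illustrative sketch only (not repository code): an additive float mask
# versus a boolean mask.  The prefill no-cache path needs the float form
# (-inf at masked positions); the splitfuse/chunked-prefill path uses the
# boolean form.
import torch


def additive_causal_mask(seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    """Float mask: 0 where attention is allowed, -inf where it is masked."""
    masked = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return torch.zeros(seq_len, seq_len, dtype=dtype).masked_fill(
        masked, float("-inf"))


def boolean_causal_mask(seq_len: int) -> torch.Tensor:
    """Boolean mask: True marks positions that must be masked out."""
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()


if __name__ == "__main__":
    float_mask = additive_causal_mask(4, torch.float16)
    bool_mask = boolean_causal_mask(4)
    print(float_mask[0, -1])  # tensor(-inf, dtype=torch.float16), as the test asserts
    print(bool_mask[0, -1])   # tensor(True)
```

After this patch, `get_attn_mask` always returns the float variant, and the boolean conversion happens only in the PrefillCacheHit branch via `get_splitfuse_attn_mask().to(torch.bool)`.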