From b9a0a75c783571caf22129612fb3338272d1782c Mon Sep 17 00:00:00 2001 From: zhaozx-cn <59479021+zhaozx-cn@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:26:59 +0800 Subject: [PATCH] fix qwen torchair attention PrefillCacheHit (#2787) ### What this PR does / why we need it? Fix qwen torchair attention PrefillCacheHit ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? vLLM version: v0.10.1.1 vLLM main: https://github.com/vllm-project/vllm/commit/e599e2c65ee32abcc986733ab0a55becea158bb4 - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/0b9a612fa327c4b0f4dedf6dc9b2f8c03eb23e36 Signed-off-by: zhaozixin Co-authored-by: zhaozixin --- vllm_ascend/torchair/torchair_attention.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index 81f2968..d2443ad 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -374,6 +374,9 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl): indices = torch.cat((block_indices, slots_indices), dim=1) torch_npu.npu_scatter_nd_update_(key_cache, indices, key) torch_npu.npu_scatter_nd_update_(value_cache, indices, value) + if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + self.key_cache = key_cache + self.value_cache = value_cache if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: assert attn_metadata is not None @@ -411,11 +414,13 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl): assert attn_metadata is not None assert attn_metadata.attn_mask is not None compress_mask = attn_metadata.attn_mask + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] torch_npu._npu_flash_attention_qlens( query=query, key_cache=self.key_cache, value_cache=self.value_cache, - block_table=attn_metadata.block_tables, + block_table=block_table, mask=compress_mask, seq_len=attn_metadata.query_lens, context_lens=attn_metadata.seq_lens,