fix qwen torchair attention PrefillCacheHit (#2787)

### What this PR does / why we need it?
Fix the Qwen torchair attention path for the `PrefillCacheHit` state: store the updated `key_cache`/`value_cache` on the backend when `attn_state == AscendAttentionState.PrefillCacheHit`, and slice `block_tables` to the actual batch size (`query_lens.shape[0]`) before passing it to `torch_npu._npu_flash_attention_qlens`.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
vLLM version: v0.10.1.1
vLLM main:
e599e2c65e

- vLLM version: main
- vLLM main:
0b9a612fa3

Signed-off-by: zhaozixin <zhaozixin1@huawei.com>
Co-authored-by: zhaozixin <zhaozixin1@huawei.com>
This commit is contained in:
zhaozx-cn
2025-09-11 14:26:59 +08:00
committed by GitHub
parent 7b2ecc1e9a
commit b9a0a75c78

View File

@@ -374,6 +374,9 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
indices = torch.cat((block_indices, slots_indices), dim=1)
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
self.key_cache = key_cache
self.value_cache = value_cache
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
@@ -411,11 +414,13 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
compress_mask = attn_metadata.attn_mask
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
torch_npu._npu_flash_attention_qlens(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
block_table=attn_metadata.block_tables,
block_table=block_table,
mask=compress_mask,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,