fix qwen torchair attention PrefillCacheHit (#2787)
### What this PR does / why we need it?
Fix qwen torchair attention PrefillCacheHit

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.10.1.1
- vLLM main: e599e2c65e
- vLLM version: main
- vLLM main: 0b9a612fa3

Signed-off-by: zhaozixin <zhaozixin1@huawei.com>
Co-authored-by: zhaozixin <zhaozixin1@huawei.com>
This commit is contained in:
@@ -374,6 +374,9 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
|||||||
indices = torch.cat((block_indices, slots_indices), dim=1)
|
indices = torch.cat((block_indices, slots_indices), dim=1)
|
||||||
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
|
torch_npu.npu_scatter_nd_update_(key_cache, indices, key)
|
||||||
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
|
torch_npu.npu_scatter_nd_update_(value_cache, indices, value)
|
||||||
|
if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||||
|
self.key_cache = key_cache
|
||||||
|
self.value_cache = value_cache
|
||||||
|
|
||||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
assert attn_metadata is not None
|
assert attn_metadata is not None
|
||||||
@@ -411,11 +414,13 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
|||||||
assert attn_metadata is not None
|
assert attn_metadata is not None
|
||||||
assert attn_metadata.attn_mask is not None
|
assert attn_metadata.attn_mask is not None
|
||||||
compress_mask = attn_metadata.attn_mask
|
compress_mask = attn_metadata.attn_mask
|
||||||
|
batch_size = attn_metadata.query_lens.shape[0]
|
||||||
|
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||||
torch_npu._npu_flash_attention_qlens(
|
torch_npu._npu_flash_attention_qlens(
|
||||||
query=query,
|
query=query,
|
||||||
key_cache=self.key_cache,
|
key_cache=self.key_cache,
|
||||||
value_cache=self.value_cache,
|
value_cache=self.value_cache,
|
||||||
block_table=attn_metadata.block_tables,
|
block_table=block_table,
|
||||||
mask=compress_mask,
|
mask=compress_mask,
|
||||||
seq_len=attn_metadata.query_lens,
|
seq_len=attn_metadata.query_lens,
|
||||||
context_lens=attn_metadata.seq_lens,
|
context_lens=attn_metadata.seq_lens,
|
||||||
|
|||||||
Reference in New Issue
Block a user