diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py index 81f2968..d2443ad 100644 --- a/vllm_ascend/torchair/torchair_attention.py +++ b/vllm_ascend/torchair/torchair_attention.py @@ -374,6 +374,9 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl): indices = torch.cat((block_indices, slots_indices), dim=1) torch_npu.npu_scatter_nd_update_(key_cache, indices, key) torch_npu.npu_scatter_nd_update_(value_cache, indices, value) + if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + self.key_cache = key_cache + self.value_cache = value_cache if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: assert attn_metadata is not None @@ -411,11 +414,13 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl): assert attn_metadata is not None assert attn_metadata.attn_mask is not None compress_mask = attn_metadata.attn_mask + batch_size = attn_metadata.query_lens.shape[0] + block_table = attn_metadata.block_tables[:batch_size, :] torch_npu._npu_flash_attention_qlens( query=query, key_cache=self.key_cache, value_cache=self.value_cache, - block_table=attn_metadata.block_tables, + block_table=block_table, mask=compress_mask, seq_len=attn_metadata.query_lens, context_lens=attn_metadata.seq_lens,