[bugfix]: use correct cache location for cross attention in torch native backend (#8622)

Author: Mahmoud Ashraf
Date: 2025-09-05 23:39:46 +03:00
Committed by: GitHub
Parent: 4efe844a25
Commit: e678cc717d


@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num