[bugfix]: use correct cache location for cross attention in torch native backend (#8622)
This commit is contained in:
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
|
if layer.is_cross_attention:
|
||||||
|
cache_loc = forward_batch.encoder_out_cache_loc
|
||||||
|
else:
|
||||||
|
cache_loc = forward_batch.out_cache_loc
|
||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||||
layer, forward_batch.out_cache_loc, k, v
|
|
||||||
)
|
|
||||||
|
|
||||||
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||||
|
|
||||||
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
|
if layer.is_cross_attention:
|
||||||
|
cache_loc = forward_batch.encoder_out_cache_loc
|
||||||
|
else:
|
||||||
|
cache_loc = forward_batch.out_cache_loc
|
||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
|
||||||
layer, forward_batch.out_cache_loc, k, v
|
|
||||||
)
|
|
||||||
|
|
||||||
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user