[bugfix]: use correct cache location for cross attention in torch native backend (#8622)
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
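For illustration, below is a minimal, self-contained sketch of why the cache location matters in both hunks above. It is not the sglang API: ToyKVPool, select_cache_loc, and the tensor shapes are made up for this example. Cross-attention K/V come from encoder tokens, so they must be scattered into the slots reserved for the encoder output (encoder_out_cache_loc) rather than into the decoder token's out_cache_loc, which is what the old code did unconditionally.

import torch

class ToyKVPool:
    # Hypothetical stand-in for a per-layer K/V buffer addressed by slot index.
    def __init__(self, num_slots: int, head_dim: int):
        self.k_buffer = torch.zeros(num_slots, head_dim)
        self.v_buffer = torch.zeros(num_slots, head_dim)

    def set_kv_buffer(self, cache_loc: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Scatter each K/V row into the slot named by cache_loc.
        self.k_buffer[cache_loc] = k
        self.v_buffer[cache_loc] = v

def select_cache_loc(is_cross_attention, encoder_out_cache_loc, out_cache_loc):
    # Mirrors the fix: cross-attention layers write to the encoder's slots.
    return encoder_out_cache_loc if is_cross_attention else out_cache_loc

pool = ToyKVPool(num_slots=8, head_dim=4)
encoder_out_cache_loc = torch.tensor([0, 1, 2])  # slots holding encoder-token K/V
out_cache_loc = torch.tensor([3])                # slot for the current decoder token

k = torch.ones(3, 4)  # encoder K for a 3-token source sequence
v = torch.ones(3, 4)
loc = select_cache_loc(True, encoder_out_cache_loc, out_cache_loc)
pool.set_kv_buffer(loc, k, v)  # fills slots 0-2; slot 3 stays free for self-attention

With the old behavior, the encoder rows would have been written to out_cache_loc, clobbering (or mismatching) the decoder token's slot instead of populating the encoder cache, which is the bug this commit fixes in the torch-native attention backend.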