diff --git a/python/sglang/srt/layers/attention/torch_native_backend.py b/python/sglang/srt/layers/attention/torch_native_backend.py index bb06076c1..6a67ea947 100644 --- a/python/sglang/srt/layers/attention/torch_native_backend.py +++ b/python/sglang/srt/layers/attention/torch_native_backend.py @@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend): else: o = torch.empty_like(q) + if layer.is_cross_attention: + cache_loc = forward_batch.encoder_out_cache_loc + else: + cache_loc = forward_batch.out_cache_loc + if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) use_gqa = layer.tp_q_head_num != layer.tp_k_head_num @@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend): else: o = torch.empty_like(q) + if layer.is_cross_attention: + cache_loc = forward_batch.encoder_out_cache_loc + else: + cache_loc = forward_batch.out_cache_loc + if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer( - layer, forward_batch.out_cache_loc, k, v - ) + forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) use_gqa = layer.tp_q_head_num != layer.tp_k_head_num