Fuse MLA set kv cache kernel (#5748)

Ke Bao
2025-04-27 09:42:22 +08:00
committed by GitHub
parent 02723e1b0d
commit 799c4bb502
4 changed files with 100 additions and 9 deletions


@@ -625,6 +625,7 @@ class FlashAttentionBackend(AttentionBackend):
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
):
if k is not None:
assert v is not None
@@ -639,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
- forward_batch.token_to_kv_pool.set_kv_buffer(
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
- v,
+ k_rope,
)
# Use precomputed metadata across all layers
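
The only functional change in this path is the cache write: instead of handing the pool a pre-assembled cache entry via set_kv_buffer, the backend now passes the compressed latent part (k) and the rotary part (k_rope) separately to a new set_mla_kv_buffer method. Below is a minimal sketch of what such a pool method could look like, assuming an MLA cache layout of kv_lora_rank + qk_rope_head_dim values per slot; the class name, shapes, and arguments are illustrative, not the actual sglang implementation.

import torch


class MLATokenToKVPoolSketch:
    """Minimal stand-in for an MLA token-to-KV pool (names and shapes are
    assumptions for illustration only)."""

    def __init__(self, num_layers, num_slots, kv_lora_rank, qk_rope_head_dim,
                 dtype=torch.bfloat16, device="cuda"):
        self.kv_lora_rank = kv_lora_rank
        # One buffer per layer; each slot stores the compressed latent (NoPE)
        # part and the rotary (RoPE) part side by side.
        self.kv_buffer = [
            torch.empty(num_slots, 1, kv_lora_rank + qk_rope_head_dim,
                        dtype=dtype, device=device)
            for _ in range(num_layers)
        ]

    def set_kv_buffer(self, layer_id, loc, cache_k):
        # Old-style path: the caller assembles the full per-token entry
        # (e.g. via torch.cat of the NoPE and RoPE parts), then one copy.
        self.kv_buffer[layer_id][loc] = cache_k

    def set_mla_kv_buffer(self, layer_id, loc, cache_k_nope, cache_k_rope):
        # New path: write both parts straight into their slices of the cache
        # and skip the intermediate concatenation; the commit fuses this
        # update into a single kernel rather than two separate copies.
        self.kv_buffer[layer_id][loc, :, : self.kv_lora_rank] = cache_k_nope
        self.kv_buffer[layer_id][loc, :, self.kv_lora_rank:] = cache_k_rope
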
@@ -887,6 +888,7 @@ class FlashAttentionBackend(AttentionBackend):
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if k is not None:
assert v is not None
@@ -901,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
- forward_batch.token_to_kv_pool.set_kv_buffer(
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
cache_loc,
k,
- v,
+ k_rope,
)
# Use precomputed metadata across all layers
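
Per the commit title, the point of set_mla_kv_buffer is to fuse this cache update into one kernel launch. A plausible sketch of such a fused write as a Triton kernel follows, with one program per token copying both slices of the cache entry; the kernel and launcher names, stride handling, and block-size choice are assumptions, not the code added by this commit.

import triton
import triton.language as tl


@triton.jit
def fused_set_mla_kv_kernel(
    kv_buffer_ptr,      # [num_slots, nope_dim + rope_dim] cache buffer
    cache_k_nope_ptr,   # [num_tokens, nope_dim] latent part for this batch
    cache_k_rope_ptr,   # [num_tokens, rope_dim] rotary part for this batch
    loc_ptr,            # [num_tokens] destination slot per token
    buf_stride, nope_stride, rope_stride,
    NOPE_DIM: tl.constexpr,
    ROPE_DIM: tl.constexpr,
    BLOCK: tl.constexpr,
):
    # One program per token: both slices of the cache entry are written in
    # the same launch, so no torch.cat and no second copy kernel is needed.
    pid = tl.program_id(0)
    slot = tl.load(loc_ptr + pid)
    offs = tl.arange(0, BLOCK)

    # Copy the compressed latent (NoPE) part.
    mask = offs < NOPE_DIM
    nope = tl.load(cache_k_nope_ptr + pid * nope_stride + offs, mask=mask)
    tl.store(kv_buffer_ptr + slot * buf_stride + offs, nope, mask=mask)

    # Copy the rotary (RoPE) part right after it.
    mask = offs < ROPE_DIM
    rope = tl.load(cache_k_rope_ptr + pid * rope_stride + offs, mask=mask)
    tl.store(kv_buffer_ptr + slot * buf_stride + NOPE_DIM + offs, rope, mask=mask)


def set_mla_kv_buffer_sketch(kv_buffer, loc, cache_k_nope, cache_k_rope):
    # Host-side launcher, assuming 2-D per-token views of each input.
    num_tokens, nope_dim = cache_k_nope.shape
    rope_dim = cache_k_rope.shape[-1]
    fused_set_mla_kv_kernel[(num_tokens,)](
        kv_buffer, cache_k_nope, cache_k_rope, loc,
        kv_buffer.stride(0), cache_k_nope.stride(0), cache_k_rope.stride(0),
        NOPE_DIM=nope_dim, ROPE_DIM=rope_dim,
        BLOCK=triton.next_power_of_2(max(nope_dim, rope_dim)),
    )

With a fused write like this in place, both forward paths above can hand k and k_rope straight through without materializing a concatenated tensor on every cache update.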