Fuse MLA set kv cache kernel (#5748)
This commit is contained in:
@@ -625,6 +625,7 @@ class FlashAttentionBackend(AttentionBackend):
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
(NOTE: per the hunk header one line was added in this range, but the +/- markers were lost in extraction and cannot be recovered from this view.)
@@ -639,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         else:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer,
                 cache_loc,
                 k,
                 v,
                 k_rope,
             )

     # Use precomputed metadata across all layers
||||
@@ -887,6 +888,7 @@ class FlashAttentionBackend(AttentionBackend):
         save_kv_cache=True,
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
(NOTE: per the hunk header one line was added in this range, but the +/- markers were lost in extraction and cannot be recovered from this view.)
@@ -901,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         else:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer,
                 cache_loc,
                 k,
                 v,
                 k_rope,
             )

     # Use precomputed metadata across all layers
||||
|
||||
Reference in New Issue
Block a user