Fuse MLA set kv cache kernel (#5748)

This commit is contained in:
Ke Bao
2025-04-27 09:42:22 +08:00
committed by GitHub
parent 02723e1b0d
commit 799c4bb502
4 changed files with 100 additions and 9 deletions

View File

@@ -92,8 +92,11 @@ class RadixAttention(nn.Module):
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
if "k_rope" not in kwargs:
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
else:
k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
return forward_batch.attn_backend.forward(
q,