Fuse MLA set kv cache kernel (#5748)
@@ -92,8 +92,11 @@ class RadixAttention(nn.Module):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
-            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            if "k_rope" not in kwargs:
+                k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+                v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+            else:
+                k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
 
         return forward_batch.attn_backend.forward(
             q,
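The new branch handles the MLA layout where the rotary (RoPE) part of the key is passed separately via kwargs as k_rope: in that case k carries only the compressed latent, whose per-head width matches v_head_dim rather than qk_head_dim. Below is a minimal, self-contained sketch of that reshape logic, not the SGLang implementation; reshape_k and the dimension constants are hypothetical, loosely following DeepSeek-V2-style MLA (latent width 512, rope width 64).

    # Sketch of the reshape branch above; reshape_k and all constants are
    # illustrative assumptions, not SGLang API.
    from typing import Optional

    import torch

    TP_K_HEAD_NUM = 1   # MLA stores a single latent "head" per token
    QK_HEAD_DIM = 576   # latent (512) + rope (64) when k_rope is fused into k
    V_HEAD_DIM = 512    # latent only, used when k_rope arrives as its own tensor

    def reshape_k(k: torch.Tensor, k_rope: Optional[torch.Tensor] = None) -> torch.Tensor:
        # With k_rope fused into k, each head slot carries QK_HEAD_DIM values;
        # with k_rope passed separately, k holds only the compressed latent.
        head_dim = QK_HEAD_DIM if k_rope is None else V_HEAD_DIM
        return k.view(-1, TP_K_HEAD_NUM, head_dim)

    tokens = 4
    k_fused = torch.randn(tokens, TP_K_HEAD_NUM * QK_HEAD_DIM)
    k_latent = torch.randn(tokens, TP_K_HEAD_NUM * V_HEAD_DIM)
    k_rope = torch.randn(tokens, TP_K_HEAD_NUM, 64)

    assert reshape_k(k_fused).shape == (tokens, 1, QK_HEAD_DIM)
    assert reshape_k(k_latent, k_rope).shape == (tokens, 1, V_HEAD_DIM)

Keeping both layouts behind a single kwargs check lets one forward path serve backends that fuse the rope part into the cached key and backends that set the latent and rope caches with separate (here, fused) kernels.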