Fuse MLA set kv cache kernel (#5748)
This commit is contained in:
@@ -757,14 +757,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
|
||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||
|
||||
if self.attention_backend == "fa3":
|
||||
attn_output = self.attn_mqa(
|
||||
q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
|
||||
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
||||
)
|
||||
else:
|
||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user