[2/2] Speed up prefill mla attention concat (#10157)
This commit is contained in:
@@ -154,6 +154,7 @@ if _is_cuda:
|
|||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
awq_dequantize,
|
awq_dequantize,
|
||||||
bmm_fp8,
|
bmm_fp8,
|
||||||
|
concat_mla_k,
|
||||||
dsv3_fused_a_gemm,
|
dsv3_fused_a_gemm,
|
||||||
dsv3_router_gemm,
|
dsv3_router_gemm,
|
||||||
merge_state_v2,
|
merge_state_v2,
|
||||||
@@ -1295,8 +1296,18 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||||
q[..., self.qk_nope_head_dim :] = q_pe
|
q[..., self.qk_nope_head_dim :] = q_pe
|
||||||
k = torch.empty_like(q)
|
k = torch.empty_like(q)
|
||||||
k[..., : self.qk_nope_head_dim] = k_nope
|
|
||||||
k[..., self.qk_nope_head_dim :] = k_pe
|
# Temporary for DeepSeek V3/R1 only, but can generalize if needed
|
||||||
|
if (
|
||||||
|
_is_cuda
|
||||||
|
and (self.num_local_heads == 128)
|
||||||
|
and (self.qk_nope_head_dim == 128)
|
||||||
|
and (self.qk_rope_head_dim == 64)
|
||||||
|
):
|
||||||
|
concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
|
||||||
|
else:
|
||||||
|
k[..., : self.qk_nope_head_dim] = k_nope
|
||||||
|
k[..., self.qk_nope_head_dim :] = k_pe
|
||||||
|
|
||||||
if not _is_npu:
|
if not _is_npu:
|
||||||
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
|
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user