[2/2] Speed up prefill mla attention concat (#10157)
This commit is contained in:
@@ -154,6 +154,7 @@ if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
concat_mla_k,
|
||||
dsv3_fused_a_gemm,
|
||||
dsv3_router_gemm,
|
||||
merge_state_v2,
|
||||
@@ -1295,8 +1296,18 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
q[..., self.qk_nope_head_dim :] = q_pe
|
||||
k = torch.empty_like(q)
|
||||
k[..., : self.qk_nope_head_dim] = k_nope
|
||||
k[..., self.qk_nope_head_dim :] = k_pe
|
||||
|
||||
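# Temporary fast path: use the fused concat_mla_k kernel only for the DeepSeek V3/R1 shape (128 local heads, qk_nope_head_dim 128, qk_rope_head_dim 64); other shapes fall back to slice assignment. Can be generalized if needed.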
|
||||
if (
|
||||
_is_cuda
|
||||
and (self.num_local_heads == 128)
|
||||
and (self.qk_nope_head_dim == 128)
|
||||
and (self.qk_rope_head_dim == 64)
|
||||
):
|
||||
concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
|
||||
else:
|
||||
k[..., : self.qk_nope_head_dim] = k_nope
|
||||
k[..., self.qk_nope_head_dim :] = k_pe
|
||||
|
||||
if not _is_npu:
|
||||
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
|
||||
|
||||
Reference in New Issue
Block a user