[1/2] Speed up trtllm_mla attention backend (>10% e2e) (#10473)

This commit is contained in:
fzyzcjy
2025-09-16 02:53:21 +08:00
committed by GitHub
parent 5c08d7d21d
commit 3b25dc127a
6 changed files with 119 additions and 3 deletions

View File

@@ -379,3 +379,15 @@ def concat_mla_k(
k_rope: torch.Tensor,
):
torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)
def concat_mla_absorb_q(
    a: torch.Tensor,
    b: torch.Tensor,
):
    """Fuse `a` and `b` along the last dimension via the sgl_kernel custom op.

    Allocates the output tensor here (shape = a's leading dims with a
    last dim of ``a.shape[-1] + b.shape[-1]``, matching a's device/dtype)
    and lets the kernel fill it in place. Per the op name, this is
    presumably a last-dim concatenation — the kernel itself is opaque here.

    Returns:
        The newly allocated tensor written by the kernel.
    """
    leading_shape = a.shape[:-1]
    fused_last_dim = a.shape[-1] + b.shape[-1]
    result = torch.empty(
        (*leading_shape, fused_last_dim),
        device=a.device,
        dtype=a.dtype,
    )
    torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, result)
    return result