[1/2] Speed up trtllm_mla attention backend (>10% e2e) (#10473)
@@ -23,6 +23,7 @@ from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_
 from sgl_kernel.elementwise import (
     FusedSetKVBufferArg,
     apply_rope_with_cos_sin_cache_inplace,
+    concat_mla_absorb_q,
     concat_mla_k,
     copy_to_gpu_no_ce,
     downcast_fp8,
@@ -379,3 +379,15 @@ def concat_mla_k(
     k_rope: torch.Tensor,
 ):
     torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)
+
+
+def concat_mla_absorb_q(
+    a: torch.Tensor,
+    b: torch.Tensor,
+):
+    *batch_dims, _ = a.shape
+    out = torch.empty(
+        (*batch_dims, a.shape[-1] + b.shape[-1]), device=a.device, dtype=a.dtype
+    )
+    torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)
+    return out
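
For context, the new wrapper concatenates two tensors along the last dimension through the custom sgl_kernel op. A minimal usage sketch follows; the tensor shapes and the CUDA device are illustrative assumptions, and the fused path is expected to match a plain torch.cat along dim=-1:

import torch
from sgl_kernel.elementwise import concat_mla_absorb_q

# Assumed MLA-style shapes: absorbed query part plus rope part (illustrative only).
a = torch.randn(8, 128, 512, device="cuda", dtype=torch.bfloat16)
b = torch.randn(8, 128, 64, device="cuda", dtype=torch.bfloat16)

out = concat_mla_absorb_q(a, b)   # fused kernel path added in this commit
ref = torch.cat([a, b], dim=-1)   # eager reference for comparison
assert out.shape == (8, 128, 576)
torch.testing.assert_close(out, ref)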