[1/2] Speed up prefill mla attention (#10156)

This commit is contained in:
fzyzcjy
2025-09-09 00:00:33 +08:00
committed by GitHub
parent 2c2b19b18b
commit 0096798ed6
6 changed files with 130 additions and 0 deletions

View File

@@ -371,3 +371,11 @@ def downcast_fp8(
def copy_to_gpu_no_ce(input: List[int], output: torch.Tensor):
torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output)
def concat_mla_k(
k: torch.Tensor,
k_nope: torch.Tensor,
k_rope: torch.Tensor,
):
torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)