[1/2] Speed up prefill mla attention (#10156)
This commit is contained in:
@@ -371,3 +371,11 @@ def downcast_fp8(
|
||||
|
||||
def copy_to_gpu_no_ce(input: List[int], output: torch.Tensor) -> None:
    """Copy the host-side integer list *input* into the GPU tensor *output*.

    Thin Python wrapper that delegates directly to the custom
    ``sgl_kernel`` extension op. "no_ce" presumably means the transfer
    avoids the CUDA copy engine (e.g. via a kernel-side write) — TODO
    confirm against the ``sgl_kernel`` implementation.

    Args:
        input: Host integers to transfer. NOTE(review): the name shadows
            the ``input`` builtin; kept as-is to match the op signature.
        output: Destination tensor, presumably preallocated on the GPU
            and written in place — verify against callers.
    """
    torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output)
|
||||
|
||||
|
||||
def concat_mla_k(
    k: torch.Tensor,
    k_nope: torch.Tensor,
    k_rope: torch.Tensor,
) -> None:
    """Fuse the MLA key components ``k_nope`` and ``k_rope`` into ``k``.

    Thin Python wrapper that delegates directly to the custom
    ``sgl_kernel`` extension op. Based on the argument order, ``k`` is
    presumably the preallocated output tensor filled in place from the
    no-position-encoding part (``k_nope``) and the rotary part
    (``k_rope``) — TODO confirm exact layout/concat axis against the
    ``sgl_kernel`` implementation.

    Args:
        k: Output key tensor, presumably written in place.
        k_nope: Non-RoPE portion of the key.
        k_rope: RoPE (rotary position embedding) portion of the key.
    """
    torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)
|
||||
|
||||
Reference in New Issue
Block a user