Add dsv3 fused a gemm to sgl-kernel (#7630)

This commit is contained in:
Ke Bao
2025-06-29 17:52:24 +08:00
committed by GitHub
parent 071a1f51ae
commit 04b35190e2
9 changed files with 800 additions and 0 deletions

View File

@@ -33,6 +33,7 @@ from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,
cutlass_scaled_fp4_mm,
dsv3_fused_a_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
int8_scaled_mm,

View File

@@ -82,6 +82,21 @@ def bmm_fp8(
return out
def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the DSV3 fused "A" GEMM custom kernel on ``mat_a`` and ``mat_b``.

    Dispatches to the ``sgl_kernel.dsv3_fused_a_gemm`` custom op, which
    writes its result into ``output`` in place.

    Args:
        mat_a: Left operand; its device and dtype are used when allocating
            a fresh output buffer.
        mat_b: Right operand.
        output: Optional preallocated result buffer. When ``None``, a new
            uninitialized tensor of shape ``(mat_a.shape[0], mat_b.shape[1])``
            is created on ``mat_a``'s device with ``mat_a``'s dtype.

    Returns:
        The tensor holding the GEMM result (either ``output`` or the newly
        allocated buffer).
    """
    if output is None:
        # Allocate the result as (rows of A) x (cols of B), matching A's
        # device/dtype; the kernel fills it, so torch.empty is sufficient.
        result_shape = (mat_a.shape[0], mat_b.shape[1])
        output = torch.empty(result_shape, device=mat_a.device, dtype=mat_a.dtype)
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
    return output
def sgl_per_token_group_quant_fp8(
input: torch.Tensor,
output_q: torch.Tensor,