Add dsv3 fused a gemm to sgl-kernel (#7630)
@@ -33,6 +33,7 @@ from sgl_kernel.gemm import (
     awq_dequantize,
     bmm_fp8,
     cutlass_scaled_fp4_mm,
+    dsv3_fused_a_gemm,
     fp8_blockwise_scaled_mm,
     fp8_scaled_mm,
     int8_scaled_mm,
@@ -82,6 +82,21 @@ def bmm_fp8(
     return out
 
 
+def dsv3_fused_a_gemm(
+    mat_a: torch.Tensor,
+    mat_b: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if output is None:
+        output = torch.empty(
+            (mat_a.shape[0], mat_b.shape[1]),
+            device=mat_a.device,
+            dtype=mat_a.dtype,
+        )
+    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
+    return output
+
+
 def sgl_per_token_group_quant_fp8(
     input: torch.Tensor,
     output_q: torch.Tensor,
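
For context, a minimal usage sketch of the new wrapper. The diff itself does not state the kernel's shape or dtype constraints; the bf16 dtype, the (K, N) weight layout, and the K=7168 / N=2112 dimensions below (DeepSeek-V3's hidden size and the fused q_a/kv_a projection width that the "dsv3 fused A" name suggests) are assumptions, as is treating a plain eager matmul as a reference.

# Usage sketch for dsv3_fused_a_gemm; shapes/dtype are assumptions, see above.
import torch
from sgl_kernel.gemm import dsv3_fused_a_gemm

M, K, N = 16, 7168, 2112  # assumed DeepSeek-V3 MLA "A"-projection GEMM shapes
mat_a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)  # activations
mat_b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)  # weight, (K, N) per the wrapper's output-shape logic

# Simplest call: the wrapper allocates the output with mat_a's device/dtype.
out = dsv3_fused_a_gemm(mat_a, mat_b)

# Hot-path variant: pass a preallocated buffer to skip the allocation.
buf = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
dsv3_fused_a_gemm(mat_a, mat_b, output=buf)

# If the kernel computes a plain GEMM (assumption), an eager bf16 matmul
# serves as a loose correctness reference.
torch.testing.assert_close(out, mat_a @ mat_b, rtol=2e-2, atol=2e-2)

Note the argument order of the underlying op (output first, then the inputs); the Python wrapper hides this, so callers only ever pass output as an optional keyword.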