Add awq dequantize kernel to sgl with 1x to 3x speedup (#4104)

This commit is contained in:
Rex
2025-03-12 00:10:02 -07:00
committed by GitHub
parent e0917e6bd0
commit 07f944631e
8 changed files with 324 additions and 0 deletions

View File

@@ -23,6 +23,7 @@ from sgl_kernel.elementwise import (
silu_and_mul,
)
from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,
cublas_grouped_gemm,
fp8_blockwise_scaled_mm,

View File

@@ -4,6 +4,12 @@ import torch
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.ByteTensor:
    """Dequantize AWQ-packed weights by dispatching to the custom CUDA op.

    Thin wrapper over ``torch.ops.sgl_kernels.awq_dequantize``; all real work
    happens in the registered kernel.

    Args:
        qweight: Packed quantized weight tensor.
        scales: Per-group dequantization scales.
        qzeros: Packed quantized zero points.

    Returns:
        The dequantized weight tensor as produced by the kernel.
        NOTE(review): the ``torch.ByteTensor`` annotation looks unlikely for a
        dequantized output (typically fp16) — confirm against the kernel
        registration. Also note the op namespace here is ``sgl_kernels``
        while the neighboring ``int8_scaled_mm`` uses ``sgl_kernel`` —
        verify which namespace the extension actually registers.
    """
    dequantized = torch.ops.sgl_kernels.awq_dequantize(qweight, scales, qzeros)
    return dequantized
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
return torch.ops.sgl_kernel.int8_scaled_mm(
mat_a,