Add awq dequantize kernel to sgl with 1x to 3x speedup (#4104)
This commit is contained in:
@@ -23,6 +23,7 @@ from sgl_kernel.elementwise import (
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
cublas_grouped_gemm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
|
||||
@@ -4,6 +4,12 @@ import torch
|
||||
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
||||
|
||||
|
||||
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ 4-bit packed weights back to a floating-point tensor.

    Thin wrapper dispatching to the custom CUDA kernel registered under
    ``torch.ops.sgl_kernels``.

    Args:
        qweight: Packed 4-bit quantized weight tensor.
        scales: Per-group dequantization scales.
        qzeros: Packed quantized zero points.

    Returns:
        The dequantized weight tensor. (Annotation fixed from
        ``torch.ByteTensor``: dequantized output is floating point,
        presumably matching the dtype of ``scales`` — confirm against
        the kernel implementation.)
    """
    # NOTE(review): namespace here is ``sgl_kernels`` while the sibling
    # int8_scaled_mm wrapper uses ``sgl_kernel`` — verify which registry
    # name the extension actually registers under for this build.
    return torch.ops.sgl_kernels.awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return torch.ops.sgl_kernel.int8_scaled_mm(
|
||||
mat_a,
|
||||
|
||||
Reference in New Issue
Block a user