Fix awq_dequantize: call torch.ops.sgl_kernel (not torch.ops.sgl_kernels) (#4333)
This commit is contained in:
@@ -7,7 +7,7 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
|||||||
def awq_dequantize(
|
def awq_dequantize(
|
||||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||||
) -> torch.ByteTensor:
|
) -> torch.ByteTensor:
|
||||||
return torch.ops.sgl_kernels.awq_dequantize(qweight, scales, qzeros)
|
return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
|
||||||
|
|
||||||
|
|
||||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user