fix awq_dequantize (#4333)
This commit is contained in:
@@ -7,7 +7,7 @@ from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
||||
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.ByteTensor:
    """Dequantize AWQ-packed weights via the registered sgl_kernel custom op.

    Thin wrapper that forwards to ``torch.ops.sgl_kernel.awq_dequantize``.

    Args:
        qweight: Packed quantized weight tensor (AWQ layout — TODO confirm
            exact shape/dtype against the kernel registration).
        scales: Per-group dequantization scales.
        qzeros: Packed quantized zero points.

    Returns:
        The dequantized tensor produced by the custom kernel. NOTE(review):
        the annotation says ``torch.ByteTensor``; verify this matches the
        kernel's actual output dtype.
    """
    # Fix: the op namespace is ``sgl_kernel`` (singular); the previous
    # ``sgl_kernels`` namespace does not exist and raised at call time.
    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
|
||||
Reference in New Issue
Block a user