Add awq dequantize kernel to sgl with 1x to 3x speedup (#4104)

This commit is contained in:
Rex
2025-03-12 00:10:02 -07:00
committed by GitHub
parent e0917e6bd0
commit 07f944631e
8 changed files with 324 additions and 0 deletions

View File

@@ -75,6 +75,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/gemm
*/
m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
m.def(
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
"bias) -> Tensor");