diff --git a/sgl-kernel/benchmark/bench_awq_dequant.py b/sgl-kernel/benchmark/bench_awq_dequant.py
new file mode 100644
index 000000000..22280c250
--- /dev/null
+++ b/sgl-kernel/benchmark/bench_awq_dequant.py
@@ -0,0 +1,118 @@
+import itertools
+
+import torch
+import triton
+import triton.testing
+from sgl_kernel import awq_dequantize
+from vllm import _custom_ops as ops
+
+
+def vllm_awq_dequantize(
+    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
+) -> torch.Tensor:
+    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+
+
+def sglang_awq_dequantize(
+    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
+) -> torch.Tensor:
+    return awq_dequantize(qweight, scales, qzeros)
+
+
+def calculate_diff(qweight_row: int, qweight_col: int):
+    """Compare the vLLM and SGLang implementations on one shape."""
+    device = torch.device("cuda")
+    qweight = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (qweight_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+    group_size = qweight_row
+    scales_row = qweight_row // group_size
+    scales_col = qweight_col * 8
+    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
+    qzeros = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (scales_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
+    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
+
+    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
+    print(f"Mean absolute difference: {output_diff}")
+
+    if torch.allclose(
+        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
+    ):
+        print("✅ All implementations match")
+    else:
+        print("❌ Implementations differ")
+
+
+qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
+qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
+
+configs = list(itertools.product(qweight_row_range, qweight_cols_range))
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["qweight_row", "qweight_col"],
+        x_vals=configs,
+        line_arg="provider",
+        line_vals=["vllm", "sglang"],
+        line_names=["VLLM", "SGL Kernel"],
+        styles=[("blue", "-"), ("green", "-")],
+        ylabel="us",
+        plot_name="awq-dequantize-performance",
+        args={},
+    )
+)
+def benchmark(qweight_row, qweight_col, provider):
+    device = torch.device("cuda")
+    qweight = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (qweight_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+    group_size = qweight_row
+    scales_row = qweight_row // group_size
+    scales_col = qweight_col * 8
+    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
+    qzeros = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (scales_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "vllm":
+        fn = lambda: vllm_awq_dequantize(
+            qweight.clone(), scales.clone(), qzeros.clone()
+        )
+    elif provider == "sglang":
+        fn = lambda: sglang_awq_dequantize(
+            qweight.clone(), scales.clone(), qzeros.clone()
+        )
+
+    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    calculate_diff(qweight_row=3584, qweight_col=448)
+    benchmark.run(print_data=True)
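Reviewer note: as a sanity reference for the shapes and semantics exercised above, here is a minimal pure-PyTorch sketch of AWQ dequantization, assuming the standard AWQ nibble interleaving [0, 4, 1, 5, 2, 6, 3, 7] (the same order the CUDA kernel below extracts). `awq_dequantize_ref` and `AWQ_ORDER` are illustrative names, not part of this patch:

```python
import torch

# AWQ packs 8 unsigned 4-bit values per int32 in this interleaved nibble order.
AWQ_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def awq_dequantize_ref(
    qweight: torch.Tensor,  # [rows, cols] int32
    scales: torch.Tensor,   # [rows // group_size, cols * 8] fp16
    qzeros: torch.Tensor,   # [rows // group_size, cols] int32
    group_size: int,
) -> torch.Tensor:          # [rows, cols * 8] fp16
    shifts = torch.tensor([4 * i for i in AWQ_ORDER], device=qweight.device)
    # Unpack nibbles; fine here since the benchmark draws non-negative int32s.
    w = ((qweight.unsqueeze(-1) >> shifts) & 0xF).reshape(qweight.shape[0], -1)
    z = ((qzeros.unsqueeze(-1) >> shifts) & 0xF).reshape(qzeros.shape[0], -1)
    # Broadcast per-group zeros and scales over the rows of each group.
    z = z.repeat_interleave(group_size, dim=0).to(torch.float16)
    s = scales.repeat_interleave(group_size, dim=0)
    return (w.to(torch.float16) - z) * s
```

This is only meant for eyeballing correctness on small shapes; the benchmarked kernels compute the same (w - z) * s per group.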
diff --git a/sgl-kernel/csrc/gemm/awq_kernel.cu b/sgl-kernel/csrc/gemm/awq_kernel.cu
new file mode 100644
index 000000000..2b697cae4
--- /dev/null
+++ b/sgl-kernel/csrc/gemm/awq_kernel.cu
@@ -0,0 +1,127 @@
+// Adapted from
+// https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda_fp16.h>
+#include <torch/all.h>
+
+__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+  uint4 result;
+
+  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
+  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
+
+  // First, we extract the i4s and construct an intermediate fp16 number.
+  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
+  static constexpr uint32_t TOP_MASK = 0x00f000f0;
+  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+  // Note that the entire sequence only requires 1 shift instruction. This is
+  // thanks to the register packing format and the fact that we force our
+  // integers to be unsigned, and account for this in the fp16 subtractions. In
+  // addition, I exploit the fact that sub and fma have the same throughput in
+  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
+  // the bottom bits beforehand.
+
+  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
+  // dependency if we issue immediately before required.
+  const uint32_t top_i4s = i4s >> 8;
+  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(h[0])
+               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(h[1])
+               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(h[2])
+               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(h[3])
+               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+
+  // This is the half2 {1024, 1024} represented as an integer.
+  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
+  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
+  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
+  // This is the half2 {-64, -64} represented as an integer.
+  static constexpr uint32_t NEG_64 = 0xd400d400;
+
+  // Finally, we construct the output numbers.
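+  // Reviewer note on the constants above: or-ing a nibble v into the low
+  // mantissa bits of 0x6400 yields the fp16 value 1024 + v, so subtracting
+  // {1024, 1024} recovers v exactly. A nibble still sitting in bits 4..7
+  // contributes 16 * v to the mantissa (giving 1024 + 16 * v), so one fma
+  // with {1/16, 1/16} and {-64, -64} (= -1024 / 16) recovers v with no shift.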
+  // Convert elt_01
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+  // Convert elt_23
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+  // Convert elt_45
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+  // Convert elt_67
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+
+  return result;
+#else
+  assert(false);
+  return {};
+#endif
+}
+
+__global__ void __launch_bounds__(256) dequantize_weights(
+    int* __restrict__ qweight,
+    half* __restrict__ scales,
+    int* __restrict__ qzeros,
+    half* __restrict__ output,
+    int group_size,
+    int qweight_cols) {
+  int col = blockIdx.x * blockDim.x + threadIdx.x;
+  int row = blockIdx.y * blockDim.y + threadIdx.y;
+
+  uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + (row / group_size) * qweight_cols]);
+  uint4 loaded_scale = *(uint4*)(scales + 8 * col + (row / group_size) * qweight_cols * 8);
+
+  uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);
+
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
+  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
+  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
+  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
+  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
+  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));
+
+  half* output_ptr = output + 8 * col + 8 * row * qweight_cols;
+  *(uint4*)output_ptr = weight_fp16;
+}
+
+torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
+  int qweight_rows = qweight.size(0);
+  int qweight_cols = qweight.size(1);
+  int group_size = qweight_rows / scales.size(0);
+
+  int x_num_threads = 16;
+  int y_num_threads = 16;
+  int x_blocks = qweight_cols / x_num_threads;
+  int y_blocks = qweight_rows / y_num_threads;
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
+
+  auto output_tensor_options = torch::TensorOptions().dtype(scales.dtype()).device(scales.device());
+  at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);
+
+  auto _qweight = reinterpret_cast<int*>(qweight.data_ptr());
+  auto _scales = reinterpret_cast<half*>(scales.data_ptr());
+  auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr());
+  auto _output = reinterpret_cast<half*>(output.data_ptr());
+
+  dim3 num_blocks(x_blocks, y_blocks);
+  dim3 threads_per_block(x_num_threads, y_num_threads);
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
+      _qweight, _scales, _zeros, _output, group_size, qweight_cols);
+
+  return output;
+}
diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc
index d8bd89917..a5a115e3a 100644
--- a/sgl-kernel/csrc/torch_extension.cc
+++ b/sgl-kernel/csrc/torch_extension.cc
@@ -75,6 +75,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/gemm
    */
+  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
+  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
+
   m.def(
       "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
       "bias) -> Tensor");
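With the schema and CUDA impl registered under the `sgl_kernel` library, the op is reachable both through the Python wrapper below and directly via `torch.ops`. A minimal smoke call, assuming a CUDA build of sgl-kernel is installed (shapes mirror the smallest test case, one quantization group):

```python
import torch
import sgl_kernel  # noqa: F401  (loads the extension and registers the ops)

qweight = torch.randint(0, 2**31 - 1, (128, 16), dtype=torch.int32, device="cuda")
scales = torch.rand(1, 16 * 8, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (1, 16), dtype=torch.int32, device="cuda")

out = torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
assert out.shape == (128, 16 * 8) and out.dtype == torch.float16
```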
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 34ce443a2..934478eec 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -112,6 +112,7 @@ void apply_rope_pos_ids_cos_sin_cache(
 /*
  * From csrc/gemm
  */
+torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
 torch::Tensor int8_scaled_mm(
     const torch::Tensor& mat_a,
     const torch::Tensor& mat_b,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index c8cb0443d..da2959269 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -23,6 +23,7 @@ from sgl_kernel.elementwise import (
     silu_and_mul,
 )
 from sgl_kernel.gemm import (
+    awq_dequantize,
     bmm_fp8,
     cublas_grouped_gemm,
     fp8_blockwise_scaled_mm,
diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py
index e5936da56..d68d0a437 100644
--- a/sgl-kernel/python/sgl_kernel/gemm.py
+++ b/sgl-kernel/python/sgl_kernel/gemm.py
@@ -4,6 +4,12 @@ import torch
 from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
 
 
+def awq_dequantize(
+    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
+) -> torch.Tensor:
+    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
+
+
 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     return torch.ops.sgl_kernel.int8_scaled_mm(
         mat_a,
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 0cf88ff06..885e65489 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -150,6 +150,7 @@ sources = [
     "csrc/elementwise/rope.cu",
     "csrc/gemm/bmm_fp8.cu",
     "csrc/gemm/cublas_grouped_gemm.cu",
+    "csrc/gemm/awq_kernel.cu",
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "csrc/gemm/int8_gemm_kernel.cu",
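One constraint worth noting before the tests: the launcher above computes `qweight_cols / 16` and `qweight_rows / 16` blocks with truncating integer division and the kernel has no bounds check, so both dimensions must be multiples of 16 (every size in the test grid below is). A hypothetical Python-side guard, not part of this patch:

```python
import torch


def check_awq_shapes(qweight: torch.Tensor) -> None:
    # 16x16 thread tiles with truncating division: a non-multiple-of-16
    # dimension would leave trailing rows/columns of the output unwritten.
    rows, cols = qweight.shape
    assert rows % 16 == 0 and cols % 16 == 0, (
        f"awq_dequantize expects dims divisible by 16, got {rows}x{cols}"
    )
```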
diff --git a/sgl-kernel/tests/test_awq_dequant.py b/sgl-kernel/tests/test_awq_dequant.py
new file mode 100644
index 000000000..c2a2ee84d
--- /dev/null
+++ b/sgl-kernel/tests/test_awq_dequant.py
@@ -0,0 +1,67 @@
+import itertools
+
+import pytest
+import torch
+from sgl_kernel import awq_dequantize
+from vllm import _custom_ops as ops
+
+
+def vllm_awq_dequantize(
+    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
+) -> torch.Tensor:
+    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+
+
+def sglang_awq_dequantize(
+    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
+) -> torch.Tensor:
+    return awq_dequantize(qweight, scales, qzeros)
+
+
+@pytest.mark.parametrize(
+    "qweight_row,qweight_col",
+    list(
+        itertools.product(
+            [3584, 18944, 128, 256, 512, 1024], [448, 576, 4736, 16, 32, 64, 128]
+        )
+    ),
+)
+def test_awq_dequant_compare_implementations(
+    qweight_row: int,
+    qweight_col: int,
+):
+    device = torch.device("cuda")
+
+    qweight = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (qweight_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+    group_size = qweight_row
+    scales_row = qweight_row // group_size
+    scales_col = qweight_col * 8
+    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
+    qzeros = torch.randint(
+        0,
+        torch.iinfo(torch.int32).max,
+        (scales_row, qweight_col),
+        dtype=torch.int32,
+        device=device,
+    )
+
+    # Run both implementations
+    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
+    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
+
+    # Compare results
+    torch.testing.assert_close(
+        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
+    )
+
+
+if __name__ == "__main__":
+    # Run the tests in this file directly
+    pytest.main([__file__])