diff --git a/sgl-kernel/csrc/gemm/awq_kernel.cu b/sgl-kernel/csrc/gemm/awq_kernel.cu
index 2b697cae4..0c144d40f 100644
--- a/sgl-kernel/csrc/gemm/awq_kernel.cu
+++ b/sgl-kernel/csrc/gemm/awq_kernel.cu
@@ -3,6 +3,16 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda_fp16.h>
 #include <torch/all.h>
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+#endif
+
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
 
 __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
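A note on the `lop3<lut>` helper added above: the immediate is the truth table of the desired boolean function evaluated on the canonical operand masks a=0xF0, b=0xCC, c=0xAA, so `(0xf0 & 0xcc) | 0xaa` (= 0xEA) selects f(a, b, c) = (a & b) | c, i.e. mask-and-merge in a single instruction. A standalone sanity check of that encoding (plain Python, not part of the patch):

```python
# LOP3's immediate is f(0xF0, 0xCC, 0xAA) for the boolean function f you want.
A, B, C = 0xF0, 0xCC, 0xAA
lut = (A & B) | C  # 0xEA -> f(a, b, c) = (a & b) | c
assert lut == 0xEA

# Verify the truth table bit-by-bit against the intended function.
for i in range(8):
    a, b, c = (A >> i) & 1, (B >> i) & 1, (C >> i) & 1
    assert ((lut >> i) & 1) == ((a & b) | c)
```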
"r"(weight_fp16.z), "r"(loaded_scale.z)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w)); - asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w)); + // Use PTX assembly for FP16 operations + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x)); + asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y)); + asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z)); + asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w)); + asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w)); - half* output_ptr = output + 8 * col + 8 * row * qweight_cols; - *(uint4*)output_ptr = weight_fp16; + OutputT* output_ptr = output + 8 * col + 8 * row * qweight_cols; + *(uint4*)output_ptr = weight_fp16; + } else if constexpr (std::is_same::value) { + uint4 weight_raw = dequantize_s4_to_bf16x2(qweight[col + row * qweight_cols]); + uint4 zero_raw = dequantize_s4_to_bf16x2(qzeros[col + group_idx * qweight_cols]); + uint4 scale_raw = *reinterpret_cast(scales + scale_offset); + + // Vectorized processing (each uint4 contains 4 nv_bfloat162) + nv_bfloat162* weight_vec = reinterpret_cast(&weight_raw); + nv_bfloat162* zero_vec = reinterpret_cast(&zero_raw); + nv_bfloat162* scale_vec = reinterpret_cast(&scale_raw); + +// Single instruction dual-channel operation +#pragma unroll + for (int i = 0; i < 4; ++i) { // uint4 = 4 * nv_bfloat162 + weight_vec[i] = __hmul2(__hsub2(weight_vec[i], zero_vec[i]), scale_vec[i]); + } + + // Directly store to OutputT array (guaranteed contiguous memory) + OutputT* output_ptr = output + 8 * col + row * qweight_cols * 8; + static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch"); + *reinterpret_cast(output_ptr) = weight_raw; + } } torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) { @@ -112,16 +192,23 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch: at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options); auto _qweight = reinterpret_cast(qweight.data_ptr()); - auto _scales = reinterpret_cast(scales.data_ptr()); auto _zeros = reinterpret_cast(qzeros.data_ptr()); - auto _output = reinterpret_cast(output.data_ptr()); dim3 num_blocks(x_blocks, y_blocks); dim3 threads_per_block(x_num_threads, y_num_threads); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - dequantize_weights<<>>( - _qweight, _scales, _zeros, _output, group_size, qweight_cols); + + if (scales.scalar_type() == at::ScalarType::Half) { + auto _scales = reinterpret_cast(scales.data_ptr()); + auto _output = reinterpret_cast(output.data_ptr()); + dequantize_weights + <<>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols); + } else { + auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr()); + auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr()); + dequantize_weights<__nv_bfloat16> + <<>>(_qweight, _scales, _zeros, 
diff --git a/sgl-kernel/tests/test_awq_dequant.py b/sgl-kernel/tests/test_awq_dequant.py
index c2a2ee84d..bad3e2c10 100644
--- a/sgl-kernel/tests/test_awq_dequant.py
+++ b/sgl-kernel/tests/test_awq_dequant.py
@@ -7,6 +7,57 @@ from sgl_kernel import awq_dequantize
 from vllm import _custom_ops as ops
 
 
+def reverse_awq_order(t: torch.Tensor):
+    bits = 4
+    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+    reverse_order_tensor = torch.arange(
+        t.shape[-1],
+        dtype=torch.int32,
+        device=t.device,
+    )
+    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
+    reverse_order_tensor = reverse_order_tensor.view(-1)
+
+    t = t[:, reverse_order_tensor] & 0xF
+    return t
+
+
+# qweights - [R     , C // 8], int32
+# scales   - [R // G, C     ], float16
+# zeros    - [R // G, C // 8], int32
+def awq_dequantize_torch(
+    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
+) -> torch.Tensor:
+
+    if group_size == -1:
+        group_size = qweight.shape[0]
+
+    bits = 4
+    shifts = torch.arange(0, 32, bits, device=qzeros.device)
+
+    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
+        torch.int8
+    )
+
+    iweights = iweights.view(iweights.shape[0], -1)
+
+    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
+        torch.int8
+    )
+    zeros = zeros.view(qzeros.shape[0], -1)
+    zeros = reverse_awq_order(zeros)
+
+    iweights = reverse_awq_order(iweights)
+
+    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
+    zeros = torch.bitwise_and(zeros, (2**bits) - 1)
+
+    scales = scales.repeat_interleave(group_size, dim=0)
+    zeros = zeros.repeat_interleave(group_size, dim=0)
+    return (iweights - zeros) * scales
+
+
 def vllm_awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.Tensor:
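For orientation, here is how the pure-torch reference above can be exercised on its own (illustrative shapes; runs on CPU and assumes `awq_dequantize_torch` from the hunk above is in scope):

```python
import torch

R, C, G = 64, 64, 64  # rows, output columns, group size (a single group)
qweight = torch.randint(0, 2**31 - 1, (R, C // 8), dtype=torch.int32)
qzeros = torch.randint(0, 2**31 - 1, (R // G, C // 8), dtype=torch.int32)
scales = torch.rand(R // G, C, dtype=torch.bfloat16)

out = awq_dequantize_torch(qweight, scales, qzeros, group_size=G)
print(out.shape, out.dtype)  # torch.Size([64, 64]) torch.bfloat16
```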
@@ -20,16 +71,17 @@ def sglang_awq_dequantize(
 
 
 @pytest.mark.parametrize(
-    "qweight_row,qweight_col",
+    "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024], [448, 576, 4736, 16, 32, 64, 128]
+            [3584, 18944, 128, 256, 512, 1024],
+            [448, 576, 4736, 16, 32, 64, 128],
+            [True, False],
         )
     ),
 )
 def test_awq_dequant_compare_implementations(
-    qweight_row: int,
-    qweight_col: int,
+    qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
@@ -43,7 +95,12 @@ def test_awq_dequant_compare_implementations(
     group_size = qweight_row
     scales_row = qweight_row // group_size
     scales_col = qweight_col * 8
-    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
+
+    if is_bf16_act:
+        scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device)
+    else:
+        scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
+
     qzeros = torch.randint(
         0,
         torch.iinfo(torch.int32).max,
@@ -53,13 +110,21 @@ def test_awq_dequant_compare_implementations(
         (scales_row, qweight_col),
         dtype=torch.int32,
         device=device,
     )
 
     # Run both implementations
-    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
+    vllm_out = vllm_awq_dequantize(qweight, scales.to(torch.float16), qzeros)
+    torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
     sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
 
     # Compare results
     torch.testing.assert_close(
-        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
+        torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
     )
+    if not is_bf16_act:
+        torch.testing.assert_close(
+            vllm_out.to(torch.float32),
+            sglang_out.to(torch.float32),
+            rtol=1e-3,
+            atol=1e-5,
+        )
 
 
 if __name__ == "__main__":
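A hypothetical smoke test of the new bf16 path from Python (assumes a CUDA device with the sgl-kernel extension installed; shapes follow the test's conventions, with group_size equal to the number of rows):

```python
import torch
from sgl_kernel import awq_dequantize

R, C8 = 128, 16  # qweight rows and packed columns; output has C8 * 8 = 128 columns
qweight = torch.randint(0, 2**31 - 1, (R, C8), dtype=torch.int32, device="cuda")
scales = torch.rand(1, C8 * 8, dtype=torch.bfloat16, device="cuda")  # one group of R rows
qzeros = torch.randint(0, 2**31 - 1, (1, C8), dtype=torch.int32, device="cuda")

out = awq_dequantize(qweight, scales, qzeros)
print(out.shape)  # torch.Size([128, 128]); dtype should follow scales (bfloat16)
```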