From 7750b91ca81d15b85290703f24f8cd2716fe149a Mon Sep 17 00:00:00 2001 From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Date: Fri, 18 Jul 2025 14:27:25 -0700 Subject: [PATCH] [AMD] Add triton awq_dequantize kernel to support AWQ on ROCm (#7661) --- python/sglang/srt/layers/quantization/awq.py | 12 +- .../srt/layers/quantization/awq_triton.py | 339 ++++++++++++++++++ python/sglang/srt/models/deepseek_v2.py | 6 +- test/srt/run_suite.py | 1 + test/srt/test_awq_dequant.py | 175 +++++++++ 5 files changed, 530 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/awq_triton.py create mode 100644 test/srt/test_awq_dequant.py diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 453267383..c20beb2ff 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -43,11 +43,20 @@ try: except ImportError: ops = None -from sglang.srt.utils import is_cuda +from sglang.srt.utils import is_cuda, is_hip _is_cuda = is_cuda() +_is_hip = is_hip() if _is_cuda: from sgl_kernel import awq_dequantize, fused_marlin_moe +elif _is_hip: + from sglang.srt.layers.quantization.awq_triton import ( + awq_dequantize_triton as awq_dequantize, + ) + + warnings.warn(f"HIP does not support fused_marlin_moe currently.") +else: + warnings.warn(f"Only CUDA and HIP support AWQ currently.") logger = logging.getLogger(__name__) @@ -398,7 +407,6 @@ class AWQLinearMethod(LinearMethodBase): pack_factor = self.quant_config.pack_factor out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) reshaped_x = x.reshape(-1, x.shape[-1]) - out = awq_dequantize(qweight, scales, qzeros) out = torch.matmul(reshaped_x, out) diff --git a/python/sglang/srt/layers/quantization/awq_triton.py b/python/sglang/srt/layers/quantization/awq_triton.py new file mode 100644 index 000000000..13352efdb --- /dev/null +++ b/python/sglang/srt/layers/quantization/awq_triton.py @@ -0,0 
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import triton
import triton.language as tl

AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


# Dequantize a tile of AWQ-packed weights.
#
# Program axis 0 walks tiles of BLOCK_SIZE_X packed int32 columns and program
# axis 1 walks tiles of BLOCK_SIZE_Y rows.  Every packed int32 is expanded into
# eight 4-bit values, so a (BLOCK_SIZE_Y, BLOCK_SIZE_X) input tile produces a
# (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8) output tile.
#
# All pointers are addressed with flat row-major offsets (row * row_length +
# col), and the zeros/scales row is computed ONCE per tile from the tile's
# first row, so the launcher must make sure a tile never spans two groups.
@triton.jit
def awq_dequantize_kernel(
    qweight_ptr,  # quantized matrix
    scales_ptr,  # scales, per group
    zeros_ptr,  # zeros, per group
    group_size,  # Should always be one of the supported group sizes
    result_ptr,  # Output matrix
    num_cols,  # input num cols in qweight
    num_rows,  # input num rows in qweight
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
):
    # Setup the pids.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)

    # Compute offsets and masks for qweight_ptr.
    offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

    masks_y = offsets_y < num_rows
    masks_x = offsets_x < num_cols

    masks = masks_y[:, None] & masks_x[None, :]

    # Compute offsets and masks for result output ptr.
    result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
    result_offsets = (
        8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :]
    )

    result_masks_y = result_offsets_y < num_rows
    result_masks_x = result_offsets_x < num_cols * 8
    result_masks = result_masks_y[:, None] & result_masks_x[None, :]

    # Load the weights.
    # Three interleaves repeat every packed word eight times in a row, giving
    # one copy per 4-bit value that is extracted below.
    iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
    iweights = tl.interleave(iweights, iweights)
    iweights = tl.interleave(iweights, iweights)
    iweights = tl.interleave(iweights, iweights)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = (
        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
    ).reshape(8)

    # Use this to compute a set of shifts that can be used to unpack and
    # reorder the values in iweights and zeros.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    iweights = (iweights >> shifts) & 0xF

    # Compute zero offsets and masks.
    # tl.arange(0, 1) yields a single row: one group index for the whole tile.
    zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
    zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

    zero_masks_y = zero_offsets_y < num_rows // group_size
    zero_masks_x = zero_offsets_x < num_cols
    zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]

    # Load the zeros.
    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.interleave(zeros, zeros)
    zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    zeros = (zeros >> shifts) & 0xF

    # Compute scale offsets and masks.
    # Scales are stored unpacked, hence the "* 8" on the column extent.
    scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
    scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
    scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :]
    scale_masks_y = scale_offsets_y < num_rows // group_size
    scale_masks_x = scale_offsets_x < num_cols * 8
    scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]

    # Load the scales.
    scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
    scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Dequantize.
    iweights = (iweights - zeros) * scales
    iweights = iweights.to(result_ptr.type.element_ty)

    # Finally, store.
    tl.store(result_ptr + result_offsets, iweights, result_masks)


# Fused AWQ dequantize + matmul with split-K.
#
# Program axis 0 enumerates (M tile, N tile) pairs; program axis 1 is the
# split-K slice.  Each slice starts at K offset pid_z * BLOCK_SIZE_K and then
# strides by BLOCK_SIZE_K * SPLIT_K, writing its partial product into plane
# pid_z of the output (c_ptr + pid_z * N * M); the caller reduces the planes.
#
# As in the dequantize kernel, pointers use flat row-major offsets and the
# zeros/scales row is computed once per K tile, so a K tile must not span two
# quantization groups.
@triton.jit
def awq_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    zeros_ptr,
    scales_ptr,
    M,
    N,
    K,
    group_size,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode.  Use below instead.
    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = c_ptr.type.element_ty

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode.  Use below instead.
    # accumulator = tl.arange(0, BLOCK_SIZE_N)
    # accumulator = tl.broadcast_to(accumulator[None, :],
    # (BLOCK_SIZE_M, BLOCK_SIZE_N))
    # accumulator = accumulator & 0x0
    # accumulator = accumulator.to(accumulator_dtype)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = (
        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
    ).reshape(8)

    # Create the necessary shifts to use to unpack.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    masks_am = offsets_am < M

    # Packed weight / zero columns: eight output columns per int32 word.
    offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_bn = offsets_bn < N // 8

    offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
    masks_zn = offsets_zn < N // 8

    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    masks_sn = offsets_sn < N

    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
    # block_offset = BLOCK_SIZE_K * SPLIT_K
    # for k in range(0, (K + block_offset - 1) // (block_offset)):
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a, other=0.0)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b, other=0.0)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)
        b = tl.interleave(b, b)

        # Dequantize b.
        # Single group row for this K tile, derived from the tile's first k.
        offsets_szk = (
            BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K
        ) // group_size + tl.arange(0, 1)
        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
        masks_zk = offsets_szk < K // group_size
        masks_z = masks_zk[:, None] & masks_zn[None, :]
        zeros_ptrs = zeros_ptr + offsets_z
        zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.interleave(zeros, zeros)
        zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
        masks_sk = offsets_szk < K // group_size
        masks_s = masks_sk[:, None] & masks_sn[None, :]
        scales_ptrs = scales_ptr + offsets_s
        scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
        scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))

        b = (b >> shifts) & 0xF
        zeros = (zeros >> shifts) & 0xF
        b = (b - zeros) * scales
        b = b.to(c_ptr.type.element_ty)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        # Advance this split-K slice to its next K tile.
        offsets_k += BLOCK_SIZE_K * SPLIT_K
        a_ptrs += BLOCK_SIZE_K * SPLIT_K
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)

    c = accumulator.to(c_ptr.type.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# qweights - [K     , M // 8], int32
# scales   - [K // G, M     ], float16
# zeros    - [K // G, M // 8], int32
def awq_dequantize_triton(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    zeros: torch.Tensor,
    block_size_x: int = 32,
    block_size_y: int = 32,
) -> torch.Tensor:
    """Dequantize AWQ 4-bit packed weights with ``awq_dequantize_kernel``.

    Args:
        qweight: Packed weights, ``[K, M // 8]`` (eight 4-bit values per word).
        scales: Per-group scales, ``[K // G, M]``; its dtype is the output dtype.
        zeros: Packed per-group zero points, ``[K // G, M // 8]``.
        block_size_x: Kernel tile width, in packed ``qweight`` columns.
        block_size_y: Kernel tile height, in ``qweight`` rows.

    Returns:
        A ``[K, 8 * qweight.shape[1]]`` tensor on ``qweight.device``.

    Raises:
        AssertionError: If the shapes are inconsistent or the inferred group
            size is not supported.
        ValueError: If a row tile could span two quantization groups.
    """
    K = qweight.shape[0]
    M = scales.shape[1]
    group_size = qweight.shape[0] // scales.shape[0]

    assert K > 0 and M > 0
    assert scales.shape[0] == K // group_size and scales.shape[1] == M
    assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    # The kernel picks ONE zeros/scales row per row tile (from the tile's first
    # row).  With more than one group that is only correct when tiles never
    # cross a group boundary; otherwise the output is silently wrong.
    if scales.shape[0] > 1 and group_size % block_size_y != 0:
        raise ValueError(
            f"block_size_y ({block_size_y}) must evenly divide the group size "
            f"({group_size}) when there is more than one quantization group"
        )

    # The kernel addresses memory with flat row-major offsets, so strided
    # views must be materialized first (no-op for contiguous tensors).
    qweight = qweight.contiguous()
    scales = scales.contiguous()
    zeros = zeros.contiguous()

    # Result tensor:
    # number of rows = same as input tensor
    # number of cols = 8 x input tensor num cols
    result = torch.empty(
        qweight.shape[0],
        qweight.shape[1] * 8,
        device=qweight.device,
        dtype=scales.dtype,
    )

    Y = qweight.shape[0]  # num rows
    X = qweight.shape[1]  # num cols

    grid = lambda META: (
        triton.cdiv(X, META["BLOCK_SIZE_X"]),
        triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
    )
    awq_dequantize_kernel[grid](
        qweight,
        scales,
        zeros,
        group_size,
        result,
        X,
        Y,
        BLOCK_SIZE_X=block_size_x,
        BLOCK_SIZE_Y=block_size_y,
    )

    return result


# input   - [M, K]
# qweight - [K, N // 8]
# qzeros  - [K // G, N // 8]
# scales  - [K // G, N]
# split_k_iters - parallelism along K-dimension, int, power of 2.
def awq_gemm_triton(
    input: torch.Tensor,
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    split_k_iters: int,
    block_size_m: int = 32,
    block_size_n: int = 32,
    block_size_k: int = 32,
) -> torch.Tensor:
    """Compute ``input @ dequantize(qweight)`` with ``awq_gemm_kernel``.

    Args:
        input: Activations, ``[M, K]``.
        qweight: Packed 4-bit weights, ``[K, N // 8]``.
        scales: Per-group scales, ``[K // G, N]``; its dtype is the output dtype.
        qzeros: Packed per-group zero points, ``[K // G, N // 8]``.
        split_k_iters: Number of split-K slices; a power of two, at most 32.
        block_size_m: Kernel tile size along M.
        block_size_n: Kernel tile size along N (unpacked columns).
        block_size_k: Kernel tile size along K.

    Returns:
        The ``[M, N]`` product, obtained by summing the per-slice partial
        results.

    Raises:
        AssertionError: If the shapes, ``split_k_iters`` or the inferred group
            size are invalid.
        ValueError: If ``block_size_n`` is not a multiple of 8, or a K tile
            could span two quantization groups.
    """
    M, K = input.shape
    N = qweight.shape[1] * 8
    group_size = qweight.shape[0] // qzeros.shape[0]

    assert N > 0 and K > 0 and M > 0
    assert qweight.shape[0] == K and qweight.shape[1] == N // 8
    assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
    assert scales.shape[0] == K // group_size and scales.shape[1] == N
    assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
    assert split_k_iters <= 32
    assert group_size <= K

    # The kernel works on BLOCK_SIZE_N // 8 packed columns per tile and
    # reshapes them back to BLOCK_SIZE_N, which only lines up for multiples
    # of 8.
    if block_size_n % 8 != 0:
        raise ValueError(
            f"block_size_n ({block_size_n}) must be a multiple of 8, the "
            "number of 4-bit values packed into one int32"
        )

    # The kernel picks ONE zeros/scales row per K tile (from the tile's first
    # k index).  With more than one group that is only correct when K tiles
    # never cross a group boundary; otherwise the result is silently wrong.
    if qzeros.shape[0] > 1 and group_size % block_size_k != 0:
        raise ValueError(
            f"block_size_k ({block_size_k}) must evenly divide the group size "
            f"({group_size}) when there is more than one quantization group"
        )

    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    # The kernel addresses memory with flat row-major offsets, so strided
    # views must be materialized first (no-op for contiguous tensors).
    input = input.contiguous()
    qweight = qweight.contiguous()
    scales = scales.contiguous()
    qzeros = qzeros.contiguous()

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        split_k_iters,
    )

    # One [M, N] plane per split-K slice; the planes are reduced below.
    result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)

    # A = input, B = qweight, C = result
    # A = M x K, B = K x N, C = M x N
    awq_gemm_kernel[grid](
        input,
        qweight,
        result,
        qzeros,
        scales,
        M,
        N,
        K,
        group_size,
        BLOCK_SIZE_M=block_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        SPLIT_K=split_k_iters,
    )

    result = result.sum(0)

    return result
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/quantization/test_awq_triton.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
unittest version of the AWQ Triton kernel tests.

Run with:
    python -m unittest test_awq_dequant.py
"""
import unittest

import torch

from sglang.srt.layers.quantization.awq_triton import (
    AWQ_TRITON_SUPPORTED_GROUP_SIZES,
    awq_dequantize_triton,
    awq_gemm_triton,
)
from sglang.test.test_utils import CustomTestCase

device = "cuda"


def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    """Permute columns in blocks of eight by [0, 4, 1, 5, 2, 6, 3, 7].

    The low four bits of every element are kept and the result is returned
    as a contiguous tensor.
    """
    values_per_word = 32 // 4
    order = [0, 4, 1, 5, 2, 6, 3, 7]
    columns = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    permuted = columns.view(-1, values_per_word)[:, order].view(-1)
    nibbles = t[:, permuted] & 0xF
    return nibbles.contiguous()


def awq_dequantize_torch(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    """Pure-PyTorch reference for AWQ dequantization.

    A ``group_size`` of -1 means one group covering every row of ``qweight``.
    """
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    nibble_mask = (2**bits) - 1
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    def unpack(packed: torch.Tensor) -> torch.Tensor:
        # One trailing axis entry per 4-bit field, then flatten and reorder.
        fields = torch.bitwise_right_shift(
            packed[:, :, None], shifts[None, None, :]
        ).to(torch.int8)
        flat = reverse_awq_order(fields.view(packed.shape[0], -1))
        return torch.bitwise_and(flat, nibble_mask)

    iweights = unpack(qweight)
    zeros = unpack(qzeros)

    # Expand the per-group rows so they line up with the weight rows.
    expanded_scales = scales.repeat_interleave(group_size, dim=0)
    expanded_zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - expanded_zeros) * expanded_scales


class TestAWQTriton(CustomTestCase):
    """Compare the Triton AWQ kernels against reference computations."""

    def test_dequantize(self):
        """Sweep row/column/group-size combinations for the dequantize kernel."""
        rows_list = [3584, 18944, 128, 256, 512, 1024]
        cols_list = [448, 576, 4736, 16, 32, 64, 128]

        cases = [
            (n_rows, n_cols, grp)
            for n_rows in rows_list
            for n_cols in cols_list
            for grp in AWQ_TRITON_SUPPORTED_GROUP_SIZES
        ]
        for n_rows, n_cols, grp in cases:
            with self.subTest(rows=n_rows, cols=n_cols, g=grp):
                self._run_dequant_case(
                    qweight_rows=n_rows,
                    qweight_cols=n_cols,
                    group_size=grp,
                )

    def _run_dequant_case(self, qweight_rows, qweight_cols, group_size):
        """Check the Triton dequantize output against the PyTorch reference."""
        # -1 stands for a single group spanning all rows.
        effective_group = qweight_rows if group_size == -1 else group_size
        num_groups = qweight_rows // effective_group
        int32_max = torch.iinfo(torch.int32).max

        torch.manual_seed(0)

        qweight = torch.randint(
            0,
            int32_max,
            (qweight_rows, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            num_groups,
            qweight_cols * 8,
            dtype=torch.float16,
            device=device,
        )
        zeros = torch.randint(
            0,
            int32_max,
            (num_groups, qweight_cols),
            dtype=torch.int32,
            device=device,
        )

        expected = awq_dequantize_torch(qweight, scales, zeros, effective_group)
        actual = awq_dequantize_triton(qweight, scales, zeros)

        # sanity
        has_inf = torch.any(torch.isinf(actual))
        self.assertFalse(has_inf or torch.any(torch.isnan(actual)))
        torch.testing.assert_close(expected, actual)

    # GEMM
    def test_gemm(self):
        """Sweep shapes, group sizes and split-K values for the GEMM kernel."""
        N_list = [1, 2, 4, 8, 14, 17, 23, 32]
        K_list = [128]
        M_list = [16, 24, 32]
        splitK_list = [1, 8]

        cases = [
            (N, K, M, grp, split)
            for N in N_list
            for K in K_list
            for M in M_list
            for grp in AWQ_TRITON_SUPPORTED_GROUP_SIZES
            for split in splitK_list
        ]
        for N, K, M, grp, split in cases:
            with self.subTest(N=N, K=K, M=M, g=grp, sk=split):
                self._run_gemm_case(
                    N=N,
                    K=K,
                    M=M,
                    group_size=grp,
                    splitK=split,
                )

    def _run_gemm_case(self, N, K, M, group_size, splitK):
        """Check the fused GEMM against matmul with Triton-dequantized weights."""
        # -1 stands for a single group spanning the whole K dimension.
        effective_group = K if group_size == -1 else group_size
        num_groups = K // effective_group
        packed_cols = M // 8
        int32_max = torch.iinfo(torch.int32).max

        torch.manual_seed(0)

        x = torch.rand((N, K), dtype=torch.float32, device=device)
        qweight = torch.randint(
            0,
            int32_max,
            (K, packed_cols),
            dtype=torch.int32,
            device=device,
        )
        qzeros = torch.randint(
            0,
            int32_max,
            (num_groups, packed_cols),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand((num_groups, M), dtype=torch.float32, device=device)

        tri_out = awq_gemm_triton(x, qweight, scales, qzeros, splitK)

        tri_has_inf = torch.any(torch.isinf(tri_out))
        self.assertFalse(tri_has_inf or torch.any(torch.isnan(tri_out)))

        # dequantize & compare
        w_deq = awq_dequantize_triton(qweight, scales, qzeros)
        ref_out = torch.matmul(x, w_deq)

        ref_has_inf = torch.any(torch.isinf(ref_out))
        self.assertFalse(ref_has_inf or torch.any(torch.isnan(ref_out)))

        torch.testing.assert_close(tri_out.cpu(), ref_out.cpu(), atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    unittest.main(verbosity=2)