From b2a189dd11ae30d3b6ad637b227146f7a9560142 Mon Sep 17 00:00:00 2001 From: strgrb Date: Fri, 18 Apr 2025 15:05:24 +0800 Subject: [PATCH] use sglang_per_token_group_quant_fp8 from sgl-kernel instead of triton kernel (#5473) Co-authored-by: Zhang Kaihong --- .../srt/layers/quantization/fp8_kernel.py | 29 +++++++++++++++---- .../srt/layers/quantization/fp8_utils.py | 2 +- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 71de14b46..c3613fe83 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -275,6 +275,8 @@ def sglang_per_token_group_quant_fp8( x: torch.Tensor, group_size: int, eps: float = 1e-10, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, ): assert ( x.shape[-1] % group_size == 0 @@ -282,11 +284,28 @@ def sglang_per_token_group_quant_fp8( assert x.is_contiguous(), "`x` is not contiguous" x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type) - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) + if column_major_scales: + if scale_tma_aligned: + # aligned to 4 * sizeof(float) + aligned_size = (x.shape[-2] + 3) // 4 * 4 + x_s = torch.empty( + x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), + device=x.device, + dtype=torch.float32, + ).permute(-1, -2)[: x.shape[-2], :] + else: + x_s = torch.empty( + (x.shape[-1] // group_size,) + x.shape[:-1], + device=x.device, + dtype=torch.float32, + ).permute(-1, -2) + else: + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max) return x_q, x_s diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 7acf95678..1599cf26b 100644 
--- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -141,7 +141,7 @@ def apply_w8a8_block_fp8_linear( gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output) else: if _enable_jit_deepgemm: - q_input, x_scale = per_token_group_quant_fp8( + q_input, x_scale = sglang_per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=True,