From 45e3a7bc41d750659bbd57b3981a91fa0dd0f1c0 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Wed, 12 Feb 2025 18:40:42 +0800
Subject: [PATCH] use sgl_per_token_group_quant_fp8 kernel (#3493)

---
 python/pyproject.toml                        |  2 +-
 .../layers/moe/fused_moe_triton/fused_moe.py |  9 ++++-
 .../srt/layers/quantization/fp8_kernel.py    | 34 +++++++++++++++++++
 3 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 03a71af1e..296440016 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -25,7 +25,7 @@ runtime_common = [
 ]
 srt = [
   "sglang[runtime_common]", "cuda-python",
-  "sgl-kernel>=0.0.3.post3", "torch", "vllm>=0.6.4.post1,<=0.7.2",
+  "sgl-kernel>=0.0.3.post4", "torch", "vllm>=0.6.4.post1,<=0.7.2",
   "flashinfer_python>=0.2.0.post2",
   "outlines>=0.0.44,<=0.1.11"
 ]

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index b9390fd42..04292764c 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -33,6 +33,10 @@ _is_rocm = torch.cuda.is_available() and torch.version.hip
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
+
 if _is_cuda or _is_rocm:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -488,7 +492,10 @@ def invoke_fused_moe_kernel(
     else:
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-        A, A_scale = per_token_group_quant_fp8(A, block_k)
+        if _is_cuda:
+            A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
+        else:
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
         assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
         assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
         assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 28c371cfe..8ff18715f 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -27,6 +27,10 @@ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_fp8
+
 logger = logging.getLogger(__name__)
 
 
@@ -135,6 +139,36 @@ def per_token_group_quant_fp8(
     return x_q, x_s
 
 
+def sglang_per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = fp8_type_,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` must be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+
+    fp8_min = -fp8_max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_fp8_matmul(
     # Pointers to inputs and output
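
Reviewer note (not part of the patch): on CUDA builds the per-token-group FP8
quantization now dispatches to the fused sgl-kernel op, while the existing
Triton `per_token_group_quant_fp8` path remains the fallback for ROCm and
other builds, so behavior off CUDA is unchanged. Below is a minimal smoke
test sketching how the new wrapper is exercised; the tensor shape, dtype, and
group size are illustrative assumptions, not values taken from the patch.

    # Hypothetical smoke test, assuming a CUDA build with sgl-kernel >= 0.0.3.post4.
    import torch

    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
    )

    # 16 tokens x 512 features; the wrapper asserts that the last dimension is
    # divisible by group_size and that the input tensor is contiguous.
    x = torch.randn(16, 512, device="cuda", dtype=torch.float16)
    group_size = 128  # one fp32 scale per 128 contiguous elements of the last dim

    x_q, x_s = sglang_per_token_group_quant_fp8(x, group_size)

    assert x_q.shape == x.shape                  # quantized values, same shape as x
    assert x_s.shape == (16, 512 // group_size)  # one scale per group
    print(x_q.dtype, x_s.dtype)  # torch.float8_e4m3fn torch.float32

In `invoke_fused_moe_kernel` the group size is `block_k`, so the per-group
scales line up one-to-one with the K-blocks of the block-quantized weights,
which is exactly what the `triton.cdiv` shape assertions after the call check.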