diff --git a/python/pyproject.toml b/python/pyproject.toml
index b45d75e53..7471f84bf 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -49,7 +49,7 @@ runtime_common = [
 ]
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.7",
+    "sgl-kernel==0.1.8.post1",
     "flashinfer_python==0.2.6.post1",
     "torch==2.7.1",
     "torchaudio==2.7.1",
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 75bccc9dc..357146469 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -605,7 +605,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.7",
+            "0.1.8.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index f92ff801f..60ba7f5c5 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -280,6 +280,7 @@ def sglang_per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -287,8 +288,20 @@
     assert x.is_contiguous(), "`x` is not contiguous"
 
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-    if column_major_scales:
+    if scale_ue8m0:
+        assert column_major_scales and scale_tma_aligned
+        x_q_mn, x_q_k = x.shape
+        x_s_mn, x_s_k = x_q_mn, x_q_k // 128
+        aligned_mn = align(x_s_mn, 4)
+        aligned_k = align(x_s_k, 4)
+        x_s = torch.empty(
+            (aligned_k // 4, aligned_mn),
+            device=x.device,
+            dtype=torch.int,
+        ).permute(-1, -2)[:x_s_mn, :]
+    elif column_major_scales:
         if scale_tma_aligned:
+            # TODO extract "align" function
             # aligned to 4 * sizeof(float)
             aligned_size = (x.shape[-2] + 3) // 4 * 4
             x_s = torch.empty(
@@ -309,7 +322,9 @@
             dtype=torch.float32,
         )
 
     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
 
     return x_q, x_s
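
Note on the new `scale_ue8m0` layout: each group scale is stored as a UE8M0 value (an unsigned 8-bit power-of-two exponent, no mantissa), so four of them pack into one `torch.int` element; allocating the buffer as `(aligned_k // 4, aligned_mn)` and taking a `permute(-1, -2)` view yields column-major scales with both dimensions padded to multiples of 4, as the TMA-alignment path requires. Below is a minimal, CPU-runnable sketch of this layout logic. The `align` helper is called but not defined in these hunks (the TODO above asks to extract it), so the version here is an assumption generalizing the existing `(x.shape[-2] + 3) // 4 * 4` pattern, and the input shape is hypothetical:

```python
import torch


def align(x: int, alignment: int) -> int:
    # Assumed helper: round `x` up to the nearest multiple of `alignment`,
    # generalizing the inline `(x + 3) // 4 * 4` pattern in the diff.
    return (x + alignment - 1) // alignment * alignment


# Hypothetical input: 6 tokens, hidden size 512 -> 512 // 128 = 4 scale groups.
x = torch.randn(6, 512)
x_s_mn, x_s_k = x.shape[0], x.shape[1] // 128

aligned_mn = align(x_s_mn, 4)  # 6 -> 8 (pad token dim to a multiple of 4)
aligned_k = align(x_s_k, 4)    # 4 -> 4 (pad group dim to a multiple of 4)

# Four UE8M0 bytes pack into one int32, so the packed group dimension is
# aligned_k // 4. Permuting the (packed_k, mn) allocation gives a
# column-major view, and the slice trims the padding rows back to x_s_mn.
x_s = torch.empty(
    (aligned_k // 4, aligned_mn),
    dtype=torch.int,
).permute(-1, -2)[:x_s_mn, :]

print(x_s.shape)     # torch.Size([6, 1])
print(x_s.stride())  # (1, 8): unit stride down a column, padded mn stride across
```

The updated `sgl_per_token_group_quant_fp8` call then receives `scale_ue8m0` and writes the packed exponents through this view; the new assertion guards against callers requesting UE8M0 scales without the column-major, TMA-aligned layout it presupposes.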