diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index b488a65c0..acde08f82 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -173,6 +173,7 @@ def _per_token_group_quant_fp8_colmajor( fp8_max, # Meta-parameters BLOCK: tl.constexpr, + SCALE_UE8M0: tl.constexpr, ): """A Triton-accelerated function to perform per-token-group quantization on a tensor. @@ -197,6 +198,8 @@ def _per_token_group_quant_fp8_colmajor( # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) y_s = _absmax / fp8_max + if SCALE_UE8M0: + y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s)))) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) @@ -209,6 +212,7 @@ def per_token_group_quant_fp8( eps: float = 1e-10, column_major_scales: bool = False, scale_tma_aligned: bool = False, + scale_ue8m0: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. @@ -229,29 +233,17 @@ def per_token_group_quant_fp8( assert x.is_contiguous(), "`x` is not contiguous" x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) + x_s = create_per_token_group_quant_fp8_output_scale( + x_shape=x.shape, + device=x.device, + group_size=group_size, + column_major_scales=column_major_scales, + scale_tma_aligned=scale_tma_aligned, + scale_ue8m0=False, + ) + M = x.numel() // group_size N = group_size - if column_major_scales: - if scale_tma_aligned: - # aligned to 4 * sizeof(float) - aligned_size = (x.shape[-2] + 3) // 4 * 4 - x_s = torch.empty( - x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), - device=x.device, - dtype=torch.float32, - ).permute(-1, -2)[: x.shape[-2], :] - else: - x_s = torch.empty( - (x.shape[-1] // group_size,) + x.shape[:-1], - device=x.device, - dtype=torch.float32, - ).permute(-1, -2) - else: - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) BLOCK = triton.next_power_of_2(N) # heuristics for number of warps @@ -271,8 +263,10 @@ def per_token_group_quant_fp8( BLOCK=BLOCK, num_warps=num_warps, num_stages=num_stages, + SCALE_UE8M0=scale_ue8m0, ) else: + assert not scale_ue8m0 _per_token_group_quant_fp8[(M,)]( x, x_q, @@ -287,9 +281,66 @@ def per_token_group_quant_fp8( num_stages=num_stages, ) + if scale_ue8m0: + from deep_gemm.utils.layout import transform_sf_into_required_layout + + assert group_size == 128 + x_s = transform_sf_into_required_layout( + x_s, + num_groups=None, + mn=x_q.shape[0], + k=x_q.shape[1], + recipe=(1, group_size, group_size), + is_sfa=True, + ) + return x_q, x_s +def create_per_token_group_quant_fp8_output_scale( + x_shape, + device, + group_size, + column_major_scales: bool, + scale_tma_aligned: bool, + scale_ue8m0: bool, +): + if scale_ue8m0: + assert column_major_scales and scale_tma_aligned + x_q_mn, x_q_k = x_shape + x_s_mn, x_s_k = x_q_mn, x_q_k // 128 + aligned_mn = align(x_s_mn, 4) + aligned_k = align(x_s_k, 4) + # TODO(FIXME): Fix cuda kernel and recover here to empty. + return torch.zeros( + (aligned_k // 4, aligned_mn), + device=device, + dtype=torch.int, + ).transpose(0, 1)[:x_s_mn, :] + elif column_major_scales: + if scale_tma_aligned: + # TODO extract "align" function + # aligned to 4 * sizeof(float) + aligned_size = (x_shape[-2] + 3) // 4 * 4 + return torch.empty( + x_shape[:-2] + (x_shape[-1] // group_size, aligned_size), + device=device, + dtype=torch.float32, + ).permute(-1, -2)[: x_shape[-2], :] + else: + return torch.empty( + (x_shape[-1] // group_size,) + x_shape[:-1], + device=device, + dtype=torch.float32, + ).permute(-1, -2) + else: + return torch.empty( + x_shape[:-1] + (x_shape[-1] // group_size,), + device=device, + dtype=torch.float32, + ) + + def sglang_per_token_group_quant_fp8( x: torch.Tensor, group_size: int, @@ -303,41 +354,20 @@ def sglang_per_token_group_quant_fp8( ), "the last dimension of `x` cannot be divisible by `group_size`" assert x.is_contiguous(), "`x` is not contiguous" - x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) if scale_ue8m0: - assert column_major_scales and scale_tma_aligned - x_q_mn, x_q_k = x.shape - x_s_mn, x_s_k = x_q_mn, x_q_k // 128 - aligned_mn = align(x_s_mn, 4) - aligned_k = align(x_s_k, 4) - # TODO(FIXME): Fix cuda kernel and recover here to empty. - x_s = torch.zeros( - (aligned_k // 4, aligned_mn), - device=x.device, - dtype=torch.int, - ).transpose(0, 1)[:x_s_mn, :] - elif column_major_scales: - if scale_tma_aligned: - # TODO extract "align" function - # aligned to 4 * sizeof(float) - aligned_size = (x.shape[-2] + 3) // 4 * 4 - x_s = torch.empty( - x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), - device=x.device, - dtype=torch.float32, - ).permute(-1, -2)[: x.shape[-2], :] - else: - x_s = torch.empty( - (x.shape[-1] // group_size,) + x.shape[:-1], - device=x.device, - dtype=torch.float32, - ).permute(-1, -2) - else: - x_s = torch.empty( - x.shape[:-1] + (x.shape[-1] // group_size,), - device=x.device, - dtype=torch.float32, - ) + # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case + assert x.shape[-1] % (group_size * 4) == 0 + + x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) + x_s = create_per_token_group_quant_fp8_output_scale( + x_shape=x.shape, + device=x.device, + group_size=group_size, + column_major_scales=column_major_scales, + scale_tma_aligned=scale_tma_aligned, + scale_ue8m0=scale_ue8m0, + ) + if x.shape[0] > 0: sgl_per_token_group_quant_fp8( x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0