diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 77ab92aff..2176ad228 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -113,7 +113,7 @@ if supports_custom_op():
 
 
 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -125,8 +125,8 @@ def _per_token_group_quant_fp8(
     # Avoid to divide zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -147,16 +147,16 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     y_s_inv = 1.0 / y_s
-    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y * y_s_inv, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)
 
 
 @triton.jit
-def _per_token_group_quant_fp8_colmajor(
+def _per_token_group_quant_8bit_colmajor(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -169,8 +169,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Avoid to divide zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
     SCALE_UE8M0: tl.constexpr,
@@ -197,19 +197,20 @@ def _per_token_group_quant_fp8_colmajor(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     if SCALE_UE8M0:
         y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y / y_s, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)
 
 
-def per_token_group_quant_fp8(
+def _per_token_group_quant_8bit_raw(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_dtype,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
@@ -223,6 +224,7 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
+        dtype: The dtype of the output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
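The hunks above rename the Triton kernels from fp8-specific to generic 8-bit names and thread `bit8_min`/`bit8_max` through as runtime arguments. For intuition, here is a minimal eager-mode sketch of the per-token-group quantization these kernels compute, useful as a reference in comparison tests; the helper name is hypothetical and the final cast's rounding behavior is an assumption, not part of this patch:

import torch

def reference_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    bit8_min: float,
    bit8_max: float,
    dtype: torch.dtype,
    eps: float = 1e-10,
):
    # Split the last dim into `group_size`-wide groups and compute one
    # absmax-based scale per group (mirroring `_absmax` in the kernel).
    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / bit8_max
    # Scale, clamp to the 8-bit range, then cast, like the kernel's `.to(...)`.
    q = (groups / scale).clamp(bit8_min, bit8_max).reshape(x.shape).to(dtype)
    return q, scale.squeeze(-1)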
@@ -232,7 +234,21 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    if _is_hip:
+        if dtype == torch.int8:
+            bit8_max = 127.0
+        else:
+            bit8_max = 224.0
+        bit8_min = -bit8_max  # TODO incorrect for int8
+    else:
+        if dtype == torch.int8:
+            info = torch.iinfo(dtype)
+        else:
+            info = torch.finfo(dtype)
+        bit8_max = info.max
+        bit8_min = info.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
         x_shape=x.shape,
         device=x.device,
@@ -250,7 +266,7 @@
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
     if column_major_scales:
-        _per_token_group_quant_fp8_colmajor[(M,)](
+        _per_token_group_quant_8bit_colmajor[(M,)](
             x,
             x_q,
             x_s,
@@ -258,8 +274,8 @@
             x.shape[1],
             x_s.stride(1),
             eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -267,15 +283,15 @@
         )
     else:
         assert not scale_ue8m0
-        _per_token_group_quant_fp8[(M,)](
+        _per_token_group_quant_8bit[(M,)](
             x,
             x_q,
             x_s,
             group_size,
             N,
             eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -297,6 +313,117 @@ def per_token_group_quant_fp8(
     return x_q, x_s
 
 
+# backward compatibility
+per_token_group_quant_fp8 = _per_token_group_quant_8bit_raw
+
+
+def _per_token_group_quant_8bit_fuse_silu_and_mul(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
+    masked_m: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Another way to implement (can be used in e.g. comparison tests)
+    # from sgl_kernel import silu_and_mul
+    # x_after_silu_and_mul = silu_and_mul(x)
+    # return per_token_group_quant_fp8(
+    #     x_after_silu_and_mul,
+    #     group_size=group_size,
+    #     eps=eps,
+    #     column_major_scales=column_major_scales,
+    #     scale_tma_aligned=scale_tma_aligned,
+    #     scale_ue8m0=scale_ue8m0,
+    # )
+
+    from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+    from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
+
+    assert column_major_scales
+    assert scale_tma_aligned
+    assert scale_ue8m0
+
+    needs_unsqueeze = x.dim() == 2
+    if needs_unsqueeze:
+        num_tokens, _ = x.shape
+        x = x.unsqueeze(0)
+        assert masked_m is None
+        masked_m = torch.tensor([num_tokens], device=x.device, dtype=torch.int32)
+
+    # Use `zeros` for easier testing
+    output = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2),
+        device=x.device,
+        dtype=dst_dtype,
+    )
+    # Use `zeros` for easier testing
+    output_scale_for_kernel = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2 // group_size),
+        device=x.device,
+        dtype=torch.float32,
+    )
+    silu_and_mul_masked_post_quant_fwd(
+        input=x,
+        output=output,
+        output_scale=output_scale_for_kernel,
+        quant_group_size=group_size,
+        masked_m=masked_m,
+        scale_ue8m0=scale_ue8m0,
+    )
+
+    assert group_size == 128
+    output_scale = transform_sf_into_required_layout(
+        output_scale_for_kernel,
+        num_groups=output.shape[0],
+        mn=output.shape[-2],
+        k=output.shape[-1],
+        recipe=(1, group_size, group_size),
+        is_sfa=True,
+    )
+
+    if needs_unsqueeze:
+        output = output.squeeze(0)
+        output_scale = output_scale.squeeze(0)
+
+    return output, output_scale
+
+
+def per_token_group_quant_8bit(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if fuse_silu_and_mul:
+        return _per_token_group_quant_8bit_fuse_silu_and_mul(
+            x=x,
+            group_size=group_size,
+            dst_dtype=dst_dtype,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            masked_m=masked_m,
+        )
+    else:
+        return _per_token_group_quant_8bit_raw(
+            x=x,
+            group_size=group_size,
+            eps=eps,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            dtype=dst_dtype,
+        )
+
+
 def create_per_token_group_quant_fp8_output_scale(
     x_shape,
     device,
@@ -307,16 +434,16 @@
 ):
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x_shape
+        *x_batch, x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        return torch.zeros(
-            (aligned_k // 4, aligned_mn),
+        return torch.empty(
+            (*x_batch, aligned_k // 4, aligned_mn),
             device=device,
             dtype=torch.int,
-        ).transpose(0, 1)[:x_s_mn, :]
+        ).transpose(-1, -2)[..., :x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
@@ -341,39 +468,6 @@ def create_per_token_group_quant_fp8_output_scale(
         )
 
 
-# TODO maybe unify int8 and fp8 code later
-def per_token_group_quant_8bit(
-    x: torch.Tensor,
-    group_size: int,
-    dst_dtype: torch.dtype,
-    eps: float = 1e-10,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-    scale_ue8m0: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
-
-    if dst_dtype == torch.int8:
-        assert not column_major_scales
-        assert not scale_tma_aligned
-        assert not scale_ue8m0
-        return per_token_group_quant_int8(
-            x=x,
-            group_size=group_size,
-            eps=eps,
-            dtype=dst_dtype,
-        )
-
-    return per_token_group_quant_fp8(
-        x=x,
-        group_size=group_size,
-        eps=eps,
-        column_major_scales=column_major_scales,
-        scale_tma_aligned=scale_tma_aligned,
-        scale_ue8m0=scale_ue8m0,
-    )
-
-
 def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
@@ -381,15 +475,19 @@ def sglang_per_token_group_quant_fp8(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))
+
+    x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
-        x_shape=x.shape,
+        x_shape=out_shape,
         device=x.device,
         group_size=group_size,
         column_major_scales=column_major_scales,
@@ -414,6 +512,8 @@ def sglang_per_token_group_quant_8bit(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     from sglang.srt.layers.quantization.int8_kernel import (
         sglang_per_token_group_quant_int8,
@@ -422,6 +522,8 @@ def sglang_per_token_group_quant_8bit(
     if dst_dtype == torch.int8:
         assert not column_major_scales
         assert not scale_tma_aligned
+        assert not fuse_silu_and_mul
+        assert masked_m is None
         return sglang_per_token_group_quant_int8(
             x=x,
             group_size=group_size,
@@ -436,6 +538,8 @@ def sglang_per_token_group_quant_8bit(
         column_major_scales=column_major_scales,
         scale_tma_aligned=scale_tma_aligned,
         scale_ue8m0=scale_ue8m0,
+        fuse_silu_and_mul=fuse_silu_and_mul,
+        masked_m=masked_m,
     )
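For orientation, a hedged usage sketch of the unified `per_token_group_quant_8bit` entry point introduced above. The shapes, device, and flag combination are illustrative assumptions; per the asserts in the patch, the fused path additionally requires deep_gemm to be installed, `group_size == 128`, and the UE8M0/TMA-aligned column-major scale layout:

import torch
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_8bit

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# Plain path: quantize each 128-wide group; dst_dtype selects the fp8/int8 range.
x_q, x_s = per_token_group_quant_8bit(x, group_size=128, dst_dtype=torch.float8_e4m3fn)

# Fused path: applies SiLU-and-mul (halving the last dim) before quantizing.
y_q, y_s = per_token_group_quant_8bit(
    x,
    group_size=128,
    dst_dtype=torch.float8_e4m3fn,
    column_major_scales=True,
    scale_tma_aligned=True,
    scale_ue8m0=True,
    fuse_silu_and_mul=True,
)
assert y_q.shape == (16, 2048)  # last dim halved by the fused silu_and_mul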