diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 86d8155f8..e742f19c3 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -239,6 +239,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         column_major_scales=True,
         scale_tma_aligned=True,
     )
+
+    if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
+        _check_ue8m0("x_scale", x_scale)
+        _check_ue8m0("weight_scale", weight_scale)
+
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -247,6 +252,11 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
+def _check_ue8m0(name, x):
+    x_ceil = ceil_to_ue8m0(x)
+    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
+
+
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -380,6 +390,11 @@ def block_quant_dequant(
     return (x_q_block.to(torch.float32) * x_scale_repeat).to(dtype)
 
 
+# COPIED FROM DeepGEMM
+def ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+
+
 def channel_quant_to_tensor_quant(
     x_q_channel: torch.Tensor,
     x_s: torch.Tensor,
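
When SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0 is set, the new _check_ue8m0 asserts that both scale tensors are already UE8M0-representable, i.e. exact powers of two (unsigned, 8 exponent bits, 0 mantissa bits). The standalone sketch below illustrates the property the assert enforces; the ceil_to_ue8m0 body is taken from the diff, while the sample tensors and the print are illustrative assumptions, not part of the patch.

import torch

def ceil_to_ue8m0(x: torch.Tensor):
    # Round each scale up to the nearest power of two, i.e. the smallest
    # UE8M0-representable value >= |x|.
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

power_of_two = torch.tensor([0.25, 1.0, 8.0])
arbitrary = torch.tensor([0.3, 1.5])

# Power-of-two scales survive the round-trip unchanged, so the assert in
# _check_ue8m0 passes for them ...
assert torch.all(power_of_two == ceil_to_ue8m0(power_of_two))

# ... while arbitrary scales are bumped to the next power of two, which is
# exactly the mismatch the sanity check would report.
print(ceil_to_ue8m0(arbitrary))  # tensor([0.5000, 2.0000])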