From 5e5c30d9ab8f1f07155eee61f2cab95a8e7cc350 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Fri, 5 Sep 2025 19:52:32 +0800
Subject: [PATCH] Tiny let DeepGEMM scale checks cover more cases (#7182)

Co-authored-by: Yineng Zhang
---
 .../deep_gemm_wrapper/entrypoint.py      | 27 +++++++++++++++++++
 .../srt/layers/quantization/fp8_utils.py | 10 -------
 2 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
index eedaa3c9b..02945f449 100644
--- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
+++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -18,6 +19,8 @@ if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
     from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 
+_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
+
 
 # TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
@@ -31,6 +34,9 @@ def grouped_gemm_nt_f8f8bf16_masked(
     _, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
@@ -53,6 +59,9 @@ def grouped_gemm_nt_f8f8bf16_contig(
     num_groups, n, _ = rhs[0].shape
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
 
@@ -67,6 +76,9 @@ def gemm_nt_f8f8bf16(
     num_groups = 1
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
 
+    _sanity_check_input(lhs)
+    _sanity_check_input(rhs)
+
     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
         deep_gemm.fp8_gemm_nt(
             lhs,
@@ -90,3 +102,18 @@ def configure_deep_gemm_num_sms(num_sms):
         yield
     finally:
         deep_gemm.set_num_sms(original_num_sms)
+
+
+def _sanity_check_input(x_fp8: Tuple[torch.Tensor, torch.Tensor]):
+    if not _SANITY_CHECK:
+        return
+
+    x, x_scale = x_fp8
+
+    if x_scale.dtype == torch.int:
+        return
+
+    from sglang.srt.layers.quantization.fp8_utils import ceil_to_ue8m0
+
+    x_scale_ceil = ceil_to_ue8m0(x_scale)
+    assert torch.all(x_scale == x_scale_ceil), f"{x_scale=} {x_scale_ceil=}"
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 42c894590..e4bcbe23c 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -248,11 +248,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
 
-    # NOTE(alcanderian): Useless when scale is packed to int32
-    # if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
-    #     _check_ue8m0("x_scale", x_scale)
-    #     _check_ue8m0("weight_scale", ws)
-
     output = w8a8_block_fp8_matmul_deepgemm(
         q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
     )
@@ -261,11 +256,6 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
-def _check_ue8m0(name, x):
-    x_ceil = ceil_to_ue8m0(x)
-    assert torch.all(x == x_ceil), f"{name=} {x=} {x_ceil=}"
-
-
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
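
As context for review, a minimal sketch of the invariant the new _sanity_check_input
assertion enforces: a UE8M0 scale is an exact power of two, so a tensor of valid
scales is a fixed point of ceil_to_ue8m0. The approx_ceil_to_ue8m0 helper below is
a hypothetical stand-in for the real ceil_to_ue8m0 in
sglang.srt.layers.quantization.fp8_utils; it illustrates the check, not the actual
implementation:

    import torch

    def approx_ceil_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
        # Hypothetical stand-in: round each scale up to the nearest power of
        # two (UE8M0 = unsigned 8-bit exponent, no sign or mantissa bits).
        return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

    # Scales that are already exact powers of two are unchanged by the
    # rounding, so the patched assertion passes ...
    ue8m0_scales = torch.tensor([0.25, 1.0, 2.0, 64.0])
    assert torch.all(ue8m0_scales == approx_ceil_to_ue8m0(ue8m0_scales))

    # ... while arbitrary fp32 scales generally are not, which is what
    # _sanity_check_input now flags on both lhs and rhs of every kernel.
    raw_scales = torch.tensor([0.3, 1.5, 3.0])
    assert not torch.all(raw_scales == approx_ceil_to_ue8m0(raw_scales))

Per the patch, the check is opt-in via the SGLANG_DEEPGEMM_SANITY_CHECK environment
variable and is skipped when scales are already packed into int32, where an
elementwise UE8M0 comparison no longer applies.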