diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py index e374759c4..3c36fcda4 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py @@ -44,6 +44,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): global _DO_COMPILE_ALL global _IS_FIRST_RANK_ON_NODE + # Update UE8M0 scaling configuration based on server args + from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( + update_deepgemm_scale_ue8m0, + ) + + update_deepgemm_scale_ue8m0(server_args.disable_deepgemm_ue8m0) + # Generate m_max m_max = 1024 * 16 if server_args.chunked_prefill_size < 1: diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py index ecf7d1647..d3397534f 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py @@ -29,4 +29,21 @@ def _is_blackwell_arch() -> bool: ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm() DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch() -DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL +# Allow disabling UE8M0 scaling for accuracy-critical workloads +# This can help with DeepSeek EP accuracy issues on B200 GPUs +# Will be updated by server args in update_deepgemm_scale_ue8m0() +DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL and get_bool_env_var( + "SGL_ENABLE_DEEPGEMM_UE8M0", default="true" +) + + +def update_deepgemm_scale_ue8m0(disable_ue8m0: bool): + """Update DEEPGEMM_SCALE_UE8M0 based on server arguments.""" + global DEEPGEMM_SCALE_UE8M0 + if disable_ue8m0: + DEEPGEMM_SCALE_UE8M0 = False + logger.info("DeepGEMM UE8M0 scaling disabled via server argument") + else: + DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL and get_bool_env_var( + "SGL_ENABLE_DEEPGEMM_UE8M0", default="true" + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c6255223d..8730c4c49 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -268,6 +268,7 @@ class ServerArgs: flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" enable_flashinfer_allreduce_fusion: bool = False deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" + disable_deepgemm_ue8m0: bool = False ep_num_redundant_experts: int = 0 ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None init_expert_location: str = "trivial" @@ -1562,6 +1563,11 @@ class ServerArgs: default="auto", help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.", ) + parser.add_argument( + "--disable-deepgemm-ue8m0", + action="store_true", + help="Disable DeepGEMM UE8M0 scaling optimizations. This can help with accuracy issues on Blackwell GPUs (B200) for certain models like DeepSeek.", + ) parser.add_argument( "--ep-num-redundant-experts", type=int,