From eebfdb94593ffdf92ac7bac3dbc82c80a3baf3f2 Mon Sep 17 00:00:00 2001 From: JieXin Liang Date: Sun, 27 Apr 2025 09:12:48 +0800 Subject: [PATCH] [fix] fix potential bumpy throughtput with deepgemm (#5722) --- python/sglang/compile_deep_gemm.py | 2 +- .../srt/layers/quantization/deep_gemm.py | 25 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/sglang/compile_deep_gemm.py b/python/sglang/compile_deep_gemm.py index b86086e20..a3e171464 100644 --- a/python/sglang/compile_deep_gemm.py +++ b/python/sglang/compile_deep_gemm.py @@ -27,7 +27,7 @@ from sglang.srt.warmup import warmup multiprocessing.set_start_method("spawn", force=True) # Reduce warning -os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1" +os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1" # Force enable deep gemm os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1" # Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case diff --git a/python/sglang/srt/layers/quantization/deep_gemm.py b/python/sglang/srt/layers/quantization/deep_gemm.py index 261542fb2..08ba0b9f9 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm.py +++ b/python/sglang/srt/layers/quantization/deep_gemm.py @@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var( "SGL_JIT_DEEPGEMM_PRECOMPILE", "true" ) -_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true") +_DO_COMPILE_ALL = True +_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true") _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4) -_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false") +_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false") # Force redirect deep_gemm cache_dir os.environ["DG_CACHE_DIR"] = os.getenv( @@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv( def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): global _BUILTIN_M_LIST - global _DO_COMPILE + global _DO_COMPILE_ALL + global _IS_FIRST_RANK_ON_NODE # Generate m_max m_max = 1024 * 16 @@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): m_max = min(1024 * 128, m_max) _BUILTIN_M_LIST = list(range(1, m_max + 1)) - # Check if is the first rank on node - _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id + _IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id + + # Check if is the first rank on node. + # Default each rank will try compile all Ms to + # load all symbols at the launch stages. + # Avoid loading symbols at the serving stages. + _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE class DeepGemmKernelType(IntEnum): @@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic def _compile_warning_1(): - if not _IN_PRE_COMPILE_STAGE: + if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE: logger.warning( "Entering DeepGEMM JIT Pre-Complie session. " "And it may takes a long time(Typically 10-20 mins) " @@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all( query_key = (kernel_type, n, k, num_groups) if ( _ENABLE_JIT_DEEPGEMM_PRECOMPILE - and _DO_COMPILE + and _DO_COMPILE_ALL and _INITIALIZATION_DICT.get(query_key) is None ): _INITIALIZATION_DICT[query_key] = True @@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all( logger.info( f"Try DeepGEMM JIT Compiling for " f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms." - f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}" + f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}" ) # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced @@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16( @contextmanager def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType): - if _IN_PRE_COMPILE_STAGE: + if _IN_PRECOMPILE_STAGE: yield return