[fix] fix potential bumpy throughtput with deepgemm (#5722)
This commit is contained in:
@@ -27,7 +27,7 @@ from sglang.srt.warmup import warmup
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
# Reduce warning
|
||||
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
|
||||
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
|
||||
# Force enable deep gemm
|
||||
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
|
||||
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
|
||||
|
||||
@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
||||
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
|
||||
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
|
||||
)
|
||||
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
|
||||
_DO_COMPILE_ALL = True
|
||||
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
|
||||
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
|
||||
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
|
||||
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
|
||||
|
||||
# Force redirect deep_gemm cache_dir
|
||||
os.environ["DG_CACHE_DIR"] = os.getenv(
|
||||
@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(
|
||||
|
||||
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
||||
global _BUILTIN_M_LIST
|
||||
global _DO_COMPILE
|
||||
global _DO_COMPILE_ALL
|
||||
global _IS_FIRST_RANK_ON_NODE
|
||||
|
||||
# Generate m_max
|
||||
m_max = 1024 * 16
|
||||
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
||||
m_max = min(1024 * 128, m_max)
|
||||
_BUILTIN_M_LIST = list(range(1, m_max + 1))
|
||||
|
||||
# Check if is the first rank on node
|
||||
_DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
|
||||
_IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id
|
||||
|
||||
# Check if is the first rank on node.
|
||||
# Default each rank will try compile all Ms to
|
||||
# load all symbols at the launch stages.
|
||||
# Avoid loading symbols at the serving stages.
|
||||
_DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
|
||||
|
||||
|
||||
class DeepGemmKernelType(IntEnum):
|
||||
@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
|
||||
|
||||
|
||||
def _compile_warning_1():
|
||||
if not _IN_PRE_COMPILE_STAGE:
|
||||
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
|
||||
logger.warning(
|
||||
"Entering DeepGEMM JIT Pre-Complie session. "
|
||||
"And it may takes a long time(Typically 10-20 mins) "
|
||||
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
|
||||
query_key = (kernel_type, n, k, num_groups)
|
||||
if (
|
||||
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
|
||||
and _DO_COMPILE
|
||||
and _DO_COMPILE_ALL
|
||||
and _INITIALIZATION_DICT.get(query_key) is None
|
||||
):
|
||||
_INITIALIZATION_DICT[query_key] = True
|
||||
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
|
||||
logger.info(
|
||||
f"Try DeepGEMM JIT Compiling for "
|
||||
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
|
||||
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
|
||||
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
|
||||
)
|
||||
|
||||
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
|
||||
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
|
||||
|
||||
@contextmanager
|
||||
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
|
||||
if _IN_PRE_COMPILE_STAGE:
|
||||
if _IN_PRECOMPILE_STAGE:
|
||||
yield
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user