Move deep gemm related arguments to sglang.srt.environ (#11547)

This commit is contained in:
Liangsheng Yin
2025-10-14 00:34:35 +08:00
committed by GitHub
parent bfadb5ea5f
commit acc2327bbd
20 changed files with 187 additions and 189 deletions

View File

@@ -19,6 +19,7 @@ import requests
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.environ import envs
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import ServerArgs
@@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup
multiprocessing.set_start_method("spawn", force=True)
# Reduce warning
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
# Force enable deep gemm
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"

View File

@@ -180,6 +180,7 @@ class Envs:
SGLANG_EXPERT_LOCATION_UPDATER_CANARY = EnvBool(False)
SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS = EnvBool(False)
SGLANG_LOG_EXPERT_LOCATION_METADATA = EnvBool(False)
SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR = EnvStr("/tmp")
# TBO
SGLANG_TBO_DEBUG = EnvBool(False)

View File

@@ -16,21 +16,20 @@ from __future__ import annotations
import logging
import math
import os
import time
from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
import einops
import torch
import torch.distributed
from sglang.srt.environ import envs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
from sglang.srt.utils import Withable, is_npu
_is_npu = is_npu()
@@ -839,7 +838,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
def _dump_to_file(name, data):
save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
save_dir = envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get()
path_output = save_dir / name
logger.info(f"Write expert distribution to {path_output}")
if not save_dir.exists():

View File

@@ -7,11 +7,12 @@ from typing import Dict, List, Tuple
import torch
from tqdm import tqdm
from sglang.srt.environ import envs
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
from sglang.srt.utils import ceil_div, get_bool_env_var
logger = logging.getLogger(__name__)
@@ -20,12 +21,9 @@ if ENABLE_JIT_DEEPGEMM:
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
)
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
_DO_COMPILE_ALL = True
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
# Force redirect deep_gemm cache_dir

View File

@@ -1,6 +1,7 @@
import logging
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, is_blackwell
logger = logging.getLogger(__name__)
@@ -15,7 +16,7 @@ def _compute_enable_deep_gemm():
except ImportError:
return False
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()