Move DeepGEMM-related environment variables to sglang.srt.environ (#11547)
This commit is contained in:
@@ -19,6 +19,7 @@ import requests
|
||||
|
||||
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
# Reduce warnings
|
||||
os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
|
||||
envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True)
|
||||
# Force-enable DeepGEMM
|
||||
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
|
||||
envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True)
|
||||
# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
|
||||
os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
|
||||
|
||||
|
||||
@@ -180,6 +180,7 @@ class Envs:
|
||||
SGLANG_EXPERT_LOCATION_UPDATER_CANARY = EnvBool(False)
|
||||
SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS = EnvBool(False)
|
||||
SGLANG_LOG_EXPERT_LOCATION_METADATA = EnvBool(False)
|
||||
SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR = EnvStr("/tmp")
|
||||
|
||||
# TBO
|
||||
SGLANG_TBO_DEBUG = EnvBool(False)
|
||||
|
||||
@@ -16,21 +16,20 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable, get_bool_env_var, is_npu
|
||||
from sglang.srt.utils import Withable, is_npu
|
||||
|
||||
_is_npu = is_npu()
|
||||
|
||||
@@ -839,7 +838,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
|
||||
|
||||
|
||||
def _dump_to_file(name, data):
|
||||
save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
|
||||
save_dir = envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get()
|
||||
path_output = save_dir / name
|
||||
logger.info(f"Write expert distribution to {path_output}")
|
||||
if not save_dir.exists():
|
||||
|
||||
@@ -7,11 +7,12 @@ from typing import Dict, List, Tuple
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
|
||||
ENABLE_JIT_DEEPGEMM,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
|
||||
from sglang.srt.utils import ceil_div, get_bool_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -20,12 +21,9 @@ if ENABLE_JIT_DEEPGEMM:
|
||||
|
||||
|
||||
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
||||
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
|
||||
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
|
||||
)
|
||||
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get()
|
||||
_DO_COMPILE_ALL = True
|
||||
_IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
|
||||
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
|
||||
_IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
|
||||
|
||||
# Force redirect deep_gemm cache_dir
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.utils import get_device_sm, is_blackwell
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,7 +16,7 @@ def _compute_enable_deep_gemm():
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
|
||||
return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
|
||||
|
||||
|
||||
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
|
||||
|
||||
Reference in New Issue
Block a user