diff --git a/docs/basic_usage/deepseek.md b/docs/basic_usage/deepseek.md index 0f397b3cd..7e5daa898 100644 --- a/docs/basic_usage/deepseek.md +++ b/docs/basic_usage/deepseek.md @@ -144,7 +144,7 @@ With data parallelism attention enabled, we have achieved up to **1.9x** decodin - **DeepGEMM**: The [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) kernel library optimized for FP8 matrix multiplications. -**Usage**: The activation and weight optimization above are turned on by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper GPUs and disabled by default on other devices. DeepGEMM can also be manually turned off by setting the environment variable `SGL_ENABLE_JIT_DEEPGEMM=0`. +**Usage**: The activation and weight optimization above are turned on by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper GPUs and disabled by default on other devices. DeepGEMM can also be manually turned off by setting the environment variable `SGLANG_ENABLE_JIT_DEEPGEMM=0`. Before serving the DeepSeek model, precompile the DeepGEMM kernels using: ```bash diff --git a/docs/references/environment_variables.md b/docs/references/environment_variables.md index e332aac1f..2e5277a3f 100644 --- a/docs/references/environment_variables.md +++ b/docs/references/environment_variables.md @@ -32,9 +32,9 @@ SGLang supports various environment variables that can be used to configure its | Environment Variable | Description | Default Value | | --- | --- | --- | -| `SGL_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels | `"true"` | -| `SGL_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` | -| `SGL_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` | +| `SGLANG_ENABLE_JIT_DEEPGEMM` | Enable Just-In-Time compilation of DeepGEMM kernels | `"true"` | +| `SGLANG_JIT_DEEPGEMM_PRECOMPILE` | Enable precompilation of DeepGEMM kernels | `"true"` | +| `SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS` | Number of workers for parallel DeepGEMM kernel compilation | `4` | | `SGL_IN_DEEPGEMM_PRECOMPILE_STAGE` | Indicator flag used during the DeepGEMM precompile script | `"false"` | | `SGL_DG_CACHE_DIR` | Directory for caching compiled DeepGEMM kernels | `~/.cache/deep_gemm` | | `SGL_DG_USE_NVRTC` | Use NVRTC (instead of Triton) for JIT compilation (Experimental) | `"0"` | diff --git a/docs/references/multi_node_deployment/lws_pd/lws-examples/d.yaml b/docs/references/multi_node_deployment/lws_pd/lws-examples/d.yaml index ac1d295eb..dbb51b519 100644 --- a/docs/references/multi_node_deployment/lws_pd/lws-examples/d.yaml +++ b/docs/references/multi_node_deployment/lws_pd/lws-examples/d.yaml @@ -80,7 +80,7 @@ spec: value: "true" - name: SGLANG_MOONCAKE_TRANS_THREAD value: "16" - - name: SGL_ENABLE_JIT_DEEPGEMM + - name: SGLANG_ENABLE_JIT_DEEPGEMM value: "1" - name: NCCL_IB_HCA value: ^=mlx5_0,mlx5_5,mlx5_6 @@ -217,7 +217,7 @@ spec: value: "5" - name: SGLANG_MOONCAKE_TRANS_THREAD value: "16" - - name: SGL_ENABLE_JIT_DEEPGEMM + - name: SGLANG_ENABLE_JIT_DEEPGEMM value: "1" - name: NCCL_IB_HCA value: ^=mlx5_0,mlx5_5,mlx5_6 diff --git a/docs/references/multi_node_deployment/lws_pd/lws-examples/p.yaml b/docs/references/multi_node_deployment/lws_pd/lws-examples/p.yaml index 62df262bb..1c5b58704 100644 --- a/docs/references/multi_node_deployment/lws_pd/lws-examples/p.yaml +++ b/docs/references/multi_node_deployment/lws_pd/lws-examples/p.yaml @@ -71,7 +71,7 @@ spec: value: "1" - name: SGLANG_SET_CPU_AFFINITY value: "true" - - name: SGL_ENABLE_JIT_DEEPGEMM + - name: SGLANG_ENABLE_JIT_DEEPGEMM value: "1" - name: NCCL_IB_QPS_PER_CONNECTION value: "8" @@ -224,7 +224,7 @@ spec: value: "0" - name: SGLANG_MOONCAKE_TRANS_THREAD value: "8" - - name: SGL_ENABLE_JIT_DEEPGEMM + - name: SGLANG_ENABLE_JIT_DEEPGEMM value: "1" - name: SGL_CHUNKED_PREFIX_CACHE_THRESHOLD value: "0" diff --git a/docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md b/docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md index eb8454997..b35089683 100644 --- a/docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md +++ b/docs/references/multi_node_deployment/lws_pd/lws_pd_deploy.md @@ -98,7 +98,7 @@ spec: value: "1" - name: SGLANG_SET_CPU_AFFINITY value: "true" - - name: SGL_ENABLE_JIT_DEEPGEMM + - name: SGLANG_ENABLE_JIT_DEEPGEMM value: "1" - name: NCCL_IB_QPS_PER_CONNECTION value: "8" @@ -257,7 +257,7 @@ spec: value: "0" - name: SGLANG_MOONCAKE_TRANS_THREAD value: "8" - - name: SGL_ENABLE_JIT_DEEPGEMM + - name: SGLANG_ENABLE_JIT_DEEPGEMM value: "1" - name: SGL_CHUNKED_PREFIX_CACHE_THRESHOLD value: "0" @@ -421,7 +421,7 @@ spec: value: "true" - name: SGLANG_MOONCAKE_TRANS_THREAD value: "16" - - name: SGL_ENABLE_JIT_DEEPGEMM + - name: SGLANG_ENABLE_JIT_DEEPGEMM value: "1" - name: NCCL_IB_HCA value: ^=mlx5_0,mlx5_5,mlx5_6 @@ -560,7 +560,7 @@ spec: value: "5" - name: SGLANG_MOONCAKE_TRANS_THREAD value: "16" - - name: SGL_ENABLE_JIT_DEEPGEMM + - name: SGLANG_ENABLE_JIT_DEEPGEMM value: "1" - name: NCCL_IB_HCA value: ^=mlx5_0,mlx5_5,mlx5_6 diff --git a/python/sglang/compile_deep_gemm.py b/python/sglang/compile_deep_gemm.py index 5504bc448..739b907bf 100644 --- a/python/sglang/compile_deep_gemm.py +++ b/python/sglang/compile_deep_gemm.py @@ -19,6 +19,7 @@ import requests from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST from sglang.srt.entrypoints.http_server import launch_server +from sglang.srt.environ import envs from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.server_args import ServerArgs @@ -28,9 +29,9 @@ from sglang.srt.warmup import warmup multiprocessing.set_start_method("spawn", force=True) # Reduce warning -os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1" +envs.SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE.set(True) # Force enable deep gemm -os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1" +envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(True) # Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0" diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index f3bb8c005..06b68c523 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -180,6 +180,7 @@ class Envs: SGLANG_EXPERT_LOCATION_UPDATER_CANARY = EnvBool(False) SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS = EnvBool(False) SGLANG_LOG_EXPERT_LOCATION_METADATA = EnvBool(False) + SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR = EnvStr("/tmp") # TBO SGLANG_TBO_DEBUG = EnvBool(False) diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py index 3faf981ef..2f51d4d67 100644 --- a/python/sglang/srt/eplb/expert_distribution.py +++ b/python/sglang/srt/eplb/expert_distribution.py @@ -16,21 +16,20 @@ from __future__ import annotations import logging import math -import os import time from abc import ABC from collections import deque from contextlib import contextmanager -from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type import einops import torch import torch.distributed +from sglang.srt.environ import envs from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import Withable, get_bool_env_var, is_npu +from sglang.srt.utils import Withable, is_npu _is_npu = is_npu() @@ -839,7 +838,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin): def _dump_to_file(name, data): - save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp")) + save_dir = envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.get() path_output = save_dir / name logger.info(f"Write expert distribution to {path_output}") if not save_dir.exists(): diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py index e374759c4..0f4aa9449 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py @@ -7,11 +7,12 @@ from typing import Dict, List, Tuple import torch from tqdm import tqdm +from sglang.srt.environ import envs from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import ( ENABLE_JIT_DEEPGEMM, ) from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var +from sglang.srt.utils import ceil_div, get_bool_env_var logger = logging.getLogger(__name__) @@ -20,12 +21,9 @@ if ENABLE_JIT_DEEPGEMM: _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) -_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var( - "SGL_JIT_DEEPGEMM_PRECOMPILE", "true" -) +_ENABLE_JIT_DEEPGEMM_PRECOMPILE = envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get() _DO_COMPILE_ALL = True _IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true") -_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4) _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false") # Force redirect deep_gemm cache_dir diff --git a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py index 62073e38c..6a7ae00d0 100644 --- a/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +++ b/python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py @@ -1,6 +1,7 @@ import logging -from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell +from sglang.srt.environ import envs +from sglang.srt.utils import get_device_sm, is_blackwell logger = logging.getLogger(__name__) @@ -15,7 +16,7 @@ def _compute_enable_deep_gemm(): except ImportError: return False - return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true") + return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get() ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm() diff --git a/test/srt/ep/test_eplb.py b/test/srt/ep/test_eplb.py index c2acc07bb..748dd39c8 100755 --- a/test/srt/ep/test_eplb.py +++ b/test/srt/ep/test_eplb.py @@ -5,6 +5,7 @@ from pathlib import Path from types import SimpleNamespace import sglang as sgl +from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( @@ -23,44 +24,43 @@ class _BaseTestDynamicEPLB(CustomTestCase): def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "2", - "--dp", - "2", - "--enable-dp-attention", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "normal", - "--disable-cuda-graph", - "--enable-eplb", - "--ep-num-redundant-experts", - "4", - "--eplb-rebalance-num-iterations", - "50", - "--expert-distribution-recorder-buffer-size", - "50", - # TODO pr-chain: enable later - # "--enable-expert-distribution-metrics", - # TODO auto determine these flags - "--expert-distribution-recorder-mode", - "stat", - "--ep-dispatch-algorithm", - "static", - *cls.extra_args, - ], - env={ - "SGL_ENABLE_JIT_DEEPGEMM": "0", - "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1", - **os.environ, - }, - ) + with ( + envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False), + envs.SGLANG_EXPERT_LOCATION_UPDATER_CANARY.override(True), + ): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "normal", + "--disable-cuda-graph", + "--enable-eplb", + "--ep-num-redundant-experts", + "4", + "--eplb-rebalance-num-iterations", + "50", + "--expert-distribution-recorder-buffer-size", + "50", + # TODO pr-chain: enable later + # "--enable-expert-distribution-metrics", + # TODO auto determine these flags + "--expert-distribution-recorder-mode", + "stat", + "--ep-dispatch-algorithm", + "static", + *cls.extra_args, + ], + ) @classmethod def tearDownClass(cls): @@ -89,7 +89,7 @@ class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB): class TestStaticEPLB(CustomTestCase): def test_save_expert_distribution_and_init_expert_location(self): - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0" + envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) with tempfile.TemporaryDirectory() as tmp_dir: engine_kwargs = dict( @@ -108,7 +108,7 @@ class TestStaticEPLB(CustomTestCase): ) print(f"Action: start engine") - os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir + envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.set(tmp_dir) engine = sgl.Engine( **engine_kwargs, disable_overlap_schedule=True, diff --git a/test/srt/ep/test_moe_deepep.py b/test/srt/ep/test_moe_deepep.py index aa9d7a1f8..00f7cd59b 100644 --- a/test/srt/ep/test_moe_deepep.py +++ b/test/srt/ep/test_moe_deepep.py @@ -3,6 +3,7 @@ import os import unittest from types import SimpleNamespace +from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( @@ -55,48 +56,45 @@ class TestDPAttn(unittest.TestCase): def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "2", - "--dp", - "2", - "--enable-dp-attention", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "normal", - "--disable-cuda-graph", - # Test custom config - "--deepep-config", - json.dumps( - { - "normal_dispatch": { - "num_sms": 20, - "num_max_nvl_chunked_send_tokens": 16, - "num_max_nvl_chunked_recv_tokens": 256, - "num_max_rdma_chunked_send_tokens": 6, - "num_max_rdma_chunked_recv_tokens": 128, - }, - "normal_combine": { - "num_sms": 20, - "num_max_nvl_chunked_send_tokens": 6, - "num_max_nvl_chunked_recv_tokens": 256, - "num_max_rdma_chunked_send_tokens": 6, - "num_max_rdma_chunked_recv_tokens": 128, - }, - } - ), - ], - env={ - "SGL_ENABLE_JIT_DEEPGEMM": "0", - **os.environ, - }, - ) + with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "normal", + "--disable-cuda-graph", + # Test custom config + "--deepep-config", + json.dumps( + { + "normal_dispatch": { + "num_sms": 20, + "num_max_nvl_chunked_send_tokens": 16, + "num_max_nvl_chunked_recv_tokens": 256, + "num_max_rdma_chunked_send_tokens": 6, + "num_max_rdma_chunked_recv_tokens": 128, + }, + "normal_combine": { + "num_sms": 20, + "num_max_nvl_chunked_send_tokens": 6, + "num_max_nvl_chunked_recv_tokens": 256, + "num_max_rdma_chunked_send_tokens": 6, + "num_max_rdma_chunked_recv_tokens": 128, + }, + } + ), + ], + ) @classmethod def tearDownClass(cls): diff --git a/test/srt/test_disaggregation_different_tp.py b/test/srt/test_disaggregation_different_tp.py index 9664d7cec..a146cbfe2 100644 --- a/test/srt/test_disaggregation_different_tp.py +++ b/test/srt/test_disaggregation_different_tp.py @@ -1,8 +1,7 @@ -import os -import time import unittest from types import SimpleNamespace +from sglang.srt.environ import envs from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_utils import ( @@ -18,8 +17,7 @@ class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase): def setUpClass(cls): super().setUpClass() # Temporarily disable JIT DeepGEMM - cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA @@ -90,8 +88,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase): def setUpClass(cls): super().setUpClass() # Temporarily disable JIT DeepGEMM - cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA @@ -162,8 +159,7 @@ class TestDisaggregationMooncakeMHAPrefillLargerTP(TestDisaggregationBase): def setUpClass(cls): super().setUpClass() # Temporarily disable JIT DeepGEMM - cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) cls.model = DEFAULT_MODEL_NAME_FOR_TEST @@ -234,8 +230,7 @@ class TestDisaggregationMooncakeMHADecodeLargerTP(TestDisaggregationBase): def setUpClass(cls): super().setUpClass() # Temporarily disable JIT DeepGEMM - cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) cls.model = DEFAULT_MODEL_NAME_FOR_TEST diff --git a/test/srt/test_disaggregation_dp_attention.py b/test/srt/test_disaggregation_dp_attention.py index dd82fe887..c4a90c1c2 100644 --- a/test/srt/test_disaggregation_dp_attention.py +++ b/test/srt/test_disaggregation_dp_attention.py @@ -2,6 +2,7 @@ import os import unittest from types import SimpleNamespace +from sglang.srt.environ import envs from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_disaggregation_utils import TestDisaggregationBase from sglang.test.test_utils import ( @@ -16,8 +17,7 @@ class TestDisaggregationDPAttention(TestDisaggregationBase): def setUpClass(cls): super().setUpClass() # Temporarily disable JIT DeepGEMM - cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM") - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 5d4add72f..aeb989941 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -6,9 +6,9 @@ from pathlib import Path import requests import torch +from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( - DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, @@ -32,7 +32,7 @@ class TestExpertDistribution(CustomTestCase): def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1): """Test expert distribution record endpoints""" with tempfile.TemporaryDirectory() as tmp_dir: - os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir + envs.SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR.set(tmp_dir) process = popen_launch_server( model_path, diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index c9f286fca..739b143fa 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -1,9 +1,9 @@ -import os import unittest from types import SimpleNamespace import requests +from sglang.srt.environ import envs from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( @@ -77,14 +77,16 @@ class BaseFlashAttentionTest(CustomTestCase): def setUpClass(cls): # disable deep gemm precompile to make launch server faster # please don't do this if you want to make your inference workload faster - os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=cls.get_server_args(), - ) + with ( + envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.override(False), + envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False), + ): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.get_server_args(), + ) @classmethod def tearDownClass(cls): diff --git a/test/srt/test_hybrid_attn_backend.py b/test/srt/test_hybrid_attn_backend.py index 1574ff873..e43158a2e 100644 --- a/test/srt/test_hybrid_attn_backend.py +++ b/test/srt/test_hybrid_attn_backend.py @@ -4,6 +4,7 @@ from types import SimpleNamespace import requests +from sglang.srt.environ import envs from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( @@ -49,18 +50,20 @@ class TestHybridAttnBackendBase(CustomTestCase): def setUpClass(cls): # disable deep gemm precompile to make launch server faster # please don't do this if you want to make your inference workload faster - os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" - if cls.speculative_decode: - model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST - else: - model = cls.model - cls.process = popen_launch_server( - model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=cls.get_server_args(), - ) + with ( + envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.override(False), + envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False), + ): + if cls.speculative_decode: + model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST + else: + model = cls.model + cls.process = popen_launch_server( + model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.get_server_args(), + ) @classmethod def tearDownClass(cls): diff --git a/test/srt/test_ngram_speculative_decoding.py b/test/srt/test_ngram_speculative_decoding.py index 4495f9121..3106fa970 100644 --- a/test/srt/test_ngram_speculative_decoding.py +++ b/test/srt/test_ngram_speculative_decoding.py @@ -1,9 +1,9 @@ -import os import unittest from types import SimpleNamespace import requests +from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( @@ -47,8 +47,8 @@ class TestNgramSpeculativeDecodingBase(CustomTestCase): def setUpClass(cls): # disable deep gemm precompile to make launch server faster # please don't do this if you want to make your inference workload faster - os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False) + envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) model = cls.model cls.process = popen_launch_server( model, diff --git a/test/srt/test_standalone_speculative_decoding.py b/test/srt/test_standalone_speculative_decoding.py index e2962b716..70d18db4a 100644 --- a/test/srt/test_standalone_speculative_decoding.py +++ b/test/srt/test_standalone_speculative_decoding.py @@ -1,9 +1,9 @@ -import os import unittest from types import SimpleNamespace import requests +from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( @@ -55,8 +55,8 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase): def setUpClass(cls): # disable deep gemm precompile to make launch server faster # please don't do this if you want to make your inference workload faster - os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false" - os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false" + envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.set(False) + envs.SGLANG_ENABLE_JIT_DEEPGEMM.set(False) model = cls.model cls.process = popen_launch_server( model, diff --git a/test/srt/test_two_batch_overlap.py b/test/srt/test_two_batch_overlap.py index 6aa550c46..7c8bc7eb1 100644 --- a/test/srt/test_two_batch_overlap.py +++ b/test/srt/test_two_batch_overlap.py @@ -1,9 +1,9 @@ -import os import unittest from types import SimpleNamespace import requests +from sglang.srt.environ import envs from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.two_batch_overlap import ( compute_split_seq_index, @@ -25,26 +25,26 @@ class TestTwoBatchOverlap(unittest.TestCase): def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "2", - "--dp", - "2", - "--enable-dp-attention", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "normal", - "--disable-cuda-graph", # DeepEP normal does not support CUDA Graph - "--enable-two-batch-overlap", - ], - env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ}, - ) + with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "normal", + "--disable-cuda-graph", # DeepEP normal does not support CUDA Graph + "--enable-two-batch-overlap", + ], + ) @classmethod def tearDownClass(cls): @@ -126,26 +126,26 @@ class TestQwen3TwoBatchOverlap(TestTwoBatchOverlap): cls.model = DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-1234" - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--tp", - "2", - "--dp", - "2", - "--enable-dp-attention", - "--moe-a2a-backend", - "deepep", - "--deepep-mode", - "normal", - "--disable-cuda-graph", # DeepEP normal does not support CUDA Graph - "--enable-two-batch-overlap", - ], - env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ}, - ) + with envs.SGLANG_ENABLE_JIT_DEEPGEMM.override(False): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-mode", + "normal", + "--disable-cuda-graph", # DeepEP normal does not support CUDA Graph + "--enable-two-batch-overlap", + ], + ) if __name__ == "__main__":