From 516738b0963de21e4f8ef04cd00d1f0da827816f Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 13 Oct 2025 19:34:43 +0800 Subject: [PATCH] Deprecate `global_server_args_dict` (#11528) --- python/sglang/global_config.py | 3 -- .../device_communicators/pynccl_allocator.py | 4 +- .../srt/eplb/expert_location_dispatch.py | 4 +- .../srt/eplb/expert_location_updater.py | 4 +- .../attention/double_sparsity_backend.py | 4 +- .../attention/flashattention_backend.py | 4 +- .../attention/flashinfer_mla_backend.py | 10 ++-- .../srt/layers/attention/nsa/nsa_indexer.py | 4 +- .../triton_ops/double_sparsity_attention.py | 4 +- .../layers/attention/trtllm_mla_backend.py | 8 +-- python/sglang/srt/layers/attention/vision.py | 6 +-- python/sglang/srt/layers/communicator.py | 13 +++-- python/sglang/srt/layers/logits_processor.py | 16 +++--- .../srt/layers/moe/fused_moe_triton/layer.py | 2 - .../sglang/srt/layers/quantization/mxfp4.py | 8 +-- python/sglang/srt/layers/sampler.py | 10 ++-- python/sglang/srt/managers/mm_utils.py | 4 +- python/sglang/srt/managers/schedule_batch.py | 50 ++----------------- python/sglang/srt/managers/scheduler.py | 19 +++---- python/sglang/srt/managers/tp_worker.py | 3 +- python/sglang/srt/mem_cache/common.py | 12 ++--- .../sglang/srt/model_executor/model_runner.py | 37 ++++++-------- python/sglang/srt/model_loader/loader.py | 19 +++---- python/sglang/srt/models/apertus.py | 5 +- python/sglang/srt/models/arcee.py | 4 +- python/sglang/srt/models/bailing_moe.py | 12 ++--- python/sglang/srt/models/bailing_moe_nextn.py | 7 ++- python/sglang/srt/models/deepseek_nextn.py | 4 +- python/sglang/srt/models/deepseek_v2.py | 44 ++++++++-------- python/sglang/srt/models/falcon_h1.py | 4 +- python/sglang/srt/models/glm4_moe.py | 20 +++----- python/sglang/srt/models/glm4_moe_nextn.py | 4 +- python/sglang/srt/models/glm4v_moe.py | 10 ++-- python/sglang/srt/models/gpt_oss.py | 8 +-- python/sglang/srt/models/grok.py | 12 ++--- python/sglang/srt/models/llama.py | 4 +- 
python/sglang/srt/models/longcat_flash.py | 10 ++-- python/sglang/srt/models/mllama4.py | 4 +- python/sglang/srt/models/qwen2_moe.py | 8 +-- python/sglang/srt/models/qwen3_moe.py | 8 +-- python/sglang/srt/models/qwen3_next.py | 4 +- python/sglang/srt/models/qwen3_next_mtp.py | 7 ++- python/sglang/srt/models/qwen3_vl_moe.py | 14 ++---- python/sglang/srt/models/step3_vl.py | 5 +- .../srt/sampling/sampling_batch_info.py | 19 +++---- python/sglang/srt/server_args.py | 21 +++++++- python/sglang/srt/single_batch_overlap.py | 1 - python/sglang/srt/speculative/eagle_info.py | 11 ++-- .../sglang/srt/speculative/eagle_info_v2.py | 10 ++-- python/sglang/srt/speculative/eagle_worker.py | 12 ++--- python/sglang/srt/speculative/ngram_info.py | 10 ++-- python/sglang/srt/two_batch_overlap.py | 11 ++-- test/srt/rl/test_fp32_lm_head.py | 12 +++-- test/srt/test_gptqmodel_dynamic.py | 8 +-- 54 files changed, 240 insertions(+), 321 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 552351466..fcd65b5ed 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -6,9 +6,6 @@ class GlobalConfig: """ Store some global constants. - - See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores - many global runtime arguments as well. 
""" def __init__(self): diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py index d7274cf2c..9ce1c1c20 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -5,7 +5,7 @@ from packaging import version from torch.cuda.memory import CUDAPluggableAllocator from sglang.srt.distributed.parallel_state import GroupCoordinator -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.server_args import get_global_server_args nccl_allocator_source = """ #include @@ -32,7 +32,7 @@ _graph_pool_id = None def is_symmetric_memory_enabled(): - return global_server_args_dict["enable_symm_mem"] + return get_global_server_args().enable_symm_mem def set_graph_pool_id(graph_pool_id): diff --git a/python/sglang/srt/eplb/expert_location_dispatch.py b/python/sglang/srt/eplb/expert_location_dispatch.py index 624dc3fd1..7ac03390a 100644 --- a/python/sglang/srt/eplb/expert_location_dispatch.py +++ b/python/sglang/srt/eplb/expert_location_dispatch.py @@ -18,7 +18,7 @@ from typing import Literal, Optional import torch from sglang.srt.eplb.expert_location import get_global_expert_location_metadata -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.server_args import get_global_server_args @dataclass @@ -34,7 +34,7 @@ class ExpertLocationDispatchInfo: @classmethod def init_new(cls, layer_id: int): - ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"] + ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm expert_location_metadata = get_global_expert_location_metadata() assert expert_location_metadata is not None diff --git a/python/sglang/srt/eplb/expert_location_updater.py b/python/sglang/srt/eplb/expert_location_updater.py index 772e65f18..286f1d0e3 100644 --- 
a/python/sglang/srt/eplb/expert_location_updater.py +++ b/python/sglang/srt/eplb/expert_location_updater.py @@ -24,7 +24,7 @@ from sglang.srt.eplb.expert_location import ( ExpertLocationMetadata, get_global_expert_location_metadata, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import get_bool_env_var logger = logging.getLogger(__name__) @@ -97,7 +97,7 @@ def _update_expert_weights_with_canary( canary_tensor = ( _get_canary_value(old_expert_location_metadata, layer_id) .clone() - .to(device=global_server_args_dict["device"], non_blocking=True) + .to(device=get_global_server_args().device, non_blocking=True) ) routed_experts_weights_of_layer[layer_id].append(canary_tensor) diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 47b867f61..76a63a093 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -5,8 +5,8 @@ from typing import TYPE_CHECKING import torch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention @@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend): # TODO: Change the hard-coded block_seq_num self.BLOCK_SEQ = 128 - if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): + if get_global_server_args().triton_attention_reduce_in_fp32: self.reduce_dtype = torch.float32 else: self.reduce_dtype = torch.float16 diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py 
index 927f1d93c..4fae2cb1d 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -11,8 +11,8 @@ import triton.language as tl from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.radix_attention import AttentionType -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.spec_info import SpecInput if TYPE_CHECKING: @@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend): ): # Do multi-head attention with chunked prefix cache if forward_batch.attn_attend_prefix_cache: - assert not global_server_args_dict["disable_chunked_prefix_cache"] + assert not get_global_server_args().disable_chunked_prefix_cache # MHA for chunked prefix kv cache when running model with MLA assert forward_batch.prefix_chunk_idx is not None assert forward_batch.prefix_chunk_cu_seq_lens is not None diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 6efda7775..82d1b05b4 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, ) from sglang.srt.layers.dp_attention import get_attention_tp_size -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import ( is_flashinfer_available, @@ -193,9 +193,9 @@ 
class FlashInferMLAAttnBackend(AttentionBackend): self.skip_prefill = skip_prefill self.enable_chunk_kv = ( not skip_prefill - and global_server_args_dict["disaggregation_mode"] != "decode" - and not global_server_args_dict["disable_chunked_prefix_cache"] - and not global_server_args_dict["flashinfer_mla_disable_ragged"] + and get_global_server_args().disaggregation_mode != "decode" + and not get_global_server_args().disable_chunked_prefix_cache + and not get_global_server_args().flashinfer_mla_disable_ragged ) self.page_size = model_runner.page_size @@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): prefix_lens = forward_batch.extend_prefix_lens extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) use_ragged = ( - not global_server_args_dict["flashinfer_mla_disable_ragged"] + not get_global_server_args().flashinfer_mla_disable_ragged and extend_no_prefix ) diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 798e1c0a8..13ab6b172 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.rotary_embedding import get_rope_wrapper -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool @@ -162,7 +162,7 @@ class Indexer(CustomOp): base=rope_theta, # type: ignore rope_scaling=rope_scaling, is_neox_style=False, - device=global_server_args_dict["device"], + 
device=get_global_server_args().device, ) self.block_size = block_size self.scale_fmt = scale_fmt diff --git a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py index 72e0bfe78..8e972f408 100644 --- a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py @@ -2,7 +2,7 @@ import torch import triton import triton.language as tl -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import is_cuda, is_hip _is_cuda = is_cuda() @@ -11,7 +11,7 @@ if _is_cuda: _is_hip = is_hip() -if global_server_args_dict.get("attention_reduce_in_fp32", False): +if get_global_server_args().triton_attention_reduce_in_fp32: REDUCE_TRITON_TYPE = tl.float32 REDUCE_TORCH_TYPE = torch.float32 else: diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 85e535b07..3727524ef 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import ( create_flashmla_kv_indices_triton, ) from sglang.srt.layers.dp_attention import get_attention_tp_size -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import is_cuda, is_flashinfer_available if is_flashinfer_available(): @@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None - self.disable_chunked_prefix_cache = 
global_server_args_dict[ - "disable_chunked_prefix_cache" - ] + self.disable_chunked_prefix_cache = ( + get_global_server_args().disable_chunked_prefix_cache + ) self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 489b8248b..bb9016e0d 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -45,7 +45,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix ROTARY_EMBED_CLASSES = { @@ -468,7 +468,7 @@ class VisionAttention(nn.Module): _passed_backend = qkv_backend qkv_backend = self._determine_attention_backend(_passed_backend) if ( - global_server_args_dict["mm_attention_backend"] is None + get_global_server_args().mm_attention_backend is None and _passed_backend is None ): print_info_once(f"Multimodal attention backend not set. 
Use {qkv_backend}.") @@ -528,7 +528,7 @@ class VisionAttention(nn.Module): - CUDA: "triton_attn" - Non-CUDA: "sdpa" """ - override_backend = global_server_args_dict["mm_attention_backend"] + override_backend = get_global_server_args().mm_attention_backend if override_backend is not None: backend = override_backend elif passed_backend is not None: diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index e050da91d..60b4e9e5f 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -40,8 +40,9 @@ from sglang.srt.layers.moe import ( get_moe_a2a_backend, should_use_flashinfer_cutlass_moe_fp4_allgather, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( get_bool_env_var, is_cuda, @@ -168,7 +169,7 @@ class LayerScatterModes: def enable_moe_dense_fully_dp(): - return global_server_args_dict["moe_dense_tp_size"] == 1 + return get_global_server_args().moe_dense_tp_size == 1 class LayerCommunicator: @@ -314,7 +315,9 @@ class LayerCommunicator: def should_fuse_mlp_allreduce_with_next_layer( self, forward_batch: ForwardBatch ) -> bool: - speculative_algo = global_server_args_dict.get("speculative_algorithm", None) + speculative_algo = SpeculativeAlgorithm.from_string( + get_global_server_args().speculative_algorithm + ) if ( is_dp_attention_enabled() and speculative_algo is not None @@ -333,7 +336,7 @@ class LayerCommunicator: static_conditions_met = ( (not self.is_last_layer) and (self._context.tp_size > 1) - and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False) + and get_global_server_args().enable_flashinfer_allreduce_fusion and _is_flashinfer_available ) @@ -531,7 +534,7 @@ class 
CommunicateWithAllReduceAndLayerNormFn: (_is_sm100_supported or _is_sm90_supported) and _is_flashinfer_available and hasattr(layernorm, "forward_with_allreduce_fusion") - and global_server_args_dict["enable_flashinfer_allreduce_fusion"] + and get_global_server_args().enable_flashinfer_allreduce_fusion and hidden_states.shape[0] <= 4096 ): hidden_states, residual = layernorm.forward_with_allreduce_fusion( diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index dfacd858c..a0cf55b0e 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import ( get_dp_device, get_dp_dtype, get_dp_hidden_size, - get_global_dp_buffer, get_local_attention_dp_size, - set_dp_buffer_len, ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, ForwardMode, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend logger = logging.getLogger(__name__) @@ -230,8 +228,8 @@ class LogitsProcessor(nn.Module): super().__init__() self.config = config self.logit_scale = logit_scale - self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"] - self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"] + self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head + self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head if self.use_attn_tp_group: self.attn_tp_size = get_attention_tp_size() self.do_tensor_parallel_all_gather = ( @@ -254,8 +252,8 @@ class LogitsProcessor(nn.Module): ): self.final_logit_softcapping = None - self.debug_tensor_dump_output_folder = global_server_args_dict.get( - "debug_tensor_dump_output_folder", None + 
self.debug_tensor_dump_output_folder = ( + get_global_server_args().debug_tensor_dump_output_folder ) def compute_logprobs_for_multi_item_scoring( @@ -372,9 +370,7 @@ class LogitsProcessor(nn.Module): logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) # Check if multi-item scoring is enabled via server args (only for prefill-only requests) - multi_item_delimiter = global_server_args_dict.get( - "multi_item_scoring_delimiter" - ) + multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter if multi_item_delimiter is not None and logits_metadata.is_prefill_only: return self.compute_logprobs_for_multi_item_scoring( input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 9cdfbc86c..1ff778184 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -27,12 +27,10 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, - QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight from sglang.srt.utils import ( cpu_has_amx_support, diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index caf323950..76757e501 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) 
from sglang.srt.layers.quantization.utils import is_layer_skipped -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( direct_register_custom_op, is_cuda, @@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel() self.with_bias = False self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4() - self.flashinfer_mxfp4_moe_precision = global_server_args_dict[ - "flashinfer_mxfp4_moe_precision" - ] + self.flashinfer_mxfp4_moe_precision = ( + get_global_server_args().flashinfer_mxfp4_moe_precision + ) self.triton_kernel_moe_forward = None self.triton_kernel_moe_with_bias_forward = None diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index f2deb2b26..bf50d4b11 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import ( is_dp_attention_enabled, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda if is_cuda(): @@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB") class Sampler(nn.Module): def __init__(self): super().__init__() - self.use_nan_detection = global_server_args_dict["enable_nan_detection"] + self.use_nan_detection = get_global_server_args().enable_nan_detection self.tp_sync_group = get_tp_group().device_group if is_dp_attention_enabled(): @@ -103,7 +103,7 @@ class Sampler(nn.Module): del logits if True: # Keep this redundant check to simplify some internal code sync - if global_server_args_dict["sampling_backend"] == 
"flashinfer": + if get_global_server_args().sampling_backend == "flashinfer": if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_p_renorm_prob(probs, sampling_info.top_ps) @@ -118,7 +118,7 @@ class Sampler(nn.Module): filter_apply_order="joint", check_nan=self.use_nan_detection, ) - elif global_server_args_dict["sampling_backend"] == "pytorch": + elif get_global_server_args().sampling_backend == "pytorch": # A slower fallback implementation with torch native operations. batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( probs, @@ -131,7 +131,7 @@ class Sampler(nn.Module): ) else: raise ValueError( - f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" + f"Invalid sampling backend: {get_global_server_args().sampling_backend}" ) if return_logprob: diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 41de295af..e2012e9de 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, MultimodalInputs, - global_server_args_dict, ) from sglang.srt.mem_cache.multimodal_cache import MultiModalCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once from sglang.utils import logger @@ -428,7 +428,7 @@ def _adjust_embedding_length( f"tokens from multimodal embeddings." 
) if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding: - chunked_prefill_size = global_server_args_dict["chunked_prefill_size"] + chunked_prefill_size = get_global_server_args().chunked_prefill_size if chunked_prefill_size != -1: logger.warning( "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill" diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 0ce5ad115..ee427cce2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -73,7 +73,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server_args import ServerArgs +from sglang.srt.server_args import ServerArgs, get_global_server_args from sglang.srt.utils import flatten_nested_list from sglang.srt.utils.common import next_power_of_2 @@ -83,47 +83,6 @@ if TYPE_CHECKING: INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 -GLOBAL_SERVER_ARGS_KEYS = [ - "attention_backend", - "mm_attention_backend", - "debug_tensor_dump_inject", - "debug_tensor_dump_output_folder", - "chunked_prefill_size", - "device", - "disable_chunked_prefix_cache", - "disable_flashinfer_cutlass_moe_fp4_allgather", - "disable_radix_cache", - "enable_dp_lm_head", - "enable_fp32_lm_head", - "flashinfer_mxfp4_moe_precision", - "enable_flashinfer_allreduce_fusion", - "moe_dense_tp_size", - "ep_dispatch_algorithm", - "ep_num_redundant_experts", - "enable_nan_detection", - "flashinfer_mla_disable_ragged", - "pp_max_micro_batch_size", - "disable_shared_experts_fusion", - "sampling_backend", - "speculative_accept_threshold_single", - "speculative_accept_threshold_acc", - "speculative_attention_mode", - "torchao_config", - 
"triton_attention_reduce_in_fp32", - "num_reserved_decode_tokens", - "weight_loader_disable_mmap", - "enable_multimodal", - "enable_symm_mem", - "enable_custom_logit_processor", - "disaggregation_mode", - "enable_deterministic_inference", - "nsa_prefill", - "nsa_decode", - "multi_item_scoring_delimiter", -] - -# Put some global args for easy access -global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS} logger = logging.getLogger(__name__) @@ -685,12 +644,9 @@ class Req: def is_prefill_only(self) -> bool: """Check if this request is prefill-only (no token generation needed).""" # NOTE: when spec is enabled, prefill_only optimizations are disabled - from sglang.srt.speculative.spec_info import SpeculativeAlgorithm - spec_alg = global_server_args_dict["speculative_algorithm"] - return self.sampling_params.max_new_tokens == 0 and ( - spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE - ) + spec_alg = get_global_server_args().speculative_algorithm + return self.sampling_params.max_new_tokens == 0 and spec_alg is None def add_latency(self, stage: RequestStage): if self.metrics_collector is None: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 1fa41c1ef..1f42543ab 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -122,7 +122,6 @@ from sglang.srt.managers.schedule_batch import ( Req, RequestStage, ScheduleBatch, - global_server_args_dict, ) from sglang.srt.managers.schedule_policy import ( AddReqResult, @@ -151,7 +150,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.parser.reasoning_parser import ReasoningParser -from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from 
sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.tracing.trace import ( @@ -448,13 +447,12 @@ class Scheduler( self.max_req_input_len, self.random_seed, self.device, - worker_global_server_args_dict, _, _, _, ) = self.tp_worker.get_worker_info() - if global_server_args_dict["pp_max_micro_batch_size"] is None: - global_server_args_dict["pp_max_micro_batch_size"] = max( + if get_global_server_args().pp_max_micro_batch_size is None: + get_global_server_args().pp_max_micro_batch_size = max( self.max_running_requests // server_args.pp_size, 1 ) @@ -466,7 +464,6 @@ class Scheduler( self.world_group = get_world_group() self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func() - global_server_args_dict.update(worker_global_server_args_dict) set_random_seed(self.random_seed) # Hybrid memory pool @@ -1942,7 +1939,7 @@ class Scheduler( return ret def get_num_allocatable_reqs(self, running_bs): - res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs + res = get_global_server_args().pp_max_micro_batch_size - running_bs if self.pp_size > 1: res = min(res, self.req_to_token_pool.available_size()) return res @@ -2686,7 +2683,7 @@ class Scheduler( ) def get_internal_state(self, recv_req: GetInternalStateReq): - ret = dict(global_server_args_dict) + ret = vars(get_global_server_args()) ret["last_gen_throughput"] = self.last_gen_throughput ret["memory_usage"] = { "weight": round( @@ -2742,11 +2739,11 @@ class Scheduler( logger.info(f"{avg_spec_accept_length=}") self.cum_spec_accept_length = self.cum_spec_accept_count = 0 for k, v in server_args_dict.items(): - global_server_args_dict[k] = v - logger.info(f"Global server args updated! {global_server_args_dict=}") + setattr(get_global_server_args(), k, v) + logger.info(f"Global server args updated! 
{get_global_server_args()=}") return SetInternalStateReqOutput( updated=True, - server_args=global_server_args_dict, + server_args=vars(get_global_server_args()), ) def handle_rpc_request(self, recv_req: RpcReqInput): diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 52a40a371..267810c06 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, ) -from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict +from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool @@ -190,7 +190,6 @@ class TpModelWorker: self.max_req_input_len, self.random_seed, self.device, - global_server_args_dict, self.model_runner.req_to_token_pool.size, self.model_runner.req_to_token_pool.max_context_len, self.model_runner.token_to_kv_pool.size, diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 040bc45bf..7dcd69410 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -11,7 +11,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool -from sglang.srt.server_args import ServerArgs +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import support_triton if TYPE_CHECKING: @@ -19,10 +19,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"] - 
-global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS} - @triton.jit def write_req_to_token_pool_triton( @@ -88,7 +84,7 @@ def write_cache_indices( prefix_tensors: list[torch.Tensor], req_to_token_pool: ReqToTokenPool, ): - if support_triton(global_server_args_dict.get("attention_backend")): + if support_triton(get_global_server_args().attention_backend): prefix_pointers = torch.tensor( [t.data_ptr() for t in prefix_tensors], device=req_to_token_pool.device, @@ -129,8 +125,8 @@ def get_last_loc( prefix_lens_tensor: torch.Tensor, ) -> torch.Tensor: if ( - global_server_args_dict["attention_backend"] != "ascend" - and global_server_args_dict["attention_backend"] != "torch_native" + get_global_server_args().attention_backend != "ascend" + and get_global_server_args().attention_backend != "torch_native" ): impl = get_last_loc_triton else: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 867d75204..3977ad01e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -83,10 +83,6 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.lora.lora_registry import LoRARef -from sglang.srt.managers.schedule_batch import ( - GLOBAL_SERVER_ARGS_KEYS, - global_server_args_dict, -) from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, PagedTokenToKVPoolAllocator, @@ -125,7 +121,11 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.server_args import ServerArgs +from sglang.srt.server_args import ( + ServerArgs, + 
get_global_server_args, + set_global_server_args_for_scheduler, +) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( MultiprocessingSerializer, @@ -278,15 +278,12 @@ class ModelRunner: # Model-specific adjustment self.model_specific_adjustment() - # Global vars - global_server_args_dict.update( - {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS} - | { - # TODO it is indeed not a "server args" - "use_mla_backend": self.use_mla_backend, - "speculative_algorithm": self.spec_algorithm, - } - ) + # Set the global server_args in the scheduler process + set_global_server_args_for_scheduler(server_args) + global_server_args = get_global_server_args() + + # FIXME: hacky set `use_mla_backend` + global_server_args.use_mla_backend = self.use_mla_backend # Init OpenMP threads binding for CPU if self.device == "cpu": @@ -419,7 +416,7 @@ class ModelRunner: # In layered loading, torchao may have been applied if not torchao_applied: apply_torchao_config_to_model( - self.model, global_server_args_dict["torchao_config"] + self.model, get_global_server_args().torchao_config ) # Apply torch TP if the model supports it @@ -1879,12 +1876,10 @@ class ModelRunner: self.server_args.attention_backend ) - global_server_args_dict.update( - { - "decode_attention_backend": self.decode_attention_backend_str, - "prefill_attention_backend": self.prefill_attention_backend_str, - } - ) + ( + get_global_server_args().prefill_attention_backend, + get_global_server_args().decode_attention_backend, + ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str) return attn_backend def _get_attention_backend_from_str(self, backend_str: str): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index de58a8dd7..691a23b64 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -4,7 +4,6 @@ from __future__ import annotations # ruff: noqa: SIM117 
import collections -import concurrent import dataclasses import fnmatch import glob @@ -12,12 +11,10 @@ import json import logging import math import os -import re import socket import threading import time from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, suppress from typing import ( TYPE_CHECKING, @@ -33,10 +30,10 @@ from typing import ( import huggingface_hub import numpy as np -import requests -import safetensors.torch import torch +from sglang.srt.server_args import get_global_server_args + # Try to import accelerate (optional dependency) try: from accelerate import infer_auto_device_map, init_empty_weights @@ -81,8 +78,6 @@ DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = ( 0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration ) from sglang.srt.model_loader.weight_utils import ( - _BAR_FORMAT, - default_weight_loader, download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, @@ -445,10 +440,8 @@ class DefaultModelLoader(BaseModelLoader): hf_weights_files, ) elif use_safetensors: - from sglang.srt.managers.schedule_batch import global_server_args_dict - - weight_loader_disable_mmap = global_server_args_dict.get( - "weight_loader_disable_mmap" + weight_loader_disable_mmap = ( + get_global_server_args().weight_loader_disable_mmap ) if extra_config.get("enable_multithread_load"): @@ -616,9 +609,9 @@ class LayeredModelLoader(DefaultModelLoader): device_config: DeviceConfig, ) -> nn.Module: from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model - from sglang.srt.managers.schedule_batch import global_server_args_dict + from sglang.srt.server_args import get_global_server_args - torchao_config = global_server_args_dict.get("torchao_config") + torchao_config = get_global_server_args().torchao_config target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): diff --git 
a/python/sglang/srt/models/apertus.py b/python/sglang/srt/models/apertus.py index 161cf1062..ca84264b9 100644 --- a/python/sglang/srt/models/apertus.py +++ b/python/sglang/srt/models/apertus.py @@ -46,15 +46,14 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers -from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -447,7 +446,7 @@ class ApertusForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/models/arcee.py b/python/sglang/srt/models/arcee.py index f9ebfe19a..5afd5f34f 100644 --- a/python/sglang/srt/models/arcee.py +++ b/python/sglang/srt/models/arcee.py @@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers logger = logging.getLogger(__name__) @@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module): 
config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index 23313cb42..2cb7d5961 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" SGLang BailingMoE model.""" +"""SGLang BailingMoE model.""" import logging from typing import Any, Dict, Iterable, Optional, Tuple, Union @@ -68,7 +68,6 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -76,6 +75,7 @@ from sglang.srt.models.utils import ( create_fused_set_kv_buffer_arg, enable_fused_set_kv_buffer, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers LoraConfig = None @@ -204,8 +204,8 @@ class BailingMoESparseMoeBlock(nn.Module): else: self.router_dtype = torch.bfloat16 - # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now - assert global_server_args_dict["ep_num_redundant_experts"] == 0 + # TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now + assert 
get_global_server_args().ep_num_redundant_experts == 0 # check group topk self.num_expert_group = getattr(config, "n_group", 0) self.topk_group = getattr(config, "topk_group", 0) @@ -220,7 +220,7 @@ class BailingMoESparseMoeBlock(nn.Module): self.use_grouped_topk = False self.num_experts = ( - config.num_experts + global_server_args_dict["ep_num_redundant_experts"] + config.num_experts + get_global_server_args().ep_num_redundant_experts ) self.gate = BailingMoEGate( @@ -824,7 +824,7 @@ class BailingMoEForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/bailing_moe_nextn.py b/python/sglang/srt/models/bailing_moe_nextn.py index 49198001c..76a24f4c9 100644 --- a/python/sglang/srt/models/bailing_moe_nextn.py +++ b/python/sglang/srt/models/bailing_moe_nextn.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" SGLang BailingMoENextN model.""" +"""SGLang BailingMoENextN model.""" import logging from typing import Iterable, Optional, Tuple @@ -29,15 +29,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix LoraConfig = None @@ -145,7 +144,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("model.shared_head.head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 0914ead19..a01f386da 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import 
BumpAllocator, add_prefix, is_cuda logger = logging.getLogger(__name__) @@ -152,7 +152,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("model.shared_head.head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e66dd4a1f..454e08585 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -35,7 +35,6 @@ from sglang.srt.configs.model_config import ( get_nsa_index_topk, is_deepseek_nsa, ) -from sglang.srt.debug_utils.dumper import dumper from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, get_pp_group, @@ -108,10 +107,11 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.server_args import get_global_server_args from sglang.srt.single_batch_overlap import SboFlags +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.two_batch_overlap import ( MaybeTboDeepEPDispatcher, model_forward_maybe_tbo, @@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module): self.n_shared_experts = config.n_shared_experts self.num_fused_shared_experts = ( 0 - if global_server_args_dict["disable_shared_experts_fusion"] + if get_global_server_args().disable_shared_experts_fusion else config.n_shared_experts ) self.config = config @@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module): self.experts = get_moe_impl_class(quant_config)( num_experts=config.n_routed_experts + self.num_fused_shared_experts - + 
global_server_args_dict["ep_num_redundant_experts"], + + get_global_server_args().ep_num_redundant_experts, num_fused_shared_experts=self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts, hidden_size=config.hidden_size, @@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module): self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( config.n_routed_experts - + global_server_args_dict["ep_num_redundant_experts"] + + get_global_server_args().ep_num_redundant_experts ) self.renormalize = config.norm_topk_prob self.topk_group = config.topk_group @@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module): base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, - device=global_server_args_dict["device"], + device=get_global_server_args().device, ) if rope_scaling: @@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_scale_v = None self.use_deep_gemm_bmm = False - self.flashinfer_mla_disable_ragged = global_server_args_dict[ - "flashinfer_mla_disable_ragged" - ] - self.disable_chunked_prefix_cache = global_server_args_dict[ - "disable_chunked_prefix_cache" - ] + self.flashinfer_mla_disable_ragged = ( + get_global_server_args().flashinfer_mla_disable_ragged + ) + self.disable_chunked_prefix_cache = ( + get_global_server_args().disable_chunked_prefix_cache + ) self.current_attention_backend = ( None # Attention backend used by current forward batch @@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module): ) -> AttnForwardMethod: # Determine attention backend used by current forward batch if forward_batch.forward_mode.is_decode_or_idle(): - attention_backend = global_server_args_dict["decode_attention_backend"] + attention_backend = get_global_server_args().decode_attention_backend elif ( forward_batch.forward_mode.is_target_verify() or forward_batch.forward_mode.is_draft_extend() ): # Use the specified backend for speculative operations (both verify and draft extend) - if 
global_server_args_dict["speculative_attention_mode"] == "decode": - attention_backend = global_server_args_dict["decode_attention_backend"] + if get_global_server_args().speculative_attention_mode == "decode": + attention_backend = get_global_server_args().decode_attention_backend else: # default to prefill - attention_backend = global_server_args_dict["prefill_attention_backend"] + attention_backend = get_global_server_args().prefill_attention_backend else: - attention_backend = global_server_args_dict["prefill_attention_backend"] + attention_backend = get_global_server_args().prefill_attention_backend self.current_attention_backend = attention_backend handler = AttentionBackendRegistry.get_handler(attention_backend) @@ -2365,7 +2365,9 @@ class DeepseekV2DecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.speculative_algorithm = global_server_args_dict["speculative_algorithm"] + self.speculative_algorithm = SpeculativeAlgorithm.from_string( + get_global_server_args().speculative_algorithm + ) self.layer_id = layer_id self.is_nextn = is_nextn self.self_attn = DeepseekV2AttentionMLA( @@ -2817,7 +2819,7 @@ class DeepseekV2ForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) @@ -2837,7 +2839,7 @@ class DeepseekV2ForCausalLM(nn.Module): self, architecture: str = "DeepseekV3ForCausalLM" ): self.num_fused_shared_experts = 0 - if global_server_args_dict["disable_shared_experts_fusion"]: + if get_global_server_args().disable_shared_experts_fusion: return # Only Deepseek V3/R1 can use shared experts fusion optimization now. 
@@ -2856,7 +2858,7 @@ class DeepseekV2ForCausalLM(nn.Module): disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts." if disable_reason is not None: - global_server_args_dict["disable_shared_experts_fusion"] = True + get_global_server_args().disable_shared_experts_fusion = True self.num_fused_shared_experts = 0 log_info_on_rank0( logger, diff --git a/python/sglang/srt/models/falcon_h1.py b/python/sglang/srt/models/falcon_h1.py index f05a395d9..01f07be1d 100644 --- a/python/sglang/srt/models/falcon_h1.py +++ b/python/sglang/srt/models/falcon_h1.py @@ -33,9 +33,9 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, make_layers logger = logging.getLogger(__name__) @@ -483,7 +483,7 @@ class FalconH1ForCausalLM(nn.Module): quant_config=quant_config, org_num_embeddings=config.vocab_size, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.lm_head = self.lm_head.float() self.lm_head_multiplier = config.lm_head_multiplier diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index d4cc9e1e6..5080bf88f 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -56,18 +56,13 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig -from 
sglang.srt.layers.quantization.fp8_kernel import ( - is_fp8_fnuz, - per_tensor_quant_mla_fp8, - per_token_group_quant_mla_deep_gemm_masked_fp8, -) +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -77,6 +72,7 @@ from sglang.srt.models.deepseek_v2 import ( DeepseekV2Model, DeepseekV2MoE, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher from sglang.srt.utils import ( BumpAllocator, @@ -395,7 +391,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self.n_shared_experts = config.n_shared_experts self.num_fused_shared_experts = ( 0 - if global_server_args_dict["disable_shared_experts_fusion"] + if get_global_server_args().disable_shared_experts_fusion else config.n_shared_experts ) self.config = config @@ -432,7 +428,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self.experts = get_moe_impl_class(quant_config)( num_experts=config.n_routed_experts + self.num_fused_shared_experts - + global_server_args_dict["ep_num_redundant_experts"], + + get_global_server_args().ep_num_redundant_experts, num_fused_shared_experts=self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts, hidden_size=config.hidden_size, @@ -476,7 +472,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( config.n_routed_experts - + global_server_args_dict["ep_num_redundant_experts"] + + 
get_global_server_args().ep_num_redundant_experts ) self.renormalize = config.norm_topk_prob self.topk_group = config.topk_group @@ -758,7 +754,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) @@ -774,7 +770,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): self, architecture: str = "Glm4MoeForCausalLM" ): self.num_fused_shared_experts = 0 - if global_server_args_dict["disable_shared_experts_fusion"]: + if get_global_server_args().disable_shared_experts_fusion: return # Only Deepseek V3/R1 can use shared experts fusion optimization now. @@ -790,7 +786,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism." 
if disable_reason is not None: - global_server_args_dict["disable_shared_experts_fusion"] = True + get_global_server_args().disable_shared_experts_fusion = True self.num_fused_shared_experts = 0 log_info_on_rank0( logger, diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py index 4816f5775..8697fc1a1 100644 --- a/python/sglang/srt/models/glm4_moe_nextn.py +++ b/python/sglang/srt/models/glm4_moe_nextn.py @@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import BumpAllocator, add_prefix logger = logging.getLogger(__name__) @@ -145,7 +145,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("model.shared_head.head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py index fb3d26f11..2688d1225 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -16,10 +16,10 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.glm4_moe import Glm4MoeModel from 
sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0 from sglang.srt.utils.hf_transformers_utils import get_processor @@ -47,7 +47,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): self.determine_num_fused_shared_experts("Glm4MoeForCausalLM") self.num_fused_shared_experts = ( 0 - if global_server_args_dict["disable_shared_experts_fusion"] + if get_global_server_args().disable_shared_experts_fusion else config.n_shared_experts ) @@ -68,7 +68,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @@ -81,7 +81,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): self, architecture: str = "Glm4MoeForCausalLM" ): self.num_fused_shared_experts = 0 - if global_server_args_dict["disable_shared_experts_fusion"]: + if get_global_server_args().disable_shared_experts_fusion: return # Only Deepseek V3/R1 can use shared experts fusion optimization now. @@ -97,7 +97,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism." 
if disable_reason is not None: - global_server_args_dict["disable_shared_experts_fusion"] = True + get_global_server_args().disable_shared_experts_fusion = True self.num_fused_shared_experts = 0 log_info_on_rank0( logger, diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 982400514..1f280f37e 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.utils import ( create_fused_set_kv_buffer_arg, enable_fused_set_kv_buffer, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( LazyValue, add_prefix, @@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module): } self.experts = experts_type( num_experts=config.num_local_experts - + global_server_args_dict["ep_num_redundant_experts"], + + get_global_server_args().ep_num_redundant_experts, top_k=config.num_experts_per_tok, layer_id=layer_id, hidden_size=config.hidden_size, @@ -259,7 +259,7 @@ class GptOssAttention(nn.Module): # Choose dtype of sinks based on attention backend: trtllm_mha requires float32, # others can use bfloat16 - attn_backend = global_server_args_dict.get("attention_backend") + attn_backend = get_global_server_args().attention_backend sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16 self.sinks = nn.Parameter( torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False @@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module): config.hidden_size, # quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + 
use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) self.capture_aux_hidden_states = False diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index aa4a05713..1f4a3b443 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -28,7 +28,6 @@ from torch import nn from transformers import PretrainedConfig from sglang.srt.distributed import ( - get_moe_expert_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, @@ -36,7 +35,6 @@ from sglang.srt.distributed import ( ) from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.elementwise import ( - experts_combine_triton, fused_dual_residual_rmsnorm, fused_rmsnorm, gelu_and_mul_triton, @@ -64,10 +62,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file logger = logging.getLogger(__name__) @@ -866,10 +864,10 @@ class Grok1ForCausalLM(nn.Module): # Dump tensors for debugging global debug_tensor_dump_output_folder, debug_tensor_dump_inject - debug_tensor_dump_output_folder = global_server_args_dict[ - "debug_tensor_dump_output_folder" - ] - debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"] + debug_tensor_dump_output_folder = ( + get_global_server_args().debug_tensor_dump_output_folder + ) + debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject warnings.filterwarnings("ignore", category=FutureWarning) if 
get_tensor_model_parallel_rank() == 0: diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 420a9d0f4..dbf6968ee 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers from sglang.utils import get_exception_traceback @@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py index 8af280771..edfadfa0a 100644 --- a/python/sglang/srt/models/longcat_flash.py +++ b/python/sglang/srt/models/longcat_flash.py @@ -32,14 +32,10 @@ import concurrent.futures import logging -import os -from enum import IntEnum, auto -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple import torch -import torch.nn.functional as F from torch import nn -from tqdm import tqdm from sglang.srt.configs import LongcatFlashConfig from sglang.srt.distributed import ( @@ -85,10 +81,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict 
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( BumpAllocator, LazyValue, @@ -595,7 +591,7 @@ class LongcatFlashForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index bca9e7cc3..c68c394e2 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -31,9 +31,9 @@ from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, MultimodalInputs, - global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import is_cpu _is_cpu = is_cpu() @@ -448,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module): ) self.has_vision = ( - self.has_vision_weights and global_server_args_dict["enable_multimodal"] + self.has_vision_weights and get_global_server_args().enable_multimodal ) if self.has_vision: diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index f00610454..5d84a23af 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -64,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, 
PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.server_args import get_global_server_args from sglang.srt.two_batch_overlap import model_forward_maybe_tbo from sglang.srt.utils import add_prefix, is_cuda, make_layers @@ -156,7 +156,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): layer_id=self.layer_id, top_k=config.num_experts_per_tok, num_experts=config.num_experts - + global_server_args_dict["ep_num_redundant_experts"], + + get_global_server_args().ep_num_redundant_experts, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, quant_config=quant_config, @@ -192,7 +192,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( - config.num_experts + global_server_args_dict["ep_num_redundant_experts"] + config.num_experts + get_global_server_args().ep_num_redundant_experts ) self.top_k = config.num_experts_per_tok @@ -643,7 +643,7 @@ class Qwen2MoeForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) # For EAGLE3 support diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index c4842416c..9991eb96b 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -54,7 +54,6 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -64,6 +63,7 @@ from sglang.srt.models.utils import ( create_fused_set_kv_buffer_arg, enable_fused_set_kv_buffer, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( add_prefix, is_cuda, @@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): self.experts = get_moe_impl_class(quant_config)( num_experts=config.num_experts - + global_server_args_dict["ep_num_redundant_experts"], + + get_global_server_args().ep_num_redundant_experts, top_k=config.num_experts_per_tok, layer_id=layer_id, hidden_size=config.hidden_size, @@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( - config.num_experts + global_server_args_dict["ep_num_redundant_experts"] + config.num_experts + get_global_server_args().ep_num_redundant_experts ) self.top_k = config.num_experts_per_tok @@ -693,7 +693,7 @@ class Qwen3MoeForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) self.capture_aux_hidden_states = False diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 2a1b9d48c..1b11aa30b 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -39,7 +39,6 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import 
ForwardBatch from sglang.srt.model_loader.weight_utils import ( @@ -47,6 +46,7 @@ from sglang.srt.model_loader.weight_utils import ( sharded_weight_loader, ) from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( LazyValue, add_prefix, @@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module): quant_config=quant_config, org_num_embeddings=config.vocab_size, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.lm_head = self.lm_head.float() self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py index b123efcf8..aa0f8ec1e 100644 --- a/python/sglang/srt/models/qwen3_next_mtp.py +++ b/python/sglang/srt/models/qwen3_next_mtp.py @@ -21,14 +21,13 @@ from torch import nn from transformers import PretrainedConfig from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size -from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm +from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.models.qwen3_moe import Qwen3MoeModel from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix logger = logging.getLogger(__name__) @@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM): config.hidden_size, quant_config=quant_config, 
prefix=add_prefix("model.shared_head.head", prefix), - use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py index 125114749..507403adb 100644 --- a/python/sglang/srt/models/qwen3_vl_moe.py +++ b/python/sglang/srt/models/qwen3_vl_moe.py @@ -38,20 +38,12 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.mm_utils import ( - MultiModalityDataPaddingPatternMultimodalTokens, - general_mm_embed_routine, -) -from sglang.srt.managers.schedule_batch import ( - MultimodalDataItem, - MultimodalInputs, - global_server_args_dict, -) +from sglang.srt.managers.mm_utils import general_mm_embed_routine +from sglang.srt.managers.schedule_batch import MultimodalDataItem from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from sglang.srt.models.qwen3_moe import Qwen3MoeModel from sglang.srt.models.qwen3_vl import ( Qwen3_VisionTransformer, Qwen3VLForConditionalGeneration, diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py index 626406da1..14d277f9f 100644 --- a/python/sglang/srt/models/step3_vl.py +++ b/python/sglang/srt/models/step3_vl.py @@ -57,7 +57,6 @@ from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, MultimodalInputs, - global_server_args_dict, ) from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -300,7 +299,7 @@ class Step3TextDecoderLayer(nn.Module): # self.n_shared_experts = 1 # self.num_fused_shared_experts = ( # 0 - # if global_server_args_dict["disable_shared_experts_fusion"] + # if global_server_args.disable_shared_experts_fusion # else self.n_shared_experts # ) self.num_fused_shared_experts = 0 @@ -774,7 +773,7 @@ class Step3VLForConditionalGeneration(nn.Module): # self.n_shared_experts = 1 # self.num_fused_shared_experts = ( # 0 - # if global_server_args_dict["disable_shared_experts_fusion"] + # if global_server_args.disable_shared_experts_fusion # else self.n_shared_experts # ) self.num_fused_shared_experts = 0 diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index d636ccdd0..cff9419b7 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -2,7 +2,6 @@ from __future__ import annotations import dataclasses import logging -import threading from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch @@ -10,6 +9,7 @@ import torch import sglang.srt.sampling.penaltylib as penaltylib from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.sampling_params import TOP_K_ALL +from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -66,16 +66,10 @@ class SamplingBatchInfo: # Handle logit bias logit_bias: Optional[torch.Tensor] = None - @classmethod - def _get_global_server_args_dict(cls): - from sglang.srt.managers.schedule_batch import global_server_args_dict - - return global_server_args_dict - @classmethod def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): - global_server_args_dict = 
cls._get_global_server_args_dict() - enable_deterministic = global_server_args_dict["enable_deterministic_inference"] + global_server_args = get_global_server_args() + enable_deterministic = global_server_args.enable_deterministic_inference reqs = batch.reqs device = batch.device @@ -112,10 +106,9 @@ class SamplingBatchInfo: logit_bias[i, int(key)] = value # Check if any request has custom logit processor - has_custom_logit_processor = global_server_args_dict[ - "enable_custom_logit_processor" - ] and any( # check the flag first. - r.custom_logit_processor for r in reqs + has_custom_logit_processor = ( + global_server_args.enable_custom_logit_processor + and any(r.custom_logit_processor for r in reqs) # check the flag first. ) # then check the requests. if has_custom_logit_processor: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index eeab694d0..e81e6c53b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -55,6 +55,7 @@ from sglang.utils import is_in_ci logger = logging.getLogger(__name__) + # Define constants LOAD_FORMAT_CHOICES = [ "auto", @@ -3357,6 +3358,22 @@ class ServerArgs: ) +# NOTE: This is a global variable to hold the server args for scheduler. +_global_server_args: Optional[ServerArgs] = None + + +def set_global_server_args_for_scheduler(server_args: ServerArgs): + global _global_server_args + _global_server_args = server_args + + +def get_global_server_args() -> ServerArgs: + if _global_server_args is None: + raise ValueError("Global server args is not set yet!") + + return _global_server_args + + def prepare_server_args(argv: List[str]) -> ServerArgs: """ Prepare the server arguments from the command line arguments. 
@@ -3391,8 +3408,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs: parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) raw_args = parser.parse_args(argv) - server_args = ServerArgs.from_cli_args(raw_args) - return server_args + + return ServerArgs.from_cli_args(raw_args) ZMQ_TCP_PORT_DELTA = 233 diff --git a/python/sglang/srt/single_batch_overlap.py b/python/sglang/srt/single_batch_overlap.py index b8839c68f..c56af1731 100644 --- a/python/sglang/srt/single_batch_overlap.py +++ b/python/sglang/srt/single_batch_overlap.py @@ -6,7 +6,6 @@ import torch from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.layers.moe.utils import is_sbo_enabled from sglang.srt.layers.quantization import deep_gemm_wrapper -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import get_int_env_var diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index d230cf193..ad94a7c5c 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.overlap_utils import FutureIndices -from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, @@ -19,6 +19,7 @@ from sglang.srt.mem_cache.common import ( get_last_loc, ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.server_args import get_global_server_args from 
sglang.srt.speculative.eagle_info_v2 import ( EagleDraftInputV2Mixin, EagleVerifyInputV2Mixin, @@ -332,12 +333,8 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, - threshold_single=global_server_args_dict[ - "speculative_accept_threshold_single" - ], - threshold_acc=global_server_args_dict[ - "speculative_accept_threshold_acc" - ], + threshold_single=get_global_server_args().speculative_accept_threshold_single, + threshold_acc=get_global_server_args().speculative_accept_threshold_acc, deterministic=True, ) diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index 23902a846..982343ce9 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -11,7 +11,6 @@ import triton.language as tl from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import ModelWorkerBatch -from sglang.srt.managers.scheduler import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -19,6 +18,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, ) from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.build_eagle_tree import TreeMaskMode from sglang.srt.speculative.spec_utils import ( SIMULATE_ACC_LEN, @@ -265,12 +265,8 @@ class EagleVerifyInputV2Mixin: uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, - threshold_single=global_server_args_dict[ - "speculative_accept_threshold_single" - ], - threshold_acc=global_server_args_dict[ - "speculative_accept_threshold_acc" - ], + 
threshold_single=get_global_server_args().speculative_accept_threshold_single, + threshold_acc=get_global_server_args().speculative_accept_threshold_acc, deterministic=True, ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 162ce53ec..f501f9d8b 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -14,7 +14,7 @@ from sglang.srt.distributed import ( ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs -from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.mem_cache.common import ( @@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) -from sglang.srt.server_args import ServerArgs +from sglang.srt.server_args import ServerArgs, get_global_server_args from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( EAGLEDraftCudaGraphRunner, @@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker): ) def _create_flashinfer_decode_backend(self): - if not global_server_args_dict["use_mla_backend"]: + if not get_global_server_args().use_mla_backend: from sglang.srt.layers.attention.flashinfer_backend import ( FlashInferMultiStepDraftBackend, ) @@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker): ) def _create_trtllm_mla_decode_backend(self): - if not global_server_args_dict["use_mla_backend"]: + if not get_global_server_args().use_mla_backend: raise ValueError( "trtllm_mla backend requires MLA model (use_mla_backend=True)." 
) @@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker): ) def _create_flashinfer_prefill_backend(self): - if not global_server_args_dict["use_mla_backend"]: + if not get_global_server_args().use_mla_backend: from sglang.srt.layers.attention.flashinfer_backend import ( FlashInferAttnBackend, ) @@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker): return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False) def _create_trtllm_mla_prefill_backend(self): - if not global_server_args_dict["use_mla_backend"]: + if not get_global_server_args().use_mla_backend: raise ValueError( "trtllm_mla backend requires MLA model (use_mla_backend=True)." ) diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py index ce4557b89..f0d152ab4 100644 --- a/python/sglang/srt/speculative/ngram_info.py +++ b/python/sglang/srt/speculative/ngram_info.py @@ -7,6 +7,8 @@ from typing import Optional, Tuple import torch import triton +from sglang.srt.server_args import get_global_server_args + logger = logging.getLogger(__name__) from dataclasses import dataclass @@ -16,7 +18,7 @@ import torch.nn.functional as F from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor -from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, alloc_token_slots, @@ -350,10 +352,8 @@ class NgramVerifyInput(SpecInput): uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, - threshold_single=global_server_args_dict[ - "speculative_accept_threshold_single" - ], - threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"], + 
threshold_single=get_global_server_args().speculative_accept_threshold_single, + threshold_acc=get_global_server_args().speculative_accept_threshold_acc, deterministic=True, ) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 61e45440b..a5485e8e9 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -22,7 +22,7 @@ from sglang.srt.layers.moe import ( ) from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.quantization import deep_gemm_wrapper -from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict +from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, @@ -30,6 +30,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ) from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations_strategy import OperationsStrategy +from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip @@ -153,7 +154,7 @@ def _update_device_and_sum_field_from_cpu_field( cpu_value if isinstance(cpu_value, torch.Tensor) else torch.tensor(cpu_value, dtype=old_device_value.dtype) - ).to(device=global_server_args_dict["device"], non_blocking=True) + ).to(device=get_global_server_args().device, non_blocking=True) setattr(batch, device_field, new_device_value) if sum_field is not None: @@ -582,7 +583,7 @@ class TboForwardBatchPreparer: sum_field=None, ) _, child_b.extend_start_loc = compute_position( - global_server_args_dict["attention_backend"], + get_global_server_args().attention_backend, child_b.extend_prefix_lens, child_b.extend_seq_lens, child_b.extend_num_tokens, @@ -687,7 +688,7 @@ class TboForwardBatchPreparer: # TODO improve, e.g. 
unify w/ `init_raw` if ( - global_server_args_dict["moe_dense_tp_size"] == 1 + get_global_server_args().moe_dense_tp_size == 1 and batch.global_dp_buffer_len is not None ): sum_len = end_token_index - start_token_index @@ -755,7 +756,7 @@ class TboForwardBatchPreparer: value_a = min(tbo_split_token_index, num_token_non_padded) value_b = max(0, num_token_non_padded - tbo_split_token_index) return torch.tensor([value_a, value_b], dtype=torch.int32).to( - device=global_server_args_dict["device"], non_blocking=True + device=get_global_server_args().device, non_blocking=True ) @classmethod diff --git a/test/srt/rl/test_fp32_lm_head.py b/test/srt/rl/test_fp32_lm_head.py index e892e3151..dea43995b 100644 --- a/test/srt/rl/test_fp32_lm_head.py +++ b/test/srt/rl/test_fp32_lm_head.py @@ -7,7 +7,11 @@ import torch.nn as nn import torch.nn.functional as F from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.server_args import ( + ServerArgs, + get_global_server_args, + set_global_server_args_for_scheduler, +) class LMHeadStub(nn.Module): @@ -32,8 +36,10 @@ class TestLMHeadFP32(unittest.TestCase): raise unittest.SkipTest("needs CUDA GPU") def _make_logprocessor(self, vocab_size, enable_fp32): - global_server_args_dict["enable_dp_lm_head"] = False - global_server_args_dict["enable_fp32_lm_head"] = enable_fp32 + ServerArgs.__post_init__ = lambda self: None # disable validation + set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) + get_global_server_args().enable_dp_lm_head = False + get_global_server_args().enable_fp32_lm_head = enable_fp32 cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None) return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None) diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index 9be711d12..ea141df3e 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ 
b/test/srt/test_gptqmodel_dynamic.py @@ -4,6 +4,7 @@ import unittest import requests import torch +from sglang.srt.server_args import set_global_server_args_for_scheduler from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -16,17 +17,15 @@ from sglang.test.test_utils import ( def check_quant_method(model_path: str, use_marlin_kernel: bool): from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig - from sglang.srt.configs.model_config import AttentionArch, ModelConfig + from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import ( - get_tp_group, init_distributed_environment, initialize_model_parallel, - set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.layers.quantization.utils import get_dynamic_override from sglang.srt.model_loader import get_model - from sglang.srt.server_args import PortArgs, ServerArgs + from sglang.srt.server_args import ServerArgs try: init_distributed_environment( @@ -43,6 +42,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): pass server_args = ServerArgs(model_path=model_path, dtype=torch.float16) + set_global_server_args_for_scheduler(server_args) model_config = ModelConfig.from_server_args(server_args) load_config = LoadConfig()