diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index fcd65b5ed..552351466 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -6,6 +6,9 @@ class GlobalConfig: """ Store some global constants. + + See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores + many global runtime arguments as well. """ def __init__(self): diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py index 9ce1c1c20..d7274cf2c 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -5,7 +5,7 @@ from packaging import version from torch.cuda.memory import CUDAPluggableAllocator from sglang.srt.distributed.parallel_state import GroupCoordinator -from sglang.srt.server_args import get_global_server_args +from sglang.srt.managers.schedule_batch import global_server_args_dict nccl_allocator_source = """ #include @@ -32,7 +32,7 @@ _graph_pool_id = None def is_symmetric_memory_enabled(): - return get_global_server_args().enable_symm_mem + return global_server_args_dict["enable_symm_mem"] def set_graph_pool_id(graph_pool_id): diff --git a/python/sglang/srt/eplb/expert_location_dispatch.py b/python/sglang/srt/eplb/expert_location_dispatch.py index 7ac03390a..624dc3fd1 100644 --- a/python/sglang/srt/eplb/expert_location_dispatch.py +++ b/python/sglang/srt/eplb/expert_location_dispatch.py @@ -18,7 +18,7 @@ from typing import Literal, Optional import torch from sglang.srt.eplb.expert_location import get_global_expert_location_metadata -from sglang.srt.server_args import get_global_server_args +from sglang.srt.managers.schedule_batch import global_server_args_dict @dataclass @@ -34,7 +34,7 @@ class ExpertLocationDispatchInfo: @classmethod def init_new(cls, layer_id: int): - ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm + ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"] expert_location_metadata = get_global_expert_location_metadata() assert expert_location_metadata is not None diff --git a/python/sglang/srt/eplb/expert_location_updater.py b/python/sglang/srt/eplb/expert_location_updater.py index 286f1d0e3..772e65f18 100644 --- a/python/sglang/srt/eplb/expert_location_updater.py +++ b/python/sglang/srt/eplb/expert_location_updater.py @@ -24,7 +24,7 @@ from sglang.srt.eplb.expert_location import ( ExpertLocationMetadata, get_global_expert_location_metadata, ) -from sglang.srt.server_args import get_global_server_args +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_bool_env_var logger = logging.getLogger(__name__) @@ -97,7 +97,7 @@ def _update_expert_weights_with_canary( canary_tensor = ( _get_canary_value(old_expert_location_metadata, layer_id) .clone() - .to(device=get_global_server_args().device, non_blocking=True) + .to(device=global_server_args_dict["device"], non_blocking=True) ) routed_experts_weights_of_layer[layer_id].append(canary_tensor) diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 76a63a093..47b867f61 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -5,8 +5,8 @@ from typing import TYPE_CHECKING import torch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention @@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend): # TODO: Change the hard-coded block_seq_num self.BLOCK_SEQ = 128 - if get_global_server_args().triton_attention_reduce_in_fp32: + if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): self.reduce_dtype = torch.float32 else: self.reduce_dtype = torch.float16 diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 4fae2cb1d..927f1d93c 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -11,8 +11,8 @@ import triton.language as tl from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.radix_attention import AttentionType +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.spec_info import SpecInput if TYPE_CHECKING: @@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend): ): # Do multi-head attention with chunked prefix cache if forward_batch.attn_attend_prefix_cache: - assert not get_global_server_args().disable_chunked_prefix_cache + assert not global_server_args_dict["disable_chunked_prefix_cache"] # MHA for chunked prefix kv cache when running model with MLA assert forward_batch.prefix_chunk_idx is not None assert forward_batch.prefix_chunk_cu_seq_lens is not None diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 82d1b05b4..6efda7775 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, ) from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import ( is_flashinfer_available, @@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend): self.skip_prefill = skip_prefill self.enable_chunk_kv = ( not skip_prefill - and get_global_server_args().disaggregation_mode != "decode" - and not get_global_server_args().disable_chunked_prefix_cache - and not get_global_server_args().flashinfer_mla_disable_ragged + and global_server_args_dict["disaggregation_mode"] != "decode" + and not global_server_args_dict["disable_chunked_prefix_cache"] + and not global_server_args_dict["flashinfer_mla_disable_ragged"] ) self.page_size = model_runner.page_size @@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): prefix_lens = forward_batch.extend_prefix_lens extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) use_ragged = ( - not get_global_server_args().flashinfer_mla_disable_ragged + not global_server_args_dict["flashinfer_mla_disable_ragged"] and extend_no_prefix ) diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 13ab6b172..798e1c0a8 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.rotary_embedding import get_rope_wrapper +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool @@ -162,7 +162,7 @@ class Indexer(CustomOp): base=rope_theta, # type: ignore rope_scaling=rope_scaling, is_neox_style=False, - device=get_global_server_args().device, + device=global_server_args_dict["device"], ) self.block_size = block_size self.scale_fmt = scale_fmt diff --git a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py index 8e972f408..72e0bfe78 100644 --- a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py @@ -2,7 +2,7 @@ import torch import triton import triton.language as tl -from sglang.srt.server_args import get_global_server_args +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import is_cuda, is_hip _is_cuda = is_cuda() @@ -11,7 +11,7 @@ if _is_cuda: _is_hip = is_hip() -if get_global_server_args().triton_attention_reduce_in_fp32: +if global_server_args_dict.get("attention_reduce_in_fp32", False): REDUCE_TRITON_TYPE = tl.float32 REDUCE_TORCH_TYPE = torch.float32 else: diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 3727524ef..85e535b07 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import ( create_flashmla_kv_indices_triton, ) from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import is_cuda, is_flashinfer_available if is_flashinfer_available(): @@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None - self.disable_chunked_prefix_cache = ( - get_global_server_args().disable_chunked_prefix_cache - ) + self.disable_chunked_prefix_cache = global_server_args_dict[ + "disable_chunked_prefix_cache" + ] self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index bb9016e0d..489b8248b 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -45,7 +45,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb -from sglang.srt.server_args import get_global_server_args +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import add_prefix ROTARY_EMBED_CLASSES = { @@ -468,7 +468,7 @@ class VisionAttention(nn.Module): _passed_backend = qkv_backend qkv_backend = self._determine_attention_backend(_passed_backend) if ( - get_global_server_args().mm_attention_backend is None + global_server_args_dict["mm_attention_backend"] is None and _passed_backend is None ): print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.") @@ -528,7 +528,7 @@ class VisionAttention(nn.Module): - CUDA: "triton_attn" - Non-CUDA: "sdpa" """ - override_backend = get_global_server_args().mm_attention_backend + override_backend = global_server_args_dict["mm_attention_backend"] if override_backend is not None: backend = override_backend elif passed_backend is not None: diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 60b4e9e5f..e050da91d 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -40,9 +40,8 @@ from sglang.srt.layers.moe import ( get_moe_a2a_backend, should_use_flashinfer_cutlass_moe_fp4_allgather, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.server_args import get_global_server_args -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( get_bool_env_var, is_cuda, @@ -169,7 +168,7 @@ class LayerScatterModes: def enable_moe_dense_fully_dp(): - return get_global_server_args().moe_dense_tp_size == 1 + return global_server_args_dict["moe_dense_tp_size"] == 1 class LayerCommunicator: @@ -315,9 +314,7 @@ class LayerCommunicator: def should_fuse_mlp_allreduce_with_next_layer( self, forward_batch: ForwardBatch ) -> bool: - speculative_algo = SpeculativeAlgorithm.from_string( - get_global_server_args().speculative_algorithm - ) + speculative_algo = global_server_args_dict.get("speculative_algorithm", None) if ( is_dp_attention_enabled() and speculative_algo is not None @@ -336,7 +333,7 @@ class LayerCommunicator: static_conditions_met = ( (not self.is_last_layer) and (self._context.tp_size > 1) - and get_global_server_args().enable_flashinfer_allreduce_fusion + and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False) and _is_flashinfer_available ) @@ -534,7 +531,7 @@ class CommunicateWithAllReduceAndLayerNormFn: (_is_sm100_supported or _is_sm90_supported) and _is_flashinfer_available and hasattr(layernorm, "forward_with_allreduce_fusion") - and get_global_server_args().enable_flashinfer_allreduce_fusion + and global_server_args_dict["enable_flashinfer_allreduce_fusion"] and hidden_states.shape[0] <= 4096 ): hidden_states, residual = layernorm.forward_with_allreduce_fusion( diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index a0cf55b0e..dfacd858c 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -38,15 +38,17 @@ from sglang.srt.layers.dp_attention import ( get_dp_device, get_dp_dtype, get_dp_hidden_size, + get_global_dp_buffer, get_local_attention_dp_size, + set_dp_buffer_len, ) from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, ForwardMode, ) -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend logger = logging.getLogger(__name__) @@ -228,8 +230,8 @@ class LogitsProcessor(nn.Module): super().__init__() self.config = config self.logit_scale = logit_scale - self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head - self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head + self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"] + self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"] if self.use_attn_tp_group: self.attn_tp_size = get_attention_tp_size() self.do_tensor_parallel_all_gather = ( @@ -252,8 +254,8 @@ class LogitsProcessor(nn.Module): ): self.final_logit_softcapping = None - self.debug_tensor_dump_output_folder = ( - get_global_server_args().debug_tensor_dump_output_folder + self.debug_tensor_dump_output_folder = global_server_args_dict.get( + "debug_tensor_dump_output_folder", None ) def compute_logprobs_for_multi_item_scoring( @@ -370,7 +372,9 @@ class LogitsProcessor(nn.Module): logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) # Check if multi-item scoring is enabled via server args (only for prefill-only requests) - multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter + multi_item_delimiter = global_server_args_dict.get( + "multi_item_scoring_delimiter" + ) if multi_item_delimiter is not None and logits_metadata.is_prefill_only: return self.compute_logprobs_for_multi_item_scoring( input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 1ff778184..9cdfbc86c 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -27,10 +27,12 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, + QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight from sglang.srt.utils import ( cpu_has_amx_support, diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 76757e501..caf323950 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.utils import is_layer_skipped -from sglang.srt.server_args import get_global_server_args +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import ( direct_register_custom_op, is_cuda, @@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel() self.with_bias = False self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4() - self.flashinfer_mxfp4_moe_precision = ( - get_global_server_args().flashinfer_mxfp4_moe_precision - ) + self.flashinfer_mxfp4_moe_precision = global_server_args_dict[ + "flashinfer_mxfp4_moe_precision" + ] self.triton_kernel_moe_forward = None self.triton_kernel_moe_with_bias_forward = None diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index bf50d4b11..f2deb2b26 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import ( is_dp_attention_enabled, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda if is_cuda(): @@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB") class Sampler(nn.Module): def __init__(self): super().__init__() - self.use_nan_detection = get_global_server_args().enable_nan_detection + self.use_nan_detection = global_server_args_dict["enable_nan_detection"] self.tp_sync_group = get_tp_group().device_group if is_dp_attention_enabled(): @@ -103,7 +103,7 @@ class Sampler(nn.Module): del logits if True: # Keep this redundant check to simplify some internal code sync - if get_global_server_args().sampling_backend == "flashinfer": + if global_server_args_dict["sampling_backend"] == "flashinfer": if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_p_renorm_prob(probs, sampling_info.top_ps) @@ -118,7 +118,7 @@ class Sampler(nn.Module): filter_apply_order="joint", check_nan=self.use_nan_detection, ) - elif get_global_server_args().sampling_backend == "pytorch": + elif global_server_args_dict["sampling_backend"] == "pytorch": # A slower fallback implementation with torch native operations. batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( probs, @@ -131,7 +131,7 @@ class Sampler(nn.Module): ) else: raise ValueError( - f"Invalid sampling backend: {get_global_server_args().sampling_backend}" + f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" ) if return_logprob: diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index e2012e9de..41de295af 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, MultimodalInputs, + global_server_args_dict, ) from sglang.srt.mem_cache.multimodal_cache import MultiModalCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once from sglang.utils import logger @@ -428,7 +428,7 @@ def _adjust_embedding_length( f"tokens from multimodal embeddings." ) if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding: - chunked_prefill_size = get_global_server_args().chunked_prefill_size + chunked_prefill_size = global_server_args_dict["chunked_prefill_size"] if chunked_prefill_size != -1: logger.warning( "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill" diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5205f93a9..f0b03638f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -72,7 +72,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server_args import ServerArgs, get_global_server_args +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import flatten_nested_list from sglang.srt.utils.common import next_power_of_2 @@ -82,6 +82,47 @@ if TYPE_CHECKING: INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 +GLOBAL_SERVER_ARGS_KEYS = [ + "attention_backend", + "mm_attention_backend", + "debug_tensor_dump_inject", + "debug_tensor_dump_output_folder", + "chunked_prefill_size", + "device", + "disable_chunked_prefix_cache", + "disable_flashinfer_cutlass_moe_fp4_allgather", + "disable_radix_cache", + "enable_dp_lm_head", + "enable_fp32_lm_head", + "flashinfer_mxfp4_moe_precision", + "enable_flashinfer_allreduce_fusion", + "moe_dense_tp_size", + "ep_dispatch_algorithm", + "ep_num_redundant_experts", + "enable_nan_detection", + "flashinfer_mla_disable_ragged", + "pp_max_micro_batch_size", + "disable_shared_experts_fusion", + "sampling_backend", + "speculative_accept_threshold_single", + "speculative_accept_threshold_acc", + "speculative_attention_mode", + "torchao_config", + "triton_attention_reduce_in_fp32", + "num_reserved_decode_tokens", + "weight_loader_disable_mmap", + "enable_multimodal", + "enable_symm_mem", + "enable_custom_logit_processor", + "disaggregation_mode", + "enable_deterministic_inference", + "nsa_prefill", + "nsa_decode", + "multi_item_scoring_delimiter", +] + +# Put some global args for easy access +global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS} logger = logging.getLogger(__name__) @@ -642,9 +683,12 @@ class Req: def is_prefill_only(self) -> bool: """Check if this request is prefill-only (no token generation needed).""" # NOTE: when spec is enabled, prefill_only optimizations are disabled + from sglang.srt.speculative.spec_info import SpeculativeAlgorithm - spec_alg = get_global_server_args().speculative_algorithm - return self.sampling_params.max_new_tokens == 0 and spec_alg is None + spec_alg = global_server_args_dict["speculative_algorithm"] + return self.sampling_params.max_new_tokens == 0 and ( + spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE + ) def add_latency(self, stage: RequestStage): if self.metrics_collector is None: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a6ce39a4e..ea7b8222b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -122,6 +122,7 @@ from sglang.srt.managers.schedule_batch import ( Req, RequestStage, ScheduleBatch, + global_server_args_dict, ) from sglang.srt.managers.schedule_policy import ( AddReqResult, @@ -149,7 +150,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.parser.reasoning_parser import ReasoningParser -from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args +from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.tracing.trace import ( @@ -446,12 +447,13 @@ class Scheduler( self.max_req_input_len, self.random_seed, self.device, + worker_global_server_args_dict, _, _, _, ) = self.tp_worker.get_worker_info() - if get_global_server_args().pp_max_micro_batch_size is None: - get_global_server_args().pp_max_micro_batch_size = max( + if global_server_args_dict["pp_max_micro_batch_size"] is None: + global_server_args_dict["pp_max_micro_batch_size"] = max( self.max_running_requests // server_args.pp_size, 1 ) @@ -463,6 +465,7 @@ class Scheduler( self.world_group = get_world_group() self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func() + global_server_args_dict.update(worker_global_server_args_dict) set_random_seed(self.random_seed) # Hybrid memory pool @@ -1863,7 +1866,7 @@ class Scheduler( return ret def get_num_allocatable_reqs(self, running_bs): - res = get_global_server_args().pp_max_micro_batch_size - running_bs + res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs if self.pp_size > 1: res = min(res, self.req_to_token_pool.available_size()) return res @@ -2607,7 +2610,7 @@ class Scheduler( ) def get_internal_state(self, recv_req: GetInternalStateReq): - ret = vars(get_global_server_args()) + ret = dict(global_server_args_dict) ret["last_gen_throughput"] = self.last_gen_throughput ret["memory_usage"] = { "weight": round( @@ -2663,11 +2666,11 @@ class Scheduler( logger.info(f"{avg_spec_accept_length=}") self.cum_spec_accept_length = self.cum_spec_accept_count = 0 for k, v in server_args_dict.items(): - setattr(get_global_server_args(), k, v) - logger.info(f"Global server args updated! {get_global_server_args()=}") + global_server_args_dict[k] = v + logger.info(f"Global server args updated! {global_server_args_dict=}") return SetInternalStateReqOutput( updated=True, - server_args=vars(get_global_server_args()), + server_args=global_server_args_dict, ) def handle_rpc_request(self, recv_req: RpcReqInput): diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 267810c06..52a40a371 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromDistributedReqInput, UpdateWeightsFromTensorReqInput, ) -from sglang.srt.managers.schedule_batch import ModelWorkerBatch +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool @@ -190,6 +190,7 @@ class TpModelWorker: self.max_req_input_len, self.random_seed, self.device, + global_server_args_dict, self.model_runner.req_to_token_pool.size, self.model_runner.req_to_token_pool.max_context_len, self.model_runner.token_to_kv_pool.size, diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 7dcd69410..040bc45bf 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -11,7 +11,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool -from sglang.srt.server_args import get_global_server_args +from sglang.srt.server_args import ServerArgs from sglang.srt.utils import support_triton if TYPE_CHECKING: @@ -19,6 +19,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"] + +global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS} + @triton.jit def write_req_to_token_pool_triton( @@ -84,7 +88,7 @@ def write_cache_indices( prefix_tensors: list[torch.Tensor], req_to_token_pool: ReqToTokenPool, ): - if support_triton(get_global_server_args().attention_backend): + if support_triton(global_server_args_dict.get("attention_backend")): prefix_pointers = torch.tensor( [t.data_ptr() for t in prefix_tensors], device=req_to_token_pool.device, @@ -125,8 +129,8 @@ def get_last_loc( prefix_lens_tensor: torch.Tensor, ) -> torch.Tensor: if ( - get_global_server_args().attention_backend != "ascend" - and get_global_server_args().attention_backend != "torch_native" + global_server_args_dict["attention_backend"] != "ascend" + and global_server_args_dict["attention_backend"] != "torch_native" ): impl = get_last_loc_triton else: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d7c156fc8..fea4a49ef 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -83,6 +83,10 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.lora.lora_registry import LoRARef +from sglang.srt.managers.schedule_batch import ( + GLOBAL_SERVER_ARGS_KEYS, + global_server_args_dict, +) from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, PagedTokenToKVPoolAllocator, @@ -121,11 +125,7 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.server_args import ( - ServerArgs, - get_global_server_args, - set_global_server_args_for_scheduler, -) +from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( MultiprocessingSerializer, @@ -275,12 +275,15 @@ class ModelRunner: # Model-specific adjustment self.model_specific_adjustment() - # Set the global server_args in the scheduler process - set_global_server_args_for_scheduler(server_args) - global_server_args = get_global_server_args() - - # FIXME: hacky set `use_mla_backend` - global_server_args.use_mla_backend = self.use_mla_backend + # Global vars + global_server_args_dict.update( + {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS} + | { + # TODO it is indeed not a "server args" + "use_mla_backend": self.use_mla_backend, + "speculative_algorithm": self.spec_algorithm, + } + ) # Init OpenMP threads binding for CPU if self.device == "cpu": @@ -429,7 +432,7 @@ class ModelRunner: # In layered loading, torchao may have been applied if not torchao_applied: apply_torchao_config_to_model( - self.model, get_global_server_args().torchao_config + self.model, global_server_args_dict["torchao_config"] ) # Apply torch TP if the model supports it @@ -1835,10 +1838,12 @@ class ModelRunner: self.server_args.attention_backend ) - ( - get_global_server_args().prefill_attention_backend, - get_global_server_args().decode_attention_backend, - ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str) + global_server_args_dict.update( + { + "decode_attention_backend": self.decode_attention_backend_str, + "prefill_attention_backend": self.prefill_attention_backend_str, + } + ) return attn_backend def _get_attention_backend_from_str(self, backend_str: str): diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 691a23b64..de58a8dd7 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -4,6 +4,7 @@ from __future__ import annotations # ruff: noqa: SIM117 import collections +import concurrent import dataclasses import fnmatch import glob @@ -11,10 +12,12 @@ import json import logging import math import os +import re import socket import threading import time from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, suppress from typing import ( TYPE_CHECKING, @@ -30,10 +33,10 @@ from typing import ( import huggingface_hub import numpy as np +import requests +import safetensors.torch import torch -from sglang.srt.server_args import get_global_server_args - # Try to import accelerate (optional dependency) try: from accelerate import infer_auto_device_map, init_empty_weights @@ -78,6 +81,8 @@ DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = ( 0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration ) from sglang.srt.model_loader.weight_utils import ( + _BAR_FORMAT, + default_weight_loader, download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, @@ -440,8 +445,10 @@ class DefaultModelLoader(BaseModelLoader): hf_weights_files, ) elif use_safetensors: - weight_loader_disable_mmap = ( - get_global_server_args().weight_loader_disable_mmap + from sglang.srt.managers.schedule_batch import global_server_args_dict + + weight_loader_disable_mmap = global_server_args_dict.get( + "weight_loader_disable_mmap" ) if extra_config.get("enable_multithread_load"): @@ -609,9 +616,9 @@ class LayeredModelLoader(DefaultModelLoader): device_config: DeviceConfig, ) -> nn.Module: from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model - from sglang.srt.server_args import get_global_server_args + from sglang.srt.managers.schedule_batch import global_server_args_dict - torchao_config = get_global_server_args().torchao_config + torchao_config = global_server_args_dict.get("torchao_config") target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): diff --git a/python/sglang/srt/models/apertus.py b/python/sglang/srt/models/apertus.py index ca84264b9..161cf1062 100644 --- a/python/sglang/srt/models/apertus.py +++ b/python/sglang/srt/models/apertus.py @@ -46,14 +46,15 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers +from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -446,7 +447,7 @@ class ApertusForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/models/arcee.py b/python/sglang/srt/models/arcee.py index 5afd5f34f..f9ebfe19a 100644 --- a/python/sglang/srt/models/arcee.py +++ b/python/sglang/srt/models/arcee.py @@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers logger = logging.getLogger(__name__) @@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index 2cb7d5961..23313cb42 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""SGLang BailingMoE model.""" +""" SGLang BailingMoE model.""" import logging from typing import Any, Dict, Iterable, Optional, Tuple, Union @@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -75,7 +76,6 @@ from sglang.srt.models.utils import ( create_fused_set_kv_buffer_arg, enable_fused_set_kv_buffer, ) -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers LoraConfig = None @@ -204,8 +204,8 @@ class BailingMoESparseMoeBlock(nn.Module): else: self.router_dtype = torch.bfloat16 - # TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now - assert get_global_server_args().ep_num_redundant_experts == 0 + # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now + assert global_server_args_dict["ep_num_redundant_experts"] == 0 # check group topk self.num_expert_group = getattr(config, "n_group", 0) self.topk_group = getattr(config, "topk_group", 0) @@ -220,7 +220,7 @@ class BailingMoESparseMoeBlock(nn.Module): self.use_grouped_topk = False self.num_experts = ( - config.num_experts + get_global_server_args().ep_num_redundant_experts + config.num_experts + global_server_args_dict["ep_num_redundant_experts"] ) self.gate = BailingMoEGate( @@ -824,7 +824,7 @@ class BailingMoEForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/bailing_moe_nextn.py b/python/sglang/srt/models/bailing_moe_nextn.py index 76a24f4c9..49198001c 100644 --- a/python/sglang/srt/models/bailing_moe_nextn.py +++ b/python/sglang/srt/models/bailing_moe_nextn.py @@ -17,7 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""SGLang BailingMoENextN model.""" +""" SGLang BailingMoENextN model.""" import logging from typing import Iterable, Optional, Tuple @@ -29,14 +29,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix LoraConfig = None @@ -144,7 +145,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("model.shared_head.head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index a01f386da..0914ead19 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda logger = logging.getLogger(__name__) @@ -152,7 +152,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("model.shared_head.head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index bb773a8ad..e66dd4a1f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -35,6 +35,7 @@ from sglang.srt.configs.model_config import ( get_nsa_index_topk, is_deepseek_nsa, ) +from sglang.srt.debug_utils.dumper import dumper from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, get_pp_group, @@ -107,11 +108,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.server_args import get_global_server_args from sglang.srt.single_batch_overlap import SboFlags -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.two_batch_overlap import ( MaybeTboDeepEPDispatcher, model_forward_maybe_tbo, @@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module): self.n_shared_experts = config.n_shared_experts self.num_fused_shared_experts = ( 0 - if get_global_server_args().disable_shared_experts_fusion + if global_server_args_dict["disable_shared_experts_fusion"] else config.n_shared_experts ) self.config = config @@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module): self.experts = get_moe_impl_class(quant_config)( num_experts=config.n_routed_experts + self.num_fused_shared_experts - + get_global_server_args().ep_num_redundant_experts, + + global_server_args_dict["ep_num_redundant_experts"], num_fused_shared_experts=self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts, hidden_size=config.hidden_size, @@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module): self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( config.n_routed_experts - + get_global_server_args().ep_num_redundant_experts + + global_server_args_dict["ep_num_redundant_experts"] ) self.renormalize = config.norm_topk_prob self.topk_group = config.topk_group @@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module): base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, - device=get_global_server_args().device, + device=global_server_args_dict["device"], ) if rope_scaling: @@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_scale_v = None self.use_deep_gemm_bmm = False - self.flashinfer_mla_disable_ragged = ( - get_global_server_args().flashinfer_mla_disable_ragged, - ) - self.disable_chunked_prefix_cache = ( - get_global_server_args().disable_chunked_prefix_cache - ) + self.flashinfer_mla_disable_ragged = global_server_args_dict[ + "flashinfer_mla_disable_ragged" + ] + self.disable_chunked_prefix_cache = global_server_args_dict[ + "disable_chunked_prefix_cache" + ] self.current_attention_backend = ( None # Attention backend used by current forward batch @@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module): ) -> AttnForwardMethod: # Determine attention backend used by current forward batch if forward_batch.forward_mode.is_decode_or_idle(): - attention_backend = get_global_server_args().decode_attention_backend + attention_backend = global_server_args_dict["decode_attention_backend"] elif ( forward_batch.forward_mode.is_target_verify() or forward_batch.forward_mode.is_draft_extend() ): # Use the specified backend for speculative operations (both verify and draft extend) - if get_global_server_args().speculative_attention_mode == "decode": - attention_backend = get_global_server_args().decode_attention_backend + if global_server_args_dict["speculative_attention_mode"] == "decode": + attention_backend = global_server_args_dict["decode_attention_backend"] else: # default to prefill - attention_backend = get_global_server_args().prefill_attention_backend + attention_backend = global_server_args_dict["prefill_attention_backend"] else: - attention_backend = get_global_server_args().prefill_attention_backend + attention_backend = global_server_args_dict["prefill_attention_backend"] self.current_attention_backend = attention_backend handler = AttentionBackendRegistry.get_handler(attention_backend) @@ -2365,9 +2365,7 @@ class DeepseekV2DecoderLayer(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.speculative_algorithm = SpeculativeAlgorithm.from_string( - get_global_server_args().speculative_algorithm - ) + self.speculative_algorithm = global_server_args_dict["speculative_algorithm"] self.layer_id = layer_id self.is_nextn = is_nextn self.self_attn = DeepseekV2AttentionMLA( @@ -2819,7 +2817,7 @@ class DeepseekV2ForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) @@ -2839,7 +2837,7 @@ class DeepseekV2ForCausalLM(nn.Module): self, architecture: str = "DeepseekV3ForCausalLM" ): self.num_fused_shared_experts = 0 - if get_global_server_args().disable_shared_experts_fusion: + if global_server_args_dict["disable_shared_experts_fusion"]: return # Only Deepseek V3/R1 can use shared experts fusion optimization now. @@ -2858,7 +2856,7 @@ class DeepseekV2ForCausalLM(nn.Module): disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts." if disable_reason is not None: - get_global_server_args().disable_shared_experts_fusion = True + global_server_args_dict["disable_shared_experts_fusion"] = True self.num_fused_shared_experts = 0 log_info_on_rank0( logger, diff --git a/python/sglang/srt/models/falcon_h1.py b/python/sglang/srt/models/falcon_h1.py index 01f07be1d..f05a395d9 100644 --- a/python/sglang/srt/models/falcon_h1.py +++ b/python/sglang/srt/models/falcon_h1.py @@ -33,9 +33,9 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, make_layers logger = logging.getLogger(__name__) @@ -483,7 +483,7 @@ class FalconH1ForCausalLM(nn.Module): quant_config=quant_config, org_num_embeddings=config.vocab_size, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.lm_head = self.lm_head.float() self.lm_head_multiplier = config.lm_head_multiplier diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 5080bf88f..d4cc9e1e6 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -56,13 +56,18 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz +from sglang.srt.layers.quantization.fp8_kernel import ( + is_fp8_fnuz, + per_tensor_quant_mla_fp8, + per_token_group_quant_mla_deep_gemm_masked_fp8, +) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -72,7 +77,6 @@ from sglang.srt.models.deepseek_v2 import ( DeepseekV2Model, DeepseekV2MoE, ) -from sglang.srt.server_args import get_global_server_args from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher from sglang.srt.utils import ( BumpAllocator, @@ -391,7 +395,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self.n_shared_experts = config.n_shared_experts self.num_fused_shared_experts = ( 0 - if get_global_server_args().disable_shared_experts_fusion + if global_server_args_dict["disable_shared_experts_fusion"] else config.n_shared_experts ) self.config = config @@ -428,7 +432,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self.experts = get_moe_impl_class(quant_config)( num_experts=config.n_routed_experts + self.num_fused_shared_experts - + get_global_server_args().ep_num_redundant_experts, + + global_server_args_dict["ep_num_redundant_experts"], num_fused_shared_experts=self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts, hidden_size=config.hidden_size, @@ -472,7 +476,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( config.n_routed_experts - + get_global_server_args().ep_num_redundant_experts + + global_server_args_dict["ep_num_redundant_experts"] ) self.renormalize = config.norm_topk_prob self.topk_group = config.topk_group @@ -754,7 +758,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) @@ -770,7 +774,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): self, architecture: str = "Glm4MoeForCausalLM" ): self.num_fused_shared_experts = 0 - if get_global_server_args().disable_shared_experts_fusion: + if global_server_args_dict["disable_shared_experts_fusion"]: return # Only Deepseek V3/R1 can use shared experts fusion optimization now. @@ -786,7 +790,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism." if disable_reason is not None: - get_global_server_args().disable_shared_experts_fusion = True + global_server_args_dict["disable_shared_experts_fusion"] = True self.num_fused_shared_experts = 0 log_info_on_rank0( logger, diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py index 8697fc1a1..4816f5775 100644 --- a/python/sglang/srt/models/glm4_moe_nextn.py +++ b/python/sglang/srt/models/glm4_moe_nextn.py @@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import BumpAllocator, add_prefix logger = logging.getLogger(__name__) @@ -145,7 +145,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("model.shared_head.head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py index 2688d1225..fb3d26f11 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -16,10 +16,10 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.glm4_moe import Glm4MoeModel from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0 from sglang.srt.utils.hf_transformers_utils import get_processor @@ -47,7 +47,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): self.determine_num_fused_shared_experts("Glm4MoeForCausalLM") self.num_fused_shared_experts = ( 0 - if get_global_server_args().disable_shared_experts_fusion + if global_server_args_dict["disable_shared_experts_fusion"] else config.n_shared_experts ) @@ -68,7 +68,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @@ -81,7 +81,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): self, architecture: str = "Glm4MoeForCausalLM" ): self.num_fused_shared_experts = 0 - if get_global_server_args().disable_shared_experts_fusion: + if global_server_args_dict["disable_shared_experts_fusion"]: return # Only Deepseek V3/R1 can use shared experts fusion optimization now. @@ -97,7 +97,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism." if disable_reason is not None: - get_global_server_args().disable_shared_experts_fusion = True + global_server_args_dict["disable_shared_experts_fusion"] = True self.num_fused_shared_experts = 0 log_info_on_rank0( logger, diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 1f280f37e..982400514 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.utils import ( create_fused_set_kv_buffer_arg, enable_fused_set_kv_buffer, ) -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( LazyValue, add_prefix, @@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module): } self.experts = experts_type( num_experts=config.num_local_experts - + get_global_server_args().ep_num_redundant_experts, + + global_server_args_dict["ep_num_redundant_experts"], top_k=config.num_experts_per_tok, layer_id=layer_id, hidden_size=config.hidden_size, @@ -259,7 +259,7 @@ class GptOssAttention(nn.Module): # Choose dtype of sinks based on attention backend: trtllm_mha requires float32, # others can use bfloat16 - attn_backend = get_global_server_args().attention_backend + attn_backend = global_server_args_dict.get("attention_backend") sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16 self.sinks = nn.Parameter( torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False @@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module): config.hidden_size, # quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.capture_aux_hidden_states = False diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 1f4a3b443..aa4a05713 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -28,6 +28,7 @@ from torch import nn from transformers import PretrainedConfig from sglang.srt.distributed import ( + get_moe_expert_parallel_world_size, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, @@ -35,6 +36,7 @@ from sglang.srt.distributed import ( ) from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.elementwise import ( + experts_combine_triton, fused_dual_residual_rmsnorm, fused_rmsnorm, gelu_and_mul_triton, @@ -62,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file logger = logging.getLogger(__name__) @@ -864,10 +866,10 @@ class Grok1ForCausalLM(nn.Module): # Dump tensors for debugging global debug_tensor_dump_output_folder, debug_tensor_dump_inject - debug_tensor_dump_output_folder = ( - get_global_server_args().debug_tensor_dump_output_folder - ) - debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject + debug_tensor_dump_output_folder = global_server_args_dict[ + "debug_tensor_dump_output_folder" + ] + debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"] warnings.filterwarnings("ignore", category=FutureWarning) if get_tensor_model_parallel_rank() == 0: diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index dbf6968ee..420a9d0f4 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, make_layers from sglang.utils import get_exception_traceback @@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py index edfadfa0a..8af280771 100644 --- a/python/sglang/srt/models/longcat_flash.py +++ b/python/sglang/srt/models/longcat_flash.py @@ -32,10 +32,14 @@ import concurrent.futures import logging -from typing import Iterable, Optional, Tuple +import os +from enum import IntEnum, auto +from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch +import torch.nn.functional as F from torch import nn +from tqdm import tqdm from sglang.srt.configs import LongcatFlashConfig from sglang.srt.distributed import ( @@ -81,10 +85,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( BumpAllocator, LazyValue, @@ -591,7 +595,7 @@ class LongcatFlashForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index c68c394e2..bca9e7cc3 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -31,9 +31,9 @@ from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, MultimodalInputs, + global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import is_cpu _is_cpu = is_cpu() @@ -448,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module): ) self.has_vision = ( - self.has_vision_weights and get_global_server_args().enable_multimodal + self.has_vision_weights and global_server_args_dict["enable_multimodal"] ) if self.has_vision: diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 5d84a23af..f00610454 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -64,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.server_args import get_global_server_args from sglang.srt.two_batch_overlap import model_forward_maybe_tbo from sglang.srt.utils import add_prefix, is_cuda, make_layers @@ -156,7 +156,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): layer_id=self.layer_id, top_k=config.num_experts_per_tok, num_experts=config.num_experts - + get_global_server_args().ep_num_redundant_experts, + + global_server_args_dict["ep_num_redundant_experts"], hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, quant_config=quant_config, @@ -192,7 +192,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( - config.num_experts + get_global_server_args().ep_num_redundant_experts + config.num_experts + global_server_args_dict["ep_num_redundant_experts"] ) self.top_k = config.num_experts_per_tok @@ -643,7 +643,7 @@ class Qwen2MoeForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) # For EAGLE3 support diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 9991eb96b..c4842416c 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -54,6 +54,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -63,7 +64,6 @@ from sglang.srt.models.utils import ( create_fused_set_kv_buffer_arg, enable_fused_set_kv_buffer, ) -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( add_prefix, is_cuda, @@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): self.experts = get_moe_impl_class(quant_config)( num_experts=config.num_experts - + get_global_server_args().ep_num_redundant_experts, + + global_server_args_dict["ep_num_redundant_experts"], top_k=config.num_experts_per_tok, layer_id=layer_id, hidden_size=config.hidden_size, @@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( - config.num_experts + get_global_server_args().ep_num_redundant_experts + config.num_experts + global_server_args_dict["ep_num_redundant_experts"] ) self.top_k = config.num_experts_per_tok @@ -693,7 +693,7 @@ class Qwen3MoeForCausalLM(nn.Module): config.hidden_size, quant_config=quant_config, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) self.capture_aux_hidden_states = False diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 1b11aa30b..2a1b9d48c 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import ( @@ -46,7 +47,6 @@ from sglang.srt.model_loader.weight_utils import ( sharded_weight_loader, ) from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( LazyValue, add_prefix, @@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module): quant_config=quant_config, org_num_embeddings=config.vocab_size, prefix=add_prefix("lm_head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.lm_head = self.lm_head.float() self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py index aa0f8ec1e..b123efcf8 100644 --- a/python/sglang/srt/models/qwen3_next_mtp.py +++ b/python/sglang/srt/models/qwen3_next_mtp.py @@ -21,13 +21,14 @@ from torch import nn from transformers import PretrainedConfig from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size -from sglang.srt.layers.layernorm import GemmaRMSNorm +from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.qwen3_moe import Qwen3MoeModel from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix logger = logging.getLogger(__name__) @@ -68,7 +69,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM): config.hidden_size, quant_config=quant_config, prefix=add_prefix("model.shared_head.head", prefix), - use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) diff --git a/python/sglang/srt/models/qwen3_vl_moe.py b/python/sglang/srt/models/qwen3_vl_moe.py index 507403adb..125114749 100644 --- a/python/sglang/srt/models/qwen3_vl_moe.py +++ b/python/sglang/srt/models/qwen3_vl_moe.py @@ -38,12 +38,20 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead -from sglang.srt.managers.mm_utils import general_mm_embed_routine -from sglang.srt.managers.schedule_batch import MultimodalDataItem +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternMultimodalTokens, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import ( + MultimodalDataItem, + MultimodalInputs, + global_server_args_dict, +) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.qwen3_moe import Qwen3MoeModel +from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel from sglang.srt.models.qwen3_vl import ( Qwen3_VisionTransformer, Qwen3VLForConditionalGeneration, diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py index 14d277f9f..626406da1 100644 --- a/python/sglang/srt/models/step3_vl.py +++ b/python/sglang/srt/models/step3_vl.py @@ -57,6 +57,7 @@ from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, MultimodalInputs, + global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -299,7 +300,7 @@ class Step3TextDecoderLayer(nn.Module): # self.n_shared_experts = 1 # self.num_fused_shared_experts = ( # 0 - # if global_server_args.disable_shared_experts_fusion + # if global_server_args_dict["disable_shared_experts_fusion"] # else self.n_shared_experts # ) self.num_fused_shared_experts = 0 @@ -773,7 +774,7 @@ class Step3VLForConditionalGeneration(nn.Module): # self.n_shared_experts = 1 # self.num_fused_shared_experts = ( # 0 - # if global_server_args.disable_shared_experts_fusion + # if global_server_args_dict["disable_shared_experts_fusion"] # else self.n_shared_experts # ) self.num_fused_shared_experts = 0 diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index cff9419b7..d636ccdd0 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -2,6 +2,7 @@ from __future__ import annotations import dataclasses import logging +import threading from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch @@ -9,7 +10,6 @@ import torch import sglang.srt.sampling.penaltylib as penaltylib from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.sampling_params import TOP_K_ALL -from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -66,10 +66,16 @@ class SamplingBatchInfo: # Handle logit bias logit_bias: Optional[torch.Tensor] = None + @classmethod + def _get_global_server_args_dict(cls): + from sglang.srt.managers.schedule_batch import global_server_args_dict + + return global_server_args_dict + @classmethod def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): - global_server_args = get_global_server_args() - enable_deterministic = global_server_args.enable_deterministic_inference + global_server_args_dict = cls._get_global_server_args_dict() + enable_deterministic = global_server_args_dict["enable_deterministic_inference"] reqs = batch.reqs device = batch.device @@ -106,9 +112,10 @@ class SamplingBatchInfo: logit_bias[i, int(key)] = value # Check if any request has custom logit processor - has_custom_logit_processor = ( - global_server_args.enable_custom_logit_processor - and any(r.custom_logit_processor for r in reqs) # check the flag first. + has_custom_logit_processor = global_server_args_dict[ + "enable_custom_logit_processor" + ] and any( # check the flag first. + r.custom_logit_processor for r in reqs ) # then check the requests. if has_custom_logit_processor: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e20e35129..8053be39d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -53,7 +53,6 @@ from sglang.utils import is_in_ci logger = logging.getLogger(__name__) - # Define constants LOAD_FORMAT_CHOICES = [ "auto", @@ -3324,22 +3323,6 @@ class ServerArgs: ) -# NOTE: This is a global variable to hold the server args for scheduler. -_global_server_args: Optional[ServerArgs] = None - - -def set_global_server_args_for_scheduler(server_args: ServerArgs): - global _global_server_args - _global_server_args = server_args - - -def get_global_server_args() -> ServerArgs: - if _global_server_args is None: - raise ValueError("Global server args is not set yet!") - - return _global_server_args - - def prepare_server_args(argv: List[str]) -> ServerArgs: """ Prepare the server arguments from the command line arguments. @@ -3374,8 +3357,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs: parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) raw_args = parser.parse_args(argv) - - return ServerArgs.from_cli_args(raw_args) + server_args = ServerArgs.from_cli_args(raw_args) + return server_args ZMQ_TCP_PORT_DELTA = 233 diff --git a/python/sglang/srt/single_batch_overlap.py b/python/sglang/srt/single_batch_overlap.py index c56af1731..b8839c68f 100644 --- a/python/sglang/srt/single_batch_overlap.py +++ b/python/sglang/srt/single_batch_overlap.py @@ -6,6 +6,7 @@ import torch from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.layers.moe.utils import is_sbo_enabled from sglang.srt.layers.quantization import deep_gemm_wrapper +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import get_int_env_var diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index ad94a7c5c..d230cf193 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.overlap_utils import FutureIndices -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, @@ -19,7 +19,6 @@ from sglang.srt.mem_cache.common import ( get_last_loc, ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode -from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.eagle_info_v2 import ( EagleDraftInputV2Mixin, EagleVerifyInputV2Mixin, @@ -333,8 +332,12 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin): uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, - threshold_single=get_global_server_args().speculative_accept_threshold_single, - threshold_acc=get_global_server_args().speculative_accept_threshold_acc, + threshold_single=global_server_args_dict[ + "speculative_accept_threshold_single" + ], + threshold_acc=global_server_args_dict[ + "speculative_accept_threshold_acc" + ], deterministic=True, ) diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index 982343ce9..23902a846 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -11,6 +11,7 @@ import triton.language as tl from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import ModelWorkerBatch +from sglang.srt.managers.scheduler import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -18,7 +19,6 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, ) from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.build_eagle_tree import TreeMaskMode from sglang.srt.speculative.spec_utils import ( SIMULATE_ACC_LEN, @@ -265,8 +265,12 @@ class EagleVerifyInputV2Mixin: uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, - threshold_single=get_global_server_args().speculative_accept_threshold_single, - threshold_acc=get_global_server_args().speculative_accept_threshold_acc, + threshold_single=global_server_args_dict[ + "speculative_accept_threshold_single" + ], + threshold_acc=global_server_args_dict[ + "speculative_accept_threshold_acc" + ], deterministic=True, ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index f501f9d8b..162ce53ec 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -14,7 +14,7 @@ from sglang.srt.distributed import ( ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.mem_cache.common import ( @@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) -from sglang.srt.server_args import ServerArgs, get_global_server_args +from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( EAGLEDraftCudaGraphRunner, @@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker): ) def _create_flashinfer_decode_backend(self): - if not get_global_server_args().use_mla_backend: + if not global_server_args_dict["use_mla_backend"]: from sglang.srt.layers.attention.flashinfer_backend import ( FlashInferMultiStepDraftBackend, ) @@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker): ) def _create_trtllm_mla_decode_backend(self): - if not get_global_server_args().use_mla_backend: + if not global_server_args_dict["use_mla_backend"]: raise ValueError( "trtllm_mla backend requires MLA model (use_mla_backend=True)." ) @@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker): ) def _create_flashinfer_prefill_backend(self): - if not get_global_server_args().use_mla_backend: + if not global_server_args_dict["use_mla_backend"]: from sglang.srt.layers.attention.flashinfer_backend import ( FlashInferAttnBackend, ) @@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker): return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False) def _create_trtllm_mla_prefill_backend(self): - if not get_global_server_args().use_mla_backend: + if not global_server_args_dict["use_mla_backend"]: raise ValueError( "trtllm_mla backend requires MLA model (use_mla_backend=True)." ) diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py index f0d152ab4..ce4557b89 100644 --- a/python/sglang/srt/speculative/ngram_info.py +++ b/python/sglang/srt/speculative/ngram_info.py @@ -7,8 +7,6 @@ from typing import Optional, Tuple import torch import triton -from sglang.srt.server_args import get_global_server_args - logger = logging.getLogger(__name__) from dataclasses import dataclass @@ -18,7 +16,7 @@ import torch.nn.functional as F from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, alloc_token_slots, @@ -352,8 +350,10 @@ class NgramVerifyInput(SpecInput): uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, - threshold_single=get_global_server_args().speculative_accept_threshold_single, - threshold_acc=get_global_server_args().speculative_accept_threshold_acc, + threshold_single=global_server_args_dict[ + "speculative_accept_threshold_single" + ], + threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"], deterministic=True, ) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index a5485e8e9..61e45440b 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -22,7 +22,7 @@ from sglang.srt.layers.moe import ( ) from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.quantization import deep_gemm_wrapper -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, @@ -30,7 +30,6 @@ from sglang.srt.model_executor.forward_batch_info import ( ) from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations_strategy import OperationsStrategy -from sglang.srt.server_args import get_global_server_args from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip @@ -154,7 +153,7 @@ def _update_device_and_sum_field_from_cpu_field( cpu_value if isinstance(cpu_value, torch.Tensor) else torch.tensor(cpu_value, dtype=old_device_value.dtype) - ).to(device=get_global_server_args().device, non_blocking=True) + ).to(device=global_server_args_dict["device"], non_blocking=True) setattr(batch, device_field, new_device_value) if sum_field is not None: @@ -583,7 +582,7 @@ class TboForwardBatchPreparer: sum_field=None, ) _, child_b.extend_start_loc = compute_position( - get_global_server_args().attention_backend, + global_server_args_dict["attention_backend"], child_b.extend_prefix_lens, child_b.extend_seq_lens, child_b.extend_num_tokens, @@ -688,7 +687,7 @@ class TboForwardBatchPreparer: # TODO improve, e.g. unify w/ `init_raw` if ( - get_global_server_args().moe_dense_tp_size == 1 + global_server_args_dict["moe_dense_tp_size"] == 1 and batch.global_dp_buffer_len is not None ): sum_len = end_token_index - start_token_index @@ -756,7 +755,7 @@ class TboForwardBatchPreparer: value_a = min(tbo_split_token_index, num_token_non_padded) value_b = max(0, num_token_non_padded - tbo_split_token_index) return torch.tensor([value_a, value_b], dtype=torch.int32).to( - device=get_global_server_args().device, non_blocking=True + device=global_server_args_dict["device"], non_blocking=True ) @classmethod diff --git a/test/srt/rl/test_fp32_lm_head.py b/test/srt/rl/test_fp32_lm_head.py index dea43995b..e892e3151 100644 --- a/test/srt/rl/test_fp32_lm_head.py +++ b/test/srt/rl/test_fp32_lm_head.py @@ -7,11 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.server_args import ( - ServerArgs, - get_global_server_args, - set_global_server_args_for_scheduler, -) +from sglang.srt.managers.schedule_batch import global_server_args_dict class LMHeadStub(nn.Module): @@ -36,10 +32,8 @@ class TestLMHeadFP32(unittest.TestCase): raise unittest.SkipTest("needs CUDA GPU") def _make_logprocessor(self, vocab_size, enable_fp32): - ServerArgs.__post_init__ = lambda self: None # disable validation - set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) - get_global_server_args().enable_dp_lm_head = False - get_global_server_args().enable_fp32_lm_head = enable_fp32 + global_server_args_dict["enable_dp_lm_head"] = False + global_server_args_dict["enable_fp32_lm_head"] = enable_fp32 cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None) return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None) diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index ea141df3e..9be711d12 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -4,7 +4,6 @@ import unittest import requests import torch -from sglang.srt.server_args import set_global_server_args_for_scheduler from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -17,15 +16,17 @@ from sglang.test.test_utils import ( def check_quant_method(model_path: str, use_marlin_kernel: bool): from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig - from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.distributed import ( + get_tp_group, init_distributed_environment, initialize_model_parallel, + set_custom_all_reduce, ) from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.layers.quantization.utils import get_dynamic_override from sglang.srt.model_loader import get_model - from sglang.srt.server_args import ServerArgs + from sglang.srt.server_args import PortArgs, ServerArgs try: init_distributed_environment( @@ -42,7 +43,6 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): pass server_args = ServerArgs(model_path=model_path, dtype=torch.float16) - set_global_server_args_for_scheduler(server_args) model_config = ModelConfig.from_server_args(server_args) load_config = LoadConfig()