Deprecate global_server_args_dict (#11331)

This commit is contained in:
Liangsheng Yin
2025-10-13 01:20:47 +08:00
committed by GitHub
parent 2157d12ae8
commit 1083e7e3df
54 changed files with 240 additions and 321 deletions

View File

@@ -6,9 +6,6 @@
 class GlobalConfig:
     """
     Store some global constants.
-    See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
-    many global runtime arguments as well.
     """

     def __init__(self):

View File

@@ -5,7 +5,7 @@ from packaging import version
 from torch.cuda.memory import CUDAPluggableAllocator

 from sglang.srt.distributed.parallel_state import GroupCoordinator
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args

 nccl_allocator_source = """
 #include <nccl.h>

@@ -32,7 +32,7 @@ _graph_pool_id = None

 def is_symmetric_memory_enabled():
-    return global_server_args_dict["enable_symm_mem"]
+    return get_global_server_args().enable_symm_mem

 def set_graph_pool_id(graph_pool_id):

View File

@@ -18,7 +18,7 @@ from typing import Literal, Optional
 import torch

 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args

 @dataclass
@@ -34,7 +34,7 @@ class ExpertLocationDispatchInfo:
     @classmethod
     def init_new(cls, layer_id: int):
-        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
+        ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm
         expert_location_metadata = get_global_expert_location_metadata()
         assert expert_location_metadata is not None

View File

@@ -24,7 +24,7 @@ from sglang.srt.eplb.expert_location import (
ExpertLocationMetadata, ExpertLocationMetadata,
get_global_expert_location_metadata, get_global_expert_location_metadata,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import get_bool_env_var from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -97,7 +97,7 @@ def _update_expert_weights_with_canary(
canary_tensor = ( canary_tensor = (
_get_canary_value(old_expert_location_metadata, layer_id) _get_canary_value(old_expert_location_metadata, layer_id)
.clone() .clone()
.to(device=global_server_args_dict["device"], non_blocking=True) .to(device=get_global_server_args().device, non_blocking=True)
) )
routed_experts_weights_of_layer[layer_id].append(canary_tensor) routed_experts_weights_of_layer[layer_id].append(canary_tensor)

View File

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
import torch import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
# TODO: Change the hard-coded block_seq_num # TODO: Change the hard-coded block_seq_num
self.BLOCK_SEQ = 128 self.BLOCK_SEQ = 128
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): if get_global_server_args().triton_attention_reduce_in_fp32:
self.reduce_dtype = torch.float32 self.reduce_dtype = torch.float32
else: else:
self.reduce_dtype = torch.float16 self.reduce_dtype = torch.float16

View File

@@ -11,8 +11,8 @@ import triton.language as tl
from sglang.srt.configs.model_config import AttentionArch from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.speculative.spec_info import SpecInput
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
): ):
# Do multi-head attention with chunked prefix cache # Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache: if forward_batch.attn_attend_prefix_cache:
assert not global_server_args_dict["disable_chunked_prefix_cache"] assert not get_global_server_args().disable_chunked_prefix_cache
# MHA for chunked prefix kv cache when running model with MLA # MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None assert forward_batch.prefix_chunk_cu_seq_lens is not None

View File

@@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton, create_flashinfer_kv_indices_triton,
) )
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import ( from sglang.srt.utils import (
is_flashinfer_available, is_flashinfer_available,
@@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.skip_prefill = skip_prefill self.skip_prefill = skip_prefill
self.enable_chunk_kv = ( self.enable_chunk_kv = (
not skip_prefill not skip_prefill
and global_server_args_dict["disaggregation_mode"] != "decode" and get_global_server_args().disaggregation_mode != "decode"
and not global_server_args_dict["disable_chunked_prefix_cache"] and not get_global_server_args().disable_chunked_prefix_cache
and not global_server_args_dict["flashinfer_mla_disable_ragged"] and not get_global_server_args().flashinfer_mla_disable_ragged
) )
self.page_size = model_runner.page_size self.page_size = model_runner.page_size
@@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
prefix_lens = forward_batch.extend_prefix_lens prefix_lens = forward_batch.extend_prefix_lens
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
use_ragged = ( use_ragged = (
not global_server_args_dict["flashinfer_mla_disable_ragged"] not get_global_server_args().flashinfer_mla_disable_ragged
and extend_no_prefix and extend_no_prefix
) )

View File

@@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -162,7 +162,7 @@ class Indexer(CustomOp):
base=rope_theta, # type: ignore base=rope_theta, # type: ignore
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
is_neox_style=False, is_neox_style=False,
device=global_server_args_dict["device"], device=get_global_server_args().device,
) )
self.block_size = block_size self.block_size = block_size
self.scale_fmt = scale_fmt self.scale_fmt = scale_fmt

View File

@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_hip

 _is_cuda = is_cuda()
@@ -11,7 +11,7 @@ if _is_cuda:

 _is_hip = is_hip()

-if global_server_args_dict.get("attention_reduce_in_fp32", False):
+if get_global_server_args().triton_attention_reduce_in_fp32:
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:

View File

@@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import (
create_flashmla_kv_indices_triton, create_flashmla_kv_indices_triton,
) )
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cuda, is_flashinfer_available from sglang.srt.utils import is_cuda, is_flashinfer_available
if is_flashinfer_available(): if is_flashinfer_available():
@@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
self.disable_chunked_prefix_cache = global_server_args_dict[ self.disable_chunked_prefix_cache = (
"disable_chunked_prefix_cache" get_global_server_args().disable_chunked_prefix_cache
] )
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

View File

@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
) )
from sglang.srt.layers.quantization import QuantizationConfig from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
ROTARY_EMBED_CLASSES = { ROTARY_EMBED_CLASSES = {
@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
_passed_backend = qkv_backend _passed_backend = qkv_backend
qkv_backend = self._determine_attention_backend(_passed_backend) qkv_backend = self._determine_attention_backend(_passed_backend)
if ( if (
global_server_args_dict["mm_attention_backend"] is None get_global_server_args().mm_attention_backend is None
and _passed_backend is None and _passed_backend is None
): ):
print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.") print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
- CUDA: "triton_attn" - CUDA: "triton_attn"
- Non-CUDA: "sdpa" - Non-CUDA: "sdpa"
""" """
override_backend = global_server_args_dict["mm_attention_backend"] override_backend = get_global_server_args().mm_attention_backend
if override_backend is not None: if override_backend is not None:
backend = override_backend backend = override_backend
elif passed_backend is not None: elif passed_backend is not None:

View File

@@ -40,8 +40,9 @@ from sglang.srt.layers.moe import (
get_moe_a2a_backend, get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather, should_use_flashinfer_cutlass_moe_fp4_allgather,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import ( from sglang.srt.utils import (
get_bool_env_var, get_bool_env_var,
is_cuda, is_cuda,
@@ -168,7 +169,7 @@ class LayerScatterModes:
def enable_moe_dense_fully_dp(): def enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1 return get_global_server_args().moe_dense_tp_size == 1
class LayerCommunicator: class LayerCommunicator:
@@ -314,7 +315,9 @@ class LayerCommunicator:
def should_fuse_mlp_allreduce_with_next_layer( def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch self, forward_batch: ForwardBatch
) -> bool: ) -> bool:
speculative_algo = global_server_args_dict.get("speculative_algorithm", None) speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
if ( if (
is_dp_attention_enabled() is_dp_attention_enabled()
and speculative_algo is not None and speculative_algo is not None
@@ -333,7 +336,7 @@ class LayerCommunicator:
static_conditions_met = ( static_conditions_met = (
(not self.is_last_layer) (not self.is_last_layer)
and (self._context.tp_size > 1) and (self._context.tp_size > 1)
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False) and get_global_server_args().enable_flashinfer_allreduce_fusion
and _is_flashinfer_available and _is_flashinfer_available
) )
@@ -531,7 +534,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
(_is_sm100_supported or _is_sm90_supported) (_is_sm100_supported or _is_sm90_supported)
and _is_flashinfer_available and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion") and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"] and get_global_server_args().enable_flashinfer_allreduce_fusion
and hidden_states.shape[0] <= 4096 and hidden_states.shape[0] <= 4096
): ):
hidden_states, residual = layernorm.forward_with_allreduce_fusion( hidden_states, residual = layernorm.forward_with_allreduce_fusion(

View File

@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
get_dp_device, get_dp_device,
get_dp_dtype, get_dp_dtype,
get_dp_hidden_size, get_dp_hidden_size,
get_global_dp_buffer,
get_local_attention_dp_size, get_local_attention_dp_size,
set_dp_buffer_len,
) )
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode, CaptureHiddenMode,
ForwardBatch, ForwardBatch,
ForwardMode, ForwardMode,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -230,8 +228,8 @@ class LogitsProcessor(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.logit_scale = logit_scale self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"] self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"] self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
if self.use_attn_tp_group: if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size() self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = ( self.do_tensor_parallel_all_gather = (
@@ -254,8 +252,8 @@ class LogitsProcessor(nn.Module):
): ):
self.final_logit_softcapping = None self.final_logit_softcapping = None
self.debug_tensor_dump_output_folder = global_server_args_dict.get( self.debug_tensor_dump_output_folder = (
"debug_tensor_dump_output_folder", None get_global_server_args().debug_tensor_dump_output_folder
) )
def compute_logprobs_for_multi_item_scoring( def compute_logprobs_for_multi_item_scoring(
@@ -372,9 +370,7 @@ class LogitsProcessor(nn.Module):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
# Check if multi-item scoring is enabled via server args (only for prefill-only requests) # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = global_server_args_dict.get( multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
"multi_item_scoring_delimiter"
)
if multi_item_delimiter is not None and logits_metadata.is_prefill_only: if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring( return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter

View File

@@ -27,12 +27,10 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase, FusedMoEMethodBase,
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import ( from sglang.srt.utils import (
cpu_has_amx_support, cpu_has_amx_support,

View File

@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import ( from sglang.srt.utils import (
direct_register_custom_op, direct_register_custom_op,
is_cuda, is_cuda,
@@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel() self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False self.with_bias = False
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4() self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.flashinfer_mxfp4_moe_precision = global_server_args_dict[ self.flashinfer_mxfp4_moe_precision = (
"flashinfer_mxfp4_moe_precision" get_global_server_args().flashinfer_mxfp4_moe_precision
] )
self.triton_kernel_moe_forward = None self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None self.triton_kernel_moe_with_bias_forward = None

View File

@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
is_dp_attention_enabled, is_dp_attention_enabled,
) )
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
if is_cuda(): if is_cuda():
@@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
class Sampler(nn.Module): class Sampler(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.use_nan_detection = global_server_args_dict["enable_nan_detection"] self.use_nan_detection = get_global_server_args().enable_nan_detection
self.tp_sync_group = get_tp_group().device_group self.tp_sync_group = get_tp_group().device_group
if is_dp_attention_enabled(): if is_dp_attention_enabled():
@@ -104,7 +104,7 @@ class Sampler(nn.Module):
del logits del logits
if True: # Keep this redundant check to simplify some internal code sync if True: # Keep this redundant check to simplify some internal code sync
if global_server_args_dict["sampling_backend"] == "flashinfer": if get_global_server_args().sampling_backend == "flashinfer":
if sampling_info.need_min_p_sampling: if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps) probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -119,7 +119,7 @@ class Sampler(nn.Module):
filter_apply_order="joint", filter_apply_order="joint",
check_nan=self.use_nan_detection, check_nan=self.use_nan_detection,
) )
elif global_server_args_dict["sampling_backend"] == "pytorch": elif get_global_server_args().sampling_backend == "pytorch":
# A slower fallback implementation with torch native operations. # A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs, probs,
@@ -132,7 +132,7 @@ class Sampler(nn.Module):
) )
else: else:
raise ValueError( raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
) )
if return_logprob: if return_logprob:

View File

@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
Modality, Modality,
MultimodalDataItem, MultimodalDataItem,
MultimodalInputs, MultimodalInputs,
global_server_args_dict,
) )
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
from sglang.utils import logger from sglang.utils import logger
@@ -428,7 +428,7 @@ def _adjust_embedding_length(
f"tokens from multimodal embeddings." f"tokens from multimodal embeddings."
) )
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding: if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"] chunked_prefill_size = get_global_server_args().chunked_prefill_size
if chunked_prefill_size != -1: if chunked_prefill_size != -1:
logger.warning( logger.warning(
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill" "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"

View File

@@ -72,7 +72,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils import flatten_nested_list from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import next_power_of_2 from sglang.srt.utils.common import next_power_of_2
@@ -82,47 +82,6 @@ if TYPE_CHECKING:
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
GLOBAL_SERVER_ARGS_KEYS = [
"attention_backend",
"mm_attention_backend",
"debug_tensor_dump_inject",
"debug_tensor_dump_output_folder",
"chunked_prefill_size",
"device",
"disable_chunked_prefix_cache",
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"ep_num_redundant_experts",
"enable_nan_detection",
"flashinfer_mla_disable_ragged",
"pp_max_micro_batch_size",
"disable_shared_experts_fusion",
"sampling_backend",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
"speculative_attention_mode",
"torchao_config",
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
"weight_loader_disable_mmap",
"enable_multimodal",
"enable_symm_mem",
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
"multi_item_scoring_delimiter",
]
# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -683,12 +642,9 @@ class Req:
def is_prefill_only(self) -> bool: def is_prefill_only(self) -> bool:
"""Check if this request is prefill-only (no token generation needed).""" """Check if this request is prefill-only (no token generation needed)."""
# NOTE: when spec is enabled, prefill_only optimizations are disabled # NOTE: when spec is enabled, prefill_only optimizations are disabled
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
spec_alg = global_server_args_dict["speculative_algorithm"] spec_alg = get_global_server_args().speculative_algorithm
return self.sampling_params.max_new_tokens == 0 and ( return self.sampling_params.max_new_tokens == 0 and spec_alg is None
spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
)
def add_latency(self, stage: RequestStage): def add_latency(self, stage: RequestStage):
if self.metrics_collector is None: if self.metrics_collector is None:

View File

@@ -122,7 +122,6 @@ from sglang.srt.managers.schedule_batch import (
Req, Req,
RequestStage, RequestStage,
ScheduleBatch, ScheduleBatch,
global_server_args_dict,
) )
from sglang.srt.managers.schedule_policy import ( from sglang.srt.managers.schedule_policy import (
AddReqResult, AddReqResult,
@@ -150,7 +149,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import ( from sglang.srt.tracing.trace import (
@@ -447,13 +446,12 @@ class Scheduler(
self.max_req_input_len, self.max_req_input_len,
self.random_seed, self.random_seed,
self.device, self.device,
worker_global_server_args_dict,
_, _,
_, _,
_, _,
) = self.tp_worker.get_worker_info() ) = self.tp_worker.get_worker_info()
if global_server_args_dict["pp_max_micro_batch_size"] is None: if get_global_server_args().pp_max_micro_batch_size is None:
global_server_args_dict["pp_max_micro_batch_size"] = max( get_global_server_args().pp_max_micro_batch_size = max(
self.max_running_requests // server_args.pp_size, 1 self.max_running_requests // server_args.pp_size, 1
) )
@@ -465,7 +463,6 @@ class Scheduler(
self.world_group = get_world_group() self.world_group = get_world_group()
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func() self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed) set_random_seed(self.random_seed)
# Hybrid memory pool # Hybrid memory pool
@@ -1866,7 +1863,7 @@ class Scheduler(
return ret return ret
def get_num_allocatable_reqs(self, running_bs): def get_num_allocatable_reqs(self, running_bs):
res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs res = get_global_server_args().pp_max_micro_batch_size - running_bs
if self.pp_size > 1: if self.pp_size > 1:
res = min(res, self.req_to_token_pool.available_size()) res = min(res, self.req_to_token_pool.available_size())
return res return res
@@ -2610,7 +2607,7 @@ class Scheduler(
) )
def get_internal_state(self, recv_req: GetInternalStateReq): def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict) ret = vars(get_global_server_args())
ret["last_gen_throughput"] = self.last_gen_throughput ret["last_gen_throughput"] = self.last_gen_throughput
ret["memory_usage"] = { ret["memory_usage"] = {
"weight": round( "weight": round(
@@ -2666,11 +2663,11 @@ class Scheduler(
logger.info(f"{avg_spec_accept_length=}") logger.info(f"{avg_spec_accept_length=}")
self.cum_spec_accept_length = self.cum_spec_accept_count = 0 self.cum_spec_accept_length = self.cum_spec_accept_count = 0
for k, v in server_args_dict.items(): for k, v in server_args_dict.items():
global_server_args_dict[k] = v setattr(get_global_server_args(), k, v)
logger.info(f"Global server args updated! {global_server_args_dict=}") logger.info(f"Global server args updated! {get_global_server_args()=}")
return SetInternalStateReqOutput( return SetInternalStateReqOutput(
updated=True, updated=True,
server_args=global_server_args_dict, server_args=vars(get_global_server_args()),
) )
def handle_rpc_request(self, recv_req: RpcReqInput): def handle_rpc_request(self, recv_req: RpcReqInput):

View File

@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqInput,
) )
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -190,7 +190,6 @@ class TpModelWorker:
self.max_req_input_len, self.max_req_input_len,
self.random_seed, self.random_seed,
self.device, self.device,
global_server_args_dict,
self.model_runner.req_to_token_pool.size, self.model_runner.req_to_token_pool.size,
self.model_runner.req_to_token_pool.max_context_len, self.model_runner.req_to_token_pool.max_context_len,
self.model_runner.token_to_kv_pool.size, self.model_runner.token_to_kv_pool.size,

View File

@@ -11,7 +11,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import support_triton from sglang.srt.utils import support_triton
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -19,10 +19,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@triton.jit @triton.jit
def write_req_to_token_pool_triton( def write_req_to_token_pool_triton(
@@ -88,7 +84,7 @@ def write_cache_indices(
prefix_tensors: list[torch.Tensor], prefix_tensors: list[torch.Tensor],
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
): ):
if support_triton(global_server_args_dict.get("attention_backend")): if support_triton(get_global_server_args().attention_backend):
prefix_pointers = torch.tensor( prefix_pointers = torch.tensor(
[t.data_ptr() for t in prefix_tensors], [t.data_ptr() for t in prefix_tensors],
device=req_to_token_pool.device, device=req_to_token_pool.device,
@@ -129,8 +125,8 @@ def get_last_loc(
prefix_lens_tensor: torch.Tensor, prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
if ( if (
global_server_args_dict["attention_backend"] != "ascend" get_global_server_args().attention_backend != "ascend"
and global_server_args_dict["attention_backend"] != "torch_native" and get_global_server_args().attention_backend != "torch_native"
): ):
impl = get_last_loc_triton impl = get_last_loc_triton
else: else:

View File

@@ -83,10 +83,6 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS,
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import ( from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator, BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator, PagedTokenToKVPoolAllocator,
@@ -125,7 +121,11 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import (
ServerArgs,
get_global_server_args,
set_global_server_args_for_scheduler,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import ( from sglang.srt.utils import (
MultiprocessingSerializer, MultiprocessingSerializer,
@@ -275,15 +275,12 @@ class ModelRunner:
# Model-specific adjustment # Model-specific adjustment
self.model_specific_adjustment() self.model_specific_adjustment()
# Global vars # Set the global server_args in the scheduler process
global_server_args_dict.update( set_global_server_args_for_scheduler(server_args)
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS} global_server_args = get_global_server_args()
| {
# TODO it is indeed not a "server args" # FIXME: hacky set `use_mla_backend`
"use_mla_backend": self.use_mla_backend, global_server_args.use_mla_backend = self.use_mla_backend
"speculative_algorithm": self.spec_algorithm,
}
)
# Init OpenMP threads binding for CPU # Init OpenMP threads binding for CPU
if self.device == "cpu": if self.device == "cpu":
@@ -432,7 +429,7 @@ class ModelRunner:
# In layered loading, torchao may have been applied # In layered loading, torchao may have been applied
if not torchao_applied: if not torchao_applied:
apply_torchao_config_to_model( apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"] self.model, get_global_server_args().torchao_config
) )
# Apply torch TP if the model supports it # Apply torch TP if the model supports it
@@ -1838,12 +1835,10 @@ class ModelRunner:
self.server_args.attention_backend self.server_args.attention_backend
) )
global_server_args_dict.update( (
{ get_global_server_args().prefill_attention_backend,
"decode_attention_backend": self.decode_attention_backend_str, get_global_server_args().decode_attention_backend,
"prefill_attention_backend": self.prefill_attention_backend_str, ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
}
)
return attn_backend return attn_backend
def _get_attention_backend_from_str(self, backend_str: str): def _get_attention_backend_from_str(self, backend_str: str):

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
# ruff: noqa: SIM117 # ruff: noqa: SIM117
import collections import collections
import concurrent
import dataclasses import dataclasses
import fnmatch import fnmatch
import glob import glob
@@ -12,12 +11,10 @@ import json
import logging import logging
import math import math
import os import os
import re
import socket import socket
import threading import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@@ -33,10 +30,10 @@ from typing import (
import huggingface_hub import huggingface_hub
import numpy as np import numpy as np
import requests
import safetensors.torch
import torch import torch
from sglang.srt.server_args import get_global_server_args
# Try to import accelerate (optional dependency) # Try to import accelerate (optional dependency)
try: try:
from accelerate import infer_auto_device_map, init_empty_weights from accelerate import infer_auto_device_map, init_empty_weights
@@ -81,8 +78,6 @@ DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration 0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
) )
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
_BAR_FORMAT,
default_weight_loader,
download_safetensors_index_file_from_hf, download_safetensors_index_file_from_hf,
download_weights_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_duplicate_safetensors_files,
@@ -445,10 +440,8 @@ class DefaultModelLoader(BaseModelLoader):
hf_weights_files, hf_weights_files,
) )
elif use_safetensors: elif use_safetensors:
from sglang.srt.managers.schedule_batch import global_server_args_dict weight_loader_disable_mmap = (
get_global_server_args().weight_loader_disable_mmap
weight_loader_disable_mmap = global_server_args_dict.get(
"weight_loader_disable_mmap"
) )
if extra_config.get("enable_multithread_load"): if extra_config.get("enable_multithread_load"):
@@ -616,9 +609,9 @@ class LayeredModelLoader(DefaultModelLoader):
device_config: DeviceConfig, device_config: DeviceConfig,
) -> nn.Module: ) -> nn.Module:
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.server_args import get_global_server_args
torchao_config = global_server_args_dict.get("torchao_config") torchao_config = get_global_server_args().torchao_config
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):

View File

@@ -46,15 +46,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
kv_cache_scales_loader, kv_cache_scales_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -447,7 +446,7 @@ class ApertusForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

View File

@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
kv_cache_scales_loader, kv_cache_scales_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers from sglang.srt.utils import add_prefix, make_layers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

View File

@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" SGLang BailingMoE model.""" """SGLang BailingMoE model."""
import logging import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union from typing import Any, Dict, Iterable, Optional, Tuple, Union
@@ -68,7 +68,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -76,6 +75,7 @@ from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg, create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer, enable_fused_set_kv_buffer,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
LoraConfig = None LoraConfig = None
@@ -204,8 +204,8 @@ class BailingMoESparseMoeBlock(nn.Module):
else: else:
self.router_dtype = torch.bfloat16 self.router_dtype = torch.bfloat16
# TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now # TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now
assert global_server_args_dict["ep_num_redundant_experts"] == 0 assert get_global_server_args().ep_num_redundant_experts == 0
# check group topk # check group topk
self.num_expert_group = getattr(config, "n_group", 0) self.num_expert_group = getattr(config, "n_group", 0)
self.topk_group = getattr(config, "topk_group", 0) self.topk_group = getattr(config, "topk_group", 0)
@@ -220,7 +220,7 @@ class BailingMoESparseMoeBlock(nn.Module):
self.use_grouped_topk = False self.use_grouped_topk = False
self.num_experts = ( self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"] config.num_experts + get_global_server_args().ep_num_redundant_experts
) )
self.gate = BailingMoEGate( self.gate = BailingMoEGate(
@@ -824,7 +824,7 @@ class BailingMoEForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)

View File

@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" SGLang BailingMoENextN model.""" """SGLang BailingMoENextN model."""
import logging import logging
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Tuple
@@ -29,15 +29,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
LoraConfig = None LoraConfig = None
@@ -145,7 +144,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix), prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)

View File

@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -152,7 +152,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix), prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)

View File

@@ -35,7 +35,6 @@ from sglang.srt.configs.model_config import (
get_nsa_index_topk, get_nsa_index_topk,
is_deepseek_nsa, is_deepseek_nsa,
) )
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_moe_expert_parallel_world_size, get_moe_expert_parallel_world_size,
get_pp_group, get_pp_group,
@@ -108,10 +107,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.single_batch_overlap import SboFlags from sglang.srt.single_batch_overlap import SboFlags
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.two_batch_overlap import ( from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher, MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo, model_forward_maybe_tbo,
@@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module):
self.n_shared_experts = config.n_shared_experts self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = ( self.num_fused_shared_experts = (
0 0
if global_server_args_dict["disable_shared_experts_fusion"] if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts else config.n_shared_experts
) )
self.config = config self.config = config
@@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module):
self.experts = get_moe_impl_class(quant_config)( self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts num_experts=config.n_routed_experts
+ self.num_fused_shared_experts + self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"], + get_global_server_args().ep_num_redundant_experts,
num_fused_shared_experts=self.num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
@@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module):
self.ep_size = get_moe_expert_parallel_world_size() self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = ( self.num_experts = (
config.n_routed_experts config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"] + get_global_server_args().ep_num_redundant_experts
) )
self.renormalize = config.norm_topk_prob self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group self.topk_group = config.topk_group
@@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module):
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
is_neox_style=False, is_neox_style=False,
device=global_server_args_dict["device"], device=get_global_server_args().device,
) )
if rope_scaling: if rope_scaling:
@@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale_v = None self.w_scale_v = None
self.use_deep_gemm_bmm = False self.use_deep_gemm_bmm = False
self.flashinfer_mla_disable_ragged = global_server_args_dict[ self.flashinfer_mla_disable_ragged = (
"flashinfer_mla_disable_ragged" get_global_server_args().flashinfer_mla_disable_ragged,
] )
self.disable_chunked_prefix_cache = global_server_args_dict[ self.disable_chunked_prefix_cache = (
"disable_chunked_prefix_cache" get_global_server_args().disable_chunked_prefix_cache
] )
self.current_attention_backend = ( self.current_attention_backend = (
None # Attention backend used by current forward batch None # Attention backend used by current forward batch
@@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module):
) -> AttnForwardMethod: ) -> AttnForwardMethod:
# Determine attention backend used by current forward batch # Determine attention backend used by current forward batch
if forward_batch.forward_mode.is_decode_or_idle(): if forward_batch.forward_mode.is_decode_or_idle():
attention_backend = global_server_args_dict["decode_attention_backend"] attention_backend = get_global_server_args().decode_attention_backend
elif ( elif (
forward_batch.forward_mode.is_target_verify() forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend() or forward_batch.forward_mode.is_draft_extend()
): ):
# Use the specified backend for speculative operations (both verify and draft extend) # Use the specified backend for speculative operations (both verify and draft extend)
if global_server_args_dict["speculative_attention_mode"] == "decode": if get_global_server_args().speculative_attention_mode == "decode":
attention_backend = global_server_args_dict["decode_attention_backend"] attention_backend = get_global_server_args().decode_attention_backend
else: # default to prefill else: # default to prefill
attention_backend = global_server_args_dict["prefill_attention_backend"] attention_backend = get_global_server_args().prefill_attention_backend
else: else:
attention_backend = global_server_args_dict["prefill_attention_backend"] attention_backend = get_global_server_args().prefill_attention_backend
self.current_attention_backend = attention_backend self.current_attention_backend = attention_backend
handler = AttentionBackendRegistry.get_handler(attention_backend) handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -2365,7 +2365,9 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None) rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.speculative_algorithm = global_server_args_dict["speculative_algorithm"] self.speculative_algorithm = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
self.layer_id = layer_id self.layer_id = layer_id
self.is_nextn = is_nextn self.is_nextn = is_nextn
self.self_attn = DeepseekV2AttentionMLA( self.self_attn = DeepseekV2AttentionMLA(
@@ -2817,7 +2819,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
@@ -2837,7 +2839,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self, architecture: str = "DeepseekV3ForCausalLM" self, architecture: str = "DeepseekV3ForCausalLM"
): ):
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]: if get_global_server_args().disable_shared_experts_fusion:
return return
# Only Deepseek V3/R1 can use shared experts fusion optimization now. # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2856,7 +2858,7 @@ class DeepseekV2ForCausalLM(nn.Module):
disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts." disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
if disable_reason is not None: if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
log_info_on_rank0( log_info_on_rank0(
logger, logger,

View File

@@ -33,9 +33,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, make_layers from sglang.srt.utils import add_prefix, is_cuda, make_layers
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -483,7 +483,7 @@ class FalconH1ForCausalLM(nn.Module):
quant_config=quant_config, quant_config=quant_config,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.lm_head = self.lm_head.float() self.lm_head = self.lm_head.float()
self.lm_head_multiplier = config.lm_head_multiplier self.lm_head_multiplier = config.lm_head_multiplier

View File

@@ -56,18 +56,13 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
is_fp8_fnuz,
per_tensor_quant_mla_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -77,6 +72,7 @@ from sglang.srt.models.deepseek_v2 import (
DeepseekV2Model, DeepseekV2Model,
DeepseekV2MoE, DeepseekV2MoE,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import ( from sglang.srt.utils import (
BumpAllocator, BumpAllocator,
@@ -395,7 +391,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.n_shared_experts = config.n_shared_experts self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = ( self.num_fused_shared_experts = (
0 0
if global_server_args_dict["disable_shared_experts_fusion"] if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts else config.n_shared_experts
) )
self.config = config self.config = config
@@ -432,7 +428,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.experts = get_moe_impl_class(quant_config)( self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts num_experts=config.n_routed_experts
+ self.num_fused_shared_experts + self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"], + get_global_server_args().ep_num_redundant_experts,
num_fused_shared_experts=self.num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
@@ -476,7 +472,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.ep_size = get_moe_expert_parallel_world_size() self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = ( self.num_experts = (
config.n_routed_experts config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"] + get_global_server_args().ep_num_redundant_experts
) )
self.renormalize = config.norm_topk_prob self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group self.topk_group = config.topk_group
@@ -758,7 +754,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
@@ -774,7 +770,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
self, architecture: str = "Glm4MoeForCausalLM" self, architecture: str = "Glm4MoeForCausalLM"
): ):
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]: if get_global_server_args().disable_shared_experts_fusion:
return return
# Only Deepseek V3/R1 can use shared experts fusion optimization now. # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -790,7 +786,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism." disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None: if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
log_info_on_rank0( log_info_on_rank0(
logger, logger,

View File

@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix from sglang.srt.utils import BumpAllocator, add_prefix
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix), prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)

View File

@@ -16,10 +16,10 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4_moe import Glm4MoeModel from sglang.srt.models.glm4_moe import Glm4MoeModel
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0 from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
from sglang.srt.utils.hf_transformers_utils import get_processor from sglang.srt.utils.hf_transformers_utils import get_processor
@@ -47,7 +47,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM") self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = ( self.num_fused_shared_experts = (
0 0
if global_server_args_dict["disable_shared_experts_fusion"] if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts else config.n_shared_experts
) )
@@ -68,7 +68,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -81,7 +81,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self, architecture: str = "Glm4MoeForCausalLM" self, architecture: str = "Glm4MoeForCausalLM"
): ):
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]: if get_global_server_args().disable_shared_experts_fusion:
return return
# Only Deepseek V3/R1 can use shared experts fusion optimization now. # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -97,7 +97,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism." disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None: if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
log_info_on_rank0( log_info_on_rank0(
logger, logger,

View File

@@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.utils import ( from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg, create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer, enable_fused_set_kv_buffer,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import ( from sglang.srt.utils import (
LazyValue, LazyValue,
add_prefix, add_prefix,
@@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module):
} }
self.experts = experts_type( self.experts = experts_type(
num_experts=config.num_local_experts num_experts=config.num_local_experts
+ global_server_args_dict["ep_num_redundant_experts"], + get_global_server_args().ep_num_redundant_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
layer_id=layer_id, layer_id=layer_id,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
@@ -259,7 +259,7 @@ class GptOssAttention(nn.Module):
# Choose dtype of sinks based on attention backend: trtllm_mha requires float32, # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
# others can use bfloat16 # others can use bfloat16
attn_backend = global_server_args_dict.get("attention_backend") attn_backend = get_global_server_args().attention_backend
sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16 sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
self.sinks = nn.Parameter( self.sinks = nn.Parameter(
torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
@@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
# quant_config=quant_config, # quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False self.capture_aux_hidden_states = False

View File

@@ -28,7 +28,6 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
@@ -36,7 +35,6 @@ from sglang.srt.distributed import (
) )
from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.elementwise import ( from sglang.srt.layers.elementwise import (
experts_combine_triton,
fused_dual_residual_rmsnorm, fused_dual_residual_rmsnorm,
fused_rmsnorm, fused_rmsnorm,
gelu_and_mul_triton, gelu_and_mul_triton,
@@ -64,10 +62,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -869,10 +867,10 @@ class Grok1ForCausalLM(nn.Module):
# Dump tensors for debugging # Dump tensors for debugging
global debug_tensor_dump_output_folder, debug_tensor_dump_inject global debug_tensor_dump_output_folder, debug_tensor_dump_inject
debug_tensor_dump_output_folder = global_server_args_dict[ debug_tensor_dump_output_folder = (
"debug_tensor_dump_output_folder" get_global_server_args().debug_tensor_dump_output_folder
] )
debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"] debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=FutureWarning)
if get_tensor_model_parallel_rank() == 0: if get_tensor_model_parallel_rank() == 0:

View File

@@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
kv_cache_scales_loader, kv_cache_scales_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
@@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

View File

@@ -32,14 +32,10 @@
import concurrent.futures import concurrent.futures
import logging import logging
import os from typing import Iterable, Optional, Tuple
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from tqdm import tqdm
from sglang.srt.configs import LongcatFlashConfig from sglang.srt.configs import LongcatFlashConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
@@ -85,10 +81,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import ( from sglang.srt.utils import (
BumpAllocator, BumpAllocator,
LazyValue, LazyValue,
@@ -595,7 +591,7 @@ class LongcatFlashForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)

View File

@@ -31,9 +31,9 @@ from sglang.srt.managers.schedule_batch import (
Modality, Modality,
MultimodalDataItem, MultimodalDataItem,
MultimodalInputs, MultimodalInputs,
global_server_args_dict,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cpu from sglang.srt.utils import is_cpu
_is_cpu = is_cpu() _is_cpu = is_cpu()
@@ -448,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module):
) )
self.has_vision = ( self.has_vision = (
self.has_vision_weights and global_server_args_dict["enable_multimodal"] self.has_vision_weights and get_global_server_args().enable_multimodal
) )
if self.has_vision: if self.has_vision:

View File

@@ -64,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
from sglang.srt.utils import add_prefix, is_cuda, make_layers from sglang.srt.utils import add_prefix, is_cuda, make_layers
@@ -156,7 +156,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
layer_id=self.layer_id, layer_id=self.layer_id,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
num_experts=config.num_experts num_experts=config.num_experts
+ global_server_args_dict["ep_num_redundant_experts"], + get_global_server_args().ep_num_redundant_experts,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
quant_config=quant_config, quant_config=quant_config,
@@ -192,7 +192,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
# TODO: we will support tp < ep in the future # TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size() self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = ( self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"] config.num_experts + get_global_server_args().ep_num_redundant_experts
) )
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
@@ -643,7 +643,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
# For EAGLE3 support # For EAGLE3 support

View File

@@ -54,7 +54,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -64,6 +63,7 @@ from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg, create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer, enable_fused_set_kv_buffer,
) )
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import ( from sglang.srt.utils import (
add_prefix, add_prefix,
is_cuda, is_cuda,
@@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.experts = get_moe_impl_class(quant_config)( self.experts = get_moe_impl_class(quant_config)(
num_experts=config.num_experts num_experts=config.num_experts
+ global_server_args_dict["ep_num_redundant_experts"], + get_global_server_args().ep_num_redundant_experts,
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
layer_id=layer_id, layer_id=layer_id,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
@@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# TODO: we will support tp < ep in the future # TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size() self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = ( self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"] config.num_experts + get_global_server_args().ep_num_redundant_experts
) )
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
@@ -693,7 +693,7 @@ class Qwen3MoeForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False self.capture_aux_hidden_states = False

View File

@@ -39,7 +39,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import ( from sglang.srt.model_loader.weight_utils import (
@@ -47,6 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
sharded_weight_loader, sharded_weight_loader,
) )
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import ( from sglang.srt.utils import (
LazyValue, LazyValue,
add_prefix, add_prefix,
@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
quant_config=quant_config, quant_config=quant_config,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix), prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.lm_head = self.lm_head.float() self.lm_head = self.lm_head.float()
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)

View File

@@ -21,14 +21,13 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix), prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)

View File

@@ -38,20 +38,12 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import ( from sglang.srt.managers.mm_utils import general_mm_embed_routine
MultiModalityDataPaddingPatternMultimodalTokens, from sglang.srt.managers.schedule_batch import MultimodalDataItem
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_vl import ( from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer, Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration, Qwen3VLForConditionalGeneration,

View File

@@ -57,7 +57,6 @@ from sglang.srt.managers.schedule_batch import (
Modality, Modality,
MultimodalDataItem, MultimodalDataItem,
MultimodalInputs, MultimodalInputs,
global_server_args_dict,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -300,7 +299,7 @@ class Step3TextDecoderLayer(nn.Module):
# self.n_shared_experts = 1 # self.n_shared_experts = 1
# self.num_fused_shared_experts = ( # self.num_fused_shared_experts = (
# 0 # 0
# if global_server_args_dict["disable_shared_experts_fusion"] # if global_server_args.disable_shared_experts_fusion
# else self.n_shared_experts # else self.n_shared_experts
# ) # )
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0
@@ -774,7 +773,7 @@ class Step3VLForConditionalGeneration(nn.Module):
# self.n_shared_experts = 1 # self.n_shared_experts = 1
# self.num_fused_shared_experts = ( # self.num_fused_shared_experts = (
# 0 # 0
# if global_server_args_dict["disable_shared_experts_fusion"] # if global_server_args.disable_shared_experts_fusion
# else self.n_shared_experts # else self.n_shared_experts
# ) # )
self.num_fused_shared_experts = 0 self.num_fused_shared_experts = 0

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import dataclasses import dataclasses
import logging import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch import torch
@@ -10,6 +9,7 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL from sglang.srt.sampling.sampling_params import TOP_K_ALL
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -66,16 +66,10 @@ class SamplingBatchInfo:
# Handle logit bias # Handle logit bias
logit_bias: Optional[torch.Tensor] = None logit_bias: Optional[torch.Tensor] = None
@classmethod
def _get_global_server_args_dict(cls):
from sglang.srt.managers.schedule_batch import global_server_args_dict
return global_server_args_dict
@classmethod @classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args_dict = cls._get_global_server_args_dict() global_server_args = get_global_server_args()
enable_deterministic = global_server_args_dict["enable_deterministic_inference"] enable_deterministic = global_server_args.enable_deterministic_inference
reqs = batch.reqs reqs = batch.reqs
device = batch.device device = batch.device
@@ -112,10 +106,9 @@ class SamplingBatchInfo:
logit_bias[i, int(key)] = value logit_bias[i, int(key)] = value
# Check if any request has custom logit processor # Check if any request has custom logit processor
has_custom_logit_processor = global_server_args_dict[ has_custom_logit_processor = (
"enable_custom_logit_processor" global_server_args.enable_custom_logit_processor
] and any( # check the flag first. and any(r.custom_logit_processor for r in reqs) # check the flag first.
r.custom_logit_processor for r in reqs
) # then check the requests. ) # then check the requests.
if has_custom_logit_processor: if has_custom_logit_processor:

View File

@@ -53,6 +53,7 @@ from sglang.utils import is_in_ci
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Define constants # Define constants
LOAD_FORMAT_CHOICES = [ LOAD_FORMAT_CHOICES = [
"auto", "auto",
@@ -3329,6 +3330,22 @@ class ServerArgs:
) )
# NOTE: This is a global variable to hold the server args for scheduler.
_global_server_args: Optional[ServerArgs] = None
def set_global_server_args_for_scheduler(server_args: ServerArgs):
global _global_server_args
_global_server_args = server_args
def get_global_server_args() -> ServerArgs:
if _global_server_args is None:
raise ValueError("Global server args is not set yet!")
return _global_server_args
def prepare_server_args(argv: List[str]) -> ServerArgs: def prepare_server_args(argv: List[str]) -> ServerArgs:
""" """
Prepare the server arguments from the command line arguments. Prepare the server arguments from the command line arguments.
@@ -3363,8 +3380,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser) ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv) raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args return ServerArgs.from_cli_args(raw_args)
ZMQ_TCP_PORT_DELTA = 233 ZMQ_TCP_PORT_DELTA = 233

View File

@@ -6,7 +6,6 @@ import torch
from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var from sglang.srt.utils import get_int_env_var

View File

@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.overlap_utils import FutureIndices from sglang.srt.managers.overlap_utils import FutureIndices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.common import ( from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend, alloc_paged_token_slots_extend,
@@ -19,6 +19,7 @@ from sglang.srt.mem_cache.common import (
get_last_loc, get_last_loc,
) )
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.eagle_info_v2 import ( from sglang.srt.speculative.eagle_info_v2 import (
EagleDraftInputV2Mixin, EagleDraftInputV2Mixin,
EagleVerifyInputV2Mixin, EagleVerifyInputV2Mixin,
@@ -332,12 +333,8 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
uniform_samples_for_final_sampling=coins_for_final_sampling, uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[ threshold_single=get_global_server_args().speculative_accept_threshold_single,
"speculative_accept_threshold_single" threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True, deterministic=True,
) )

View File

@@ -11,7 +11,6 @@ import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode, CaptureHiddenMode,
@@ -19,6 +18,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode, ForwardMode,
) )
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.spec_utils import ( from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN, SIMULATE_ACC_LEN,
@@ -265,12 +265,8 @@ class EagleVerifyInputV2Mixin:
uniform_samples_for_final_sampling=coins_for_final_sampling, uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[ threshold_single=get_global_server_args().speculative_accept_threshold_single,
"speculative_accept_threshold_single" threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
deterministic=True, deterministic=True,
) )

View File

@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
) )
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.common import ( from sglang.srt.mem_cache.common import (
@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch, ForwardBatch,
ForwardMode, ForwardMode,
) )
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner, EAGLEDraftCudaGraphRunner,
@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
) )
def _create_flashinfer_decode_backend(self): def _create_flashinfer_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]: if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import ( from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend, FlashInferMultiStepDraftBackend,
) )
@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
) )
def _create_trtllm_mla_decode_backend(self): def _create_trtllm_mla_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]: if not get_global_server_args().use_mla_backend:
raise ValueError( raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)." "trtllm_mla backend requires MLA model (use_mla_backend=True)."
) )
@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
) )
def _create_flashinfer_prefill_backend(self): def _create_flashinfer_prefill_backend(self):
if not global_server_args_dict["use_mla_backend"]: if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import ( from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend, FlashInferAttnBackend,
) )
@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False) return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_trtllm_mla_prefill_backend(self): def _create_trtllm_mla_prefill_backend(self):
if not global_server_args_dict["use_mla_backend"]: if not get_global_server_args().use_mla_backend:
raise ValueError( raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)." "trtllm_mla backend requires MLA model (use_mla_backend=True)."
) )

View File

@@ -7,6 +7,8 @@ from typing import Optional, Tuple
import torch import torch
import triton import triton
from sglang.srt.server_args import get_global_server_args
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from dataclasses import dataclass from dataclasses import dataclass
@@ -16,7 +18,7 @@ import torch.nn.functional as F
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.common import ( from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend, alloc_paged_token_slots_extend,
alloc_token_slots, alloc_token_slots,
@@ -350,10 +352,8 @@ class NgramVerifyInput(SpecInput):
uniform_samples_for_final_sampling=coins_for_final_sampling, uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs, target_probs=target_probs,
draft_probs=draft_probs, draft_probs=draft_probs,
threshold_single=global_server_args_dict[ threshold_single=get_global_server_args().speculative_accept_threshold_single,
"speculative_accept_threshold_single" threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
],
threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
deterministic=True, deterministic=True,
) )

View File

@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
) )
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch, ForwardBatch,
ForwardMode, ForwardMode,
@@ -30,6 +30,7 @@ from sglang.srt.model_executor.forward_batch_info import (
) )
from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
@@ -153,7 +154,7 @@ def _update_device_and_sum_field_from_cpu_field(
cpu_value cpu_value
if isinstance(cpu_value, torch.Tensor) if isinstance(cpu_value, torch.Tensor)
else torch.tensor(cpu_value, dtype=old_device_value.dtype) else torch.tensor(cpu_value, dtype=old_device_value.dtype)
).to(device=global_server_args_dict["device"], non_blocking=True) ).to(device=get_global_server_args().device, non_blocking=True)
setattr(batch, device_field, new_device_value) setattr(batch, device_field, new_device_value)
if sum_field is not None: if sum_field is not None:
@@ -582,7 +583,7 @@ class TboForwardBatchPreparer:
sum_field=None, sum_field=None,
) )
_, child_b.extend_start_loc = compute_position( _, child_b.extend_start_loc = compute_position(
global_server_args_dict["attention_backend"], get_global_server_args().attention_backend,
child_b.extend_prefix_lens, child_b.extend_prefix_lens,
child_b.extend_seq_lens, child_b.extend_seq_lens,
child_b.extend_num_tokens, child_b.extend_num_tokens,
@@ -687,7 +688,7 @@ class TboForwardBatchPreparer:
# TODO improve, e.g. unify w/ `init_raw` # TODO improve, e.g. unify w/ `init_raw`
if ( if (
global_server_args_dict["moe_dense_tp_size"] == 1 get_global_server_args().moe_dense_tp_size == 1
and batch.global_dp_buffer_len is not None and batch.global_dp_buffer_len is not None
): ):
sum_len = end_token_index - start_token_index sum_len = end_token_index - start_token_index
@@ -755,7 +756,7 @@ class TboForwardBatchPreparer:
value_a = min(tbo_split_token_index, num_token_non_padded) value_a = min(tbo_split_token_index, num_token_non_padded)
value_b = max(0, num_token_non_padded - tbo_split_token_index) value_b = max(0, num_token_non_padded - tbo_split_token_index)
return torch.tensor([value_a, value_b], dtype=torch.int32).to( return torch.tensor([value_a, value_b], dtype=torch.int32).to(
device=global_server_args_dict["device"], non_blocking=True device=get_global_server_args().device, non_blocking=True
) )
@classmethod @classmethod

View File

@@ -7,7 +7,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.server_args import (
ServerArgs,
get_global_server_args,
set_global_server_args_for_scheduler,
)
class LMHeadStub(nn.Module): class LMHeadStub(nn.Module):
@@ -32,8 +36,10 @@ class TestLMHeadFP32(unittest.TestCase):
raise unittest.SkipTest("needs CUDA GPU") raise unittest.SkipTest("needs CUDA GPU")
def _make_logprocessor(self, vocab_size, enable_fp32): def _make_logprocessor(self, vocab_size, enable_fp32):
global_server_args_dict["enable_dp_lm_head"] = False ServerArgs.__post_init__ = lambda self: None # disable validation
global_server_args_dict["enable_fp32_lm_head"] = enable_fp32 set_global_server_args_for_scheduler(ServerArgs(model_path="dummy"))
get_global_server_args().enable_dp_lm_head = False
get_global_server_args().enable_fp32_lm_head = enable_fp32
cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None) cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None) return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)

View File

@@ -4,6 +4,7 @@ import unittest
import requests import requests
import torch import torch
from sglang.srt.server_args import set_global_server_args_for_scheduler
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -16,17 +17,15 @@ from sglang.test.test_utils import (
def check_quant_method(model_path: str, use_marlin_kernel: bool): def check_quant_method(model_path: str, use_marlin_kernel: bool):
from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_tp_group,
init_distributed_environment, init_distributed_environment,
initialize_model_parallel, initialize_model_parallel,
set_custom_all_reduce,
) )
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.quantization.utils import get_dynamic_override from sglang.srt.layers.quantization.utils import get_dynamic_override
from sglang.srt.model_loader import get_model from sglang.srt.model_loader import get_model
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import ServerArgs
try: try:
init_distributed_environment( init_distributed_environment(
@@ -43,6 +42,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
pass pass
server_args = ServerArgs(model_path=model_path, dtype=torch.float16) server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
set_global_server_args_for_scheduler(server_args)
model_config = ModelConfig.from_server_args(server_args) model_config = ModelConfig.from_server_args(server_args)
load_config = LoadConfig() load_config = LoadConfig()