Deprecate global_server_args_dict (#11331)

Liangsheng Yin
2025-10-13 01:20:47 +08:00
committed by GitHub
parent 2157d12ae8
commit 1083e7e3df
54 changed files with 240 additions and 321 deletions

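The change repeated across all 54 files is mechanical: the module-level global_server_args_dict in schedule_batch.py, a plain dict seeded from selected ServerArgs fields, is replaced by the typed accessor get_global_server_args(), so call sites move from string-keyed lookups to attribute access. The dict was easy to misuse: one call site below reads global_server_args_dict.get("attention_reduce_in_fp32", False), a key never present in GLOBAL_SERVER_ARGS_KEYS, so it silently always returned False, whereas attribute access on a typed object raises AttributeError for such typos. A minimal sketch of the pattern, assuming only the two import paths shown in the hunks below:

from sglang.srt.server_args import get_global_server_args

# Before (string-keyed dict, removed by this commit):
#   from sglang.srt.managers.schedule_batch import global_server_args_dict
#   enable_symm = global_server_args_dict["enable_symm_mem"]

# After (attribute access on the process-wide ServerArgs instance):
enable_symm = get_global_server_args().enable_symm_mem
device = get_global_server_args().device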
View File

@@ -6,9 +6,6 @@
class GlobalConfig:
"""
Store some global constants.
See also python/sglang/srt/managers/schedule_batch.py::global_server_args_dict, which stores
many global runtime arguments as well.
"""
def __init__(self):

View File

@@ -5,7 +5,7 @@ from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
nccl_allocator_source = """
#include <nccl.h>
@@ -32,7 +32,7 @@ _graph_pool_id = None
def is_symmetric_memory_enabled():
return global_server_args_dict["enable_symm_mem"]
return get_global_server_args().enable_symm_mem
def set_graph_pool_id(graph_pool_id):

View File

@@ -18,7 +18,7 @@ from typing import Literal, Optional
import torch
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
@dataclass
@@ -34,7 +34,7 @@ class ExpertLocationDispatchInfo:
@classmethod
def init_new(cls, layer_id: int):
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm
expert_location_metadata = get_global_expert_location_metadata()
assert expert_location_metadata is not None

View File

@@ -24,7 +24,7 @@ from sglang.srt.eplb.expert_location import (
ExpertLocationMetadata,
get_global_expert_location_metadata,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__)
@@ -97,7 +97,7 @@ def _update_expert_weights_with_canary(
canary_tensor = (
_get_canary_value(old_expert_location_metadata, layer_id)
.clone()
.to(device=global_server_args_dict["device"], non_blocking=True)
.to(device=get_global_server_args().device, non_blocking=True)
)
routed_experts_weights_of_layer[layer_id].append(canary_tensor)

View File

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
# TODO: Change the hard-coded block_seq_num
self.BLOCK_SEQ = 128
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
if get_global_server_args().triton_attention_reduce_in_fp32:
self.reduce_dtype = torch.float32
else:
self.reduce_dtype = torch.float16

View File

@@ -11,8 +11,8 @@ import triton.language as tl
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
if TYPE_CHECKING:
@@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not global_server_args_dict["disable_chunked_prefix_cache"]
assert not get_global_server_args().disable_chunked_prefix_cache
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None

View File

@@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
is_flashinfer_available,
@@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.skip_prefill = skip_prefill
self.enable_chunk_kv = (
not skip_prefill
and global_server_args_dict["disaggregation_mode"] != "decode"
and not global_server_args_dict["disable_chunked_prefix_cache"]
and not global_server_args_dict["flashinfer_mla_disable_ragged"]
and get_global_server_args().disaggregation_mode != "decode"
and not get_global_server_args().disable_chunked_prefix_cache
and not get_global_server_args().flashinfer_mla_disable_ragged
)
self.page_size = model_runner.page_size
@@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
prefix_lens = forward_batch.extend_prefix_lens
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
use_ragged = (
not global_server_args_dict["flashinfer_mla_disable_ragged"]
not get_global_server_args().flashinfer_mla_disable_ragged
and extend_no_prefix
)

View File

@@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -162,7 +162,7 @@ class Indexer(CustomOp):
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
device=get_global_server_args().device,
)
self.block_size = block_size
self.scale_fmt = scale_fmt

View File

@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cuda, is_hip
_is_cuda = is_cuda()
@@ -11,7 +11,7 @@ if _is_cuda:
_is_hip = is_hip()
if global_server_args_dict.get("attention_reduce_in_fp32", False):
if get_global_server_args().triton_attention_reduce_in_fp32:
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:

View File

@@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import (
create_flashmla_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cuda, is_flashinfer_available
if is_flashinfer_available():
@@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
self.disable_chunked_prefix_cache = (
get_global_server_args().disable_chunked_prefix_cache
)
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

View File

@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix
ROTARY_EMBED_CLASSES = {
@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
_passed_backend = qkv_backend
qkv_backend = self._determine_attention_backend(_passed_backend)
if (
global_server_args_dict["mm_attention_backend"] is None
get_global_server_args().mm_attention_backend is None
and _passed_backend is None
):
print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
- CUDA: "triton_attn"
- Non-CUDA: "sdpa"
"""
override_backend = global_server_args_dict["mm_attention_backend"]
override_backend = get_global_server_args().mm_attention_backend
if override_backend is not None:
backend = override_backend
elif passed_backend is not None:

View File

@@ -40,8 +40,9 @@ from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
@@ -168,7 +169,7 @@ class LayerScatterModes:
def enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1
return get_global_server_args().moe_dense_tp_size == 1
class LayerCommunicator:
@@ -314,7 +315,9 @@ class LayerCommunicator:
def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch
) -> bool:
speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
if (
is_dp_attention_enabled()
and speculative_algo is not None
@@ -333,7 +336,7 @@ class LayerCommunicator:
static_conditions_met = (
(not self.is_last_layer)
and (self._context.tp_size > 1)
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and get_global_server_args().enable_flashinfer_allreduce_fusion
and _is_flashinfer_available
)
@@ -531,7 +534,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
(_is_sm100_supported or _is_sm90_supported)
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and get_global_server_args().enable_flashinfer_allreduce_fusion
and hidden_states.shape[0] <= 4096
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(

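Note the type shift in should_fuse_mlp_allreduce_with_next_layer above: the dict used to hold a SpeculativeAlgorithm enum injected by ModelRunner, while ServerArgs.speculative_algorithm is the raw CLI string (or None), so call sites now convert explicitly via SpeculativeAlgorithm.from_string. A sketch of that idiom with a toy enum; the member names are assumptions, and only from_string plus the None handling are visible in this diff:

from enum import Enum, auto
from typing import Optional

class SpeculativeAlgorithm(Enum):
    # Illustrative members only; the real enum lives in
    # sglang.srt.speculative.spec_info and may differ.
    NONE = auto()
    EAGLE = auto()

    @staticmethod
    def from_string(name: Optional[str]) -> "SpeculativeAlgorithm":
        # A missing algorithm (speculative decoding disabled) maps to NONE,
        # matching how the diff treats speculative_algorithm=None.
        if name is None:
            return SpeculativeAlgorithm.NONE
        return SpeculativeAlgorithm[name.upper()]

assert SpeculativeAlgorithm.from_string(None) is SpeculativeAlgorithm.NONE
assert SpeculativeAlgorithm.from_string("eagle") is SpeculativeAlgorithm.EAGLE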
View File

@@ -38,17 +38,15 @@ from sglang.srt.layers.dp_attention import (
get_dp_device,
get_dp_dtype,
get_dp_hidden_size,
get_global_dp_buffer,
get_local_attention_dp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
logger = logging.getLogger(__name__)
@@ -230,8 +228,8 @@ class LogitsProcessor(nn.Module):
super().__init__()
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
@@ -254,8 +252,8 @@ class LogitsProcessor(nn.Module):
):
self.final_logit_softcapping = None
self.debug_tensor_dump_output_folder = global_server_args_dict.get(
"debug_tensor_dump_output_folder", None
self.debug_tensor_dump_output_folder = (
get_global_server_args().debug_tensor_dump_output_folder
)
def compute_logprobs_for_multi_item_scoring(
@@ -372,9 +370,7 @@ class LogitsProcessor(nn.Module):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = global_server_args_dict.get(
"multi_item_scoring_delimiter"
)
multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter

View File

@@ -27,12 +27,10 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import (
cpu_has_amx_support,

View File

@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda,
@@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
"flashinfer_mxfp4_moe_precision"
]
self.flashinfer_mxfp4_moe_precision = (
get_global_server_args().flashinfer_mxfp4_moe_precision
)
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None

View File

@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
is_dp_attention_enabled,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
if is_cuda():
@@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
self.use_nan_detection = get_global_server_args().enable_nan_detection
self.tp_sync_group = get_tp_group().device_group
if is_dp_attention_enabled():
@@ -104,7 +104,7 @@ class Sampler(nn.Module):
del logits
if True: # Keep this redundant check to simplify some internal code sync
if global_server_args_dict["sampling_backend"] == "flashinfer":
if get_global_server_args().sampling_backend == "flashinfer":
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -119,7 +119,7 @@ class Sampler(nn.Module):
filter_apply_order="joint",
check_nan=self.use_nan_detection,
)
elif global_server_args_dict["sampling_backend"] == "pytorch":
elif get_global_server_args().sampling_backend == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
@@ -132,7 +132,7 @@ class Sampler(nn.Module):
)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
)
if return_logprob:

View File

@@ -16,10 +16,10 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
from sglang.utils import logger
@@ -428,7 +428,7 @@ def _adjust_embedding_length(
f"tokens from multimodal embeddings."
)
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
chunked_prefill_size = get_global_server_args().chunked_prefill_size
if chunked_prefill_size != -1:
logger.warning(
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"

View File

@@ -72,7 +72,7 @@ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import next_power_of_2
@@ -82,47 +82,6 @@ if TYPE_CHECKING:
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
GLOBAL_SERVER_ARGS_KEYS = [
"attention_backend",
"mm_attention_backend",
"debug_tensor_dump_inject",
"debug_tensor_dump_output_folder",
"chunked_prefill_size",
"device",
"disable_chunked_prefix_cache",
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"ep_num_redundant_experts",
"enable_nan_detection",
"flashinfer_mla_disable_ragged",
"pp_max_micro_batch_size",
"disable_shared_experts_fusion",
"sampling_backend",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
"speculative_attention_mode",
"torchao_config",
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
"weight_loader_disable_mmap",
"enable_multimodal",
"enable_symm_mem",
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
"multi_item_scoring_delimiter",
]
# Put some global args for easy access
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
logger = logging.getLogger(__name__)
@@ -683,12 +642,9 @@ class Req:
def is_prefill_only(self) -> bool:
"""Check if this request is prefill-only (no token generation needed)."""
# NOTE: when spec is enabled, prefill_only optimizations are disabled
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
spec_alg = global_server_args_dict["speculative_algorithm"]
return self.sampling_params.max_new_tokens == 0 and (
spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
)
spec_alg = get_global_server_args().speculative_algorithm
return self.sampling_params.max_new_tokens == 0 and spec_alg is None
def add_latency(self, stage: RequestStage):
if self.metrics_collector is None:

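With the GLOBAL_SERVER_ARGS_KEYS allowlist and its dict gone from schedule_batch.py, the single source of truth is a process-wide ServerArgs instance behind get_global_server_args() and set_global_server_args_for_scheduler(). Their bodies live in sglang.srt.server_args and are not part of this diff, so the following is only a plausible minimal sketch of the singleton they imply, with a stand-in ServerArgs:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ServerArgs:
    # Tiny stand-in; the real dataclass has dozens of fields.
    device: str = "cuda"
    enable_symm_mem: bool = False

_global_server_args: Optional[ServerArgs] = None

def set_global_server_args_for_scheduler(server_args: ServerArgs) -> None:
    # Called once per scheduler process (see the ModelRunner hunk below)
    # before any call site reads the args.
    global _global_server_args
    _global_server_args = server_args

def get_global_server_args() -> ServerArgs:
    assert _global_server_args is not None, "server args not initialized"
    return _global_server_args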
View File

@@ -122,7 +122,6 @@ from sglang.srt.managers.schedule_batch import (
Req,
RequestStage,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
AddReqResult,
@@ -150,7 +149,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
@@ -447,13 +446,12 @@ class Scheduler(
self.max_req_input_len,
self.random_seed,
self.device,
worker_global_server_args_dict,
_,
_,
_,
) = self.tp_worker.get_worker_info()
if global_server_args_dict["pp_max_micro_batch_size"] is None:
global_server_args_dict["pp_max_micro_batch_size"] = max(
if get_global_server_args().pp_max_micro_batch_size is None:
get_global_server_args().pp_max_micro_batch_size = max(
self.max_running_requests // server_args.pp_size, 1
)
@@ -465,7 +463,6 @@ class Scheduler(
self.world_group = get_world_group()
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Hybrid memory pool
@@ -1866,7 +1863,7 @@ class Scheduler(
return ret
def get_num_allocatable_reqs(self, running_bs):
res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
res = get_global_server_args().pp_max_micro_batch_size - running_bs
if self.pp_size > 1:
res = min(res, self.req_to_token_pool.available_size())
return res
@@ -2610,7 +2607,7 @@ class Scheduler(
)
def get_internal_state(self, recv_req: GetInternalStateReq):
ret = dict(global_server_args_dict)
ret = dict(vars(get_global_server_args()))
ret["last_gen_throughput"] = self.last_gen_throughput
ret["memory_usage"] = {
"weight": round(
@@ -2666,11 +2663,11 @@ class Scheduler(
logger.info(f"{avg_spec_accept_length=}")
self.cum_spec_accept_length = self.cum_spec_accept_count = 0
for k, v in server_args_dict.items():
global_server_args_dict[k] = v
logger.info(f"Global server args updated! {global_server_args_dict=}")
setattr(get_global_server_args(), k, v)
logger.info(f"Global server args updated! {get_global_server_args()=}")
return SetInternalStateReqOutput(
updated=True,
server_args=global_server_args_dict,
server_args=vars(get_global_server_args()),
)
def handle_rpc_request(self, recv_req: RpcReqInput):

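get_internal_state and set_internal_state now round-trip through vars() and setattr on the live ServerArgs object, keeping the RPC payload a plain dict while mutating typed fields. One caveat, reflected in the dict(...) copy above: vars() returns the object's actual __dict__, so the snapshot must be copied before extra keys such as last_gen_throughput are added, or they would leak into the global ServerArgs. A toy illustration of the round trip:

class Args:
    def __init__(self):
        self.sampling_backend = "flashinfer"
        self.enable_nan_detection = False

args = Args()

# get_internal_state-style read: copy the live __dict__ into a reply payload.
snapshot = dict(vars(args))
snapshot["last_gen_throughput"] = 123.0  # extra key stays out of args
assert not hasattr(args, "last_gen_throughput")

# set_internal_state-style write: apply a dict of updates field by field.
for k, v in {"enable_nan_detection": True}.items():
    setattr(args, k, v)
assert args.enable_nan_detection is True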
View File

@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -190,7 +190,6 @@ class TpModelWorker:
self.max_req_input_len,
self.random_seed,
self.device,
global_server_args_dict,
self.model_runner.req_to_token_pool.size,
self.model_runner.req_to_token_pool.max_context_len,
self.model_runner.token_to_kv_pool.size,

View File

@@ -11,7 +11,7 @@ from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import support_triton
if TYPE_CHECKING:
@@ -19,10 +19,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@triton.jit
def write_req_to_token_pool_triton(
@@ -88,7 +84,7 @@ def write_cache_indices(
prefix_tensors: list[torch.Tensor],
req_to_token_pool: ReqToTokenPool,
):
if support_triton(global_server_args_dict.get("attention_backend")):
if support_triton(get_global_server_args().attention_backend):
prefix_pointers = torch.tensor(
[t.data_ptr() for t in prefix_tensors],
device=req_to_token_pool.device,
@@ -129,8 +125,8 @@ def get_last_loc(
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
if (
global_server_args_dict["attention_backend"] != "ascend"
and global_server_args_dict["attention_backend"] != "torch_native"
get_global_server_args().attention_backend != "ascend"
and get_global_server_args().attention_backend != "torch_native"
):
impl = get_last_loc_triton
else:

View File

@@ -83,10 +83,6 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS,
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
@@ -125,7 +121,11 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import (
ServerArgs,
get_global_server_args,
set_global_server_args_for_scheduler,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
MultiprocessingSerializer,
@@ -275,15 +275,12 @@ class ModelRunner:
# Model-specific adjustment
self.model_specific_adjustment()
# Global vars
global_server_args_dict.update(
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
| {
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
)
# Set the global server_args in the scheduler process
set_global_server_args_for_scheduler(server_args)
global_server_args = get_global_server_args()
# FIXME: hacky set `use_mla_backend`
global_server_args.use_mla_backend = self.use_mla_backend
# Init OpenMP threads binding for CPU
if self.device == "cpu":
@@ -432,7 +429,7 @@ class ModelRunner:
# In layered loading, torchao may have been applied
if not torchao_applied:
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
self.model, get_global_server_args().torchao_config
)
# Apply torch TP if the model supports it
@@ -1838,12 +1835,10 @@ class ModelRunner:
self.server_args.attention_backend
)
global_server_args_dict.update(
{
"decode_attention_backend": self.decode_attention_backend_str,
"prefill_attention_backend": self.prefill_attention_backend_str,
}
)
(
get_global_server_args().prefill_attention_backend,
get_global_server_args().decode_attention_backend,
) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
return attn_backend
def _get_attention_backend_from_str(self, backend_str: str):

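ModelRunner is now the one place that publishes the global instance, and runtime-derived values that are not really CLI args (use_mla_backend, the resolved prefill/decode attention backends) are stamped onto the same object afterwards, as the FIXME admits. That patching works only because ServerArgs is an ordinary dataclass with an open __dict__; a slotted class would reject it, as this toy comparison (stand-in classes, not the real ones) shows:

from dataclasses import dataclass

@dataclass
class OpenArgs:  # like ServerArgs: a plain dataclass with an open __dict__
    device: str = "cuda"

@dataclass(slots=True)  # Python 3.10+; a slotted variant for contrast
class SlottedArgs:
    device: str = "cuda"

open_args = OpenArgs()
open_args.use_mla_backend = True  # fine: lands in __dict__

slotted = SlottedArgs()
try:
    slotted.use_mla_backend = True  # AttributeError: no such slot
except AttributeError:
    print("slots=True would break the use_mla_backend patching")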
View File

@@ -4,7 +4,6 @@ from __future__ import annotations
# ruff: noqa: SIM117
import collections
import concurrent
import dataclasses
import fnmatch
import glob
@@ -12,12 +11,10 @@ import json
import logging
import math
import os
import re
import socket
import threading
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, suppress
from typing import (
TYPE_CHECKING,
@@ -33,10 +30,10 @@ from typing import (
import huggingface_hub
import numpy as np
import requests
import safetensors.torch
import torch
from sglang.srt.server_args import get_global_server_args
# Try to import accelerate (optional dependency)
try:
from accelerate import infer_auto_device_map, init_empty_weights
@@ -81,8 +78,6 @@ DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
)
from sglang.srt.model_loader.weight_utils import (
_BAR_FORMAT,
default_weight_loader,
download_safetensors_index_file_from_hf,
download_weights_from_hf,
filter_duplicate_safetensors_files,
@@ -445,10 +440,8 @@ class DefaultModelLoader(BaseModelLoader):
hf_weights_files,
)
elif use_safetensors:
from sglang.srt.managers.schedule_batch import global_server_args_dict
weight_loader_disable_mmap = global_server_args_dict.get(
"weight_loader_disable_mmap"
weight_loader_disable_mmap = (
get_global_server_args().weight_loader_disable_mmap
)
if extra_config.get("enable_multithread_load"):
@@ -616,9 +609,9 @@ class LayeredModelLoader(DefaultModelLoader):
device_config: DeviceConfig,
) -> nn.Module:
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.server_args import get_global_server_args
torchao_config = global_server_args_dict.get("torchao_config")
torchao_config = get_global_server_args().torchao_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):

View File

@@ -46,15 +46,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -447,7 +446,7 @@ class ApertusForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

View File

@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
logger = logging.getLogger(__name__)
@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

View File

@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SGLang BailingMoE model."""
"""SGLang BailingMoE model."""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union
@@ -68,7 +68,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -76,6 +75,7 @@ from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
LoraConfig = None
@@ -204,8 +204,8 @@ class BailingMoESparseMoeBlock(nn.Module):
else:
self.router_dtype = torch.bfloat16
# TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
assert global_server_args_dict["ep_num_redundant_experts"] == 0
# TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported yet
assert get_global_server_args().ep_num_redundant_experts == 0
# check group topk
self.num_expert_group = getattr(config, "n_group", 0)
self.topk_group = getattr(config, "topk_group", 0)
@@ -220,7 +220,7 @@ class BailingMoESparseMoeBlock(nn.Module):
self.use_grouped_topk = False
self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
config.num_experts + get_global_server_args().ep_num_redundant_experts
)
self.gate = BailingMoEGate(
@@ -824,7 +824,7 @@ class BailingMoEForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)

View File

@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SGLang BailingMoENextN model."""
"""SGLang BailingMoENextN model."""
import logging
from typing import Iterable, Optional, Tuple
@@ -29,15 +29,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix
LoraConfig = None
@@ -145,7 +144,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)

View File

@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
logger = logging.getLogger(__name__)
@@ -152,7 +152,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)

View File

@@ -35,7 +35,6 @@ from sglang.srt.configs.model_config import (
get_nsa_index_topk,
is_deepseek_nsa,
)
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
@@ -108,10 +107,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.single_batch_overlap import SboFlags
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
@@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module):
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts
)
self.config = config
@@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
+ get_global_server_args().ep_num_redundant_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
@@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module):
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]
+ get_global_server_args().ep_num_redundant_experts
)
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
@@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
device=get_global_server_args().device,
)
if rope_scaling:
@@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale_v = None
self.use_deep_gemm_bmm = False
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
self.flashinfer_mla_disable_ragged = (
get_global_server_args().flashinfer_mla_disable_ragged
)
self.disable_chunked_prefix_cache = (
get_global_server_args().disable_chunked_prefix_cache
)
self.current_attention_backend = (
None # Attention backend used by current forward batch
@@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module):
) -> AttnForwardMethod:
# Determine attention backend used by current forward batch
if forward_batch.forward_mode.is_decode_or_idle():
attention_backend = global_server_args_dict["decode_attention_backend"]
attention_backend = get_global_server_args().decode_attention_backend
elif (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
# Use the specified backend for speculative operations (both verify and draft extend)
if global_server_args_dict["speculative_attention_mode"] == "decode":
attention_backend = global_server_args_dict["decode_attention_backend"]
if get_global_server_args().speculative_attention_mode == "decode":
attention_backend = get_global_server_args().decode_attention_backend
else: # default to prefill
attention_backend = global_server_args_dict["prefill_attention_backend"]
attention_backend = get_global_server_args().prefill_attention_backend
else:
attention_backend = global_server_args_dict["prefill_attention_backend"]
attention_backend = get_global_server_args().prefill_attention_backend
self.current_attention_backend = attention_backend
handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -2365,7 +2365,9 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
self.layer_id = layer_id
self.is_nextn = is_nextn
self.self_attn = DeepseekV2AttentionMLA(
@@ -2817,7 +2819,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
@@ -2837,7 +2839,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self, architecture: str = "DeepseekV3ForCausalLM"
):
self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]:
if get_global_server_args().disable_shared_experts_fusion:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2856,7 +2858,7 @@ class DeepseekV2ForCausalLM(nn.Module):
disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,

View File

@@ -33,9 +33,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, make_layers
logger = logging.getLogger(__name__)
@@ -483,7 +483,7 @@ class FalconH1ForCausalLM(nn.Module):
quant_config=quant_config,
org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.lm_head = self.lm_head.float()
self.lm_head_multiplier = config.lm_head_multiplier

View File

@@ -56,18 +56,13 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
per_tensor_quant_mla_fp8,
per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -77,6 +72,7 @@ from sglang.srt.models.deepseek_v2 import (
DeepseekV2Model,
DeepseekV2MoE,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import (
BumpAllocator,
@@ -395,7 +391,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts
)
self.config = config
@@ -432,7 +428,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
+ get_global_server_args().ep_num_redundant_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
@@ -476,7 +472,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]
+ get_global_server_args().ep_num_redundant_experts
)
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
@@ -758,7 +754,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
@@ -774,7 +770,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]:
if get_global_server_args().disable_shared_experts_fusion:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -790,7 +786,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,

View File

@@ -30,9 +30,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import BumpAllocator, add_prefix
logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)

View File

@@ -16,10 +16,10 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.glm4_moe import Glm4MoeModel
from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
from sglang.srt.utils.hf_transformers_utils import get_processor
@@ -47,7 +47,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts
)
@@ -68,7 +68,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -81,7 +81,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
self, architecture: str = "Glm4MoeForCausalLM"
):
self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]:
if get_global_server_args().disable_shared_experts_fusion:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -97,7 +97,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,

View File

@@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
LazyValue,
add_prefix,
@@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module):
}
self.experts = experts_type(
num_experts=config.num_local_experts
+ global_server_args_dict["ep_num_redundant_experts"],
+ get_global_server_args().ep_num_redundant_experts,
top_k=config.num_experts_per_tok,
layer_id=layer_id,
hidden_size=config.hidden_size,
@@ -259,7 +259,7 @@ class GptOssAttention(nn.Module):
# Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
# others can use bfloat16
attn_backend = global_server_args_dict.get("attention_backend")
attn_backend = get_global_server_args().attention_backend
sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
self.sinks = nn.Parameter(
torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
@@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module):
config.hidden_size,
# quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False

View File

@@ -28,7 +28,6 @@ from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@@ -36,7 +35,6 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.elementwise import (
experts_combine_triton,
fused_dual_residual_rmsnorm,
fused_rmsnorm,
gelu_and_mul_triton,
@@ -64,10 +62,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
logger = logging.getLogger(__name__)
@@ -869,10 +867,10 @@ class Grok1ForCausalLM(nn.Module):
# Dump tensors for debugging
global debug_tensor_dump_output_folder, debug_tensor_dump_inject
debug_tensor_dump_output_folder = global_server_args_dict[
"debug_tensor_dump_output_folder"
]
debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
debug_tensor_dump_output_folder = (
get_global_server_args().debug_tensor_dump_output_folder
)
debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
warnings.filterwarnings("ignore", category=FutureWarning)
if get_tensor_model_parallel_rank() == 0:

View File

@@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback
@@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

View File

@@ -32,14 +32,10 @@
import concurrent.futures
import logging
import os
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union
from typing import Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from sglang.srt.configs import LongcatFlashConfig
from sglang.srt.distributed import (
@@ -85,10 +81,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
BumpAllocator,
LazyValue,
@@ -595,7 +591,7 @@ class LongcatFlashForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)

View File

@@ -31,9 +31,9 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cpu
_is_cpu = is_cpu()
@@ -448,7 +448,7 @@ class Llama4ForConditionalGeneration(nn.Module):
)
self.has_vision = (
self.has_vision_weights and global_server_args_dict["enable_multimodal"]
self.has_vision_weights and get_global_server_args().enable_multimodal
)
if self.has_vision:

View File

@@ -64,10 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
from sglang.srt.utils import add_prefix, is_cuda, make_layers
@@ -156,7 +156,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
layer_id=self.layer_id,
top_k=config.num_experts_per_tok,
num_experts=config.num_experts
+ global_server_args_dict["ep_num_redundant_experts"],
+ get_global_server_args().ep_num_redundant_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
@@ -192,7 +192,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
config.num_experts + get_global_server_args().ep_num_redundant_experts
)
self.top_k = config.num_experts_per_tok
@@ -643,7 +643,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
# For EAGLE3 support

View File

@@ -54,7 +54,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -64,6 +63,7 @@ from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
add_prefix,
is_cuda,
@@ -104,7 +104,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.num_experts
+ global_server_args_dict["ep_num_redundant_experts"],
+ get_global_server_args().ep_num_redundant_experts,
top_k=config.num_experts_per_tok,
layer_id=layer_id,
hidden_size=config.hidden_size,
@@ -125,7 +125,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
config.num_experts + get_global_server_args().ep_num_redundant_experts
)
self.top_k = config.num_experts_per_tok
@@ -693,7 +693,7 @@ class Qwen3MoeForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False

View File

@@ -39,7 +39,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
@@ -47,6 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
sharded_weight_loader,
)
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
LazyValue,
add_prefix,
@@ -905,7 +905,7 @@ class Qwen3NextForCausalLM(nn.Module):
quant_config=quant_config,
org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.lm_head = self.lm_head.float()
self.logits_processor = LogitsProcessor(config)

View File

@@ -21,14 +21,13 @@ from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
@@ -69,7 +68,7 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("model.shared_head.head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)

View File

@@ -38,20 +38,12 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.managers.mm_utils import general_mm_embed_routine
from sglang.srt.managers.schedule_batch import MultimodalDataItem
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_vl import (
Qwen3_VisionTransformer,
Qwen3VLForConditionalGeneration,

View File

@@ -57,7 +57,6 @@ from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -300,7 +299,7 @@ class Step3TextDecoderLayer(nn.Module):
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# 0
# if global_server_args_dict["disable_shared_experts_fusion"]
# if global_server_args.disable_shared_experts_fusion
# else self.n_shared_experts
# )
self.num_fused_shared_experts = 0
@@ -774,7 +773,7 @@ class Step3VLForConditionalGeneration(nn.Module):
# self.n_shared_experts = 1
# self.num_fused_shared_experts = (
# 0
# if global_server_args_dict["disable_shared_experts_fusion"]
# if global_server_args.disable_shared_experts_fusion
# else self.n_shared_experts
# )
self.num_fused_shared_experts = 0

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import dataclasses
import logging
import threading
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
@@ -10,6 +9,7 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -66,16 +66,10 @@ class SamplingBatchInfo:
# Handle logit bias
logit_bias: Optional[torch.Tensor] = None
@classmethod
def _get_global_server_args_dict(cls):
from sglang.srt.managers.schedule_batch import global_server_args_dict
return global_server_args_dict
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args_dict = cls._get_global_server_args_dict()
enable_deterministic = global_server_args_dict["enable_deterministic_inference"]
global_server_args = get_global_server_args()
enable_deterministic = global_server_args.enable_deterministic_inference
reqs = batch.reqs
device = batch.device
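The deleted _get_global_server_args_dict classmethod deferred its import into the method body, presumably to avoid a circular import with schedule_batch; since the args now live in sglang.srt.server_args, a plain top-level import suffices. A sketch of the two styles, with the old one kept as comments:

# Old workaround (removed above): import deferred into the method body.
#     from sglang.srt.managers.schedule_batch import global_server_args_dict
#     enable_deterministic = global_server_args_dict["enable_deterministic_inference"]
# New: top-level import of a lightweight accessor, attribute access at the call site.
from sglang.srt.server_args import get_global_server_args

enable_deterministic = get_global_server_args().enable_deterministic_inference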
@@ -112,10 +106,9 @@ class SamplingBatchInfo:
logit_bias[i, int(key)] = value
# Check if any request has custom logit processor
has_custom_logit_processor = global_server_args_dict[
    "enable_custom_logit_processor"
] and any(  # check the flag first.
    r.custom_logit_processor for r in reqs
has_custom_logit_processor = (
    global_server_args.enable_custom_logit_processor  # check the flag first,
    and any(r.custom_logit_processor for r in reqs)  # then check the requests.
)
if has_custom_logit_processor:

View File

@@ -53,6 +53,7 @@ from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
# Define constants
LOAD_FORMAT_CHOICES = [
"auto",
@@ -3329,6 +3330,22 @@ class ServerArgs:
)
# NOTE: This is a global variable to hold the server args for the scheduler.
_global_server_args: Optional[ServerArgs] = None
def set_global_server_args_for_scheduler(server_args: ServerArgs):
    global _global_server_args
    _global_server_args = server_args


def get_global_server_args() -> ServerArgs:
    if _global_server_args is None:
        raise ValueError("Global server args is not set yet!")
    return _global_server_args
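A minimal, self-contained usage sketch of the accessor pair defined above; treating model_path as ServerArgs' required field is the only assumption beyond this diff:

from sglang.srt.server_args import (
    ServerArgs,
    get_global_server_args,
    set_global_server_args_for_scheduler,
)

# The scheduler installs the parsed args once at startup...
args = ServerArgs(model_path="dummy-model")  # model_path assumed required
set_global_server_args_for_scheduler(args)

# ...and any later reader gets the same object, or a clear error if unset.
assert get_global_server_args().model_path == "dummy-model"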
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.
@@ -3363,8 +3380,8 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args
return ServerArgs.from_cli_args(raw_args)
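A hedged example of driving prepare_server_args from a raw argv list; the --model-path spelling follows the usual argparse derivation of the field name and should be treated as an assumption:

from sglang.srt.server_args import prepare_server_args

# Hypothetical argv; --model-path mirrors the ServerArgs.model_path field.
server_args = prepare_server_args(["--model-path", "dummy-model"])
print(server_args.model_path)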
ZMQ_TCP_PORT_DELTA = 233

View File

@@ -6,7 +6,6 @@ import torch
from sglang.srt.layers.moe import get_moe_runner_backend
from sglang.srt.layers.moe.utils import is_sbo_enabled
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import get_int_env_var

View File

@@ -11,7 +11,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.overlap_utils import FutureIndices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend,
@@ -19,6 +19,7 @@ from sglang.srt.mem_cache.common import (
get_last_loc,
)
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.eagle_info_v2 import (
EagleDraftInputV2Mixin,
EagleVerifyInputV2Mixin,
@@ -332,12 +333,8 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
deterministic=True,
)
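Both thresholds now come off the same singleton, so binding it to a local once avoids repeated getter calls; a small equivalent-read sketch (assumes the global args are already set):

from sglang.srt.server_args import get_global_server_args

args = get_global_server_args()
threshold_single = args.speculative_accept_threshold_single
threshold_acc = args.speculative_accept_threshold_acc

The same two-line read recurs verbatim in the V2 mixin and the ngram verifier below.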

View File

@@ -11,7 +11,6 @@ import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
@@ -19,6 +18,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
@@ -265,12 +265,8 @@ class EagleVerifyInputV2Mixin:
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict[
"speculative_accept_threshold_acc"
],
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
deterministic=True,
)

View File

@@ -14,7 +14,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.common import (
@@ -27,7 +27,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import ServerArgs, get_global_server_args
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
@@ -261,7 +261,7 @@ class EAGLEWorker(TpModelWorker):
)
def _create_flashinfer_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]:
if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferMultiStepDraftBackend,
)
@@ -325,7 +325,7 @@ class EAGLEWorker(TpModelWorker):
)
def _create_trtllm_mla_decode_backend(self):
if not global_server_args_dict["use_mla_backend"]:
if not get_global_server_args().use_mla_backend:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
@@ -340,7 +340,7 @@ class EAGLEWorker(TpModelWorker):
)
def _create_flashinfer_prefill_backend(self):
if not global_server_args_dict["use_mla_backend"]:
if not get_global_server_args().use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
)
@@ -376,7 +376,7 @@ class EAGLEWorker(TpModelWorker):
return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
def _create_trtllm_mla_prefill_backend(self):
if not global_server_args_dict["use_mla_backend"]:
if not get_global_server_args().use_mla_backend:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
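The four guards above share one shape; a hypothetical helper (not part of this commit, which keeps the checks inline) that captures it:

from sglang.srt.server_args import get_global_server_args

def require_mla_backend(backend_name: str) -> None:
    # Hypothetical consolidation of the repeated guard above.
    if not get_global_server_args().use_mla_backend:
        raise ValueError(
            f"{backend_name} backend requires MLA model (use_mla_backend=True)."
        )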

View File

@@ -7,6 +7,8 @@ from typing import Optional, Tuple
import torch
import triton
from sglang.srt.server_args import get_global_server_args
logger = logging.getLogger(__name__)
from dataclasses import dataclass
@@ -16,7 +18,7 @@ import torch.nn.functional as F
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend,
alloc_token_slots,
@@ -350,10 +352,8 @@ class NgramVerifyInput(SpecInput):
uniform_samples_for_final_sampling=coins_for_final_sampling,
target_probs=target_probs,
draft_probs=draft_probs,
threshold_single=global_server_args_dict[
"speculative_accept_threshold_single"
],
threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
threshold_single=get_global_server_args().speculative_accept_threshold_single,
threshold_acc=get_global_server_args().speculative_accept_threshold_acc,
deterministic=True,
)

View File

@@ -22,7 +22,7 @@ from sglang.srt.layers.moe import (
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
@@ -30,6 +30,7 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
@@ -153,7 +154,7 @@ def _update_device_and_sum_field_from_cpu_field(
cpu_value
if isinstance(cpu_value, torch.Tensor)
else torch.tensor(cpu_value, dtype=old_device_value.dtype)
).to(device=global_server_args_dict["device"], non_blocking=True)
).to(device=get_global_server_args().device, non_blocking=True)
setattr(batch, device_field, new_device_value)
if sum_field is not None:
@@ -582,7 +583,7 @@ class TboForwardBatchPreparer:
sum_field=None,
)
_, child_b.extend_start_loc = compute_position(
global_server_args_dict["attention_backend"],
get_global_server_args().attention_backend,
child_b.extend_prefix_lens,
child_b.extend_seq_lens,
child_b.extend_num_tokens,
@@ -687,7 +688,7 @@ class TboForwardBatchPreparer:
# TODO improve, e.g. unify w/ `init_raw`
if (
global_server_args_dict["moe_dense_tp_size"] == 1
get_global_server_args().moe_dense_tp_size == 1
and batch.global_dp_buffer_len is not None
):
sum_len = end_token_index - start_token_index
@@ -755,7 +756,7 @@ class TboForwardBatchPreparer:
value_a = min(tbo_split_token_index, num_token_non_padded)
value_b = max(0, num_token_non_padded - tbo_split_token_index)
return torch.tensor([value_a, value_b], dtype=torch.int32).to(
device=global_server_args_dict["device"], non_blocking=True
device=get_global_server_args().device, non_blocking=True
)
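The split-count tensor above uses the same copy idiom as the canary and CPU-field updates earlier in the commit; a standalone sketch with made-up values:

import torch

from sglang.srt.server_args import get_global_server_args

# Made-up split sizes; non_blocking=True lets the host-to-device copy
# overlap with compute when the destination is a CUDA device.
split_counts = torch.tensor([3, 5], dtype=torch.int32).to(
    device=get_global_server_args().device, non_blocking=True
)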
@classmethod