Revert "Deprecate global_server_args_dict" (#11520)

This commit is contained in:
Cheng Wan
2025-10-12 17:40:40 -07:00
committed by GitHub
parent 6cd296940a
commit 1bdd010291
54 changed files with 321 additions and 240 deletions

View File

@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
# TODO: Change the hard-coded block_seq_num
self.BLOCK_SEQ = 128
if get_global_server_args().triton_attention_reduce_in_fp32:
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
self.reduce_dtype = torch.float32
else:
self.reduce_dtype = torch.float16

View File

@@ -11,8 +11,8 @@ import triton.language as tl
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
if TYPE_CHECKING:
@@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
assert not global_server_args_dict["disable_chunked_prefix_cache"]
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None

View File

@@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
is_flashinfer_available,
@@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.skip_prefill = skip_prefill
self.enable_chunk_kv = (
not skip_prefill
and get_global_server_args().disaggregation_mode != "decode"
and not get_global_server_args().disable_chunked_prefix_cache
and not get_global_server_args().flashinfer_mla_disable_ragged
and global_server_args_dict["disaggregation_mode"] != "decode"
and not global_server_args_dict["disable_chunked_prefix_cache"]
and not global_server_args_dict["flashinfer_mla_disable_ragged"]
)
self.page_size = model_runner.page_size
@@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
prefix_lens = forward_batch.extend_prefix_lens
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
use_ragged = (
not get_global_server_args().flashinfer_mla_disable_ragged
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and extend_no_prefix
)

View File

@@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -162,7 +162,7 @@ class Indexer(CustomOp):
base=rope_theta, # type: ignore
rope_scaling=rope_scaling,
is_neox_style=False,
device=get_global_server_args().device,
device=global_server_args_dict["device"],
)
self.block_size = block_size
self.scale_fmt = scale_fmt

View File

@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, is_hip
_is_cuda = is_cuda()
@@ -11,7 +11,7 @@ if _is_cuda:
_is_hip = is_hip()
if get_global_server_args().triton_attention_reduce_in_fp32:
if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:

View File

@@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import (
create_flashmla_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import is_cuda, is_flashinfer_available
if is_flashinfer_available():
@@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
self.disable_chunked_prefix_cache = (
get_global_server_args().disable_chunked_prefix_cache
)
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

View File

@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import add_prefix
ROTARY_EMBED_CLASSES = {
@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
_passed_backend = qkv_backend
qkv_backend = self._determine_attention_backend(_passed_backend)
if (
get_global_server_args().mm_attention_backend is None
global_server_args_dict["mm_attention_backend"] is None
and _passed_backend is None
):
print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
- CUDA: "triton_attn"
- Non-CUDA: "sdpa"
"""
override_backend = get_global_server_args().mm_attention_backend
override_backend = global_server_args_dict["mm_attention_backend"]
if override_backend is not None:
backend = override_backend
elif passed_backend is not None:

View File

@@ -40,9 +40,8 @@ from sglang.srt.layers.moe import (
get_moe_a2a_backend,
should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
@@ -169,7 +168,7 @@ class LayerScatterModes:
def enable_moe_dense_fully_dp():
return get_global_server_args().moe_dense_tp_size == 1
return global_server_args_dict["moe_dense_tp_size"] == 1
class LayerCommunicator:
@@ -315,9 +314,7 @@ class LayerCommunicator:
def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch
) -> bool:
speculative_algo = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
if (
is_dp_attention_enabled()
and speculative_algo is not None
@@ -336,7 +333,7 @@ class LayerCommunicator:
static_conditions_met = (
(not self.is_last_layer)
and (self._context.tp_size > 1)
and get_global_server_args().enable_flashinfer_allreduce_fusion
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and _is_flashinfer_available
)
@@ -534,7 +531,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
(_is_sm100_supported or _is_sm90_supported)
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and get_global_server_args().enable_flashinfer_allreduce_fusion
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 4096
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(

View File

@@ -38,15 +38,17 @@ from sglang.srt.layers.dp_attention import (
get_dp_device,
get_dp_dtype,
get_dp_hidden_size,
get_global_dp_buffer,
get_local_attention_dp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
logger = logging.getLogger(__name__)
@@ -228,8 +230,8 @@ class LogitsProcessor(nn.Module):
super().__init__()
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
@@ -252,8 +254,8 @@ class LogitsProcessor(nn.Module):
):
self.final_logit_softcapping = None
self.debug_tensor_dump_output_folder = (
get_global_server_args().debug_tensor_dump_output_folder
self.debug_tensor_dump_output_folder = global_server_args_dict.get(
"debug_tensor_dump_output_folder", None
)
def compute_logprobs_for_multi_item_scoring(
@@ -370,7 +372,9 @@ class LogitsProcessor(nn.Module):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
multi_item_delimiter = global_server_args_dict.get(
"multi_item_scoring_delimiter"
)
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter

View File

@@ -27,10 +27,12 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
from sglang.srt.utils import (
cpu_has_amx_support,

View File

@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.server_args import get_global_server_args
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda,
@@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.flashinfer_mxfp4_moe_precision = (
get_global_server_args().flashinfer_mxfp4_moe_precision
)
self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
"flashinfer_mxfp4_moe_precision"
]
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None

View File

@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
is_dp_attention_enabled,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
if is_cuda():
@@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.use_nan_detection = get_global_server_args().enable_nan_detection
self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
self.tp_sync_group = get_tp_group().device_group
if is_dp_attention_enabled():
@@ -103,7 +103,7 @@ class Sampler(nn.Module):
del logits
if True: # Keep this redundant check to simplify some internal code sync
if get_global_server_args().sampling_backend == "flashinfer":
if global_server_args_dict["sampling_backend"] == "flashinfer":
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -118,7 +118,7 @@ class Sampler(nn.Module):
filter_apply_order="joint",
check_nan=self.use_nan_detection,
)
elif get_global_server_args().sampling_backend == "pytorch":
elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
@@ -131,7 +131,7 @@ class Sampler(nn.Module):
)
else:
raise ValueError(
f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
if return_logprob: