Revert "Deprecate global_server_args_dict" (#11520)
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import get_global_server_args
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -42,7 +42,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         # TODO: Change the hard-coded block_seq_num
         self.BLOCK_SEQ = 128
 
-        if get_global_server_args().triton_attention_reduce_in_fp32:
+        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
             self.reduce_dtype = torch.float32
         else:
             self.reduce_dtype = torch.float16
@@ -11,8 +11,8 @@ import triton.language as tl
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import AttentionType
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 
 if TYPE_CHECKING:
@@ -830,7 +830,7 @@ class FlashAttentionBackend(AttentionBackend):
         ):
             # Do multi-head attention with chunked prefix cache
             if forward_batch.attn_attend_prefix_cache:
-                assert not get_global_server_args().disable_chunked_prefix_cache
+                assert not global_server_args_dict["disable_chunked_prefix_cache"]
                 # MHA for chunked prefix kv cache when running model with MLA
                 assert forward_batch.prefix_chunk_idx is not None
                 assert forward_batch.prefix_chunk_cu_seq_lens is not None
@@ -28,8 +28,8 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
@@ -193,9 +193,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         self.enable_chunk_kv = (
             not skip_prefill
-            and get_global_server_args().disaggregation_mode != "decode"
-            and not get_global_server_args().disable_chunked_prefix_cache
-            and not get_global_server_args().flashinfer_mla_disable_ragged
+            and global_server_args_dict["disaggregation_mode"] != "decode"
+            and not global_server_args_dict["disable_chunked_prefix_cache"]
+            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
         )
         self.page_size = model_runner.page_size
 
@@ -306,7 +306,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         prefix_lens = forward_batch.extend_prefix_lens
         extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
         use_ragged = (
-            not get_global_server_args().flashinfer_mla_disable_ragged
+            not global_server_args_dict["flashinfer_mla_disable_ragged"]
             and extend_no_prefix
         )
 
@@ -23,9 +23,9 @@ from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import get_global_server_args
 
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -162,7 +162,7 @@ class Indexer(CustomOp):
             base=rope_theta,  # type: ignore
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=get_global_server_args().device,
+            device=global_server_args_dict["device"],
         )
         self.block_size = block_size
         self.scale_fmt = scale_fmt
@@ -2,7 +2,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
@@ -11,7 +11,7 @@ if _is_cuda:
 
 _is_hip = is_hip()
 
-if get_global_server_args().triton_attention_reduce_in_fp32:
+if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
@@ -20,8 +20,8 @@ from sglang.srt.layers.attention.utils import (
     create_flashmla_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_flashinfer_available
 
 if is_flashinfer_available():
@@ -123,9 +123,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
         self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
-        self.disable_chunked_prefix_cache = (
-            get_global_server_args().disable_chunked_prefix_cache
-        )
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
 
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import add_prefix
 
 ROTARY_EMBED_CLASSES = {
@@ -468,7 +468,7 @@ class VisionAttention(nn.Module):
         _passed_backend = qkv_backend
         qkv_backend = self._determine_attention_backend(_passed_backend)
         if (
-            get_global_server_args().mm_attention_backend is None
+            global_server_args_dict["mm_attention_backend"] is None
             and _passed_backend is None
         ):
             print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
@@ -528,7 +528,7 @@ class VisionAttention(nn.Module):
         - CUDA: "triton_attn"
         - Non-CUDA: "sdpa"
         """
-        override_backend = get_global_server_args().mm_attention_backend
+        override_backend = global_server_args_dict["mm_attention_backend"]
         if override_backend is not None:
             backend = override_backend
         elif passed_backend is not None:
@@ -40,9 +40,8 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import get_global_server_args
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
@@ -169,7 +168,7 @@ class LayerScatterModes:
 
 
 def enable_moe_dense_fully_dp():
-    return get_global_server_args().moe_dense_tp_size == 1
+    return global_server_args_dict["moe_dense_tp_size"] == 1
 
 
 class LayerCommunicator:
@@ -315,9 +314,7 @@ class LayerCommunicator:
     def should_fuse_mlp_allreduce_with_next_layer(
         self, forward_batch: ForwardBatch
     ) -> bool:
-        speculative_algo = SpeculativeAlgorithm.from_string(
-            get_global_server_args().speculative_algorithm
-        )
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
         if (
             is_dp_attention_enabled()
             and speculative_algo is not None
@@ -336,7 +333,7 @@ class LayerCommunicator:
         static_conditions_met = (
             (not self.is_last_layer)
             and (self._context.tp_size > 1)
-            and get_global_server_args().enable_flashinfer_allreduce_fusion
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
             and _is_flashinfer_available
         )
 
@@ -534,7 +531,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
-            and get_global_server_args().enable_flashinfer_allreduce_fusion
+            and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
             and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
@@ -38,15 +38,17 @@ from sglang.srt.layers.dp_attention import (
     get_dp_device,
     get_dp_dtype,
     get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
@@ -228,8 +230,8 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.use_attn_tp_group = get_global_server_args().enable_dp_lm_head
-        self.use_fp32_lm_head = get_global_server_args().enable_fp32_lm_head
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
         if self.use_attn_tp_group:
             self.attn_tp_size = get_attention_tp_size()
             self.do_tensor_parallel_all_gather = (
@@ -252,8 +254,8 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None
 
-        self.debug_tensor_dump_output_folder = (
-            get_global_server_args().debug_tensor_dump_output_folder
+        self.debug_tensor_dump_output_folder = global_server_args_dict.get(
+            "debug_tensor_dump_output_folder", None
         )
 
     def compute_logprobs_for_multi_item_scoring(
@@ -370,7 +372,9 @@
         logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
 
         # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
-        multi_item_delimiter = get_global_server_args().multi_item_scoring_delimiter
+        multi_item_delimiter = global_server_args_dict.get(
+            "multi_item_scoring_delimiter"
+        )
         if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
             return self.compute_logprobs_for_multi_item_scoring(
                 input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
@@ -27,10 +27,12 @@ from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
     cpu_has_amx_support,
@@ -31,7 +31,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.server_args import get_global_server_args
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda,
@@ -265,9 +265,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
-        self.flashinfer_mxfp4_moe_precision = (
-            get_global_server_args().flashinfer_mxfp4_moe_precision
-        )
+        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+            "flashinfer_mxfp4_moe_precision"
+        ]
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
 
 if is_cuda():
@@ -33,7 +33,7 @@ RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detection = get_global_server_args().enable_nan_detection
+        self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
         self.tp_sync_group = get_tp_group().device_group
 
         if is_dp_attention_enabled():
@@ -103,7 +103,7 @@ class Sampler(nn.Module):
             del logits
 
         if True:  # Keep this redundant check to simplify some internal code sync
-            if get_global_server_args().sampling_backend == "flashinfer":
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
                 if sampling_info.need_min_p_sampling:
                     probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                     probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -118,7 +118,7 @@ class Sampler(nn.Module):
                     filter_apply_order="joint",
                     check_nan=self.use_nan_detection,
                 )
-            elif get_global_server_args().sampling_backend == "pytorch":
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                     probs,
@@ -131,7 +131,7 @@ class Sampler(nn.Module):
                 )
             else:
                 raise ValueError(
-                    f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
         if return_logprob: