Deprecate global_server_args_dict (#11331)

Author: Liangsheng Yin (committed by GitHub)
Date: 2025-10-13 01:20:47 +08:00
parent 2157d12ae8
commit 1083e7e3df
54 changed files with 240 additions and 321 deletions
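
In short, the commit replaces dictionary lookups on the module-level global_server_args_dict with attribute access on the ServerArgs object returned by get_global_server_args(). Below is a minimal sketch of the call-site migration pattern seen throughout the diff; it assumes sglang is installed and the global server args have been initialized, and the variable names are illustrative rather than copied from the commit.

    # Old pattern (removed by this commit): read options from a global dict
    # exported by sglang.srt.managers.schedule_batch.
    #   from sglang.srt.managers.schedule_batch import global_server_args_dict
    #   device = global_server_args_dict["device"]

    # New pattern: read the same option as an attribute of the ServerArgs instance.
    from sglang.srt.server_args import get_global_server_args

    device = get_global_server_args().device

    # Writes go through attribute assignment as well, e.g. the shared-experts
    # fusion fallback later in this diff:
    #   get_global_server_args().disable_shared_experts_fusion = True

    # Note: ServerArgs keeps speculative_algorithm as a string, so call sites that
    # need the enum now convert it explicitly (see the DeepseekV2DecoderLayer hunk):
    #   SpeculativeAlgorithm.from_string(get_global_server_args().speculative_algorithm)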


@@ -35,7 +35,6 @@ from sglang.srt.configs.model_config import (
get_nsa_index_topk,
is_deepseek_nsa,
)
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
@@ -108,10 +107,11 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.server_args import get_global_server_args
from sglang.srt.single_batch_overlap import SboFlags
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
@@ -520,7 +520,7 @@ class DeepseekV2MoE(nn.Module):
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
0
if global_server_args_dict["disable_shared_experts_fusion"]
if get_global_server_args().disable_shared_experts_fusion
else config.n_shared_experts
)
self.config = config
@@ -549,7 +549,7 @@ class DeepseekV2MoE(nn.Module):
self.experts = get_moe_impl_class(quant_config)(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
+ global_server_args_dict["ep_num_redundant_experts"],
+ get_global_server_args().ep_num_redundant_experts,
num_fused_shared_experts=self.num_fused_shared_experts,
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
hidden_size=config.hidden_size,
@@ -627,7 +627,7 @@ class DeepseekV2MoE(nn.Module):
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]
+ get_global_server_args().ep_num_redundant_experts
)
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
@@ -1125,7 +1125,7 @@ class DeepseekV2AttentionMLA(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
device=get_global_server_args().device,
)
if rope_scaling:
@@ -1169,12 +1169,12 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale_v = None
self.use_deep_gemm_bmm = False
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.disable_chunked_prefix_cache = global_server_args_dict[
"disable_chunked_prefix_cache"
]
self.flashinfer_mla_disable_ragged = (
get_global_server_args().flashinfer_mla_disable_ragged
)
self.disable_chunked_prefix_cache = (
get_global_server_args().disable_chunked_prefix_cache
)
self.current_attention_backend = (
None # Attention backend used by current forward batch
@@ -1253,18 +1253,18 @@ class DeepseekV2AttentionMLA(nn.Module):
) -> AttnForwardMethod:
# Determine attention backend used by current forward batch
if forward_batch.forward_mode.is_decode_or_idle():
attention_backend = global_server_args_dict["decode_attention_backend"]
attention_backend = get_global_server_args().decode_attention_backend
elif (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
# Use the specified backend for speculative operations (both verify and draft extend)
if global_server_args_dict["speculative_attention_mode"] == "decode":
attention_backend = global_server_args_dict["decode_attention_backend"]
if get_global_server_args().speculative_attention_mode == "decode":
attention_backend = get_global_server_args().decode_attention_backend
else: # default to prefill
attention_backend = global_server_args_dict["prefill_attention_backend"]
attention_backend = get_global_server_args().prefill_attention_backend
else:
attention_backend = global_server_args_dict["prefill_attention_backend"]
attention_backend = get_global_server_args().prefill_attention_backend
self.current_attention_backend = attention_backend
handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -2365,7 +2365,9 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
get_global_server_args().speculative_algorithm
)
self.layer_id = layer_id
self.is_nextn = is_nextn
self.self_attn = DeepseekV2AttentionMLA(
@@ -2817,7 +2819,7 @@ class DeepseekV2ForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
)
self.logits_processor = LogitsProcessor(config)
@@ -2837,7 +2839,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self, architecture: str = "DeepseekV3ForCausalLM"
):
self.num_fused_shared_experts = 0
if global_server_args_dict["disable_shared_experts_fusion"]:
if get_global_server_args().disable_shared_experts_fusion:
return
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2856,7 +2858,7 @@ class DeepseekV2ForCausalLM(nn.Module):
disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
get_global_server_args().disable_shared_experts_fusion = True
self.num_fused_shared_experts = 0
log_info_on_rank0(
logger,