diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py index 83fd42250..a7b8875d2 100644 --- a/python/sglang/srt/eplb/expert_distribution.py +++ b/python/sglang/srt/eplb/expert_distribution.py @@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC): rank: int, ): if server_args.expert_distribution_recorder_mode is not None: + assert expert_location_metadata is not None, ( + "ExpertLocationMetadata is required for expert distribution recording. " + "One possible reason is that you are using a model that does not support " + "expert distribution recording. Try setting " + "`get_model_config_for_expert_location` in your model.") return _ExpertDistributionRecorderReal( server_args, expert_location_metadata, rank ) diff --git a/python/sglang/srt/eplb/expert_location.py b/python/sglang/srt/eplb/expert_location.py index 822429dc4..ef35ce7a6 100644 --- a/python/sglang/srt/eplb/expert_location.py +++ b/python/sglang/srt/eplb/expert_location.py @@ -82,6 +82,10 @@ class ExpertLocationMetadata: def init_trivial(server_args: ServerArgs, model_config: ModelConfig): """Trivial location - logical expert i corresponds to physical expert i""" common = ExpertLocationMetadata._init_common(server_args, model_config) + + if common is None: + return None + num_physical_experts = common["num_physical_experts"] model_config_for_expert_location = common["model_config_for_expert_location"] num_layers = model_config_for_expert_location.num_layers @@ -109,6 +113,10 @@ class ExpertLocationMetadata: physical_to_logical_map = physical_to_logical_map.to(server_args.device) common = ExpertLocationMetadata._init_common(server_args, model_config) + + if common is None: + return None + model_config_for_expert_location = common["model_config_for_expert_location"] logical_to_all_physical_map = _compute_logical_to_all_physical_map( physical_to_logical_map, @@ -133,6 +141,10 @@ class ExpertLocationMetadata: logical_count = 
logical_count.to(server_args.device) common = ExpertLocationMetadata._init_common(server_args, model_config) + + if common is None: + return None + model_config_for_expert_location = common["model_config_for_expert_location"] num_physical_experts = common["num_physical_experts"] num_groups = model_config_for_expert_location.num_groups @@ -168,6 +180,9 @@ class ExpertLocationMetadata: ModelConfigForExpertLocation.from_model_config(model_config) ) + if model_config_for_expert_location is None: + return None + num_physical_experts = ( model_config_for_expert_location.num_logical_experts + server_args.ep_num_redundant_experts @@ -398,10 +413,6 @@ class ModelConfigForExpertLocation: num_logical_experts: int num_groups: Optional[int] = None - @staticmethod - def init_dummy(): - return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1) - @staticmethod def from_model_config(model_config: ModelConfig): model_class, _ = get_model_architecture(model_config) @@ -410,12 +421,12 @@ class ModelConfigForExpertLocation: model_config.hf_config ) else: - return ModelConfigForExpertLocation.init_dummy() + return None def compute_initial_expert_location_metadata( server_args: ServerArgs, model_config: ModelConfig -) -> ExpertLocationMetadata: +) -> Optional[ExpertLocationMetadata]: data = server_args.init_expert_location if data == "trivial": return ExpertLocationMetadata.init_trivial(server_args, model_config) diff --git a/python/sglang/srt/eplb/expert_location_dispatch.py b/python/sglang/srt/eplb/expert_location_dispatch.py index 8d2160b6e..624dc3fd1 100644 --- a/python/sglang/srt/eplb/expert_location_dispatch.py +++ b/python/sglang/srt/eplb/expert_location_dispatch.py @@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo: def init_new(cls, layer_id: int): ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"] expert_location_metadata = get_global_expert_location_metadata() + assert expert_location_metadata is not None if ep_dispatch_algorithm is None: 
return None diff --git a/python/sglang/srt/eplb/expert_location_updater.py b/python/sglang/srt/eplb/expert_location_updater.py index 6fdeb0322..9887abc97 100644 --- a/python/sglang/srt/eplb/expert_location_updater.py +++ b/python/sglang/srt/eplb/expert_location_updater.py @@ -50,6 +50,8 @@ class ExpertLocationUpdater: torch.cuda.empty_cache() old_expert_location_metadata = get_global_expert_location_metadata() + assert old_expert_location_metadata is not None + _update_expert_weights( routed_experts_weights_of_layer=routed_experts_weights_of_layer, old_expert_location_metadata=old_expert_location_metadata, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index f2c1ab24d..c9a20d276 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -183,6 +183,7 @@ class EPMoE(FusedMoE): hidden_size: int, intermediate_size: int, layer_id: int, + num_fused_shared_experts: int = 0, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, @@ -196,6 +197,7 @@ class EPMoE(FusedMoE): hidden_size=hidden_size, intermediate_size=intermediate_size, top_k=top_k, + num_fused_shared_experts=num_fused_shared_experts, layer_id=layer_id, params_dtype=params_dtype, quant_config=quant_config, @@ -728,10 +730,19 @@ class EPMoE(FusedMoE): shard_id: str, expert_id: int, ) -> None: - physical_expert_ids = ( - get_global_expert_location_metadata().logical_to_all_physical( - self.layer_id, expert_id + global_expert_location_metadata = get_global_expert_location_metadata() + if global_expert_location_metadata is None: + self._weight_loader_impl( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=expert_id, ) + return + + physical_expert_ids = global_expert_location_metadata.logical_to_all_physical( + self.layer_id, expert_id ) for physical_expert_id in 
physical_expert_ids: self._weight_loader_physical( @@ -778,6 +789,7 @@ class DeepEPMoE(EPMoE): hidden_size: int, intermediate_size: int, layer_id: int, + num_fused_shared_experts: int = 0, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, @@ -792,6 +804,7 @@ class DeepEPMoE(EPMoE): hidden_size=hidden_size, intermediate_size=intermediate_size, layer_id=layer_id, + num_fused_shared_experts=num_fused_shared_experts, params_dtype=params_dtype, quant_config=quant_config, tp_size=tp_size, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 39368e879..316bced90 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -11,6 +11,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -62,8 +63,9 @@ class FusedMoE(torch.nn.Module): num_experts: int, hidden_size: int, intermediate_size: int, + layer_id: int, top_k: Optional[int] = None, - layer_id: Optional[int] = None, + num_fused_shared_experts: int = 0, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = False, quant_config: Optional[QuantizationConfig] = None, @@ -84,6 +86,7 @@ class FusedMoE(torch.nn.Module): if params_dtype is None: params_dtype = torch.get_default_dtype() + self.layer_id = layer_id self.top_k = top_k self.hidden_size = hidden_size self.tp_size = ( @@ -91,6 +94,7 @@ class FusedMoE(torch.nn.Module): ) self.tp_rank = get_tensor_model_parallel_rank() self.num_experts = num_experts + self.num_fused_shared_experts = num_fused_shared_experts self.expert_map = None if enable_flashinfer_cutlass_moe and 
quant_config is None: @@ -375,6 +379,45 @@ class FusedMoE(torch.nn.Module): shard_id: str, expert_id: int, ) -> None: + + global_expert_location_metadata = get_global_expert_location_metadata() + if global_expert_location_metadata is None: + self._weight_loader_impl( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + return + + if expert_id >= self.num_experts - self.num_fused_shared_experts: + # This is a shared expert. + physical_expert_ids = [expert_id] + else: + physical_expert_ids = ( + global_expert_location_metadata.logical_to_all_physical( + self.layer_id, expert_id + ) + ) + + for physical_expert_id in physical_expert_ids: + self._weight_loader_physical( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=physical_expert_id, + ) + + def _weight_loader_physical( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, + ) -> None: expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: return diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index b5305f923..ace06cb7b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module): num_experts=config.n_routed_experts + self.num_fused_shared_experts + global_server_args_dict["ep_num_redundant_experts"], + num_fused_shared_experts=self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, @@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module): if disable_reason is not None: global_server_args_dict["disable_shared_experts_fusion"] = True + self.num_fused_shared_experts = 0 log_info_on_rank0( logger, f"{disable_reason} Shared experts fusion 
optimization is disabled.", diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index f080beb50..6031e7600 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): num_experts=config.n_routed_experts + self.num_fused_shared_experts + global_server_args_dict["ep_num_redundant_experts"], + num_fused_shared_experts=self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, @@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): global_server_args_dict["enable_deepep_moe"] or global_server_args_dict["enable_ep_moe"] ): - disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode." + disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode." 
if disable_reason is not None: global_server_args_dict["disable_shared_experts_fusion"] = True + self.num_fused_shared_experts = 0 log_info_on_rank0( logger, f"{disable_reason} Shared experts fusion optimization is disabled.", diff --git a/python/sglang/srt/models/granitemoe.py b/python/sglang/srt/models/granitemoe.py index 1e6109209..2da7d857f 100644 --- a/python/sglang/srt/models/granitemoe.py +++ b/python/sglang/srt/models/granitemoe.py @@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module): top_k: int, hidden_size: int, intermediate_size: int, + layer_id: int, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, @@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module): top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, + layer_id=layer_id, params_dtype=params_dtype, reduce_results=True, quant_config=quant_config, @@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, + layer_id=layer_id, quant_config=quant_config, prefix=f"{prefix}.block_sparse_moe", ) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 4a46bf197..aa458bb65 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -78,6 +78,7 @@ class Grok1MoE(nn.Module): def __init__( self, config: PretrainedConfig, + layer_id: int, num_experts: int, top_k: int, hidden_size: int, @@ -128,6 +129,7 @@ class Grok1MoE(nn.Module): self.experts = MoEImpl( num_experts=num_experts, top_k=top_k, + layer_id=layer_id, hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, @@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module): ) self.block_sparse_moe = Grok1MoE( config=config, + layer_id=layer_id, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, diff --git 
a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py index 58e95bbb1..c1ed2543c 100644 --- a/python/sglang/srt/models/hunyuan.py +++ b/python/sglang/srt/models/hunyuan.py @@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module): hidden_size=config.hidden_size, intermediate_size=intermediate_size, reduce_results=False, + layer_id=layer_id, quant_config=quant_config, ) diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index cf0b20800..265a9391d 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -87,6 +87,7 @@ class Llama4MoE(nn.Module): def __init__( self, config: Llama4TextConfig, + layer_id: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -114,6 +115,7 @@ class Llama4MoE(nn.Module): num_experts=config.num_local_experts, hidden_size=config.hidden_size, intermediate_size=intermediate_size_moe, + layer_id=layer_id, reduce_results=False, quant_config=quant_config, apply_router_weight_on_input=True, @@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module): if is_moe_layer: self.feed_forward = Llama4MoE( config=config, + layer_id=layer_id, quant_config=quant_config, prefix=add_prefix("feed_forward", prefix), ) diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index b09fc2f24..365825d20 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -69,6 +69,7 @@ class MixtralMoE(nn.Module): top_k: int, hidden_size: int, intermediate_size: int, + layer_id: int, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, @@ -97,6 +98,7 @@ class MixtralMoE(nn.Module): self.experts = MoEImpl( num_experts=num_experts, top_k=top_k, + layer_id=layer_id, hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, @@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module): 
top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, + layer_id=layer_id, quant_config=quant_config, prefix=add_prefix("block_sparse_moe", prefix), ) diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index ce53f2b01..e2db2dceb 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module): params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + layer_id: int = 0, prefix: str = "", ): super().__init__() @@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module): reduce_results=True, quant_config=quant_config, tp_size=tp_size, + layer_id=layer_id, prefix=add_prefix("experts", prefix), ) @@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module): top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, + layer_id=layer_id, quant_config=quant_config, prefix=add_prefix("mlp", prefix), ) diff --git a/python/sglang/srt/models/phimoe.py b/python/sglang/srt/models/phimoe.py index 865b94f51..4604aeef9 100644 --- a/python/sglang/srt/models/phimoe.py +++ b/python/sglang/srt/models/phimoe.py @@ -210,6 +210,7 @@ class PhiMoE(nn.Module): self.experts = FusedMoE( num_experts=num_experts, top_k=top_k, + layer_id=layer_id, hidden_size=hidden_size, intermediate_size=intermediate_size, reduce_results=True,