diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 9d3ae3947..52752a7ce 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -65,6 +65,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -88,7 +89,7 @@ def fused_topk(
 
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     return topk_weights, topk_ids
 
 
@@ -355,12 +356,13 @@ def select_experts(
         assert (
             num_token_non_padded is None
         ), "num_token_non_padded is not yet supported in fused_topk"
-        assert expert_location_dispatch_info is None
+        # Qwen3MOE uses fused_topk
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
+            expert_location_dispatch_info=expert_location_dispatch_info,
        )
     else:
         assert (
diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py
index db5b82a6d..7191bedd8 100644
--- a/python/sglang/srt/managers/expert_distribution.py
+++ b/python/sglang/srt/managers/expert_distribution.py
@@ -690,7 +690,9 @@ def _convert_global_physical_count_to_logical_count(
     )
     logical_count.scatter_add_(
         dim=2,
-        index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1),
+        index=physical_to_logical_map.unsqueeze(0)
+        .expand(dim_extra, -1, -1)
+        .to(torch.int64),
         src=global_physical_count,
     )
     return logical_count
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index b00650cc5..36ab20a8e 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -55,7 +55,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
@@ -67,6 +67,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
@@ -86,28 +88,25 @@ logger = logging.getLogger(__name__)
 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
+        layer_id: int,
         config: Qwen3MoeConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
-
+        self.layer_id = layer_id
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}."
             )
 
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-
-        self.experts = MoEImpl(
-            num_experts=config.num_experts,
+        self.experts = get_moe_impl_class()(
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
+            layer_id=layer_id,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
@@ -131,7 +130,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts = config.num_experts
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
 
@@ -139,7 +140,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=config.num_experts,
+                num_experts=self.num_experts,
                 num_local_experts=config.num_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
@@ -157,8 +158,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         else:
             return self.forward_deepep(hidden_states, forward_mode)
 
-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
 
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -189,6 +196,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 top_k=self.top_k,
                 use_grouped_topk=False,
                 renormalize=self.renormalize,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
             )
         else:
             topk_idx = torch.full(
@@ -408,6 +418,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
 
         if self.info.is_sparse:
             self.mlp = Qwen3MoeSparseMoeBlock(
+                layer_id=self.layer_id,
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
@@ -685,15 +696,7 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -770,5 +773,19 @@ class Qwen3MoeForCausalLM(nn.Module):
                 else:
                     logger.warning(f"Parameter {name} not found in params_dict")
 
+        self.routed_experts_weights_of_layer = {
+            layer_id: layer.mlp.get_moe_weights()
+            for layer_id, layer in enumerate(self.model.layers)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
+        }
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None,
+        )
+
 
 EntryClass = Qwen3MoeForCausalLM