diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 261b707d7..63b124fd7 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -607,7 +607,10 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             if hidden_states.shape[0] != 0:
-                hidden_states, _ = self.norm(hidden_states, residual)
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index d553395f2..b00650cc5 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -32,6 +32,7 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
@@ -54,8 +55,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -65,11 +68,15 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix
 
 Qwen3MoeConfig = None
 
@@ -92,7 +99,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
 
         self.experts = MoEImpl(
             num_experts=config.num_experts,
@@ -102,6 +113,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )
 
         self.gate = ReplicatedLinear(
@@ -112,7 +128,37 @@
             prefix=add_prefix("gate", prefix),
         )
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
+            self.num_experts = config.num_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.num_experts,
+                num_local_experts=config.num_experts // self.tp_size,
+                hidden_size=config.hidden_size,
+                params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                async_finish=True,  # TODO
+                return_recv_hook=True,
+            )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
@@ -126,6 +172,68 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
+            )
+        return final_hidden_states
+
 
 class Qwen3MoeAttention(nn.Module):
     def __init__(
@@ -403,7 +511,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             )
 
         # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         # TODO: use reduce-scatter in MLP to avoid this scatter
         # Scatter
@@ -577,7 +685,13 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
 
         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
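For reviewers unfamiliar with the DeepEP path, the sketch below illustrates the control flow that `forward_deepep` wires up: top-k routing on the gate logits, a dispatch step, expert computation, and a combine step. It is a self-contained toy, not sglang code: `DummyDispatcher` and `moe_forward_sketch` are hypothetical stand-ins, the real `DeepEPDispatcher` performs all-to-all communication across expert-parallel ranks, and the real expert computation runs in the `DeepEPMoE` kernels.

```python
import torch


class DummyDispatcher:
    """Hypothetical single-rank stand-in for a dispatch/combine pair."""

    def dispatch(self, hidden_states, topk_idx, topk_weights):
        # The real dispatcher all-to-alls tokens to the ranks that own their experts.
        return hidden_states, topk_idx, topk_weights

    def combine(self, expert_output, topk_idx, topk_weights):
        # The real combine gathers expert outputs back to each token's source rank.
        return expert_output


def moe_forward_sketch(hidden_states, gate_weight, top_k, renormalize=True):
    """Toy top-k routing followed by dispatch -> expert compute -> combine."""
    # Router logits: (num_tokens, num_experts)
    router_logits = hidden_states @ gate_weight
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_idx = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    dispatcher = DummyDispatcher()
    hidden_states, topk_idx, topk_weights = dispatcher.dispatch(
        hidden_states, topk_idx, topk_weights
    )
    # Stand-in "expert": identity scaled by the summed routing weights.
    expert_output = hidden_states * topk_weights.sum(dim=-1, keepdim=True).to(
        hidden_states.dtype
    )
    return dispatcher.combine(expert_output, topk_idx, topk_weights)


if __name__ == "__main__":
    x = torch.randn(4, 8)
    w_gate = torch.randn(8, 16)  # 16 toy experts
    out = moe_forward_sketch(x, w_gate, top_k=2)
    print(out.shape)  # torch.Size([4, 8])
```

In the actual patch, the zero-token `topk_idx`/`topk_weights` branch and the `forward_mode` argument let ranks with no work still take part in the collective dispatch/combine calls.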