diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index cd244d570..d24caaaba 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -141,6 +141,7 @@ class EPMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -164,6 +165,7 @@
         )
 
         self.tp_rank = get_tensor_model_parallel_rank()
+        self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
         self.num_experts_per_partition = self.num_experts // self.tp_size
@@ -837,6 +839,7 @@ class DeepEPMoE(EPMoE):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -856,6 +859,7 @@
             top_k,
             hidden_size,
             intermediate_size,
+            layer_id,
             params_dtype,
             renormalize,
             use_grouped_topk,
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 74c6216ac..cef36660a 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -283,6 +283,7 @@ class FusedMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: Optional[int] = None,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index b11734d85..849e3c76d 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -114,7 +114,6 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -216,6 +215,7 @@ class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -224,6 +224,7 @@
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.layer_id = layer_id
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -244,6 +245,7 @@
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -344,6 +346,9 @@
                 num_expert_group=self.num_expert_group,
                 correction_bias=self.correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
             )
         else:
             state.topk_idx_local = torch.full(
@@ -1183,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                layer_id=self.layer_id,
             )
         else:
             if enable_moe_dense_fully_dp():
@@ -1246,9 +1252,7 @@
         zero_allocator: BumpAllocator,
     ):
         state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
-            self.layer_communicator.prepare_attn(
-                hidden_states, residual, state.forward_batch
-            )
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
         state.update(
             dict(