From 0ca1811715eaeb4bbb47dca3adba98d58be0dcd0 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 26 May 2025 13:35:51 +0800 Subject: [PATCH] Support fake perfectly balanced EP dispatch algorithm (#6571) --- python/sglang/srt/layers/moe/topk.py | 10 ++++++++++ .../srt/managers/expert_location_dispatch.py | 14 +++++++++++++- python/sglang/srt/server_args.py | 2 +- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 52752a7ce..624162799 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -18,6 +18,7 @@ from typing import Callable, Optional import torch import torch.nn.functional as F +from sglang.srt.managers import expert_location_dispatch from sglang.srt.managers.expert_distribution import ( ExpertDistributionRecorder, get_global_expert_distribution_recorder, @@ -310,6 +311,15 @@ def select_experts( expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"] + + router_logits, correction_bias = ( + expert_location_dispatch.transform_select_experts_inputs( + router_logits=router_logits, + correction_bias=correction_bias, + info=expert_location_dispatch_info, + ) + ) + + # DeepSeek V2/V3/R1 series models use grouped_top_k if use_grouped_topk: assert topk_group is not None diff --git a/python/sglang/srt/managers/expert_location_dispatch.py b/python/sglang/srt/managers/expert_location_dispatch.py index 1e4d7b06e..6880b01a2 100644 --- a/python/sglang/srt/managers/expert_location_dispatch.py +++ b/python/sglang/srt/managers/expert_location_dispatch.py @@ -55,6 +55,18 @@ class ExpertLocationDispatchInfo: ) +def transform_select_experts_inputs( + router_logits: torch.Tensor, + correction_bias: Optional[torch.Tensor], + info: Optional[ExpertLocationDispatchInfo], +): + if (info is not None) and (info.ep_dispatch_algorithm == "fake"): + router_logits = torch.randn_like(router_logits) + if correction_bias is not None: + correction_bias = torch.zeros_like(correction_bias) + return router_logits, correction_bias + + def topk_ids_logical_to_physical( topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo] ) -> torch.Tensor: @@ -63,7 +75,7 @@ def topk_ids_logical_to_physical( if info.ep_dispatch_algorithm == "static": return _topk_ids_logical_to_physical_static(topk_ids, info) - if info.ep_dispatch_algorithm == "dynamic": + if info.ep_dispatch_algorithm in ["dynamic", "fake"]: return _topk_ids_logical_to_physical_dynamic(topk_ids, info) raise NotImplementedError diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f00aa11ac..8ac4fe494 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -172,7 +172,7 @@ class ServerArgs: enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" ep_num_redundant_experts: int = 0 - ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None + ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None init_expert_location: str = "trivial" enable_eplb: bool = False eplb_rebalance_num_iterations: int = 1000