From 5c7aa009766983361b2acaf3c3a3864101ab7b8a Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 26 May 2025 23:43:24 +0800 Subject: [PATCH] Fix EPLB algorithm fail to run when using 3 nodes for prefill (#6629) --- python/sglang/srt/managers/deepseek_eplb.py | 24 ++++++++++--------- python/sglang/srt/managers/expert_location.py | 6 +---- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/managers/deepseek_eplb.py b/python/sglang/srt/managers/deepseek_eplb.py index 6fdf3d97d..7dd015bfe 100644 --- a/python/sglang/srt/managers/deepseek_eplb.py +++ b/python/sglang/srt/managers/deepseek_eplb.py @@ -1,6 +1,6 @@ # This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package -from typing import Literal, Tuple +from typing import Literal, Optional, Tuple import torch @@ -257,11 +257,15 @@ def rebalance_experts( tokens_per_expert: torch.Tensor, num_physical_experts: int, num_local_physical_experts: int, - num_groups: int, + num_groups: Optional[int], num_nodes: int, - phase: Literal["prefill", "decode"], + phase: Literal["prefill", "decode", "null"], ): - if phase == "prefill": + if ( + (phase == "prefill") + and (num_groups is not None) + and (num_groups % num_nodes == 0) + ): return prefill_rebalance_experts( tokens_per_expert=tokens_per_expert, num_physical_experts=num_physical_experts, @@ -269,10 +273,8 @@ def rebalance_experts( num_groups=num_groups, num_nodes=num_nodes, ) - if phase == "decode": - return decode_rebalance_experts( - tokens_per_expert=tokens_per_expert, - num_physical_experts=num_physical_experts, - num_local_physical_experts=num_local_physical_experts, - ) - raise NotImplementedError + return decode_rebalance_experts( + tokens_per_expert=tokens_per_expert, + num_physical_experts=num_physical_experts, + num_local_physical_experts=num_local_physical_experts, + ) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 97d439558..3979c762f 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -135,10 +135,6 @@ class ExpertLocationMetadata: model_config_for_expert_location = common["model_config_for_expert_location"] num_physical_experts = common["num_physical_experts"] - phase = server_args.disaggregation_mode - if phase == "null" or model_config_for_expert_location.num_groups is None: - phase = "decode" - physical_to_logical_map, logical_to_all_physical_map, expert_count = ( deepseek_eplb.rebalance_experts( tokens_per_expert=logical_count, @@ -146,7 +142,7 @@ class ExpertLocationMetadata: num_local_physical_experts=num_physical_experts // common["ep_size"], num_groups=model_config_for_expert_location.num_groups, num_nodes=server_args.nnodes, - phase=phase, + phase=server_args.disaggregation_mode, ) )