diff --git a/python/sglang/srt/managers/deepseek_eplb.py b/python/sglang/srt/managers/deepseek_eplb.py index 6fdf3d97d..7dd015bfe 100644 --- a/python/sglang/srt/managers/deepseek_eplb.py +++ b/python/sglang/srt/managers/deepseek_eplb.py @@ -1,6 +1,6 @@ # This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package -from typing import Literal, Tuple +from typing import Literal, Optional, Tuple import torch @@ -257,11 +257,15 @@ def rebalance_experts( tokens_per_expert: torch.Tensor, num_physical_experts: int, num_local_physical_experts: int, - num_groups: int, + num_groups: Optional[int], num_nodes: int, - phase: Literal["prefill", "decode"], + phase: Literal["prefill", "decode", "null"], ): - if phase == "prefill": + if ( + (phase == "prefill") + and (num_groups is not None) + and (num_groups % num_nodes == 0) + ): return prefill_rebalance_experts( tokens_per_expert=tokens_per_expert, num_physical_experts=num_physical_experts, @@ -269,10 +273,8 @@ def rebalance_experts( num_groups=num_groups, num_nodes=num_nodes, ) - if phase == "decode": - return decode_rebalance_experts( - tokens_per_expert=tokens_per_expert, - num_physical_experts=num_physical_experts, - num_local_physical_experts=num_local_physical_experts, - ) - raise NotImplementedError + return decode_rebalance_experts( + tokens_per_expert=tokens_per_expert, + num_physical_experts=num_physical_experts, + num_local_physical_experts=num_local_physical_experts, + ) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 97d439558..3979c762f 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -135,10 +135,6 @@ class ExpertLocationMetadata: model_config_for_expert_location = common["model_config_for_expert_location"] num_physical_experts = common["num_physical_experts"] - phase = server_args.disaggregation_mode - if phase == "null" or model_config_for_expert_location.num_groups is None: - phase = "decode" - physical_to_logical_map, logical_to_all_physical_map, expert_count = ( deepseek_eplb.rebalance_experts( tokens_per_expert=logical_count, @@ -146,7 +142,7 @@ class ExpertLocationMetadata: num_local_physical_experts=num_physical_experts // common["ep_size"], num_groups=model_config_for_expert_location.num_groups, num_nodes=server_args.nnodes, - phase=phase, + phase=server_args.disaggregation_mode, ) )