diff --git a/python/sglang/srt/managers/eplb_algorithms/__init__.py b/python/sglang/srt/managers/eplb_algorithms/__init__.py new file mode 100644 index 000000000..7a970c320 --- /dev/null +++ b/python/sglang/srt/managers/eplb_algorithms/__init__.py @@ -0,0 +1,63 @@ +from enum import Enum, auto +from typing import Optional + +import torch + +from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec + + +class EplbAlgorithm(Enum): + deepseek = auto() + deepseek_hierarchical = auto() + deepseek_vec = auto() + deepseek_vec_hierarchical = auto() + # TODO may have more algorithm later + + +def rebalance_experts( + tokens_per_expert: torch.Tensor, + num_physical_experts: int, + num_local_physical_experts: int, + num_groups: Optional[int], + num_nodes: int, + algorithm: EplbAlgorithm, +): + if algorithm in [EplbAlgorithm.deepseek, EplbAlgorithm.deepseek_hierarchical]: + return deepseek.rebalance_experts( + weight=tokens_per_expert.sum(dim=0), + num_replicas=num_physical_experts, + num_groups=num_groups, + num_nodes=num_nodes, + num_gpus=num_physical_experts // num_local_physical_experts, + enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical, + ) + + if algorithm in [ + EplbAlgorithm.deepseek_vec, + EplbAlgorithm.deepseek_vec_hierarchical, + ]: + return deepseek_vec.rebalance_experts( + tokens_per_expert=tokens_per_expert, + num_physical_experts=num_physical_experts, + num_local_physical_experts=num_local_physical_experts, + num_groups=num_groups, + num_nodes=num_nodes, + enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical, + ) + + raise NotImplementedError + + +def compute_algorithm( + raw_algorithm: str, + num_groups: Optional[int], + num_nodes: int, +) -> EplbAlgorithm: + if raw_algorithm != "auto": + return EplbAlgorithm[raw_algorithm] + + # TODO test on real scenarios and know which ones perform better + if (num_groups is not None) and (num_groups % num_nodes == 0): + return EplbAlgorithm.deepseek_hierarchical + else: + return EplbAlgorithm.deepseek diff --git a/python/sglang/srt/managers/eplb_algorithms/deepseek.py b/python/sglang/srt/managers/eplb_algorithms/deepseek.py new file mode 100644 index 000000000..180ccdee4 --- /dev/null +++ b/python/sglang/srt/managers/eplb_algorithms/deepseek.py @@ -0,0 +1,223 @@ +# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package +from typing import Tuple + +import torch + +from sglang.srt.utils import get_bool_env_var + + +def balanced_packing( + weight: torch.Tensor, num_packs: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs + are as balanced as possible. + + Parameters: + weight: [X, n], the weight of each item + num_packs: number of packs + + Returns: + pack_index: [X, n], the pack index of each item + rank_in_pack: [X, n], the rank of the item in the pack + """ + num_layers, num_groups = weight.shape + assert num_groups % num_packs == 0 + groups_per_pack = num_groups // num_packs + + if groups_per_pack == 1: + pack_index = torch.arange( + weight.size(-1), dtype=torch.int64, device=weight.device + ).expand(weight.shape) + rank_in_pack = torch.zeros_like(weight, dtype=torch.int64) + return pack_index, rank_in_pack + + indices = weight.float().sort(-1, descending=True).indices.cpu() + pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu") + rank_in_pack = torch.full_like(pack_index, fill_value=-1) + for i in range(num_layers): + pack_weights = [0] * num_packs + pack_items = [0] * num_packs + for group in indices[i]: + pack = min( + (i for i in range(num_packs) if pack_items[i] < groups_per_pack), + key=pack_weights.__getitem__, + ) + assert pack_items[pack] < groups_per_pack + pack_index[i, group] = pack + rank_in_pack[i, group] = pack_items[pack] + pack_weights[pack] += weight[i, group] + pack_items[pack] += 1 + return pack_index, rank_in_pack + + +def replicate_experts( + weight: torch.Tensor, num_phy: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized. + + Parameters: + weight: [X, num_log] + num_phy: total number of experts after replication + + Returns: + phy2log: [X, num_phy], logical expert id of each physical expert + rank: [X, num_phy], the replica rank + logcnt: [X, num_log], number of replicas for each logical expert + """ + n, num_log = weight.shape + num_redundant = num_phy - num_log + assert num_redundant >= 0 + device = weight.device + phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) + rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) + arangen = torch.arange(n, dtype=torch.int64, device=device) + for i in range(num_log, num_phy): + redundant_indices = (weight / logcnt).max(dim=-1).indices + phy2log[:, i] = redundant_indices + rank[:, i] = logcnt[arangen, redundant_indices] + logcnt[arangen, redundant_indices] += 1 + return phy2log, rank, logcnt + + +def rebalance_experts_hierarchical( + weight: torch.Tensor, + num_physical_experts: int, + num_groups: int, + num_nodes: int, + num_gpus: int, +): + """ + Parameters: + weight: [num_moe_layers, num_logical_experts] + num_physical_experts: number of physical experts after replication + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [num_moe_layers, num_physical_experts] + logical_to_physical_map: [num_moe_layers, num_logical_experts, X] + logical_count: [num_moe_layers, num_logical_experts] + """ + num_layers, num_logical_experts = weight.shape + assert num_logical_experts % num_groups == 0 + group_size = num_logical_experts // num_groups + assert num_groups % num_nodes == 0 + groups_per_node = num_groups // num_nodes + assert num_gpus % num_nodes == 0 + assert num_physical_experts % num_gpus == 0 + phy_experts_per_gpu = num_physical_experts // num_gpus + + def inverse(perm: torch.Tensor) -> torch.Tensor: + inv = torch.empty_like(perm) + inv.scatter_( + 1, + perm, + torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand( + perm.shape + ), + ) + return inv + + # Step 1: pack groups to nodes + tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes) + log2mlog = ( + ( + (group_pack_index * groups_per_node + group_rank_in_pack) * group_size + ).unsqueeze(-1) + + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device) + ).flatten(-2) + mlog2log = inverse(log2mlog) + + # Step 2: construct redundant experts within nodes + # [num_layers * num_nodes, num_logical_experts // num_nodes] + tokens_per_mlog = weight.gather(-1, mlog2log).view( + -1, num_logical_experts // num_nodes + ) + phy2mlog, phyrank, mlogcnt = replicate_experts( + tokens_per_mlog, num_physical_experts // num_nodes + ) + + # Step 3: pack physical_experts to GPUs + # [num_layers * num_nodes, num_physical_experts // num_nodes] + tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes) + phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack + pphy2phy = inverse(phy2pphy) + + pphy2mlog = phy2mlog.gather( + -1, pphy2phy + ) # [num_layers * num_nodes, num_log_per_nodes] + pphy2mlog = ( + pphy2mlog.view(num_layers, num_nodes, -1) + + torch.arange( + 0, + num_logical_experts, + num_logical_experts // num_nodes, + device=group_pack_index.device, + ).view(1, -1, 1) + ).flatten(-2) + pphy2log = mlog2log.gather(-1, pphy2mlog) + pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + return pphy2log, pphyrank, logcnt + + +def rebalance_experts( + weight: torch.Tensor, + num_replicas: int, + num_groups: int, + num_nodes: int, + num_gpus: int, + enable_hierarchical: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Entry point for expert-parallelism load balancer. + + Parameters: + weight: [layers, num_logical_experts], the load statistics for all logical experts + num_replicas: number of physical experts, must be a multiple of `num_gpus` + num_groups: number of expert groups + num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster + num_gpus: number of GPUs, must be a multiple of `num_nodes` + + Returns: + physical_to_logical_map: [layers, num_replicas], the expert index of each replica + logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert + expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert + """ + + num_layers, num_logical_experts = weight.shape + weight = weight.float().cpu() + if enable_hierarchical: + # use hierarchical load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, num_groups, num_nodes, num_gpus + ) + else: + # use global load-balance policy + phy2log, phyrank, logcnt = rebalance_experts_hierarchical( + weight, num_replicas, 1, 1, num_gpus + ) + maxlogcnt = logcnt.max().item() + log2phy: torch.Tensor = torch.full( + (num_layers, num_logical_experts, maxlogcnt), + -1, + dtype=torch.int64, + device=logcnt.device, + ) + log2phy.view(num_layers, -1).scatter_( + -1, + phy2log * maxlogcnt + phyrank, + torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( + num_layers, -1 + ), + ) + return phy2log, log2phy, logcnt + + +__all__ = ["rebalance_experts"] diff --git a/python/sglang/srt/managers/deepseek_eplb.py b/python/sglang/srt/managers/eplb_algorithms/deepseek_vec.py similarity index 96% rename from python/sglang/srt/managers/deepseek_eplb.py rename to python/sglang/srt/managers/eplb_algorithms/deepseek_vec.py index 7dd015bfe..cb165448a 100644 --- a/python/sglang/srt/managers/deepseek_eplb.py +++ b/python/sglang/srt/managers/eplb_algorithms/deepseek_vec.py @@ -1,6 +1,5 @@ # This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package - -from typing import Literal, Optional, Tuple +from typing import Optional, Tuple import torch @@ -259,13 +258,9 @@ def rebalance_experts( num_local_physical_experts: int, num_groups: Optional[int], num_nodes: int, - phase: Literal["prefill", "decode", "null"], + enable_hierarchical: bool, ): - if ( - (phase == "prefill") - and (num_groups is not None) - and (num_groups % num_nodes == 0) - ): + if enable_hierarchical: return prefill_rebalance_experts( tokens_per_expert=tokens_per_expert, num_physical_experts=num_physical_experts, @@ -273,8 +268,9 @@ def rebalance_experts( num_groups=num_groups, num_nodes=num_nodes, ) - return decode_rebalance_experts( - tokens_per_expert=tokens_per_expert, - num_physical_experts=num_physical_experts, - num_local_physical_experts=num_local_physical_experts, - ) + else: + return decode_rebalance_experts( + tokens_per_expert=tokens_per_expert, + num_physical_experts=num_physical_experts, + num_local_physical_experts=num_local_physical_experts, + ) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 3979c762f..e0bca7ec2 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -22,7 +22,7 @@ import torch.distributed import torch.nn.functional as F from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.managers import deepseek_eplb +from sglang.srt.managers import eplb_algorithms from sglang.srt.model_loader import get_model_architecture from sglang.srt.server_args import ServerArgs @@ -134,15 +134,21 @@ class ExpertLocationMetadata: common = ExpertLocationMetadata._init_common(server_args, model_config) model_config_for_expert_location = common["model_config_for_expert_location"] num_physical_experts = common["num_physical_experts"] + num_groups = model_config_for_expert_location.num_groups + num_nodes = server_args.nnodes physical_to_logical_map, logical_to_all_physical_map, expert_count = ( - deepseek_eplb.rebalance_experts( + eplb_algorithms.rebalance_experts( tokens_per_expert=logical_count, num_physical_experts=num_physical_experts, num_local_physical_experts=num_physical_experts // common["ep_size"], - num_groups=model_config_for_expert_location.num_groups, - num_nodes=server_args.nnodes, - phase=server_args.disaggregation_mode, + num_groups=num_groups, + num_nodes=num_nodes, + algorithm=eplb_algorithms.compute_algorithm( + raw_algorithm=server_args.eplb_algorithm, + num_groups=num_groups, + num_nodes=num_nodes, + ), ) ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2fa0a0eee..12823b5ab 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -175,6 +175,7 @@ class ServerArgs: ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None init_expert_location: str = "trivial" enable_eplb: bool = False + eplb_algorithm: str = "auto" eplb_rebalance_num_iterations: int = 1000 expert_distribution_recorder_mode: Optional[ Literal["stat", "per_pass", "per_token"] @@ -1328,6 +1329,12 @@ class ServerArgs: action="store_true", help="Enable EPLB algorithm", ) + parser.add_argument( + "--eplb-algorithm", + type=str, + default=ServerArgs.eplb_algorithm, + help="Chosen EPLB algorithm", + ) parser.add_argument( "--eplb-rebalance-num-iterations", type=int,