Support picking variants of EPLB algorithms (#6728)
This commit is contained in:
63
python/sglang/srt/managers/eplb_algorithms/__init__.py
Normal file
63
python/sglang/srt/managers/eplb_algorithms/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec
|
||||
|
||||
|
||||
class EplbAlgorithm(Enum):
|
||||
deepseek = auto()
|
||||
deepseek_hierarchical = auto()
|
||||
deepseek_vec = auto()
|
||||
deepseek_vec_hierarchical = auto()
|
||||
# TODO may have more algorithm later
|
||||
|
||||
|
||||
def rebalance_experts(
|
||||
tokens_per_expert: torch.Tensor,
|
||||
num_physical_experts: int,
|
||||
num_local_physical_experts: int,
|
||||
num_groups: Optional[int],
|
||||
num_nodes: int,
|
||||
algorithm: EplbAlgorithm,
|
||||
):
|
||||
if algorithm in [EplbAlgorithm.deepseek, EplbAlgorithm.deepseek_hierarchical]:
|
||||
return deepseek.rebalance_experts(
|
||||
weight=tokens_per_expert.sum(dim=0),
|
||||
num_replicas=num_physical_experts,
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
num_gpus=num_physical_experts // num_local_physical_experts,
|
||||
enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical,
|
||||
)
|
||||
|
||||
if algorithm in [
|
||||
EplbAlgorithm.deepseek_vec,
|
||||
EplbAlgorithm.deepseek_vec_hierarchical,
|
||||
]:
|
||||
return deepseek_vec.rebalance_experts(
|
||||
tokens_per_expert=tokens_per_expert,
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
|
||||
)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def compute_algorithm(
|
||||
raw_algorithm: str,
|
||||
num_groups: Optional[int],
|
||||
num_nodes: int,
|
||||
) -> EplbAlgorithm:
|
||||
if raw_algorithm != "auto":
|
||||
return EplbAlgorithm[raw_algorithm]
|
||||
|
||||
# TODO test on real scenarios and know which ones perform better
|
||||
if (num_groups is not None) and (num_groups % num_nodes == 0):
|
||||
return EplbAlgorithm.deepseek_hierarchical
|
||||
else:
|
||||
return EplbAlgorithm.deepseek
|
||||
223
python/sglang/srt/managers/eplb_algorithms/deepseek.py
Normal file
223
python/sglang/srt/managers/eplb_algorithms/deepseek.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
|
||||
def balanced_packing(
|
||||
weight: torch.Tensor, num_packs: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
|
||||
are as balanced as possible.
|
||||
|
||||
Parameters:
|
||||
weight: [X, n], the weight of each item
|
||||
num_packs: number of packs
|
||||
|
||||
Returns:
|
||||
pack_index: [X, n], the pack index of each item
|
||||
rank_in_pack: [X, n], the rank of the item in the pack
|
||||
"""
|
||||
num_layers, num_groups = weight.shape
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = torch.arange(
|
||||
weight.size(-1), dtype=torch.int64, device=weight.device
|
||||
).expand(weight.shape)
|
||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
indices = weight.float().sort(-1, descending=True).indices.cpu()
|
||||
pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
|
||||
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
|
||||
for i in range(num_layers):
|
||||
pack_weights = [0] * num_packs
|
||||
pack_items = [0] * num_packs
|
||||
for group in indices[i]:
|
||||
pack = min(
|
||||
(i for i in range(num_packs) if pack_items[i] < groups_per_pack),
|
||||
key=pack_weights.__getitem__,
|
||||
)
|
||||
assert pack_items[pack] < groups_per_pack
|
||||
pack_index[i, group] = pack
|
||||
rank_in_pack[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight[i, group]
|
||||
pack_items[pack] += 1
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
|
||||
def replicate_experts(
|
||||
weight: torch.Tensor, num_phy: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
|
||||
|
||||
Parameters:
|
||||
weight: [X, num_log]
|
||||
num_phy: total number of experts after replication
|
||||
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
rank: [X, num_phy], the replica rank
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
num_redundant = num_phy - num_log
|
||||
assert num_redundant >= 0
|
||||
device = weight.device
|
||||
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
||||
phy2log[:, i] = redundant_indices
|
||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, rank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts_hierarchical(
|
||||
weight: torch.Tensor,
|
||||
num_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
num_physical_experts: number of physical experts after replication
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map: [num_moe_layers, num_physical_experts]
|
||||
logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
|
||||
logical_count: [num_moe_layers, num_logical_experts]
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
group_size = num_logical_experts // num_groups
|
||||
assert num_groups % num_nodes == 0
|
||||
groups_per_node = num_groups // num_nodes
|
||||
assert num_gpus % num_nodes == 0
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||
|
||||
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
||||
inv = torch.empty_like(perm)
|
||||
inv.scatter_(
|
||||
1,
|
||||
perm,
|
||||
torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
|
||||
perm.shape
|
||||
),
|
||||
)
|
||||
return inv
|
||||
|
||||
# Step 1: pack groups to nodes
|
||||
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
||||
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
|
||||
log2mlog = (
|
||||
(
|
||||
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
|
||||
).unsqueeze(-1)
|
||||
+ torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
|
||||
).flatten(-2)
|
||||
mlog2log = inverse(log2mlog)
|
||||
|
||||
# Step 2: construct redundant experts within nodes
|
||||
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(
|
||||
-1, num_logical_experts // num_nodes
|
||||
)
|
||||
phy2mlog, phyrank, mlogcnt = replicate_experts(
|
||||
tokens_per_mlog, num_physical_experts // num_nodes
|
||||
)
|
||||
|
||||
# Step 3: pack physical_experts to GPUs
|
||||
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
||||
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
||||
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
|
||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||
pphy2phy = inverse(phy2pphy)
|
||||
|
||||
pphy2mlog = phy2mlog.gather(
|
||||
-1, pphy2phy
|
||||
) # [num_layers * num_nodes, num_log_per_nodes]
|
||||
pphy2mlog = (
|
||||
pphy2mlog.view(num_layers, num_nodes, -1)
|
||||
+ torch.arange(
|
||||
0,
|
||||
num_logical_experts,
|
||||
num_logical_experts // num_nodes,
|
||||
device=group_pack_index.device,
|
||||
).view(1, -1, 1)
|
||||
).flatten(-2)
|
||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
||||
return pphy2log, pphyrank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts(
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
enable_hierarchical: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics for all logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of `num_gpus`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map: [layers, num_replicas], the expert index of each replica
|
||||
logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
|
||||
expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
|
||||
"""
|
||||
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight = weight.float().cpu()
|
||||
if enable_hierarchical:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
weight, num_replicas, 1, 1, num_gpus
|
||||
)
|
||||
maxlogcnt = logcnt.max().item()
|
||||
log2phy: torch.Tensor = torch.full(
|
||||
(num_layers, num_logical_experts, maxlogcnt),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
device=logcnt.device,
|
||||
)
|
||||
log2phy.view(num_layers, -1).scatter_(
|
||||
-1,
|
||||
phy2log * maxlogcnt + phyrank,
|
||||
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
|
||||
num_layers, -1
|
||||
),
|
||||
)
|
||||
return phy2log, log2phy, logcnt
|
||||
|
||||
|
||||
__all__ = ["rebalance_experts"]
|
||||
@@ -1,6 +1,5 @@
|
||||
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
|
||||
|
||||
from typing import Literal, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -259,13 +258,9 @@ def rebalance_experts(
|
||||
num_local_physical_experts: int,
|
||||
num_groups: Optional[int],
|
||||
num_nodes: int,
|
||||
phase: Literal["prefill", "decode", "null"],
|
||||
enable_hierarchical: bool,
|
||||
):
|
||||
if (
|
||||
(phase == "prefill")
|
||||
and (num_groups is not None)
|
||||
and (num_groups % num_nodes == 0)
|
||||
):
|
||||
if enable_hierarchical:
|
||||
return prefill_rebalance_experts(
|
||||
tokens_per_expert=tokens_per_expert,
|
||||
num_physical_experts=num_physical_experts,
|
||||
@@ -273,8 +268,9 @@ def rebalance_experts(
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
)
|
||||
return decode_rebalance_experts(
|
||||
tokens_per_expert=tokens_per_expert,
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
)
|
||||
else:
|
||||
return decode_rebalance_experts(
|
||||
tokens_per_expert=tokens_per_expert,
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
)
|
||||
@@ -22,7 +22,7 @@ import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.managers import deepseek_eplb
|
||||
from sglang.srt.managers import eplb_algorithms
|
||||
from sglang.srt.model_loader import get_model_architecture
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
@@ -134,15 +134,21 @@ class ExpertLocationMetadata:
|
||||
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
||||
model_config_for_expert_location = common["model_config_for_expert_location"]
|
||||
num_physical_experts = common["num_physical_experts"]
|
||||
num_groups = model_config_for_expert_location.num_groups
|
||||
num_nodes = server_args.nnodes
|
||||
|
||||
physical_to_logical_map, logical_to_all_physical_map, expert_count = (
|
||||
deepseek_eplb.rebalance_experts(
|
||||
eplb_algorithms.rebalance_experts(
|
||||
tokens_per_expert=logical_count,
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_physical_experts // common["ep_size"],
|
||||
num_groups=model_config_for_expert_location.num_groups,
|
||||
num_nodes=server_args.nnodes,
|
||||
phase=server_args.disaggregation_mode,
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
algorithm=eplb_algorithms.compute_algorithm(
|
||||
raw_algorithm=server_args.eplb_algorithm,
|
||||
num_groups=num_groups,
|
||||
num_nodes=num_nodes,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -175,6 +175,7 @@ class ServerArgs:
|
||||
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
|
||||
init_expert_location: str = "trivial"
|
||||
enable_eplb: bool = False
|
||||
eplb_algorithm: str = "auto"
|
||||
eplb_rebalance_num_iterations: int = 1000
|
||||
expert_distribution_recorder_mode: Optional[
|
||||
Literal["stat", "per_pass", "per_token"]
|
||||
@@ -1328,6 +1329,12 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable EPLB algorithm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eplb-algorithm",
|
||||
type=str,
|
||||
default=ServerArgs.eplb_algorithm,
|
||||
help="Chosen EPLB algorithm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eplb-rebalance-num-iterations",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user