Move files related to EPLB (#7580)
This commit is contained in:
@@ -1,63 +0,0 @@
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec
|
||||
|
||||
|
||||
class EplbAlgorithm(Enum):
    """Algorithms available for expert-parallelism load balancing (EPLB).

    The `*_hierarchical` variants additionally respect the expert-group /
    node topology when placing replicas; the plain variants balance globally.
    """

    # DeepSeek-style balancer operating on load summed over recording steps.
    deepseek = auto()
    deepseek_hierarchical = auto()
    # Vectorized variants operating on per-step load statistics.
    deepseek_vec = auto()
    deepseek_vec_hierarchical = auto()
    # TODO may have more algorithm later
|
||||
|
||||
|
||||
def rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: Optional[int],
    num_nodes: int,
    algorithm: EplbAlgorithm,
):
    """Dispatch to the balancer implementation selected by `algorithm`.

    Raises:
        NotImplementedError: for algorithms without an implementation.
    """
    deepseek_variants = (EplbAlgorithm.deepseek, EplbAlgorithm.deepseek_hierarchical)
    vec_variants = (
        EplbAlgorithm.deepseek_vec,
        EplbAlgorithm.deepseek_vec_hierarchical,
    )

    if algorithm in deepseek_variants:
        # The non-vectorized balancer consumes load summed over recording steps.
        return deepseek.rebalance_experts(
            weight=tokens_per_expert.sum(dim=0),
            num_replicas=num_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            num_gpus=num_physical_experts // num_local_physical_experts,
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical,
        )

    if algorithm in vec_variants:
        # The vectorized balancer consumes the per-step statistics directly.
        return deepseek_vec.rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
        )

    raise NotImplementedError
|
||||
|
||||
|
||||
def compute_algorithm(
    raw_algorithm: str,
    num_groups: Optional[int],
    num_nodes: int,
) -> EplbAlgorithm:
    """Resolve a user-supplied algorithm name into an `EplbAlgorithm`.

    Any name other than "auto" is looked up directly in the enum; "auto"
    picks a DeepSeek variant based on whether the group count divides
    evenly across nodes.
    """
    if raw_algorithm != "auto":
        return EplbAlgorithm[raw_algorithm]

    # TODO test on real scenarios and know which ones perform better
    groups_align_with_nodes = (num_groups is not None) and (
        num_groups % num_nodes == 0
    )
    return (
        EplbAlgorithm.deepseek_hierarchical
        if groups_align_with_nodes
        else EplbAlgorithm.deepseek
    )
|
||||
@@ -1,223 +0,0 @@
|
||||
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
|
||||
def balanced_packing(
    weight: torch.Tensor, num_packs: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
    are as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
    num_layers, num_groups = weight.shape
    assert num_groups % num_packs == 0
    groups_per_pack = num_groups // num_packs

    if groups_per_pack == 1:
        # Trivial case: every item is alone in its own pack.
        pack_index = torch.arange(
            weight.size(-1), dtype=torch.int64, device=weight.device
        ).expand(weight.shape)
        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
        return pack_index, rank_in_pack

    # Greedy packing: visit items in decreasing weight order and place each
    # onto the lightest pack that still has room. Runs on CPU since it is a
    # Python-level loop over tensor elements.
    indices = weight.float().sort(-1, descending=True).indices.cpu()
    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)
    for layer in range(num_layers):
        pack_weights = [0.0] * num_packs
        pack_items = [0] * num_packs
        for group in indices[layer]:
            # Fix: the candidate generator previously reused `i`, shadowing the
            # outer layer-loop variable; use a distinct name for clarity.
            pack = min(
                (p for p in range(num_packs) if pack_items[p] < groups_per_pack),
                key=pack_weights.__getitem__,
            )
            assert pack_items[pack] < groups_per_pack
            pack_index[layer, group] = pack
            rank_in_pack[layer, group] = pack_items[pack]
            # `.item()` keeps the running totals as Python floats instead of
            # accumulating 0-dim tensors (same ordering, less overhead).
            pack_weights[pack] += weight[layer, group].item()
            pack_items[pack] += 1
    return pack_index, rank_in_pack
|
||||
|
||||
|
||||
def replicate_experts(
    weight: torch.Tensor, num_phy: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    num_layers, num_log = weight.shape
    assert num_phy - num_log >= 0
    device = weight.device

    # Start from the identity mapping: one replica per logical expert.
    phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(
        num_layers, 1
    )
    rank = torch.zeros(num_layers, num_phy, dtype=torch.int64, device=device)
    logcnt = torch.ones(num_layers, num_log, dtype=torch.int64, device=device)

    row_idx = torch.arange(num_layers, dtype=torch.int64, device=device)
    # Greedily fill the remaining slots: each new replica goes to the logical
    # expert whose current per-replica load is the highest.
    for slot in range(num_log, num_phy):
        chosen = (weight / logcnt).max(dim=-1).indices
        phy2log[:, slot] = chosen
        rank[:, slot] = logcnt[row_idx, chosen]
        logcnt[row_idx, chosen] += 1
    return phy2log, rank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts_hierarchical(
    weight: torch.Tensor,
    num_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
):
    """
    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """
    # Naming convention used below: "mlog" is the logical-expert index AFTER
    # groups have been reordered onto nodes (log2mlog is that permutation,
    # mlog2log its inverse); "phy" is a per-node physical slot; "pphy" is the
    # final physical slot after packing onto GPUs.
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        # Row-wise inverse permutation: inv[row, perm[row, j]] = j.
        inv = torch.empty_like(perm)
        inv.scatter_(
            1,
            perm,
            torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
                perm.shape
            ),
        )
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
    # Expand the per-group placement to per-expert indices: each group occupies
    # a contiguous block of `group_size` mlog slots on its assigned node.
    log2mlog = (
        (
            (group_pack_index * groups_per_node + group_rank_in_pack) * group_size
        ).unsqueeze(-1)
        + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
    ).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    tokens_per_mlog = weight.gather(-1, mlog2log).view(
        -1, num_logical_experts // num_nodes
    )
    phy2mlog, phyrank, mlogcnt = replicate_experts(
        tokens_per_mlog, num_physical_experts // num_nodes
    )

    # Step 3: pack physical_experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    # Per-replica load: each replica of an expert carries 1/logcnt of its load.
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    pphy2mlog = phy2mlog.gather(
        -1, pphy2phy
    )  # [num_layers * num_nodes, num_log_per_nodes]
    # Re-offset the per-node mlog indices back into the global mlog index space.
    pphy2mlog = (
        pphy2mlog.view(num_layers, num_nodes, -1)
        + torch.arange(
            0,
            num_logical_experts,
            num_logical_experts // num_nodes,
            device=group_pack_index.device,
        ).view(1, -1, 1)
    ).flatten(-2)
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """

    num_layers, num_logical_experts = weight.shape
    # The balancing itself is a CPU-side Python loop, so move the stats there.
    weight = weight.float().cpu()
    if enable_hierarchical:
        # use hierarchical load-balance policy
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # use global load-balance policy
        # (expressed as the degenerate case of one group on one node)
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )
    # Build the reverse map log2phy[layer, log, r] = physical slot of replica r
    # of logical expert `log`; slots without a replica stay -1. Each physical
    # slot i scatters its index into flat position phy2log*maxlogcnt + phyrank.
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
            num_layers, -1
        ),
    )
    return phy2log, log2phy, logcnt


__all__ = ["rebalance_experts"]
|
||||
@@ -1,276 +0,0 @@
|
||||
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def pack_groups(tokens_per_group: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Greedily assign expert groups to nodes, balancing token load.

    Parameters:
        tokens_per_group: [num_layers, num_groups] per-group token counts
        num_nodes: number of nodes; must divide num_groups

    Returns:
        [num_layers, num_groups] int64 tensor (on CPU) where entry [l, g] is
        the packed position of group g: node_rank * groups_per_rank + slot.
    """
    num_layers, num_groups = tokens_per_group.shape
    assert num_groups % num_nodes == 0
    groups_per_rank = num_groups // num_nodes

    # Visit groups in decreasing load order (greedy balancing on CPU).
    indices = tokens_per_group.float().sort(-1, descending=True).indices.cpu()
    ret = torch.full_like(
        tokens_per_group, fill_value=-1, dtype=torch.int64, device="cpu"
    )
    for layer in range(num_layers):
        node_tokens = [0.0] * num_nodes
        node_groups = [0] * num_nodes

        # Fix: hoisted out of the inner loop (it was re-created per group) and
        # corrected the return annotation (it returns a sort key tuple, not int).
        def key_func(rank: int) -> Tuple[int, float]:
            # Full nodes sort last; among non-full nodes pick the lightest.
            if node_groups[rank] >= groups_per_rank:
                return 1, 0.0
            return 0, node_tokens[rank]

        for group in indices[layer]:
            rank = min(range(num_nodes), key=key_func)
            assert node_groups[rank] < groups_per_rank
            ret[layer, group] = rank * groups_per_rank + node_groups[rank]
            # `.item()` keeps the totals as Python floats rather than
            # accumulating 0-dim tensors (identical ordering).
            node_tokens[rank] += tokens_per_group[layer, group].item()
            node_groups[rank] += 1
    return ret
|
||||
|
||||
|
||||
def make_redundant_experts_chunkwise(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_physical_experts_per_chunk: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Place redundant expert replicas chunk by chunk based on load stats.

    Parameters:
        tokens_per_expert: [num_steps, num_moe_layers, num_logical_experts]
            per-recording-step load statistics
        num_physical_experts: total physical slots per layer (logical + redundant)
        num_local_physical_experts: physical slots per GPU
        num_physical_experts_per_chunk: slots per chunk; must divide
            num_physical_experts (chunks partition experts, e.g. per node)

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map:
            [num_moe_layers, num_logical_experts, num_redundancy_experts + 1],
            -1 marks unused replica slots
        logical_count: [num_moe_layers, num_logical_experts], replicas per expert
    """
    num_steps, num_moe_layers, num_logical_experts = tokens_per_expert.shape
    num_redundancy_experts = num_physical_experts - num_logical_experts

    physical_to_logical_map = torch.empty(
        num_moe_layers,
        num_physical_experts,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    logical_to_physical_map = torch.full(
        (num_moe_layers, num_logical_experts, num_redundancy_experts + 1),
        -1,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    logical_count = torch.ones(
        num_moe_layers,
        num_logical_experts,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )

    # Each chunk holds an equal share of logical experts and redundant slots.
    assert num_physical_experts % num_physical_experts_per_chunk == 0
    num_chunks = num_physical_experts // num_physical_experts_per_chunk
    assert num_logical_experts % num_chunks == 0
    num_logical_experts_per_group = num_logical_experts // num_chunks
    assert num_redundancy_experts % num_chunks == 0
    num_redundancy_experts_per_group = num_redundancy_experts // num_chunks

    arange_num_moe_layers_num_groups = torch.arange(
        num_moe_layers * num_chunks, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_logical_experts = torch.arange(
        num_logical_experts, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_logical_experts_per_group = torch.arange(
        num_logical_experts_per_group, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_groups = torch.arange(
        num_chunks, dtype=torch.int, device=tokens_per_expert.device
    )
    # The first num_logical_experts_per_group slots of each chunk hold that
    # chunk's logical experts in order; redundant slots follow.
    physical_to_logical_map.view(
        num_moe_layers, num_chunks, num_physical_experts_per_chunk
    )[:, :, :num_logical_experts_per_group] = arange_num_logical_experts.view(
        num_chunks, num_logical_experts_per_group
    )
    # Replica 0 of every logical expert is its own primary physical slot.
    logical_to_physical_map[:, :, 0] = (
        arange_num_logical_experts_per_group.expand(
            num_chunks, num_logical_experts_per_group
        )
        + arange_num_groups[:, None] * num_physical_experts_per_chunk
    ).view(num_logical_experts)

    # Tiny per-expert offsets break ties so max/argmin pick unique experts.
    tokens_per_expert_all_diff = tokens_per_expert + arange_num_logical_experts * 1e-4
    for i in range(num_redundancy_experts_per_group):
        # Greedy: for each chunk, pick the expert that minimizes the maximum
        # per-replica load (summed over steps) after granting it one replica.
        score = (
            tokens_per_expert_all_diff / logical_count
        )  # NOTE: Values in score must be different from each other
        score1 = tokens_per_expert / (logical_count + 1)
        score = score.view(
            num_steps, num_moe_layers, num_chunks, num_logical_experts_per_group
        )
        score1 = score1.view_as(score)
        values, indices = score.max(-1, keepdim=True)
        values = values.expand_as(score).contiguous()
        # Evaluate each candidate: replace the current max expert's score with
        # its post-replication score, then re-take the max.
        score.scatter_(-1, indices, score1.gather(-1, indices))
        values.scatter_(-1, indices, score.max(-1, keepdim=True).values)
        redundancy_indices = values.sum(0).argmin(-1)
        physical_to_logical_map.view(
            num_moe_layers, num_chunks, num_physical_experts_per_chunk
        )[:, :, num_logical_experts_per_group + i] = (
            redundancy_indices + arange_num_groups * num_logical_experts_per_group
        )
        # Current replica count of the chosen experts = slot for the new replica.
        redundancy_count = (
            logical_count.view(
                num_moe_layers * num_chunks, num_logical_experts_per_group
            )
            .gather(-1, redundancy_indices.view(num_moe_layers * num_chunks, 1))
            .squeeze(1)
        )
        physical_redundancy_indices = (
            (
                arange_num_groups * num_physical_experts_per_chunk
                + num_logical_experts_per_group
                + i
            )
            .expand(num_moe_layers, num_chunks)
            .flatten()
        )
        logical_to_physical_map.view(
            num_moe_layers * num_chunks,
            num_logical_experts_per_group,
            num_redundancy_experts + 1,
        )[
            arange_num_moe_layers_num_groups,
            redundancy_indices.view(num_moe_layers * num_chunks),
            redundancy_count,
        ] = physical_redundancy_indices
        logical_count.view(num_moe_layers * num_chunks, num_logical_experts_per_group)[
            arange_num_moe_layers_num_groups,
            redundancy_indices.view(num_moe_layers * num_chunks),
        ] += 1

    if num_local_physical_experts > 1:
        # Load-balancing between GPUs
        # Sort slots by per-replica load and deal them out in alternating
        # (snake-like: every other row flipped) order across local groups so
        # each GPU receives a mix of heavy and light slots.
        physical_to_logical_map_int64 = physical_to_logical_map.to(torch.int64)
        counts = logical_count.gather(-1, physical_to_logical_map_int64)
        score = tokens_per_expert.sum(0).gather(-1, physical_to_logical_map_int64)
        score = score / counts
        score = score.view(num_moe_layers, num_chunks, num_physical_experts_per_chunk)
        indices = score.argsort(-1, descending=True)
        indices += torch.arange(
            0,
            num_physical_experts,
            num_physical_experts_per_chunk,
            dtype=indices.dtype,
            device=indices.device,
        )[None, :, None]

        assert num_physical_experts_per_chunk % num_local_physical_experts == 0
        num_local_groups = num_physical_experts_per_chunk // num_local_physical_experts
        indices = indices.view(
            num_moe_layers, num_chunks, num_local_physical_experts, num_local_groups
        )
        indices[:, :, 1::2, :] = indices[:, :, 1::2, :].flip(-1)
        indices = indices.transpose(2, 3)
        indices = indices.reshape(num_moe_layers, num_physical_experts)
        physical_to_logical_map = physical_to_logical_map.gather(-1, indices)
        # Temporarily replace -1 placeholders with 0 so gather indices are
        # valid, remap physical slot ids through the permutation, then restore.
        mask = logical_to_physical_map == -1
        logical_to_physical_map[mask] = 0
        logical_to_physical_map = (
            indices.argsort(-1)
            .gather(
                -1, logical_to_physical_map.view(num_moe_layers, -1).to(torch.int64)
            )
            .view_as(logical_to_physical_map)
            .to(torch.int)
        )
        logical_to_physical_map[mask] = -1

    return physical_to_logical_map, logical_to_physical_map, logical_count
|
||||
|
||||
|
||||
def decode_rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
):
    """Rebalance experts for decode by treating all physical slots as one chunk.

    With chunk size equal to `num_physical_experts` there is exactly one
    chunk, so redundant replicas are placed globally rather than per node.
    """
    return make_redundant_experts_chunkwise(
        tokens_per_expert=tokens_per_expert,
        num_physical_experts=num_physical_experts,
        num_local_physical_experts=num_local_physical_experts,
        num_physical_experts_per_chunk=num_physical_experts,
    )
|
||||
|
||||
|
||||
def prefill_rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: int,
    num_nodes: int,
):
    """Hierarchical rebalancing: pack expert groups onto nodes, then place
    redundant replicas with one chunk per node.

    Parameters:
        tokens_per_expert: [num_steps, num_moe_layers, num_logical_experts]
        num_physical_experts: total physical slots per layer
        num_local_physical_experts: physical slots per GPU
        num_groups: number of expert groups; must divide num_logical_experts
        num_nodes: number of nodes; must divide num_groups

    Returns:
        (phy2log, log2phy, log_count) in the original (un-permuted) logical
        expert index space. "mlog" below denotes the logical index after the
        group-to-node permutation (log2mlog), with mlog2log its inverse.
    """
    # Balancing is a CPU-side computation.
    tokens_per_expert = tokens_per_expert.float().cpu()

    num_steps, _, num_logical_experts = tokens_per_expert.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0, f"{num_groups=} {num_nodes=}"

    # Pack groups onto nodes using load summed over steps.
    tokens_per_group = tokens_per_expert.sum(0).unflatten(-1, (num_groups, -1)).sum(-1)
    group_perm = pack_groups(
        tokens_per_group, num_nodes
    )  # [num_moe_layers, num_groups] => [num_moe_layers, num_nodes]

    # log2mlog [layers, #logexp] -> [layers, #logexp]
    # Each group occupies a contiguous block of `group_size` mlog slots.
    log2mlog = (
        (group_perm * group_size).unsqueeze(-1)
        + torch.arange(group_size, dtype=torch.int64, device=group_perm.device)
    ).flatten(-2)

    # mlog2log [layers, #logexp] -> [layers, #logexp], inverse of log2mlog
    mlog2log = torch.empty_like(log2mlog)
    arange = torch.arange(
        num_logical_experts, dtype=torch.int64, device=mlog2log.device
    )
    mlog2log.scatter_(1, log2mlog, arange.expand(log2mlog.size(0), -1))

    # tokens_per_mlog[i][j][k] = tokens_per_expert[i][j][mlog2log[j][k]]
    tokens_per_mlog = tokens_per_expert.gather(
        2, mlog2log.unsqueeze(0).expand(num_steps, -1, -1)
    )

    # One chunk per node keeps every redundant replica on its expert's node.
    phy2mlog, mlog2phy, mlog_count = make_redundant_experts_chunkwise(
        tokens_per_mlog,
        num_physical_experts,
        num_local_physical_experts,
        num_physical_experts // num_nodes,
    )

    # Translate results back to the original logical index space.
    # phy2log[i][j] = mlog2log[i][phy2mlog[i][j]]
    phy2log = mlog2log.gather(1, phy2mlog.to(torch.int64))

    # mlog2phy: [num_moe_layers, num_logical_experts, ...]
    # log2phy[i][j][k] = mlog2phy[i][log2mlog[i][j]][k]
    log2phy = mlog2phy.gather(
        1, log2mlog.unsqueeze(-1).expand(-1, -1, mlog2phy.size(-1)).to(torch.int64)
    )

    # log_count[i][j] = mlog_count[i][log2mlog[i][j]]
    log_count = mlog_count.gather(1, log2mlog)
    return phy2log, log2phy, log_count
|
||||
|
||||
|
||||
def rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: Optional[int],
    num_nodes: int,
    enable_hierarchical: bool,
):
    """Dispatch between the hierarchical (prefill) and global (decode)
    vectorized balancers."""
    if not enable_hierarchical:
        return decode_rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )

    return prefill_rebalance_experts(
        tokens_per_expert=tokens_per_expert,
        num_physical_experts=num_physical_experts,
        num_local_physical_experts=num_local_physical_experts,
        num_groups=num_groups,
        num_nodes=num_nodes,
    )
|
||||
@@ -1,96 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch.cuda
|
||||
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.expert_location import ExpertLocationMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EPLBManager:
    """Periodically rebalances expert placement using recorded distribution stats.

    Driven by the model runner via `on_forward_pass_end`: after every
    `eplb_rebalance_num_iterations` forward passes it recomputes expert
    locations from the recorder's logical counts and applies them, optionally
    spread over several passes in chunks of layers.
    """

    def __init__(self, model_runner: "ModelRunner"):
        super().__init__()
        self._model_runner = model_runner
        self._server_args = model_runner.server_args
        # None means "update all layers in one shot" (see rebalance()).
        self._rebalance_layers_per_chunk = (
            self._server_args.eplb_rebalance_layers_per_chunk
        )
        self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations

        # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
        # NOTE(review): the check enforces >= ("at least"), but the message
        # says "greater than" — confirm which is intended.
        assert (
            self._server_args.eplb_rebalance_num_iterations
            >= self._server_args.expert_distribution_recorder_buffer_size
        ), "eplb_rebalance_num_iterations must be greater than expert_distribution_recorder_buffer_size"

        # Rebalancing needs distribution data, so make sure recording is on.
        if not get_global_expert_distribution_recorder().recording:
            get_global_expert_distribution_recorder().start_record()

        logger.info(
            f"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations."
        )

        # Generator-based scheduler; advanced once per forward pass.
        self._main_generator = self._entrypoint()

    def on_forward_pass_end(self):
        # Advance the scheduler by one step (one forward pass).
        next(self._main_generator)

    # can be more complex if needed
    def _entrypoint(self):
        # Idle for the configured number of passes, then interleave the
        # (possibly multi-step) rebalance before starting the next cycle.
        while True:
            for _ in range(self._rebalance_num_iterations):
                yield

            yield from self.rebalance()

    def rebalance(self):
        """Recompute expert locations and apply them (generator).

        Yields between chunks so that expensive weight updates are spread
        across multiple forward passes when chunking is configured.
        """
        logger.info("[EPLBManager] rebalance start")

        # Timing is only meaningful when the update is not spread over passes.
        enable_timing = self._rebalance_layers_per_chunk is None

        if enable_timing:
            torch.cuda.synchronize()
            time_start = time.time()

        logical_count = get_global_expert_distribution_recorder().dump_record(
            output_mode="object"
        )["logical_count"]
        expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
            self._server_args, self._model_runner.model_config, logical_count
        )

        update_layer_ids_chunks = self._compute_update_layer_ids_chunks()
        for chunk_index, update_layer_ids in enumerate(update_layer_ids_chunks):
            # Only yield between chunks when there is more than one, so the
            # single-chunk case completes within the current pass.
            if len(update_layer_ids_chunks) > 1:
                yield
            self._model_runner.update_expert_location(
                expert_location_metadata,
                update_layer_ids=update_layer_ids,
            )

        msg = f"[EPLBManager] rebalance end"
        if enable_timing:
            torch.cuda.synchronize()
            time_end = time.time()
            msg += f" time={time_end - time_start:.3f}s"
        logger.info(msg)

    def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
        # Partition the MoE layer ids into chunks of the configured size;
        # the 1000000 sentinel effectively means "all layers in one chunk".
        all_layer_ids = sorted(
            list(self._model_runner.model.routed_experts_weights_of_layer.keys())
        )
        chunk_size = self._rebalance_layers_per_chunk or 1000000
        return list(_chunk_list(all_layer_ids, chunk_size=chunk_size))
|
||||
|
||||
|
||||
def _chunk_list(items: List, chunk_size):
|
||||
for start_index in range(0, len(items), chunk_size):
|
||||
yield items[start_index : start_index + chunk_size]
|
||||
@@ -1,920 +0,0 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from sglang.srt.managers.expert_location import ExpertLocationMetadata
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable, get_bool_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --------------------------------------- Entrypoint -----------------------------------------
|
||||
|
||||
_OutputMode = Literal["file", "object"]
|
||||
|
||||
|
||||
class ExpertDistributionRecorder(ABC):
    """Global expert distribution recording"""

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        # Factory: a real recorder when a recorder mode is configured,
        # otherwise a no-op stub so call sites need no conditionals.
        if server_args.expert_distribution_recorder_mode is not None:
            return _ExpertDistributionRecorderReal(
                server_args, expert_location_metadata, rank
            )
        else:
            return _ExpertDistributionRecorderNoop()

    # The context managers and hooks below are no-op defaults; subclasses
    # that actually record override them.

    @contextmanager
    def with_current_layer(self, layer_idx):
        yield

    @contextmanager
    def with_debug_name(self, debug_name):
        yield

    @contextmanager
    def disable_this_region(self):
        yield

    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        yield

    def on_select_experts(self, topk_ids: torch.Tensor):
        pass

    def on_deepep_dispatch_normal(
        self,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        pass

    def on_deepep_dispatch_low_latency(
        self, local_physical_count_of_layer: torch.Tensor
    ):
        pass

    # Record-control entry points raise by default: reaching them means
    # recording was requested without a recorder mode configured.

    def start_record(self):
        self._on_not_implemented()

    def stop_record(self):
        self._on_not_implemented()

    def dump_record(self, output_mode: _OutputMode = "file"):
        self._on_not_implemented()

    @property
    def recording(self):
        # Base class never records; the real recorder overrides this.
        return False

    def _on_not_implemented(self):
        raise Exception(
            "Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
        )
|
||||
|
||||
|
||||
class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
    # Used when ServerArgs.expert_distribution_recorder_mode is None; inherits
    # the base class's no-op hooks unchanged.
    pass
|
||||
|
||||
|
||||
class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
expert_location_metadata: "ExpertLocationMetadata",
|
||||
rank: int,
|
||||
):
|
||||
self._server_args = server_args
|
||||
self._expert_location_metadata = expert_location_metadata
|
||||
|
||||
self._recording = False
|
||||
self._disable_all = False
|
||||
self._current_forward_pass_id = Withable()
|
||||
self._current_layer_idx = Withable()
|
||||
self._current_debug_name = Withable()
|
||||
self._accumulator = _Accumulator.init_new(
|
||||
server_args, expert_location_metadata, rank
|
||||
)
|
||||
self._single_pass_gatherers = {
|
||||
k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)
|
||||
for k in self._accumulator.get_single_pass_gatherer_keys()
|
||||
}
|
||||
|
||||
if server_args.enable_expert_distribution_metrics:
|
||||
logger.info(
|
||||
"ExpertDistributionRecorder auto start record since enable_expert_distribution_metrics"
|
||||
)
|
||||
self.start_record()
|
||||
|
||||
def with_current_layer(self, layer_idx):
|
||||
return self._current_layer_idx.with_value(layer_idx)
|
||||
|
||||
def with_debug_name(self, debug_name):
|
||||
return self._current_debug_name.with_value(debug_name)
|
||||
|
||||
@contextmanager
|
||||
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
|
||||
with self._current_forward_pass_id.with_value(forward_pass_id):
|
||||
self._on_forward_pass_start(forward_batch)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._on_forward_pass_end(forward_pass_id)
|
||||
|
||||
@contextmanager
|
||||
def disable_this_region(self):
|
||||
"""Context manager to temporarily disable recording."""
|
||||
previous_disable_all = self._disable_all
|
||||
self._disable_all = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._disable_all = previous_disable_all
|
||||
|
||||
def _on_forward_pass_start(self, forward_batch: ForwardBatch):
|
||||
if not self._recording:
|
||||
return
|
||||
for gatherer_key, gatherer in self._single_pass_gatherers.items():
|
||||
gatherer.reset()
|
||||
gatherer.on_forward_pass_start(forward_batch)
|
||||
|
||||
def _on_forward_pass_end(self, forward_pass_id: int):
|
||||
if not self._recording:
|
||||
return
|
||||
for gatherer_key, gatherer in self._single_pass_gatherers.items():
|
||||
single_pass_data = gatherer.collect()
|
||||
self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)
|
||||
|
||||
def on_select_experts(self, topk_ids: torch.Tensor):
|
||||
self._on_hook("on_select_experts", topk_ids=topk_ids)
|
||||
|
||||
def on_deepep_dispatch_normal(
|
||||
self,
|
||||
local_physical_count_of_layer: List[int],
|
||||
num_tokens_per_rank,
|
||||
num_tokens_per_rdma_rank,
|
||||
num_tokens_per_expert,
|
||||
):
|
||||
self._on_hook(
|
||||
"on_deepep_dispatch_normal",
|
||||
local_physical_count_of_layer=local_physical_count_of_layer,
|
||||
num_tokens_per_rank=num_tokens_per_rank,
|
||||
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
|
||||
num_tokens_per_expert=num_tokens_per_expert,
|
||||
)
|
||||
|
||||
def on_deepep_dispatch_low_latency(
    self, local_physical_count_of_layer: torch.Tensor
):
    """Forward DeepEP low-latency dispatch counts to the active gatherer."""
    self._on_hook(
        "on_deepep_dispatch_low_latency",
        local_physical_count_of_layer=local_physical_count_of_layer,
    )
|
||||
|
||||
def _on_hook(self, hook_name: str, **kwargs):
    """Route a hook call to the gatherer selected by the current debug name.

    Skipped when recording is disabled for this region, or when we are
    neither recording nor capturing a CUDA graph (hooks also fire during
    CUDA graph capture).
    """
    if self._disable_all:
        return
    if not (self._recording or torch.cuda.is_current_stream_capturing()):
        return
    gatherer_key = self._accumulator.get_single_pass_gatherer_key(
        self._current_debug_name.value
    )
    gatherer = self._single_pass_gatherers[gatherer_key]
    hook = getattr(gatherer, hook_name)
    hook(layer_idx=self._current_layer_idx.value, **kwargs)
|
||||
|
||||
def _reset(self):
    """Reset the expert distribution recorder."""
    logger.info("Resetting ExpertDistributionRecorder...")
    # Sanity check: must not reset in the middle of a layer.
    assert (
        self._current_layer_idx.value is None
    ), f"{self._current_layer_idx.value=}"
    for gatherer in self._single_pass_gatherers.values():
        gatherer.reset()
    self._accumulator.reset()
|
||||
|
||||
def start_record(self):
    """Start recording the expert distribution."""
    if self._recording:
        logger.warning(
            "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
        )
    # Drop any stale state before recording anew.
    self._reset()
    self._recording = True
|
||||
|
||||
def stop_record(self):
    """Stop recording the expert distribution."""
    if not self._recording:
        logger.warning(
            "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
        )
    # Idempotent: always end up in the not-recording state.
    self._recording = False
|
||||
|
||||
def dump_record(self, output_mode: _OutputMode = "file"):
    """Dump the expert distribution record and reset the recorder after dumping."""
    dumped = self._accumulator.dump(output_mode=output_mode)
    self._reset()
    return dumped
|
||||
|
||||
@property
def recording(self):
    """Whether the recorder is currently accumulating data."""
    return self._recording
|
||||
|
||||
|
||||
# Process-wide recorder instance. The no-op implementation stays installed
# until `set_global_expert_distribution_recorder` replaces it.
_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (
    _ExpertDistributionRecorderNoop()
)
|
||||
|
||||
|
||||
def get_global_expert_distribution_recorder():
    """Return the process-wide ExpertDistributionRecorder instance."""
    return _global_expert_distribution_recorder
|
||||
|
||||
|
||||
def set_global_expert_distribution_recorder(value):
    """Install `value` as the process-wide ExpertDistributionRecorder."""
    global _global_expert_distribution_recorder
    _global_expert_distribution_recorder = value
|
||||
|
||||
|
||||
# --------------------------------------- SinglePassGatherer -----------------------------------------
|
||||
|
||||
|
||||
class _SinglePassGatherer(ABC):
    """Collects expert-routing statistics for a single forward pass.

    Subclasses implement whichever hook matches their data source, plus the
    `reset`/`collect` pair driven by the recorder between passes.
    """

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ) -> "_SinglePassGatherer":
        """Pick the gatherer implementation matching the server configuration."""
        mode = server_args.expert_distribution_recorder_mode

        if mode == "per_token":
            return _DetailSinglePassGatherer(
                server_args, expert_location_metadata, rank
            )

        if mode == "stat_approx":
            # The approximate stat mode is only implemented on top of DeepEP
            # normal-mode dispatch metadata.
            if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
            raise NotImplementedError

        if server_args.enable_deepep_moe:
            if server_args.deepep_mode == "normal":
                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
            if server_args.deepep_mode == "low_latency":
                return _DeepepLowLatencySinglePassGatherer(
                    expert_location_metadata, rank
                )
            raise NotImplementedError

        return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)

    def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
        self._expert_location_metadata = expert_location_metadata
        self._rank = rank

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        """Hook: called once before each forward pass. Default: no-op."""
        pass

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        """Hook: raw top-k expert ids chosen for one layer. Default: no-op."""
        pass

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        """Hook: DeepEP normal-mode dispatch statistics for one layer. Default: no-op."""
        pass

    def on_deepep_dispatch_low_latency(
        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
    ):
        """Hook: DeepEP low-latency dispatch counts for one layer. Default: no-op."""
        pass

    def reset(self):
        """Clear all state accumulated during the current pass."""
        raise NotImplementedError

    def collect(self) -> Dict:
        """Return the data gathered for the current pass."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class _DetailSinglePassGatherer(_SinglePassGatherer):
    """Records raw per-token top-k expert ids plus batch metadata ("per_token" mode)."""

    # DeepSeek V3 has this value; should generalize later
    _TOP_K_NUM = 8

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        super().__init__(expert_location_metadata, rank)
        self._metadata: Optional[Dict[str, Any]] = None
        # (layer, token, k) buffer; the token capacity is a heuristic upper bound.
        self._topk_ids_of_layer = torch.zeros(
            (
                expert_location_metadata.num_layers,
                # TODO determine the max number
                server_args.chunked_prefill_size * 8,
                self._TOP_K_NUM,
            ),
            dtype=torch.int32,
            device=server_args.device,
        )
        self._misc_objects: List[Dict[str, Any]] = []
        assert (
            not server_args.enable_two_batch_overlap
        ), "DetailSinglePassGatherer does not support TBO yet"
        # TODO assert shared experts fusion is disabled, o/w data is wrong

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        assert self._metadata is None
        self._metadata = dict(
            # TODO pr-chain
            # rids=forward_batch.rids,
            input_ids=forward_batch.input_ids.cpu().tolist(),
            positions=forward_batch.positions.cpu().tolist(),
            extend_seq_lens=forward_batch.extend_seq_lens_cpu,
            forward_mode=forward_batch.forward_mode.value,
        )

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        num_tokens, top_k = topk_ids.shape
        self._topk_ids_of_layer[layer_idx, :num_tokens, :top_k] = topk_ids

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        # Keep dispatch metadata as plain Python objects for later serialization.
        self._misc_objects.append(
            dict(
                layer_id=layer_idx,
                num_tokens_per_rank=num_tokens_per_rank.cpu().tolist(),
                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank.cpu().tolist(),
                num_tokens_per_expert=num_tokens_per_expert.cpu().tolist(),
            )
        )

    def reset(self):
        # -1 marks "no expert chosen" so stale entries are distinguishable from id 0.
        self._topk_ids_of_layer[...] = -1
        self._misc_objects.clear()
        self._metadata = None

    def collect(self) -> Dict:
        num_tokens = len(self._metadata["input_ids"])
        return dict(
            **self._metadata,
            topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :].clone().cpu(),
            misc_objects=self._misc_objects,
        )
|
||||
|
||||
|
||||
class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
    """Accumulates per-layer lists of integer counts in host memory."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # layer_idx -> list of per-expert counts for that layer
        self._objects_of_layer = {}

    def _on_layer_data(self, layer_idx: int, objects: List[int]):
        """Add `objects` element-wise into the running totals of `layer_idx`."""
        assert 0 <= layer_idx < self._expert_location_metadata.num_layers
        existing = self._objects_of_layer.get(layer_idx)
        self._objects_of_layer[layer_idx] = (
            objects if existing is None else _list_sum(existing, objects)
        )

    def reset(self):
        self._objects_of_layer.clear()

    def _collect_objects(self, pad_len: int) -> torch.Tensor:
        """Stack per-layer lists into a tensor, zero-filling layers never seen."""
        rows = [
            self._objects_of_layer.get(layer) or ([0] * pad_len)
            for layer in range(self._expert_location_metadata.num_layers)
        ]
        return torch.tensor(rows)
|
||||
|
||||
|
||||
def _list_sum(a: List, b: List) -> List:
|
||||
return [x + y for x, y in zip(a, b, strict=True)]
|
||||
|
||||
|
||||
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
    """Accumulates per-layer expert counts in a dense CUDA tensor.

    The layout is (num_layers, num_experts); the expert dimension covers
    either all physical experts globally or only this rank's local ones,
    depending on `enable_global_physical_experts`.
    """

    def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
        super().__init__(*args, **kwargs)
        self._enable_global_physical_experts = enable_global_physical_experts
        meta = self._expert_location_metadata
        num_experts = (
            meta.num_physical_experts
            if enable_global_physical_experts
            else meta.num_local_physical_experts
        )
        self._data = torch.zeros(
            (meta.num_layers, num_experts), dtype=torch.int, device="cuda"
        )

    def reset(self):
        self._data[...] = 0

    def collect(self) -> Dict:
        if self._enable_global_physical_experts:
            global_physical_count = self._data
        else:
            # Rank-local counts must be embedded into the global layout first.
            # Can optimize if bottleneck
            global_physical_count = _convert_local_to_global_physical_count(
                self._data,
                rank=self._rank,
                num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
                num_physical_experts=self._expert_location_metadata.num_physical_experts,
            )
        return dict(global_physical_count=global_physical_count)
|
||||
|
||||
|
||||
class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
    """Counts experts from the raw top-k ids chosen at routing time."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, enable_global_physical_experts=True)

    # can optimize (e.g. fuse / compile)
    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        flat_ids = topk_ids.flatten()
        valid = flat_ids != -1
        # Padded (-1) slots are redirected to index 0 with a zero contribution,
        # so they never affect the counts.
        self._data[layer_idx, :].scatter_add_(
            dim=0, index=flat_ids.masked_fill(~valid, 0).long(), src=valid.int()
        )
|
||||
|
||||
|
||||
class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
    """Counts experts from DeepEP normal-mode dispatch metadata (approximate)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Warn once, from rank 0 only, about the approximate nature of this mode.
        if torch.distributed.get_rank() == 0:
            logger.info(
                "DeepepNormalSinglePassGatherer gathers approximate statistics. "
                "If used with small batch size, consider using expert_distribution_recorder_mode=stat."
            )

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        assert isinstance(local_physical_count_of_layer, list)
        self._on_layer_data(layer_idx, local_physical_count_of_layer)

    def collect(self) -> Dict:
        local_physical_count = super()._collect_objects(
            pad_len=self._expert_location_metadata.num_local_physical_experts
        )
        global_physical_count = _convert_local_to_global_physical_count(
            local_physical_count,
            rank=self._rank,
            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
            num_physical_experts=self._expert_location_metadata.num_physical_experts,
        )
        return dict(global_physical_count=global_physical_count)
|
||||
|
||||
|
||||
class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
    """Counts experts from DeepEP low-latency dispatch tensors (rank-local layout)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, enable_global_physical_experts=False)

    def on_deepep_dispatch_low_latency(
        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
    ):
        # Most naive implementation, can optimize later
        self._data[layer_idx, :] += local_physical_count_of_layer
|
||||
|
||||
|
||||
def _convert_local_to_global_physical_count(
    local_physical_count: torch.Tensor,
    rank: int,
    num_local_physical_experts: int,
    num_physical_experts: int,
) -> torch.Tensor:
    """Embed a rank-local (num_layers, num_local) count tensor into the global
    (num_layers, num_physical) layout; everything outside this rank's slice is zero."""
    num_layers, _ = local_physical_count.shape
    ans = torch.zeros(
        (num_layers, num_physical_experts),
        dtype=local_physical_count.dtype,
        device=local_physical_count.device,
    )
    start = num_local_physical_experts * rank
    ans[:, start : start + num_local_physical_experts] = local_physical_count
    return ans
|
||||
|
||||
|
||||
# --------------------------------------- Accumulator -----------------------------------------
|
||||
|
||||
_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary"
|
||||
|
||||
|
||||
class _Accumulator(ABC):
    """Aggregates single-pass gatherer output across many forward passes."""

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ) -> "_Accumulator":
        cls = _Accumulator.get_class(server_args)
        return cls(server_args, expert_location_metadata, rank)

    @staticmethod
    def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
        """Map the recorder mode to the accumulator implementation."""
        mode_to_class = {
            "stat": _StatAccumulator,
            "stat_approx": _StatAccumulator,
            "per_pass": _DetailAccumulator,
            "per_token": _DetailAccumulator,
        }
        return mode_to_class[server_args.expert_distribution_recorder_mode]

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        self._server_args = server_args
        self._expert_location_metadata = expert_location_metadata
        self._rank = rank

    def get_single_pass_gatherer_keys(self):
        # A single gatherer by default; subclasses may provide more.
        return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]

    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
        return _SINGLE_PASS_GATHERER_KEY_PRIMARY

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        pass

    def reset(self):
        pass

    def dump(self, output_mode: _OutputMode):
        pass
|
||||
|
||||
|
||||
class _UtilizationRateAccumulatorMixin(_Accumulator):
    """Optionally logs rolling GPU balancedness metrics on top of a base accumulator.

    Enabled via `server_args.enable_expert_distribution_metrics`; metrics are
    reduced to rank 0 and logged there.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._enable = self._server_args.enable_expert_distribution_metrics

        if self._enable:
            window_sizes = [10, 100, 1000]
            self._history = _DequeCollection(maxlens=window_sizes)
            self._rank = torch.distributed.get_rank()

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)
        if self._enable:
            self._append_utilization_rate(
                forward_pass_id, single_pass_data["global_physical_count"]
            )

    def reset(self):
        super().reset()
        if self._enable:
            self._history.clear()

    def _append_utilization_rate(
        self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
    ):
        """Reduce per-GPU counts onto rank 0 and log rolling balancedness stats."""
        gpu_physical_count = compute_gpu_physical_count(
            single_pass_global_physical_count,
            num_gpu=self._expert_location_metadata.ep_size,
        )
        gpu_physical_count = gpu_physical_count.to(self._server_args.device)
        torch.distributed.reduce(
            gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM
        )

        if self._rank == 0:
            utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
            utilization_rate = torch.mean(utilization_rate_tensor).item()
            self._history.append(utilization_rate)

            gpu_physical_count_sum = gpu_physical_count.sum().item()

            logger.info(
                f"[Expert Balancedness] "
                f"forward_pass_id={forward_pass_id} "
                f"current_pass_balancedness={utilization_rate:.03f} "
                f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
                f"gpu_physical_count_sum={gpu_physical_count_sum}"
                # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
            )
|
||||
|
||||
|
||||
class _DequeCollection:
    """A set of rolling windows of different lengths over the same value stream."""

    def __init__(self, maxlens: List[int]):
        self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]

    def append(self, value):
        """Push `value` into every window."""
        for d in self._dequeues:
            d.append(value)

    def clear(self):
        """Empty every window."""
        for d in self._dequeues:
            d.clear()

    def mean(self) -> Dict[int, float]:
        """Per-window mean, keyed by window length.

        Empty windows are skipped instead of raising — the previous
        implementation hit ZeroDivisionError when called before any append().
        """
        return {d.maxlen: sum(d) / len(d) for d in self._dequeues if d}
|
||||
|
||||
|
||||
class _DetailAccumulator(_UtilizationRateAccumulatorMixin):
    """Keeps every per-pass record verbatim; dumped as one file per rank."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._records = []

    def get_single_pass_gatherer_keys(self):
        if False:  # TODO `server_args.enable_two_batch_overlap`
            return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"]
        return super().get_single_pass_gatherer_keys()

    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
        if False:  # TODO `server_args.enable_two_batch_overlap`
            return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY
        return super().get_single_pass_gatherer_key(debug_name)

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)

        def _to_host(obj):
            # Detach tensors from the GPU so records survive beyond the pass.
            return obj.cpu().clone() if isinstance(obj, torch.Tensor) else obj

        processed = {k: _to_host(v) for k, v in single_pass_data.items()}

        self._records.append(
            dict(
                forward_pass_id=forward_pass_id,
                rank=self._rank,
                gatherer_key=gatherer_key,
                **processed,
            )
        )

    def reset(self):
        super().reset()
        self._records.clear()

    def dump(self, output_mode: _OutputMode):
        assert output_mode == "file"
        output = dict(
            records=self._records,
            # NOTE: This may change during recording, so here we say it is the "last" one
            last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
        )
        _dump_to_file(
            f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output
        )
|
||||
|
||||
|
||||
class _StatAccumulator(_UtilizationRateAccumulatorMixin):
    """Keeps only aggregate counts per buffered step ("stat"/"stat_approx" modes)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._global_physical_count_of_buffered_step = _Buffer.init_new(
            item_shape=(
                self._expert_location_metadata.num_layers,
                # Cannot use local_physical_count to support select_experts
                self._expert_location_metadata.num_physical_experts,
            ),
            buffer_size=self._server_args.expert_distribution_recorder_buffer_size,
            dtype=torch.int32,
            device=self._server_args.device,
        )
        self._first_dump = True

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)
        # Can optimize if overhead here is large
        self._global_physical_count_of_buffered_step.append(
            single_pass_data["global_physical_count"]
        )

    def reset(self):
        super().reset()
        self._global_physical_count_of_buffered_step.reset()

    def dump(self, output_mode: _OutputMode):
        logical_count_of_buffered_step = (
            _convert_global_physical_count_to_logical_count(
                self._global_physical_count_of_buffered_step.get_all(),
                num_layers=self._expert_location_metadata.num_layers,
                num_logical_experts=self._expert_location_metadata.num_logical_experts,
                physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
            )
        )

        if self._first_dump:
            self._first_dump = False
            # Free cached CUDA blocks once, on the first dump (original behavior).
            torch.cuda.empty_cache()

        torch.distributed.all_reduce(
            logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
        )

        output = dict(
            rank=self._rank,
            logical_count=logical_count_of_buffered_step,
        )

        if output_mode == "file":
            # After the all-reduce every rank holds the same data; one file suffices.
            if self._rank == 0:
                _dump_to_file(f"expert_distribution_recorder_{time.time()}.pt", output)
        elif output_mode == "object":
            return output
        else:
            raise NotImplementedError
|
||||
|
||||
|
||||
def _dump_to_file(name, data):
    """Serialize `data` with torch.save under the configured dump directory.

    The directory comes from SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR
    (default /tmp) and is created on demand.
    """
    save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
    path_output = save_dir / name
    logger.info(f"Write expert distribution to {path_output}")
    # mkdir with exist_ok=True is idempotent; the previous exists() pre-check
    # was redundant and racy (TOCTOU).
    save_dir.mkdir(parents=True, exist_ok=True)
    torch.save(data, str(path_output))
|
||||
|
||||
|
||||
class _Buffer:
    """Interface for the append-only tensor buffers used by _StatAccumulator."""

    @staticmethod
    def init_new(item_shape: Tuple, buffer_size: int, dtype, device):
        """Negative buffer_size means keep everything; otherwise a ring buffer."""
        if buffer_size < 0:
            return _InfiniteBuffer(item_shape, dtype=dtype, device=device)
        return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)

    def append(self, value: torch.Tensor):
        raise NotImplementedError

    def get_all(self) -> torch.Tensor:
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError
|
||||
|
||||
|
||||
class _CircularBuffer(_Buffer):
    """Fixed-capacity ring buffer; the oldest slot is overwritten when full."""

    def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):
        self._buffer = torch.zeros(
            (buffer_size, *item_shape), dtype=dtype, device=device
        )
        self._curr_index = 0

    def append(self, value: torch.Tensor):
        self._buffer[self._curr_index] = value
        # Wrap around once the end of the buffer is reached.
        self._curr_index = (self._curr_index + 1) % len(self._buffer)

    def get_all(self) -> torch.Tensor:
        # NOTE: also returns zero-filled slots that were never written.
        return self._buffer

    def reset(self):
        self._buffer[...] = 0
|
||||
|
||||
|
||||
class _InfiniteBuffer(_Buffer):
    """Growable buffer that doubles its backing storage when full."""

    def __init__(self, item_shape: Tuple, dtype, device):
        self._item_shape = item_shape
        self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)
        self._size = 0

    def append(self, value: torch.Tensor):
        capacity = len(self._buffer)
        if self._size == capacity:
            # Amortized O(1): double the capacity and copy old contents over.
            grown = torch.zeros(
                (2 * capacity, *self._item_shape),
                dtype=self._buffer.dtype,
                device=self._buffer.device,
            )
            grown[:capacity] = self._buffer
            self._buffer = grown

        self._buffer[self._size] = value
        self._size += 1

    def get_all(self) -> torch.Tensor:
        return self._buffer[: self._size]

    def reset(self):
        # Capacity is kept; only the logical size and contents are cleared.
        self._buffer[...] = 0
        self._size = 0
|
||||
|
||||
|
||||
def _convert_global_physical_count_to_logical_count(
    # (whatever, num_layers, num_physical_experts)
    global_physical_count: torch.Tensor,
    num_layers: int,
    num_logical_experts: int,
    physical_to_logical_map: torch.Tensor,
):
    """Fold physical-expert counts into their owning logical experts.

    Replicated physical experts map to the same logical id, so their counts
    are summed via scatter_add along the expert dimension.
    """
    dim_extra, _, _ = global_physical_count.shape
    logical_count = torch.zeros(
        (dim_extra, num_layers, num_logical_experts),
        dtype=global_physical_count.dtype,
        device=global_physical_count.device,
    )
    index = (
        physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1).to(torch.int64)
    )
    logical_count.scatter_add_(dim=2, index=index, src=global_physical_count)
    return logical_count
|
||||
|
||||
|
||||
def compute_gpu_physical_count(
    physical_count_of_whatever: torch.Tensor,  # (..., num_layer, num_physical_expert)
    num_gpu: int,
):
    """Sum physical-expert counts into per-GPU totals.

    Experts are laid out contiguously per GPU, so grouping the last dimension
    into (num_gpu, experts_per_gpu) and summing the tail gives per-GPU counts.
    Implemented with a plain reshape+sum (equivalent to the previous
    einops.reduce call) to drop the einops dependency for a trivial grouped sum.

    output: gpu_physical_count_of_batch (..., num_layer, num_gpu)
    """
    *leading, num_physical_expert = physical_count_of_whatever.shape
    assert (
        num_physical_expert % num_gpu == 0
    ), f"{num_physical_expert=} not divisible by {num_gpu=}"
    grouped = physical_count_of_whatever.reshape(*leading, num_gpu, -1)
    return grouped.sum(dim=-1)
|
||||
|
||||
|
||||
def compute_utilization_rate(
    gpu_physical_count_of_batch: torch.Tensor,  # (..., num_layer, num_gpu)
):
    """Per-layer balancedness: mean/max of per-GPU load (1.0 == perfectly even).

    Uses plain torch reductions over the last (GPU) dimension (equivalent to
    the previous einops.reduce calls), dropping the einops dependency. The
    epsilon keeps an all-zero layer at ~1.0 rather than producing NaN.

    output: utilization_rate (..., num_layer)
    """
    counts = gpu_physical_count_of_batch.float()
    max_per_layer = counts.amax(dim=-1)
    mean_per_layer = counts.mean(dim=-1)
    return (mean_per_layer + 1e-5) / (max_per_layer + 1e-5)
|
||||
@@ -1,448 +0,0 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.managers import eplb_algorithms
|
||||
from sglang.srt.model_loader import get_model_architecture
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertLocationMetadata:
|
||||
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
||||
physical_to_logical_map_cpu: torch.Tensor
|
||||
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
||||
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
||||
# (layers, num_logical_experts)
|
||||
logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
|
||||
|
||||
# -------------------------------- properties ------------------------------------
|
||||
|
||||
@property
def num_layers(self) -> int:
    """Number of MoE layers (first dim of the physical-to-logical map)."""
    mapping = self.physical_to_logical_map
    return mapping.shape[0]
|
||||
|
||||
@property
def num_physical_experts(self) -> int:
    """Physical experts per layer (second dim of the physical-to-logical map)."""
    mapping = self.physical_to_logical_map
    return mapping.shape[1]
|
||||
|
||||
@property
def num_local_physical_experts(self) -> int:
    """Physical experts hosted on each rank; requires even divisibility."""
    quotient, remainder = divmod(self.num_physical_experts, self.ep_size)
    assert remainder == 0
    return quotient
|
||||
|
||||
@property
def num_logical_experts(self) -> int:
    """Logical (model-level) experts per layer."""
    return self.logical_to_all_physical_map.shape[1]
|
||||
|
||||
@property
def ep_size(self):
    """Expert-parallel group size.

    TODO change when EP size != world size
    """
    return torch.distributed.get_world_size()
|
||||
|
||||
def __post_init__(self):
    """Validate that all mapping tensors agree on their shared dimensions."""
    layers_a, physical_a = self.physical_to_logical_map.shape
    layers_b, logical_a, physical_b = self.logical_to_all_physical_map.shape
    layers_c, logical_b = self.logical_to_all_physical_map_num_valid.shape
    assert layers_a == layers_b == layers_c
    assert logical_a == logical_b
    assert physical_a == physical_b
|
||||
|
||||
# -------------------------------- construction ------------------------------------
|
||||
|
||||
@staticmethod
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
    """Trivial location - logical expert i corresponds to physical expert i"""
    common = ExpertLocationMetadata._init_common(server_args, model_config)
    location_config = common["model_config_for_expert_location"]
    num_physical_experts = common["num_physical_experts"]
    num_layers = location_config.num_layers
    num_logical_experts = location_config.num_logical_experts

    # Physical expert p maps to logical expert p % num_logical (identity plus
    # wrap-around for redundant experts), identical across layers.
    physical_to_logical_map = (
        torch.arange(0, num_physical_experts).repeat(num_layers, 1)
        % num_logical_experts
    )

    return ExpertLocationMetadata.init_by_mapping(
        server_args,
        model_config,
        physical_to_logical_map=physical_to_logical_map,
    )
|
||||
|
||||
@staticmethod
def init_by_mapping(
    server_args: ServerArgs,
    model_config: ModelConfig,
    physical_to_logical_map,
):
    """Build metadata from an explicit physical-to-logical mapping
    (tensor or nested list)."""
    if not isinstance(physical_to_logical_map, torch.Tensor):
        physical_to_logical_map = torch.tensor(physical_to_logical_map)
    physical_to_logical_map = physical_to_logical_map.to(server_args.device)

    common = ExpertLocationMetadata._init_common(server_args, model_config)
    location_config = common["model_config_for_expert_location"]
    logical_to_all_physical_map = _compute_logical_to_all_physical_map(
        physical_to_logical_map,
        num_logical_experts=location_config.num_logical_experts,
    )

    return ExpertLocationMetadata._init_raw(
        server_args=server_args,
        ep_size=common["ep_size"],
        physical_to_logical_map=physical_to_logical_map,
        logical_to_all_physical_map=logical_to_all_physical_map,
    )
|
||||
|
||||
@staticmethod
def init_by_eplb(
    server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
):
    """Build metadata by running the EPLB rebalancing algorithm on observed
    per-logical-expert token counts."""
    if not isinstance(logical_count, torch.Tensor):
        logical_count = torch.tensor(logical_count)
    if len(logical_count.shape) == 2:
        # Promote (layers, experts) to (1, layers, experts).
        logical_count = logical_count.unsqueeze(0)
    logical_count = logical_count.to(server_args.device)

    common = ExpertLocationMetadata._init_common(server_args, model_config)
    location_config = common["model_config_for_expert_location"]
    num_physical_experts = common["num_physical_experts"]
    num_groups = location_config.num_groups
    num_nodes = server_args.nnodes

    algorithm = eplb_algorithms.compute_algorithm(
        raw_algorithm=server_args.eplb_algorithm,
        num_groups=num_groups,
        num_nodes=num_nodes,
    )
    # The third return value is unused here.
    physical_to_logical_map, logical_to_all_physical_map, _expert_count = (
        eplb_algorithms.rebalance_experts(
            tokens_per_expert=logical_count,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_physical_experts // common["ep_size"],
            num_groups=num_groups,
            num_nodes=num_nodes,
            algorithm=algorithm,
        )
    )

    return ExpertLocationMetadata._init_raw(
        server_args=server_args,
        ep_size=common["ep_size"],
        physical_to_logical_map=physical_to_logical_map.to(server_args.device),
        logical_to_all_physical_map=logical_to_all_physical_map.to(
            server_args.device
        ),
    )
|
||||
|
||||
@staticmethod
def _init_common(server_args: ServerArgs, model_config: ModelConfig):
    """Derive the sizing values shared by every constructor."""
    model_config_for_expert_location = (
        ModelConfigForExpertLocation.from_model_config(model_config)
    )

    # Physical experts = logical experts plus configured redundant copies.
    num_physical_experts = (
        model_config_for_expert_location.num_logical_experts
        + server_args.ep_num_redundant_experts
    )
    ep_size = server_args.ep_size
    assert num_physical_experts % ep_size == 0
    num_local_physical_experts = num_physical_experts // ep_size

    return dict(
        model_config_for_expert_location=model_config_for_expert_location,
        num_physical_experts=num_physical_experts,
        num_local_physical_experts=num_local_physical_experts,
        ep_size=ep_size,
    )
|
||||
|
||||
@staticmethod
def _init_raw(
    server_args: ServerArgs,
    ep_size: int,
    physical_to_logical_map: torch.Tensor,
    logical_to_all_physical_map: torch.Tensor,
):
    """Assemble an ExpertLocationMetadata from the raw mapping tensors.

    Args:
        server_args: Read for ep_dispatch_algorithm; "static" triggers
            precomputation of a per-rank dispatch map.
        ep_size: Number of EP ranks.
        physical_to_logical_map: (num_layers, num_physical_experts).
        logical_to_all_physical_map: (num_layers, num_logical_experts, X);
            may be narrower than num_physical_experts and is padded here.
    """
    _, num_physical_experts = physical_to_logical_map.shape

    # Pad the last dim up to num_physical_experts with -1 so downstream code
    # can index with a fixed width regardless of replication factor.
    logical_to_all_physical_map_padded = F.pad(
        logical_to_all_physical_map,
        (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
        value=-1,
    )

    # Count of real (non -1) physical replicas per logical expert.
    logical_to_all_physical_map_num_valid = torch.count_nonzero(
        logical_to_all_physical_map != -1, dim=-1
    )

    return ExpertLocationMetadata(
        physical_to_logical_map=physical_to_logical_map,
        physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
        logical_to_all_physical_map=logical_to_all_physical_map_padded,
        logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
        logical_to_rank_dispatch_physical_map=(
            compute_logical_to_rank_dispatch_physical_map(
                logical_to_all_physical_map=logical_to_all_physical_map,
                num_gpus=ep_size,
                num_physical_experts=num_physical_experts,
                # TODO improve when we have real EP rank
                ep_rank=torch.distributed.get_rank() % ep_size,
            )
            if server_args.ep_dispatch_algorithm == "static"
            else None
        ),
    )
|
||||
|
||||
# -------------------------------- mutation ------------------------------------
|
||||
|
||||
def update(
    self,
    other: "ExpertLocationMetadata",
    update_layer_ids: List[int],
):
    """In-place copy of `other`'s mapping tensors for the given layers.

    Layers not listed in `update_layer_ids` keep their current values.
    `ep_size` must match between the two metadata objects.
    """
    for field in [
        "ep_size",
    ]:
        assert getattr(self, field) == getattr(other, field)

    for field in [
        "physical_to_logical_map",
        "physical_to_logical_map_cpu",
        "logical_to_all_physical_map",
        "logical_to_all_physical_map_num_valid",
        "logical_to_rank_dispatch_physical_map",
    ]:
        other_field = getattr(other, field)
        self_field = getattr(self, field)
        # Each tensor must be present (or absent) on both sides consistently.
        assert (other_field is not None) == (self_field is not None)
        if self_field is not None:
            # Boolean mask over layers, reshaped to broadcast over the
            # tensor's trailing dimensions.
            mask_update = torch.tensor(
                [i in update_layer_ids for i in range(self.num_layers)]
            )
            mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
            mask_update = mask_update.to(self_field.device, non_blocking=True)
            self_field[...] = torch.where(mask_update, other_field, self_field)
|
||||
|
||||
# -------------------------------- usage ------------------------------------
|
||||
|
||||
def logical_to_all_physical(
    self, layer_id: int, logical_expert_id: int
) -> List[int]:
    """Return the physical expert ids backing one logical expert, dropping -1 padding."""
    padded_row = self.logical_to_all_physical_map[layer_id, logical_expert_id]
    physical_ids: List[int] = []
    for candidate in padded_row.tolist():
        if candidate != -1:
            physical_ids.append(candidate)
    return physical_ids
|
||||
|
||||
|
||||
# Process-wide singleton holding the current expert location metadata.
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None


def get_global_expert_location_metadata():
    """Return the process-wide ExpertLocationMetadata (None if not yet set)."""
    return _global_expert_location_metadata


def set_global_expert_location_metadata(value):
    """Install the process-wide metadata; asserts it is only set once."""
    global _global_expert_location_metadata
    assert _global_expert_location_metadata is None
    _global_expert_location_metadata = value
|
||||
|
||||
|
||||
def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    """Invert physical->logical into a -1-padded logical->all-physical tensor.

    Args:
        physical_to_logical_map: (num_layers, num_physical_experts) tensor whose
            entries are logical expert ids.
        num_logical_experts: Number of logical experts (rows per layer in the output).

    Returns:
        (num_layers, num_logical_experts, X) tensor on the input's device, where
        X is the maximum replication factor; unused slots hold -1.
    """
    # This is rarely called, so we use for loops for maximum clarity.
    # The unpack also validates that the map is 2D.
    num_layers, num_physical_experts = physical_to_logical_map.shape

    logical_to_all_physical_map = [
        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
    ]
    # One bulk .tolist() instead of a per-element .item(): the original did a
    # device-to-host sync per physical expert, this does a single transfer.
    for layer_id, layer_row in enumerate(physical_to_logical_map.tolist()):
        for physical_expert_id, logical_expert_id in enumerate(layer_row):
            logical_to_all_physical_map[layer_id][logical_expert_id].append(
                physical_expert_id
            )

    logical_to_all_physical_map = _pad_nested_array(
        logical_to_all_physical_map, pad_value=-1
    )

    return torch.tensor(
        logical_to_all_physical_map, device=physical_to_logical_map.device
    )
|
||||
|
||||
|
||||
def _pad_nested_array(arr, pad_value):
|
||||
max_len = max(len(inner) for outer in arr for inner in outer)
|
||||
padded = [
|
||||
[inner + [pad_value] * (max_len - len(inner)) for inner in outer]
|
||||
for outer in arr
|
||||
]
|
||||
return padded
|
||||
|
||||
|
||||
# TODO optimize performance (rewrite and/or run in separate process with overlap)
|
||||
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    seed: int = 42,
):
    """Choose, per (layer, logical expert), the physical replica each rank dispatches to.

    For every GPU the first replica hosted on that same GPU is preferred;
    GPUs with no local replica receive a seeded, evenly-spread random choice
    from all replicas via _fair_choices. The fixed seed makes the result
    identical on every rank.

    Args:
        logical_to_all_physical_map: (num_layers, num_logical_experts, X)
            with -1 padding for absent replicas.
        num_gpus: Number of EP ranks.
        num_physical_experts: Total physical experts; experts are assumed to be
            split evenly across GPUs (num_physical_experts // num_gpus each).
        ep_rank: Rank whose slice of the dispatch map is returned.
        seed: RNG seed for the fair random assignments.

    Returns:
        (num_layers, num_logical_experts) tensor of physical expert ids for
        ep_rank, on logical_to_all_physical_map's device.
    """
    r = random.Random(seed)

    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    dtype = logical_to_all_physical_map.dtype

    # -1 marks "not yet assigned"; all slots must be filled by the end.
    logical_to_rank_dispatch_physical_map = torch.full(
        size=(num_gpus, num_layers, num_logical_experts),
        fill_value=-1,
        dtype=dtype,
    )

    for layer_id in range(num_layers):
        for logical_expert_id in range(num_logical_experts):
            candidate_physical_expert_ids = _logical_to_all_physical_raw(
                logical_to_all_physical_map, layer_id, logical_expert_id
            )
            # View over all GPUs' choices for this (layer, logical expert);
            # writes below mutate the full map in place.
            output_partial = logical_to_rank_dispatch_physical_map[
                :, layer_id, logical_expert_id
            ]

            # Prefer a replica hosted on the dispatching GPU itself.
            for gpu_id in range(num_gpus):
                same_gpu_physical_expert_ids = [
                    physical_expert_id
                    for physical_expert_id in candidate_physical_expert_ids
                    if _compute_gpu_id_of_physical_expert(
                        physical_expert_id, num_local_physical_experts
                    )
                    == gpu_id
                ]
                if len(same_gpu_physical_expert_ids) > 0:
                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]

            # GPUs without a local replica get a fair random assignment.
            num_remain = torch.sum(output_partial == -1).item()
            output_partial[output_partial == -1] = torch.tensor(
                _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
                dtype=dtype,
            )

    assert torch.all(logical_to_rank_dispatch_physical_map != -1)

    device = logical_to_all_physical_map.device
    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
|
||||
|
||||
|
||||
def _logical_to_all_physical_raw(
|
||||
logical_to_all_physical_map, layer_id: int, logical_expert_id: int
|
||||
) -> List[int]:
|
||||
return [
|
||||
physical_expert_id
|
||||
for physical_expert_id in logical_to_all_physical_map[
|
||||
layer_id, logical_expert_id
|
||||
].tolist()
|
||||
if physical_expert_id != -1
|
||||
]
|
||||
|
||||
|
||||
def _compute_gpu_id_of_physical_expert(
|
||||
physical_expert_id: int, num_local_physical_experts: int
|
||||
) -> int:
|
||||
return physical_expert_id // num_local_physical_experts
|
||||
|
||||
|
||||
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
|
||||
quotient, remainder = divmod(k, len(arr))
|
||||
ans = arr * quotient + r.sample(arr, k=remainder)
|
||||
r.shuffle(ans)
|
||||
return ans
|
||||
|
||||
|
||||
@dataclass
class ModelConfigForExpertLocation:
    """Slice of a model config that expert-location logic needs."""

    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        """Placeholder config for models without expert-location support."""
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        """Ask the model class for its expert-location config, falling back to a dummy."""
        model_class, _ = get_model_architecture(model_config)
        provider = getattr(model_class, "get_model_config_for_expert_location", None)
        if provider is None:
            return ModelConfigForExpertLocation.init_dummy()
        return provider(model_config.hf_config)
|
||||
|
||||
|
||||
def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    """Create the startup ExpertLocationMetadata from ServerArgs.init_expert_location.

    Accepted values: "trivial", a path ending in .pt or .json, or an inline
    JSON string. The decoded dict selects the constructor:
    "physical_to_logical_map" -> init_by_mapping, "logical_count" -> init_by_eplb.

    Raises:
        NotImplementedError: If the dict matches neither known format.
    """
    data = server_args.init_expert_location
    if data == "trivial":
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if data.endswith(".pt"):
        data_dict = torch.load(data, weights_only=True)
    elif data.endswith(".json"):
        data_dict = json.loads(Path(data).read_text())
    else:
        data_dict = json.loads(data)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_mapping(
            server_args, model_config, **data_dict
        )
    elif "logical_count" in data_dict:
        logger.info(
            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_eplb(
            server_args, model_config, logical_count=data_dict["logical_count"]
        )
    else:
        raise NotImplementedError(
            f"Unknown init_expert_location format ({list(data_dict.keys())=})"
        )
|
||||
@@ -1,108 +0,0 @@
|
||||
# Copyright 2023-2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
|
||||
@dataclass
class ExpertLocationDispatchInfo:
    """Per-layer slices of the expert-location maps needed to dispatch top-k ids."""

    # NOTE(review): dispatch functions below also handle "dynamic"/"fake";
    # the Literal here may be narrower than what actually flows in — verify.
    ep_dispatch_algorithm: Literal["static", "random"]
    # (num_logical_experts,)
    partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
    # (num_logical_experts, X)
    partial_logical_to_all_physical_map: torch.Tensor
    # (num_logical_experts,)
    partial_logical_to_all_physical_map_num_valid: torch.Tensor
    num_physical_experts: int

    @classmethod
    def init_new(cls, layer_id: int):
        """Slice the global expert-location metadata for `layer_id`.

        Returns None when no dispatch algorithm is configured.
        """
        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
        expert_location_metadata = get_global_expert_location_metadata()

        if ep_dispatch_algorithm is None:
            return None

        return cls(
            ep_dispatch_algorithm=ep_dispatch_algorithm,
            partial_logical_to_rank_dispatch_physical_map=(
                expert_location_metadata.logical_to_rank_dispatch_physical_map[
                    layer_id, :
                ]
                if expert_location_metadata.logical_to_rank_dispatch_physical_map
                is not None
                else None
            ),
            partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[
                layer_id, :
            ],
            num_physical_experts=expert_location_metadata.num_physical_experts,
        )
|
||||
|
||||
|
||||
def transform_select_experts_inputs(
    router_logits: torch.Tensor,
    correction_bias: Optional[torch.Tensor],
    info: Optional[ExpertLocationDispatchInfo],
):
    """Optionally replace router inputs under the "fake" dispatch algorithm.

    With "fake" dispatch, router logits are swapped for random noise and the
    correction bias (when present) is zeroed; otherwise both inputs pass
    through unchanged.
    """
    use_fake_dispatch = info is not None and info.ep_dispatch_algorithm == "fake"
    if not use_fake_dispatch:
        return router_logits, correction_bias

    fake_logits = torch.randn_like(router_logits)
    fake_bias = (
        torch.zeros_like(correction_bias) if correction_bias is not None else None
    )
    return fake_logits, fake_bias
|
||||
|
||||
|
||||
def topk_ids_logical_to_physical(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Translate logical expert ids to physical ids per the dispatch algorithm."""
    if info is None:
        return topk_ids

    handler = {
        "static": _topk_ids_logical_to_physical_static,
        "dynamic": _topk_ids_logical_to_physical_dynamic,
        "fake": _topk_ids_logical_to_physical_dynamic,
    }.get(info.ep_dispatch_algorithm)
    if handler is None:
        raise NotImplementedError(f"Unknown algorithm {info.ep_dispatch_algorithm}")
    return handler(topk_ids, info)
|
||||
|
||||
|
||||
def _topk_ids_logical_to_physical_static(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Look up this rank's fixed physical replica for each logical id."""
    dispatch_map = info.partial_logical_to_rank_dispatch_physical_map
    return dispatch_map[topk_ids]
|
||||
|
||||
|
||||
def _topk_ids_logical_to_physical_dynamic(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Map each logical id to a randomly chosen physical replica of it."""
    topk_ids_original_shape = topk_ids.shape
    device = topk_ids.device
    topk_ids = topk_ids.flatten()

    # Pick a random valid replica index per element.
    # NOTE(review): randint % num_valid is slightly biased when 65536 is not a
    # multiple of num_valid — presumably acceptable for load spreading; confirm.
    chosen_dispatch_index = (
        torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device)
        % info.partial_logical_to_all_physical_map_num_valid[topk_ids]
    )
    topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index]

    topk_ids = topk_ids.view(topk_ids_original_shape)
    return topk_ids
|
||||
@@ -58,6 +58,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.distributed import get_pp_group, get_world_group
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
@@ -65,9 +66,6 @@ from sglang.srt.hf_transformers_utils import (
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
CloseSessionReqInput,
|
||||
|
||||
Reference in New Issue
Block a user