# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.

This module implements the core rearrangement algorithm.

The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).

See [#12](https://github.com/deepseek-ai/EPLB/issues/12) for an example of
how the EPLB algorithm works.
"""

import numpy as np
import torch

from .abstract import AbstractEplbPolicy


class DefaultEplbPolicy(AbstractEplbPolicy):
    @classmethod
    def balanced_packing(
        cls, weight: torch.Tensor, num_packs: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Pack n weighted objects into m packs such that each pack contains
        exactly n/m objects and the weights of all packs are as balanced as
        possible.

        Parameters:
            weight: [X, n], the weight of each item
            num_packs: number of packs

        Returns:
            pack_index: [X, n], the pack index of each item
            rank_in_pack: [X, n], the rank of the item within its pack
        """
        num_layers, num_groups = weight.shape
        assert num_groups % num_packs == 0
        groups_per_pack = num_groups // num_packs

        device = weight.device

        if groups_per_pack == 1:
            pack_index = torch.arange(
                weight.size(-1), dtype=torch.int64, device=device
            ).expand(weight.shape)
            rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
            return pack_index, rank_in_pack

        weight_np = weight.cpu().numpy()

        # Sort and get indices in descending order
        indices_np = np.argsort(-weight_np, axis=-1)

        pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
        rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)

        # Run the packing algorithm
        for i in range(num_layers):
            pack_weights = [0.0] * num_packs
            pack_items = [0] * num_packs

            for group in indices_np[i]:
                # Find a pack with capacity that has the lowest weight
                pack = min(
                    (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
                    key=pack_weights.__getitem__,
                )

                assert pack_items[pack] < groups_per_pack
                pack_index_np[i, group] = pack
                rank_in_pack_np[i, group] = pack_items[pack]
                pack_weights[pack] += weight_np[i, group]
                pack_items[pack] += 1

        pack_index = torch.from_numpy(pack_index_np).to(device)
        rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)

        return pack_index, rank_in_pack
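
    # A hand-worked sketch of the greedy packing above (illustrative values,
    # not from the vLLM test suite): packing weights [[4, 3, 2, 1]] into 2
    # packs assigns items heaviest-first to the lightest non-full pack,
    # giving pack_index [[0, 1, 1, 0]] and rank_in_pack [[0, 0, 1, 1]], so
    # both packs end up with a total weight of 5.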

    @classmethod
    def replicate_experts(
        cls, weight: torch.Tensor, num_phy: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Replicate `num_log` experts to `num_phy` replicas, such that the
        maximum load of all replicas is minimized.

        Parameters:
            weight: [X, num_log]
            num_phy: total number of experts after replication

        Returns:
            phy2log: [X, num_phy], logical expert id of each physical expert
            rank: [X, num_phy], the replica rank
            logcnt: [X, num_log], number of replicas for each logical expert
        """
        n, num_log = weight.shape
        num_redundant = num_phy - num_log
        assert num_redundant >= 0
        device = weight.device
        phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
        rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
        logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
        arangen = torch.arange(n, dtype=torch.int64, device=device)
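        # Greedily add one replica at a time: each step clones the logical
        # expert whose estimated per-replica load (weight / logcnt) is
        # currently the highest.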
        for i in range(num_log, num_phy):
            redundant_indices = (weight / logcnt).max(dim=-1).indices
            phy2log[:, i] = redundant_indices
            rank[:, i] = logcnt[arangen, redundant_indices]
            logcnt[arangen, redundant_indices] += 1
        return phy2log, rank, logcnt
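
    # A hand-worked sketch (illustrative values): with weight [[6, 2]] and
    # num_phy=3, the single redundant slot goes to logical expert 0 (its
    # per-replica load 6 exceeds 2), giving phy2log [[0, 1, 0]],
    # rank [[0, 0, 1]], and logcnt [[2, 1]].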

    @classmethod
    def rebalance_experts_hierarchical(
        cls,
        weight: torch.Tensor,
        num_physical_experts: int,
        num_groups: int,
        num_nodes: int,
        num_gpus: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            weight: [num_moe_layers, num_logical_experts]
            num_physical_experts: number of physical experts after replication
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g., NVLink) is faster
            num_gpus: number of GPUs, must be a multiple of `num_nodes`

        Returns:
            phy2log: [layers, num_physical_experts], the logical expert id of
                each physical expert
            phyrank: [layers, num_physical_experts], the replica rank of each
                physical expert
            logcnt: [layers, num_logical_experts], number of physical replicas
                for each logical expert
        """
        num_layers, num_logical_experts = weight.shape
        assert num_logical_experts % num_groups == 0
        group_size = num_logical_experts // num_groups
        assert num_groups % num_nodes == 0
        groups_per_node = num_groups // num_nodes
        assert num_gpus % num_nodes == 0
        assert num_physical_experts % num_gpus == 0
        phy_experts_per_gpu = num_physical_experts // num_gpus
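
        # `inverse` computes the row-wise inverse of a batch of permutations:
        # inv[i, perm[i, j]] = j for every row i.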
        def inverse(perm: torch.Tensor) -> torch.Tensor:
            inv = torch.empty_like(perm)
            inv.scatter_(
                1,
                perm,
                torch.arange(
                    perm.size(1), dtype=torch.int64, device=perm.device
                ).expand(perm.shape),
            )
            return inv

        # Step 1: pack groups to nodes
        tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
        group_pack_index, group_rank_in_pack = cls.balanced_packing(
            tokens_per_group, num_nodes
        )
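        # `log2mlog` renumbers logical experts so that groups packed onto the
        # same node occupy a contiguous id range; `mlog2log` is its inverse.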
        log2mlog = (
            (
                (group_pack_index * groups_per_node + group_rank_in_pack) * group_size
            ).unsqueeze(-1)
            + torch.arange(
                group_size, dtype=torch.int64, device=group_pack_index.device
            )
        ).flatten(-2)
        mlog2log = inverse(log2mlog)

        # Step 2: construct redundant experts within nodes
        # [num_layers * num_nodes, num_logical_experts // num_nodes]
        tokens_per_mlog = weight.gather(-1, mlog2log).view(
            -1, num_logical_experts // num_nodes
        )
        phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
            tokens_per_mlog, num_physical_experts // num_nodes
        )

        # Step 3: pack physical_experts to GPUs
        # [num_layers * num_nodes, num_physical_experts // num_nodes]
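        # Each replica is assumed to receive an equal share of its expert's
        # load, so the per-replica load estimate is weight / replica count.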
        tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
        pack_index, rank_in_pack = cls.balanced_packing(
            tokens_per_phy, num_gpus // num_nodes
        )
        phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
        pphy2phy = inverse(phy2pphy)

        pphy2mlog = phy2mlog.gather(
            -1, pphy2phy
        )  # [num_layers * num_nodes, num_log_per_nodes]
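        # Shift each node's local expert ids by that node's offset to recover
        # global (meta-shuffled) logical expert ids.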
        pphy2mlog = (
            pphy2mlog.view(num_layers, num_nodes, -1)
            + torch.arange(
                0,
                num_logical_experts,
                num_logical_experts // num_nodes,
                device=group_pack_index.device,
            ).view(1, -1, 1)
        ).flatten(-2)
        pphy2log = mlog2log.gather(-1, pphy2mlog)
        pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
        logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
        return pphy2log, pphyrank, logcnt

    @classmethod
    def rebalance_experts(
        cls,
        weight: torch.Tensor,
        num_replicas: int,
        num_groups: int,
        num_nodes: int,
        num_ranks: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Entry point for the expert-parallelism load balancer.

        Parameters:
            weight: [layers, num_logical_experts], the load statistics for all
                logical experts
            num_replicas: number of physical experts, must be a multiple of
                `num_ranks`
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g., NVLink) is faster
            num_ranks: number of ranks, must be a multiple of `num_nodes`

        Returns:
            phy2log: [layers, num_replicas], the logical expert id of each
                physical replica
            log2phy: [layers, num_logical_experts, X], the physical replica
                ids for each logical expert
            logcnt: [layers, num_logical_experts], number of physical replicas
                for each logical expert
        """
        num_layers, num_logical_experts = weight.shape
        weight = weight.float()
        if num_groups % num_nodes == 0:
            # use hierarchical load-balance policy
            phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
                weight, num_replicas, num_groups, num_nodes, num_ranks
            )
        else:
            # use global load-balance policy
            phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
                weight, num_replicas, 1, 1, num_ranks
            )
        num_redundant_experts = num_replicas - num_logical_experts
        maxlogcnt = num_redundant_experts + 1
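        # Build the inverse mapping log2phy: for each logical expert, list the
        # physical replica ids that map to it, padding unused slots with -1.
        # Slot `phyrank` of expert `phy2log` receives that replica's physical id.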
        log2phy: torch.Tensor = torch.full(
            (num_layers, num_logical_experts, maxlogcnt),
            -1,
            dtype=torch.int64,
            device=logcnt.device,
        )
        log2phy.view(num_layers, -1).scatter_(
            -1,
            phy2log * maxlogcnt + phyrank,
            torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
                num_layers, -1
            ),
        )
        return phy2log, log2phy, logcnt
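

# A minimal usage sketch (illustrative shapes and values; in vLLM the load
# statistics come from expert-load metrics rather than hard-coded tensors):
#
#     weight = torch.tensor([[10, 5, 3, 2, 9, 1, 7, 4]])  # 1 layer, 8 experts
#     phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
#         weight, num_replicas=12, num_groups=4, num_nodes=2, num_ranks=4
#     )
#     # phy2log: [1, 12]; the 4 redundant slots go to the busiest experts.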