[Lint] Style: Convert example to ruff format (#5863)

### What this PR does / why we need it?
This PR fixes linting issues in the `example/` directory to align with the
project's Ruff configuration.

- vLLM version: v0.13.0
- vLLM main: bde38c11df
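
For reference, changes of this kind are what Ruff's autofixer and formatter produce; a sketch of the likely invocation (the exact rule set lives in the project's Ruff config, so the flags shown here are an assumption):

```bash
# Assumed invocation; the project's Ruff config defines the actual rules.
ruff check example/ --fix   # autofixes, e.g. rewriting typing.Tuple as builtin tuple
ruff format example/        # reformatting: line wrapping, double quotes
```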

Signed-off-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Co-authored-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Commit: 78d5ce3e01 (parent: f7b904641e)
Author: SILONG ZENG
Committed: 2026-01-13 20:46:50 +08:00 (via GitHub)
23 changed files with 678 additions and 1037 deletions


@@ -4,13 +4,11 @@ Expert parallelism load balancer (EPLB) for vLLM.
 The rearrangement algorithm is adapted from
 [DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
 """
-from typing import Tuple
 
 import torch
 
 
-def balanced_packing(weight: torch.Tensor,
-                     num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def balanced_packing(weight: torch.Tensor, num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
     are as balanced as possible.
@@ -18,8 +16,8 @@ def balanced_packing(weight: torch.Tensor,
     Parameters:
         weight: [X, n], the weight of each item
         num_packs: number of packs
 
-    Returns: 
+    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
@@ -28,26 +26,18 @@ def balanced_packing(weight: torch.Tensor,
     groups_per_pack = num_groups // num_packs
 
     if groups_per_pack == 1:
-        pack_index = torch.arange(weight.size(-1),
-                                  dtype=torch.int64,
-                                  device=weight.device).expand(weight.shape)
+        pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape)
         rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
         return pack_index, rank_in_pack
 
     indices = weight.float().sort(-1, descending=True).indices.cpu()
-    pack_index = torch.full_like(weight,
-                                 fill_value=-1,
-                                 dtype=torch.int64,
-                                 device='cpu')
+    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
     rank_in_pack = torch.full_like(pack_index, fill_value=-1)
     for i in range(num_layers):
         pack_weights = [0] * num_packs
         pack_items = [0] * num_packs
         for group in indices[i]:
-            pack = min(
-                (i
-                 for i in range(num_packs) if pack_items[i] < groups_per_pack),
-                key=pack_weights.__getitem__)
+            pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), key=pack_weights.__getitem__)
             assert pack_items[pack] < groups_per_pack
             pack_index[i, group] = pack
             rank_in_pack[i, group] = pack_items[pack]
@@ -56,16 +46,14 @@ def balanced_packing(weight: torch.Tensor,
     return pack_index, rank_in_pack
 
 
-def replicate_experts(
-        weight: torch.Tensor,
-        num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def replicate_experts(weight: torch.Tensor, num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
 
     Parameters:
         weight: [X, num_log]
         num_phy: total number of experts after replication
 
     Returns:
         phy2log: [X, num_phy], logical expert id of each physical expert
         rank: [X, num_phy], the replica rank
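
For intuition, a hedged sketch of how replication behaves (illustrative numbers, not from the diff):

```python
import torch

# Illustrative only: one layer, 4 logical experts, 6 physical slots.
weight = torch.tensor([[9.0, 3.0, 3.0, 1.0]])
phy2log, rank, logcnt = replicate_experts(weight, num_phy=6)
# Both redundant slots go to the hottest expert (id 0), since it keeps the
# highest per-replica load; replica counts sum to num_phy in every layer.
assert logcnt.sum(-1).eq(6).all()
```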
@@ -75,8 +63,7 @@ def replicate_experts(
     num_redundant = num_phy - num_log
     assert num_redundant >= 0
     device = weight.device
-    phy2log = torch.arange(num_phy, dtype=torch.int64,
-                           device=device).repeat(n, 1)
+    phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
     rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
     logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
     arangen = torch.arange(n, dtype=torch.int64, device=device)
@@ -88,9 +75,9 @@ def replicate_experts(
     return phy2log, rank, logcnt
 
 
-def rebalance_experts_hierarchical(weight: torch.Tensor,
-                                   num_physical_experts: int, num_groups: int,
-                                   num_nodes: int, num_gpus: int):
+def rebalance_experts_hierarchical(
+    weight: torch.Tensor, num_physical_experts: int, num_groups: int, num_nodes: int, num_gpus: int
+):
     """
     Parameters:
         weight: [num_moe_layers, num_logical_experts]
@@ -99,7 +86,7 @@ def rebalance_experts_hierarchical(weight: torch.Tensor,
         num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
         num_gpus: number of GPUs, must be a multiple of `num_nodes`
 
-    Returns: 
+    Returns:
         physical_to_logical_map: [num_moe_layers, num_physical_experts]
         logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
         logical_count: [num_moe_layers, num_logical_experts]
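
To make the shapes concrete, a sketch with made-up sizes (all numbers here are assumptions for illustration):

```python
import torch

# Illustrative only: 2 layers, 64 logical experts in 8 groups, 2 nodes, 8 GPUs.
weight = torch.rand(2, 64)
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
    weight, num_physical_experts=80, num_groups=8, num_nodes=2, num_gpus=8
)
assert phy2log.shape == (2, 80)  # expert id per physical slot
assert logcnt.shape == (2, 64)   # replica count per logical expert
```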
@@ -115,45 +102,37 @@ def rebalance_experts_hierarchical(weight: torch.Tensor,
     def inverse(perm: torch.Tensor) -> torch.Tensor:
         inv = torch.empty_like(perm)
-        inv.scatter_(
-            1, perm,
-            torch.arange(perm.size(1), dtype=torch.int64,
-                         device=perm.device).expand(perm.shape))
+        inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape))
         return inv
 
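(An aside on the helper above: it inverts a batched permutation with a single `scatter_`; a standalone illustration, not part of the diff:)

```python
import torch

perm = torch.tensor([[2, 0, 1]])  # perm[i] = value at position i
inv = torch.empty_like(perm)
inv.scatter_(1, perm, torch.arange(3).expand(perm.shape))
# inv is tensor([[1, 2, 0]]): inv[perm[i]] == i recovers each position.
```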
     # Step 1: pack groups to nodes
     tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
-    group_pack_index, group_rank_in_pack = balanced_packing(
-        tokens_per_group, num_nodes)
-    log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
-                 group_size).unsqueeze(-1) +
-                torch.arange(group_size,
-                             dtype=torch.int64,
-                             device=group_pack_index.device)).flatten(-2)
+    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
+    log2mlog = (
+        ((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1)
+        + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
+    ).flatten(-2)
     mlog2log = inverse(log2mlog)
 
     # Step 2: construct redundant experts within nodes
     # [num_layers * num_nodes, num_logical_experts // num_nodes]
-    tokens_per_mlog = weight.gather(-1, mlog2log).view(
-        -1, num_logical_experts // num_nodes)
-    phy2mlog, phyrank, mlogcnt = replicate_experts(
-        tokens_per_mlog, num_physical_experts // num_nodes)
+    tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes)
+    phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes)
 
     # Step 3: pack physical_experts to GPUs
     # [num_layers * num_nodes, num_physical_experts // num_nodes]
     tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
-    pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
-                                                num_gpus // num_nodes)
+    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
     phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
     pphy2phy = inverse(phy2pphy)
 
-    pphy2mlog = phy2mlog.gather(
-        -1, pphy2phy)  # [num_layers * num_nodes, num_log_per_nodes]
-    pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
-        0,
-        num_logical_experts,
-        num_logical_experts // num_nodes,
-        device=group_pack_index.device).view(1, -1, 1)).flatten(-2)
+    pphy2mlog = phy2mlog.gather(-1, pphy2phy)  # [num_layers * num_nodes, num_log_per_nodes]
+    pphy2mlog = (
+        pphy2mlog.view(num_layers, num_nodes, -1)
+        + torch.arange(0, num_logical_experts, num_logical_experts // num_nodes, device=group_pack_index.device).view(
+            1, -1, 1
+        )
+    ).flatten(-2)
     pphy2log = mlog2log.gather(-1, pphy2mlog)
     pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
     logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
@@ -161,9 +140,8 @@ def rebalance_experts_hierarchical(weight: torch.Tensor,
 
 
 def rebalance_experts(
-        weight: torch.Tensor, num_replicas: int, num_groups: int,
-        num_nodes: int,
-        num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, num_gpus: int
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Entry point for expert-parallelism load balancer.
 
@@ -174,7 +152,7 @@ def rebalance_experts(
         num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
         num_gpus: number of GPUs, must be a multiple of `num_nodes`
 
-    Returns: 
+    Returns:
         physical_to_logical_map: [layers, num_replicas], the expert index of each replica
         logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
         expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
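
Taken together, a hedged end-to-end sketch of the entry point (sizes are made up for illustration):

```python
import torch

# Illustrative only: 4 layers, 64 logical experts, 96 replicas on 16 GPUs.
weight = torch.randint(1, 1000, (4, 64))  # per-expert token counts
phy2log, log2phy, logcnt = rebalance_experts(
    weight, num_replicas=96, num_groups=8, num_nodes=2, num_gpus=16
)
# Every logical expert keeps at least one replica and slots sum to num_replicas;
# unused entries of log2phy stay -1.
assert (logcnt >= 1).all() and logcnt.sum(-1).eq(96).all()
```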
@@ -183,23 +161,20 @@ def rebalance_experts(
     weight = weight.float().cpu()
     if num_groups % num_nodes == 0:
         # use hierarchical load-balance policy
-        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
-            weight, num_replicas, num_groups, num_nodes, num_gpus)
+        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, num_groups, num_nodes, num_gpus)
     else:
         # use global load-balance policy
-        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
-            weight, num_replicas, 1, 1, num_gpus)
+        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 1, 1, num_gpus)
     maxlogcnt = logcnt.max().item()
     log2phy: torch.Tensor = torch.full(
-        (num_layers, num_logical_experts, maxlogcnt),
-        -1,
-        dtype=torch.int64,
-        device=logcnt.device)
+        (num_layers, num_logical_experts, maxlogcnt), -1, dtype=torch.int64, device=logcnt.device
+    )
     log2phy.view(num_layers, -1).scatter_(
-        -1, phy2log * maxlogcnt + phyrank,
-        torch.arange(num_replicas, dtype=torch.int64,
-                     device=log2phy.device).expand(num_layers, -1))
+        -1,
+        phy2log * maxlogcnt + phyrank,
+        torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1),
+    )
     return phy2log, log2phy, logcnt
 
 
-__all__ = ['rebalance_experts']
+__all__ = ["rebalance_experts"]