# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.

This module implements the core rearrangement algorithm.

The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).

Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""

import numpy as np
import torch

from .abstract import AbstractEplbPolicy


class DefaultEplbPolicy(AbstractEplbPolicy):
    @classmethod
    def balanced_packing(
        cls, weight: torch.Tensor, num_packs: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Pack n weighted items into `num_packs` packs such that every pack
        holds exactly n / num_packs items and the total weights of the packs
        are as balanced as possible.

        Each row of `weight` is packed independently with a greedy rule:
        items are visited from heaviest to lightest and each one goes into
        the currently lightest pack that still has a free slot.

        Parameters:
            weight: [X, n], the weight of each item
            num_packs: number of packs, must divide n

        Returns:
            pack_index: [X, n], the pack index of each item
            rank_in_pack: [X, n], the rank of the item in the pack
        """
        num_layers, num_groups = weight.shape
        assert num_groups % num_packs == 0
        groups_per_pack = num_groups // num_packs
        device = weight.device

        if groups_per_pack == 1:
            # One item per pack: item j simply goes to pack j at rank 0.
            # NOTE: pack_index is a broadcast (expanded) view here.
            identity = torch.arange(weight.size(-1), dtype=torch.int64, device=device)
            pack_index = identity.expand(weight.shape)
            rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
            return pack_index, rank_in_pack

        # The greedy loop below runs on the host with NumPy arrays.
        weight_np = weight.cpu().numpy()

        # Per-row item indices, sorted by weight in descending order
        heaviest_first = np.argsort(-weight_np, axis=-1)

        result_shape = (num_layers, num_groups)
        pack_index_np = np.full(result_shape, -1, dtype=np.int64)
        rank_in_pack_np = np.full(result_shape, -1, dtype=np.int64)

        for layer in range(num_layers):
            pack_loads = [0.0] * num_packs
            pack_sizes = [0] * num_packs
            for group in heaviest_first[layer]:
                # Among packs that still have room, take the lightest one
                # (the lowest pack index wins ties).
                open_packs = [
                    p for p in range(num_packs) if pack_sizes[p] < groups_per_pack
                ]
                target = min(open_packs, key=pack_loads.__getitem__)
                assert pack_sizes[target] < groups_per_pack

                pack_index_np[layer, group] = target
                rank_in_pack_np[layer, group] = pack_sizes[target]
                pack_loads[target] += weight_np[layer, group]
                pack_sizes[target] += 1

        pack_index = torch.from_numpy(pack_index_np).to(device)
        rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
        return pack_index, rank_in_pack

    @classmethod
    def replicate_experts(
        cls, weight: torch.Tensor, num_phy: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Replicate `num_log` logical experts into `num_phy` physical replicas
        such that the maximum load over all replicas is minimized.

        The first `num_log` physical slots hold one replica of each logical
        expert. Every remaining slot is given, row by row, to the logical
        expert whose per-replica load (weight / replica count) is currently
        the largest.

        Parameters:
            weight: [X, num_log]
            num_phy: total number of experts after replication,
                must be at least `num_log`

        Returns:
            phy2log: [X, num_phy], logical expert id of each physical expert
            rank: [X, num_phy], the replica rank
            logcnt: [X, num_log], number of replicas for each logical expert
        """
        n, num_log = weight.shape
        num_redundant = num_phy - num_log
        assert num_redundant >= 0
        device = weight.device

        int64_on_device = dict(dtype=torch.int64, device=device)
        phy2log = torch.arange(num_phy, **int64_on_device).repeat(n, 1)
        rank = torch.zeros(n, num_phy, **int64_on_device)
        logcnt = torch.ones(n, num_log, **int64_on_device)
        row_ids = torch.arange(n, **int64_on_device)

        for slot in range(num_log, num_phy):
            # For every row, the logical expert with the highest load per
            # existing replica receives the next redundant slot.
            per_replica_load = weight / logcnt
            busiest = per_replica_load.max(dim=-1).indices
            phy2log[:, slot] = busiest
            rank[:, slot] = logcnt[row_ids, busiest]
            logcnt[row_ids, busiest] += 1

        return phy2log, rank, logcnt

    @classmethod
    def rebalance_experts_hierarchical(
        cls,
        weight: torch.Tensor,
        num_physical_experts: int,
        num_groups: int,
        num_nodes: int,
        num_gpus: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Hierarchical rebalancing in three steps: pack expert groups onto
        nodes, replicate experts within each node, then pack the resulting
        physical experts onto the GPUs of each node.

        Parameters:
            weight: [num_moe_layers, num_logical_experts]
            num_physical_experts: number of physical experts after
                replication, must be a multiple of `num_gpus`
            num_groups: number of expert groups, must divide
                `num_logical_experts` and be a multiple of `num_nodes`
            num_nodes: number of server nodes, where the intra-node network
                (e.g, NVLink) is faster
            num_gpus: number of GPUs, must be a multiple of `num_nodes`

        Returns:
            phy2log: [layers, num_physical_experts], the logical expert
                index of each physical expert
            rank: [layers, num_physical_experts], the replica rank of each
                physical expert
            logcnt: [layers, num_logical_experts], number of physical
                replicas for each logical expert
        """
        num_layers, num_logical_experts = weight.shape
        assert num_logical_experts % num_groups == 0
        group_size = num_logical_experts // num_groups
        assert num_groups % num_nodes == 0
        groups_per_node = num_groups // num_nodes
        assert num_gpus % num_nodes == 0
        assert num_physical_experts % num_gpus == 0
        phy_experts_per_gpu = num_physical_experts // num_gpus

        def inverse(perm: torch.Tensor) -> torch.Tensor:
            # Row-wise inverse permutation: inv[r, perm[r, j]] = j.
            positions = torch.arange(
                perm.size(1), dtype=torch.int64, device=perm.device
            ).expand(perm.shape)
            inv = torch.empty_like(perm)
            inv.scatter_(1, perm, positions)
            return inv

        # Step 1: pack groups to nodes
        tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
        group_pack_index, group_rank_in_pack = cls.balanced_packing(
            tokens_per_group, num_nodes
        )
        # First "middle" logical index of every group after the groups have
        # been reordered node by node.
        group_start = (
            group_pack_index * groups_per_node + group_rank_in_pack
        ) * group_size
        within_group = torch.arange(
            group_size, dtype=torch.int64, device=group_pack_index.device
        )
        log2mlog = (group_start.unsqueeze(-1) + within_group).flatten(-2)
        mlog2log = inverse(log2mlog)

        # Step 2: construct redundant experts within nodes
        # [num_layers * num_nodes, num_logical_experts // num_nodes]
        tokens_per_mlog = weight.gather(-1, mlog2log).view(
            -1, num_logical_experts // num_nodes
        )
        phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
            tokens_per_mlog, num_physical_experts // num_nodes
        )

        # Step 3: pack physical_experts to GPUs
        # [num_layers * num_nodes, num_physical_experts // num_nodes]
        tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
        pack_index, rank_in_pack = cls.balanced_packing(
            tokens_per_phy, num_gpus // num_nodes
        )
        phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
        pphy2phy = inverse(phy2pphy)

        # [num_layers * num_nodes, num_physical_experts // num_nodes],
        # values are node-local middle logical indices
        pphy2mlog = phy2mlog.gather(-1, pphy2phy)
        # Shift node-local indices by each node's first middle logical index
        # so that they index into the full per-layer expert range.
        node_start = torch.arange(
            0,
            num_logical_experts,
            num_logical_experts // num_nodes,
            device=group_pack_index.device,
        ).view(1, -1, 1)
        pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + node_start).flatten(
            -2
        )

        pphy2log = mlog2log.gather(-1, pphy2mlog)
        pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
        logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
        return pphy2log, pphyrank, logcnt

    @classmethod
    def rebalance_experts(
        cls,
        weight: torch.Tensor,
        num_replicas: int,
        num_groups: int,
        num_nodes: int,
        num_ranks: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Entry point for expert-parallelism load balancer.

        The hierarchical policy is used when `num_groups` is a multiple of
        `num_nodes`; otherwise everything is treated as a single group on a
        single node (global policy).

        Parameters:
            weight: [layers, num_logical_experts], the load statistics for
                all logical experts
            num_replicas: number of physical experts, must be a multiple of
                `num_ranks`
            num_groups: number of expert groups
            num_nodes: number of server nodes, where the intra-node network
                (e.g, NVLink) is faster
            num_ranks: number of ranks, must be a multiple of `num_nodes`

        Returns:
            phy2log: [layers, num_replicas], the expert index of each replica
            log2phy: [layers, num_logical_experts, X], the replica indices
                for each expert, where X = num_replicas -
                num_logical_experts + 1 and unused entries are -1
            logcnt: [layers, num_logical_experts], number of physical
                replicas for each logical expert
        """
        num_layers, num_logical_experts = weight.shape
        weight = weight.float()

        if num_groups % num_nodes == 0:
            # use hierarchical load-balance policy
            policy_groups, policy_nodes = num_groups, num_nodes
        else:
            # use global load-balance policy
            policy_groups, policy_nodes = 1, 1
        phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
            weight, num_replicas, policy_groups, policy_nodes, num_ranks
        )

        # A single logical expert can own at most all redundant slots plus
        # its own original slot.
        num_redundant_experts = num_replicas - num_logical_experts
        maxlogcnt = num_redundant_experts + 1
        log2phy: torch.Tensor = torch.full(
            (num_layers, num_logical_experts, maxlogcnt),
            -1,
            dtype=torch.int64,
            device=logcnt.device,
        )
        # Flattened (logical expert, replica rank) position of every replica.
        flat_position = phy2log * maxlogcnt + phyrank
        replica_ids = torch.arange(
            num_replicas, dtype=torch.int64, device=log2phy.device
        ).expand(num_layers, -1)
        log2phy.view(num_layers, -1).scatter_(-1, flat_position, replica_ids)
        return phy2log, log2phy, logcnt