Move files related to EPLB (#7580)

This commit is contained in:
fzyzcjy
2025-06-30 06:39:38 +08:00
committed by GitHub
parent e3f9b54819
commit 0c9c6c75a8
22 changed files with 42 additions and 54 deletions

View File

@@ -1,63 +0,0 @@
from enum import Enum, auto
from typing import Optional
import torch
from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec
class EplbAlgorithm(Enum):
    """Selectable EPLB (expert-parallel load balancing) algorithms.

    Values are assigned with ``auto()``; only the member *names* are
    meaningful (they are looked up by string in ``compute_algorithm``).
    """

    deepseek = auto()  # flat DeepSeek-style balancing over all experts
    deepseek_hierarchical = auto()  # DeepSeek balancing with node-aware grouping
    deepseek_vec = auto()  # chunkwise (vectorized) variant, flat
    deepseek_vec_hierarchical = auto()  # chunkwise variant with node-aware grouping
    # TODO may have more algorithm later
def rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: Optional[int],
    num_nodes: int,
    algorithm: EplbAlgorithm,
):
    """Dispatch to the concrete EPLB implementation selected by ``algorithm``.

    ``tokens_per_expert`` is the recorded load statistic; the deepseek family
    consumes it summed over the step dimension, while the vec family consumes
    it as-is.
    """
    deepseek_family = {EplbAlgorithm.deepseek, EplbAlgorithm.deepseek_hierarchical}
    vec_family = {
        EplbAlgorithm.deepseek_vec,
        EplbAlgorithm.deepseek_vec_hierarchical,
    }

    if algorithm in deepseek_family:
        num_gpus = num_physical_experts // num_local_physical_experts
        return deepseek.rebalance_experts(
            weight=tokens_per_expert.sum(dim=0),
            num_replicas=num_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            num_gpus=num_gpus,
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_hierarchical,
        )

    if algorithm in vec_family:
        return deepseek_vec.rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
        )

    raise NotImplementedError
def compute_algorithm(
    raw_algorithm: str,
    num_groups: Optional[int],
    num_nodes: int,
) -> EplbAlgorithm:
    """Resolve a raw algorithm string into an ``EplbAlgorithm``.

    Any name other than ``"auto"`` is looked up directly in the enum;
    ``"auto"`` picks a deepseek variant based on the group/node layout.
    """
    if raw_algorithm != "auto":
        return EplbAlgorithm[raw_algorithm]

    # TODO test on real scenarios and know which ones perform better
    can_use_hierarchical = (num_groups is not None) and (num_groups % num_nodes == 0)
    return (
        EplbAlgorithm.deepseek_hierarchical
        if can_use_hierarchical
        else EplbAlgorithm.deepseek
    )

View File

@@ -1,223 +0,0 @@
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
from typing import Tuple
import torch
from sglang.srt.utils import get_bool_env_var
def balanced_packing(
    weight: torch.Tensor, num_packs: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Greedily pack n weighted items into m packs so that each pack holds exactly
    n/m items while the pack weights stay as balanced as possible.

    Parameters:
        weight: [X, n], the weight of each item
        num_packs: number of packs

    Returns:
        pack_index: [X, n], the pack index of each item
        rank_in_pack: [X, n], the rank of the item in the pack
    """
    num_layers, num_items = weight.shape
    assert num_items % num_packs == 0
    items_per_pack = num_items // num_packs

    # Fast path: one item per pack is simply the identity assignment.
    if items_per_pack == 1:
        identity = torch.arange(
            weight.size(-1), dtype=torch.int64, device=weight.device
        ).expand(weight.shape)
        return identity, torch.zeros_like(weight, dtype=torch.int64)

    order = weight.float().sort(-1, descending=True).indices.cpu()
    pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
    rank_in_pack = torch.full_like(pack_index, fill_value=-1)

    for layer in range(num_layers):
        pack_weights = [0] * num_packs
        pack_sizes = [0] * num_packs
        # Place heaviest items first, each into the lightest non-full pack.
        for item in order[layer]:
            open_packs = (
                p for p in range(num_packs) if pack_sizes[p] < items_per_pack
            )
            target = min(open_packs, key=pack_weights.__getitem__)
            assert pack_sizes[target] < items_per_pack
            pack_index[layer, item] = target
            rank_in_pack[layer, item] = pack_sizes[target]
            pack_weights[target] += weight[layer, item]
            pack_sizes[target] += 1

    return pack_index, rank_in_pack
def replicate_experts(
    weight: torch.Tensor, num_phy: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Replicate `num_log` experts into `num_phy` physical slots so the maximum
    per-replica load is minimized. Greedy: each extra slot duplicates the
    expert with the highest current load per replica.

    Parameters:
        weight: [X, num_log]
        num_phy: total number of experts after replication

    Returns:
        phy2log: [X, num_phy], logical expert id of each physical expert
        rank: [X, num_phy], the replica rank
        logcnt: [X, num_log], number of replicas for each logical expert
    """
    num_rows, num_log = weight.shape
    num_redundant = num_phy - num_log
    assert num_redundant >= 0
    device = weight.device

    # Slot i < num_log starts as expert i with a single replica each.
    phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(
        num_rows, 1
    )
    rank = torch.zeros(num_rows, num_phy, dtype=torch.int64, device=device)
    logcnt = torch.ones(num_rows, num_log, dtype=torch.int64, device=device)

    row_ids = torch.arange(num_rows, dtype=torch.int64, device=device)
    for slot in range(num_log, num_phy):
        # Per row, duplicate the expert whose load-per-replica is largest.
        hottest = (weight / logcnt).max(dim=-1).indices
        phy2log[:, slot] = hottest
        rank[:, slot] = logcnt[row_ids, hottest]
        logcnt[row_ids, hottest] += 1

    return phy2log, rank, logcnt
def rebalance_experts_hierarchical(
    weight: torch.Tensor,
    num_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
):
    """
    Hierarchical balancing: (1) pack expert groups onto nodes, (2) replicate
    hot experts within each node, (3) pack the physical experts onto GPUs.

    Parameters:
        weight: [num_moe_layers, num_logical_experts]
        num_physical_experts: number of physical experts after replication
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
        logical_count: [num_moe_layers, num_logical_experts]
    """
    num_layers, num_logical_experts = weight.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0
    groups_per_node = num_groups // num_nodes
    assert num_gpus % num_nodes == 0
    assert num_physical_experts % num_gpus == 0
    phy_experts_per_gpu = num_physical_experts // num_gpus

    def inverse(perm: torch.Tensor) -> torch.Tensor:
        # Invert a batched permutation: inv[i, perm[i, j]] = j.
        inv = torch.empty_like(perm)
        inv.scatter_(
            1,
            perm,
            torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
                perm.shape
            ),
        )
        return inv

    # Step 1: pack groups to nodes
    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
    # log2mlog maps each logical expert id to its node-major ("middle
    # logical") position after the group packing; mlog2log is its inverse.
    log2mlog = (
        (
            (group_pack_index * groups_per_node + group_rank_in_pack) * group_size
        ).unsqueeze(-1)
        + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
    ).flatten(-2)
    mlog2log = inverse(log2mlog)

    # Step 2: construct redundant experts within nodes
    # [num_layers * num_nodes, num_logical_experts // num_nodes]
    tokens_per_mlog = weight.gather(-1, mlog2log).view(
        -1, num_logical_experts // num_nodes
    )
    phy2mlog, phyrank, mlogcnt = replicate_experts(
        tokens_per_mlog, num_physical_experts // num_nodes
    )

    # Step 3: pack physical_experts to GPUs
    # [num_layers * num_nodes, num_physical_experts // num_nodes]
    # Per-replica load: an expert's load is split evenly among its replicas.
    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
    pphy2phy = inverse(phy2pphy)

    pphy2mlog = phy2mlog.gather(
        -1, pphy2phy
    )  # [num_layers * num_nodes, num_log_per_nodes]
    # Shift each node's local ids into the global node-major id space.
    pphy2mlog = (
        pphy2mlog.view(num_layers, num_nodes, -1)
        + torch.arange(
            0,
            num_logical_experts,
            num_logical_experts // num_nodes,
            device=group_pack_index.device,
        ).view(1, -1, 1)
    ).flatten(-2)
    # Translate back from node-major ids to the original logical ids.
    pphy2log = mlog2log.gather(-1, pphy2mlog)
    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
    return pphy2log, pphyrank, logcnt
def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()

    # The global (non-hierarchical) policy is the hierarchical one run with a
    # single group on a single node.
    if enable_hierarchical:
        groups, nodes = num_groups, num_nodes
    else:
        groups, nodes = 1, 1
    phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
        weight, num_replicas, groups, nodes, num_gpus
    )

    # Build the inverse map: for each logical expert, its physical replica
    # ids, padded with -1 up to the maximum replica count.
    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
            num_layers, -1
        ),
    )
    return phy2log, log2phy, logcnt
__all__ = ["rebalance_experts"]

View File

@@ -1,276 +0,0 @@
# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package
from typing import Optional, Tuple
import torch
def pack_groups(tokens_per_group: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Greedily assign expert groups to nodes, balancing per-node token load.

    Groups are visited from heaviest to lightest; each is placed on the node
    with the smallest accumulated load that still has a free slot (every node
    holds exactly ``num_groups // num_nodes`` groups).

    Parameters:
        tokens_per_group: [num_layers, num_groups], load of each group
        num_nodes: number of nodes to pack onto

    Returns:
        [num_layers, num_groups] int64 CPU tensor mapping each group to its
        packed position ``node_rank * groups_per_rank + slot_within_node``.
    """
    num_layers, num_groups = tokens_per_group.shape
    assert num_groups % num_nodes == 0
    groups_per_rank = num_groups // num_nodes

    indices = tokens_per_group.float().sort(-1, descending=True).indices.cpu()
    ret = torch.full_like(
        tokens_per_group, fill_value=-1, dtype=torch.int64, device="cpu"
    )
    for layer in range(num_layers):
        node_tokens = [0] * num_nodes
        node_groups = [0] * num_nodes

        # BUGFIX: the key returns a 2-tuple, not an int, so full nodes sort
        # after all non-full nodes; the annotation previously said `-> int`.
        # Defined once per layer (it only closes over the mutable counters).
        def key_func(rank: int) -> Tuple[int, int]:
            if node_groups[rank] >= groups_per_rank:
                return 1, 0
            else:
                return 0, node_tokens[rank]

        for group in indices[layer]:
            rank = min(range(num_nodes), key=key_func)
            assert node_groups[rank] < groups_per_rank
            ret[layer, group] = rank * groups_per_rank + node_groups[rank]
            node_tokens[rank] += tokens_per_group[layer, group]
            node_groups[rank] += 1
    return ret
def make_redundant_experts_chunkwise(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_physical_experts_per_chunk: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate redundant experts greedily and independently inside each chunk.

    Parameters:
        tokens_per_expert: [num_steps, num_moe_layers, num_logical_experts],
            per-step load statistics
        num_physical_experts: total physical experts (logical + redundant)
        num_local_physical_experts: physical experts per GPU; when > 1 a final
            pass permutes experts between GPUs to balance per-GPU load
        num_physical_experts_per_chunk: chunk width; each chunk gets an equal
            slice of logical experts and an equal share of redundancy slots

    Returns:
        physical_to_logical_map: [num_moe_layers, num_physical_experts]
        logical_to_physical_map:
            [num_moe_layers, num_logical_experts, num_redundancy_experts + 1],
            padded with -1
        logical_count: [num_moe_layers, num_logical_experts], replica counts
    """
    num_steps, num_moe_layers, num_logical_experts = tokens_per_expert.shape
    num_redundancy_experts = num_physical_experts - num_logical_experts
    physical_to_logical_map = torch.empty(
        num_moe_layers,
        num_physical_experts,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    logical_to_physical_map = torch.full(
        (num_moe_layers, num_logical_experts, num_redundancy_experts + 1),
        -1,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    logical_count = torch.ones(
        num_moe_layers,
        num_logical_experts,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    assert num_physical_experts % num_physical_experts_per_chunk == 0
    num_chunks = num_physical_experts // num_physical_experts_per_chunk
    assert num_logical_experts % num_chunks == 0
    num_logical_experts_per_group = num_logical_experts // num_chunks
    assert num_redundancy_experts % num_chunks == 0
    num_redundancy_experts_per_group = num_redundancy_experts // num_chunks
    # Pre-built index helpers reused inside the loop below.
    arange_num_moe_layers_num_groups = torch.arange(
        num_moe_layers * num_chunks, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_logical_experts = torch.arange(
        num_logical_experts, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_logical_experts_per_group = torch.arange(
        num_logical_experts_per_group, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_groups = torch.arange(
        num_chunks, dtype=torch.int, device=tokens_per_expert.device
    )
    # Each chunk starts with its own slice of logical experts in order;
    # redundancy slots are the trailing positions of the chunk.
    physical_to_logical_map.view(
        num_moe_layers, num_chunks, num_physical_experts_per_chunk
    )[:, :, :num_logical_experts_per_group] = arange_num_logical_experts.view(
        num_chunks, num_logical_experts_per_group
    )
    logical_to_physical_map[:, :, 0] = (
        arange_num_logical_experts_per_group.expand(
            num_chunks, num_logical_experts_per_group
        )
        + arange_num_groups[:, None] * num_physical_experts_per_chunk
    ).view(num_logical_experts)
    # Tiny per-expert offset makes all scores distinct so ties break
    # deterministically (see NOTE below).
    tokens_per_expert_all_diff = tokens_per_expert + arange_num_logical_experts * 1e-4
    for i in range(num_redundancy_experts_per_group):
        score = (
            tokens_per_expert_all_diff / logical_count
        )  # NOTE: Values in score must be different from each other
        # score1 = hypothetical per-replica load if one more replica is added.
        score1 = tokens_per_expert / (logical_count + 1)
        score = score.view(
            num_steps, num_moe_layers, num_chunks, num_logical_experts_per_group
        )
        score1 = score1.view_as(score)
        # For each chunk, simulate replicating the currently hottest expert
        # and record the resulting maximum per-replica load.
        values, indices = score.max(-1, keepdim=True)
        values = values.expand_as(score).contiguous()
        score.scatter_(-1, indices, score1.gather(-1, indices))
        values.scatter_(-1, indices, score.max(-1, keepdim=True).values)
        # Pick, per (layer, chunk), the expert whose replication minimizes the
        # max load summed over steps.
        redundancy_indices = values.sum(0).argmin(-1)
        physical_to_logical_map.view(
            num_moe_layers, num_chunks, num_physical_experts_per_chunk
        )[:, :, num_logical_experts_per_group + i] = (
            redundancy_indices + arange_num_groups * num_logical_experts_per_group
        )
        # Current replica count of the chosen experts = slot for the new
        # replica in logical_to_physical_map.
        redundancy_count = (
            logical_count.view(
                num_moe_layers * num_chunks, num_logical_experts_per_group
            )
            .gather(-1, redundancy_indices.view(num_moe_layers * num_chunks, 1))
            .squeeze(1)
        )
        physical_redundancy_indices = (
            (
                arange_num_groups * num_physical_experts_per_chunk
                + num_logical_experts_per_group
                + i
            )
            .expand(num_moe_layers, num_chunks)
            .flatten()
        )
        logical_to_physical_map.view(
            num_moe_layers * num_chunks,
            num_logical_experts_per_group,
            num_redundancy_experts + 1,
        )[
            arange_num_moe_layers_num_groups,
            redundancy_indices.view(num_moe_layers * num_chunks),
            redundancy_count,
        ] = physical_redundancy_indices
        logical_count.view(num_moe_layers * num_chunks, num_logical_experts_per_group)[
            arange_num_moe_layers_num_groups,
            redundancy_indices.view(num_moe_layers * num_chunks),
        ] += 1
    if num_local_physical_experts > 1:
        # Load-balancing between GPUs
        physical_to_logical_map_int64 = physical_to_logical_map.to(torch.int64)
        counts = logical_count.gather(-1, physical_to_logical_map_int64)
        # Per-replica load: an expert's total tokens split across its replicas.
        score = tokens_per_expert.sum(0).gather(-1, physical_to_logical_map_int64)
        score = score / counts
        score = score.view(num_moe_layers, num_chunks, num_physical_experts_per_chunk)
        indices = score.argsort(-1, descending=True)
        indices += torch.arange(
            0,
            num_physical_experts,
            num_physical_experts_per_chunk,
            dtype=indices.dtype,
            device=indices.device,
        )[None, :, None]
        assert num_physical_experts_per_chunk % num_local_physical_experts == 0
        num_local_groups = num_physical_experts_per_chunk // num_local_physical_experts
        # Alternate the direction of every other row before transposing —
        # a snake-order assignment, presumably so each GPU receives a mix of
        # hot and cold replicas (TODO confirm intent).
        indices = indices.view(
            num_moe_layers, num_chunks, num_local_physical_experts, num_local_groups
        )
        indices[:, :, 1::2, :] = indices[:, :, 1::2, :].flip(-1)
        indices = indices.transpose(2, 3)
        indices = indices.reshape(num_moe_layers, num_physical_experts)
        physical_to_logical_map = physical_to_logical_map.gather(-1, indices)
        # Remap logical_to_physical entries through the permutation while
        # keeping the -1 padding untouched.
        mask = logical_to_physical_map == -1
        logical_to_physical_map[mask] = 0
        logical_to_physical_map = (
            indices.argsort(-1)
            .gather(
                -1, logical_to_physical_map.view(num_moe_layers, -1).to(torch.int64)
            )
            .view_as(logical_to_physical_map)
            .to(torch.int)
        )
        logical_to_physical_map[mask] = -1
    return physical_to_logical_map, logical_to_physical_map, logical_count
def decode_rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
):
    """Flat (decode-time) rebalance: run the chunkwise allocator globally.

    Passing the full physical-expert count as the chunk width makes the
    chunkwise algorithm operate on a single chunk spanning all experts.
    """
    return make_redundant_experts_chunkwise(
        tokens_per_expert=tokens_per_expert,
        num_physical_experts=num_physical_experts,
        num_local_physical_experts=num_local_physical_experts,
        num_physical_experts_per_chunk=num_physical_experts,
    )
def prefill_rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: int,
    num_nodes: int,
):
    """Hierarchical (prefill-time) rebalance.

    Packs expert groups onto nodes first, then allocates redundancy chunkwise
    with one chunk per node, and finally translates all maps back from the
    packed ("mlog") id space to the original logical ids.

    Parameters:
        tokens_per_expert: [num_steps, num_moe_layers, num_logical_experts]
        num_physical_experts: total physical experts after replication
        num_local_physical_experts: physical experts per GPU
        num_groups: number of expert groups
        num_nodes: number of server nodes
    """
    tokens_per_expert = tokens_per_expert.float().cpu()
    num_steps, _, num_logical_experts = tokens_per_expert.shape
    assert num_logical_experts % num_groups == 0
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0, f"{num_groups=} {num_nodes=}"
    # Group loads summed over steps drive the node packing.
    tokens_per_group = tokens_per_expert.sum(0).unflatten(-1, (num_groups, -1)).sum(-1)
    group_perm = pack_groups(
        tokens_per_group, num_nodes
    )  # [num_moe_layers, num_groups] => [num_moe_layers, num_nodes]
    # log2mlog [layers, #logexp] -> [layers, #logexp]
    log2mlog = (
        (group_perm * group_size).unsqueeze(-1)
        + torch.arange(group_size, dtype=torch.int64, device=group_perm.device)
    ).flatten(-2)
    # mlog2log [layers, #logexp] -> [layers, #logexp], inverse of log2mlog
    mlog2log = torch.empty_like(log2mlog)
    arange = torch.arange(
        num_logical_experts, dtype=torch.int64, device=mlog2log.device
    )
    mlog2log.scatter_(1, log2mlog, arange.expand(log2mlog.size(0), -1))
    # tokens_per_mlog[i][j][k] = tokens_per_expert[i][j][mlog2log[j][k]]
    tokens_per_mlog = tokens_per_expert.gather(
        2, mlog2log.unsqueeze(0).expand(num_steps, -1, -1)
    )
    # One chunk per node so redundancy is allocated node-locally.
    phy2mlog, mlog2phy, mlog_count = make_redundant_experts_chunkwise(
        tokens_per_mlog,
        num_physical_experts,
        num_local_physical_experts,
        num_physical_experts // num_nodes,
    )
    # phy2log[i][j] = mlog2log[i][phy2mlog[i][j]]
    phy2log = mlog2log.gather(1, phy2mlog.to(torch.int64))
    # mlog2phy: [num_moe_layers, num_logical_experts, ...]
    # log2phy[i][j][k] = mlog2phy[i][log2mlog[i][j]][k]
    log2phy = mlog2phy.gather(
        1, log2mlog.unsqueeze(-1).expand(-1, -1, mlog2phy.size(-1)).to(torch.int64)
    )
    # log_count[i][j] = mlog_count[i][log2mlog[i][j]]
    log_count = mlog_count.gather(1, log2mlog)
    return phy2log, log2phy, log_count
def rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: Optional[int],
    num_nodes: int,
    enable_hierarchical: bool,
):
    """Dispatch to the hierarchical (prefill) or flat (decode) balancer."""
    if not enable_hierarchical:
        return decode_rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
    return prefill_rebalance_experts(
        tokens_per_expert=tokens_per_expert,
        num_physical_experts=num_physical_experts,
        num_local_physical_experts=num_local_physical_experts,
        num_groups=num_groups,
        num_nodes=num_nodes,
    )

View File

@@ -1,96 +0,0 @@
import logging
import time
from typing import TYPE_CHECKING, List
import torch.cuda
from sglang.srt.managers.expert_distribution import (
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ExpertLocationMetadata
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
logger = logging.getLogger(__name__)
class EPLBManager:
    """Periodically rebalances expert placement (EPLB) for a ModelRunner.

    Driven by `on_forward_pass_end`, which advances an internal generator:
    after every `eplb_rebalance_num_iterations` forward passes a rebalance is
    performed, optionally spread over several chunks of layers so a single
    forward pass does not pay the full update cost.
    """

    def __init__(self, model_runner: "ModelRunner"):
        super().__init__()

        self._model_runner = model_runner
        self._server_args = model_runner.server_args
        self._rebalance_layers_per_chunk = (
            self._server_args.eplb_rebalance_layers_per_chunk
        )
        self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations

        # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented.
        # FIX: message now matches the `>=` condition ("greater than or equal").
        assert (
            self._server_args.eplb_rebalance_num_iterations
            >= self._server_args.expert_distribution_recorder_buffer_size
        ), "eplb_rebalance_num_iterations must be greater than or equal to expert_distribution_recorder_buffer_size"

        if not get_global_expert_distribution_recorder().recording:
            get_global_expert_distribution_recorder().start_record()

        logger.info(
            f"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations."
        )

        self._main_generator = self._entrypoint()

    def on_forward_pass_end(self):
        # Advance the driver generator by exactly one forward pass.
        next(self._main_generator)

    # can be more complex if needed
    def _entrypoint(self):
        while True:
            for _ in range(self._rebalance_num_iterations):
                yield
            yield from self.rebalance()

    def rebalance(self):
        """Recompute expert locations from recorded stats and apply them.

        Implemented as a generator: when the update is chunked over layers it
        yields between chunks so each chunk runs on a separate forward pass.
        """
        logger.info("[EPLBManager] rebalance start")

        # Timing is only meaningful when the whole update happens in one shot.
        enable_timing = self._rebalance_layers_per_chunk is None
        if enable_timing:
            torch.cuda.synchronize()
            time_start = time.time()

        logical_count = get_global_expert_distribution_recorder().dump_record(
            output_mode="object"
        )["logical_count"]
        expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
            self._server_args, self._model_runner.model_config, logical_count
        )

        update_layer_ids_chunks = self._compute_update_layer_ids_chunks()
        for update_layer_ids in update_layer_ids_chunks:
            # When chunked, yield before each chunk so updates are spread
            # across forward passes.
            if len(update_layer_ids_chunks) > 1:
                yield
            self._model_runner.update_expert_location(
                expert_location_metadata,
                update_layer_ids=update_layer_ids,
            )

        # Plain string (was an f-string with no placeholders).
        msg = "[EPLBManager] rebalance end"
        if enable_timing:
            torch.cuda.synchronize()
            time_end = time.time()
            msg += f" time={time_end - time_start:.3f}s"
        logger.info(msg)

    def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
        all_layer_ids = sorted(
            self._model_runner.model.routed_experts_weights_of_layer.keys()
        )
        # No chunking configured => one huge chunk containing all layers.
        chunk_size = self._rebalance_layers_per_chunk or 1000000
        return list(_chunk_list(all_layer_ids, chunk_size=chunk_size))
def _chunk_list(items: List, chunk_size):
for start_index in range(0, len(items), chunk_size):
yield items[start_index : start_index + chunk_size]

View File

@@ -1,920 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import os
import time
from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
import einops
import torch
import torch.distributed
from sglang.srt.managers.expert_location import ExpertLocationMetadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var
logger = logging.getLogger(__name__)
# --------------------------------------- Entrypoint -----------------------------------------
_OutputMode = Literal["file", "object"]
class ExpertDistributionRecorder(ABC):
    """Global expert distribution recording.

    Base class: all context managers and hooks are no-ops, and the
    record-control methods raise until a real recorder is configured.
    """

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        # Install the real recorder only when a recorder mode is configured;
        # otherwise a no-op instance spares call sites any None checks.
        if server_args.expert_distribution_recorder_mode is not None:
            return _ExpertDistributionRecorderReal(
                server_args, expert_location_metadata, rank
            )
        else:
            return _ExpertDistributionRecorderNoop()

    @contextmanager
    def with_current_layer(self, layer_idx):
        # Default: no-op context; overridden by the real recorder.
        yield

    @contextmanager
    def with_debug_name(self, debug_name):
        yield

    @contextmanager
    def disable_this_region(self):
        yield

    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        yield

    def on_select_experts(self, topk_ids: torch.Tensor):
        pass

    def on_deepep_dispatch_normal(
        self,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        pass

    def on_deepep_dispatch_low_latency(
        self, local_physical_count_of_layer: torch.Tensor
    ):
        pass

    def start_record(self):
        self._on_not_implemented()

    def stop_record(self):
        self._on_not_implemented()

    def dump_record(self, output_mode: _OutputMode = "file"):
        self._on_not_implemented()

    @property
    def recording(self):
        # The base recorder never records.
        return False

    def _on_not_implemented(self):
        raise Exception(
            "Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
        )
class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
    # Inherits all no-op behavior from the base class; used when
    # `expert_distribution_recorder_mode` is not configured.
    pass
class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
    """Recorder used when a recorder mode is configured.

    Per-forward-pass data is collected by `_SinglePassGatherer` instances and
    folded into an `_Accumulator` at the end of each recorded forward pass.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        self._server_args = server_args
        self._expert_location_metadata = expert_location_metadata

        # `_recording` gates accumulation; `_disable_all` temporarily
        # suppresses hooks inside `disable_this_region`.
        self._recording = False
        self._disable_all = False

        # `Withable`s hold a value only inside their context manager.
        self._current_forward_pass_id = Withable()
        self._current_layer_idx = Withable()
        self._current_debug_name = Withable()

        self._accumulator = _Accumulator.init_new(
            server_args, expert_location_metadata, rank
        )
        # One gatherer per key the accumulator distinguishes.
        self._single_pass_gatherers = {
            k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)
            for k in self._accumulator.get_single_pass_gatherer_keys()
        }

        if server_args.enable_expert_distribution_metrics:
            logger.info(
                "ExpertDistributionRecorder auto start record since enable_expert_distribution_metrics"
            )
            self.start_record()

    def with_current_layer(self, layer_idx):
        return self._current_layer_idx.with_value(layer_idx)

    def with_debug_name(self, debug_name):
        return self._current_debug_name.with_value(debug_name)

    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        # Bracket a forward pass: reset gatherers on entry, collect on exit
        # (even if the pass raises).
        with self._current_forward_pass_id.with_value(forward_pass_id):
            self._on_forward_pass_start(forward_batch)
            try:
                yield
            finally:
                self._on_forward_pass_end(forward_pass_id)

    @contextmanager
    def disable_this_region(self):
        """Context manager to temporarily disable recording."""
        previous_disable_all = self._disable_all
        self._disable_all = True
        try:
            yield
        finally:
            self._disable_all = previous_disable_all

    def _on_forward_pass_start(self, forward_batch: ForwardBatch):
        if not self._recording:
            return
        for gatherer_key, gatherer in self._single_pass_gatherers.items():
            gatherer.reset()
            gatherer.on_forward_pass_start(forward_batch)

    def _on_forward_pass_end(self, forward_pass_id: int):
        if not self._recording:
            return
        for gatherer_key, gatherer in self._single_pass_gatherers.items():
            single_pass_data = gatherer.collect()
            self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)

    def on_select_experts(self, topk_ids: torch.Tensor):
        self._on_hook("on_select_experts", topk_ids=topk_ids)

    def on_deepep_dispatch_normal(
        self,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        self._on_hook(
            "on_deepep_dispatch_normal",
            local_physical_count_of_layer=local_physical_count_of_layer,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            num_tokens_per_expert=num_tokens_per_expert,
        )

    def on_deepep_dispatch_low_latency(
        self, local_physical_count_of_layer: torch.Tensor
    ):
        self._on_hook(
            "on_deepep_dispatch_low_latency",
            local_physical_count_of_layer=local_physical_count_of_layer,
        )

    def _on_hook(self, hook_name: str, **kwargs):
        # Forward the hook to the gatherer selected by the current debug name.
        # NOTE(review): hooks also fire while a CUDA stream capture is in
        # progress even if not recording — presumably so captured graphs
        # include the recording ops; confirm against graph-runner callers.
        if self._disable_all:
            return
        if not (self._recording or torch.cuda.is_current_stream_capturing()):
            return
        gatherer = self._single_pass_gatherers[
            self._accumulator.get_single_pass_gatherer_key(
                self._current_debug_name.value
            )
        ]
        getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs)

    def _reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting ExpertDistributionRecorder...")
        # Must not be called while inside a layer context.
        assert (
            self._current_layer_idx.value is None
        ), f"{self._current_layer_idx.value=}"
        for gatherer in self._single_pass_gatherers.values():
            gatherer.reset()
        self._accumulator.reset()

    def start_record(self):
        """Start recording the expert distribution."""
        if self._recording:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self._reset()
        self._recording = True

    def stop_record(self):
        """Stop recording the expert distribution."""
        if not self._recording:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
            )
        self._recording = False

    def dump_record(self, output_mode: _OutputMode = "file"):
        """Dump the expert distribution record and reset the recorder after dumping."""
        output = self._accumulator.dump(output_mode=output_mode)
        self._reset()
        return output

    @property
    def recording(self):
        # Whether accumulation is currently active.
        return self._recording
# Process-wide recorder singleton. Starts as a no-op; a real recorder is
# installed via `set_global_expert_distribution_recorder`.
_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (
    _ExpertDistributionRecorderNoop()
)


def get_global_expert_distribution_recorder():
    # Accessor for the process-wide recorder singleton.
    return _global_expert_distribution_recorder


def set_global_expert_distribution_recorder(value):
    # Replace the process-wide recorder singleton.
    global _global_expert_distribution_recorder
    _global_expert_distribution_recorder = value
# --------------------------------------- SinglePassGatherer -----------------------------------------
class _SinglePassGatherer(ABC):
    """Collects expert-distribution data for a single forward pass.

    Subclasses implement `reset`/`collect` plus whichever hooks their data
    source fires; unimplemented hooks default to no-ops.
    """

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ) -> "_SinglePassGatherer":
        # Choose the implementation from the recorder mode and the dispatch
        # backend (deepep normal / low-latency / plain select-experts).
        if server_args.expert_distribution_recorder_mode == "per_token":
            return _DetailSinglePassGatherer(
                server_args, expert_location_metadata, rank
            )

        if server_args.expert_distribution_recorder_mode == "stat_approx":
            if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
            else:
                raise NotImplementedError

        if server_args.enable_deepep_moe:
            if server_args.deepep_mode == "normal":
                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
            elif server_args.deepep_mode == "low_latency":
                return _DeepepLowLatencySinglePassGatherer(
                    expert_location_metadata, rank
                )
            else:
                raise NotImplementedError

        return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)

    def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
        self._expert_location_metadata = expert_location_metadata
        self._rank = rank

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        # Optional hook; default is a no-op.
        pass

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        pass

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        pass

    def on_deepep_dispatch_low_latency(
        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
    ):
        pass

    def reset(self):
        raise NotImplementedError

    def collect(self) -> Dict:
        raise NotImplementedError
class _DetailSinglePassGatherer(_SinglePassGatherer):
    """Per-token gatherer: records raw top-k expert ids for every token and
    layer, plus batch metadata and dispatch-side miscellany."""

    # DeepSeek V3 has this value; should generalize later
    _TOP_K_NUM = 8

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        super().__init__(expert_location_metadata, rank)
        self._metadata: Optional[Dict[str, Any]] = None
        # Pre-allocated buffer of top-k ids, indexed [layer, token, k].
        self._topk_ids_of_layer = torch.zeros(
            (
                expert_location_metadata.num_layers,
                # TODO determine the max number
                server_args.chunked_prefill_size * 8,
                self._TOP_K_NUM,
            ),
            dtype=torch.int32,
            device=server_args.device,
        )
        self._misc_objects: List[Dict[str, Any]] = []
        assert (
            not server_args.enable_two_batch_overlap
        ), "DetailSinglePassGatherer does not support TBO yet"
        # TODO assert shared experts fusion is disabled, o/w data is wrong

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        # Snapshot batch metadata exactly once per pass.
        assert self._metadata is None
        self._metadata = dict(
            # TODO pr-chain
            # rids=forward_batch.rids,
            input_ids=forward_batch.input_ids.cpu().tolist(),
            positions=forward_batch.positions.cpu().tolist(),
            extend_seq_lens=forward_batch.extend_seq_lens_cpu,
            forward_mode=forward_batch.forward_mode.value,
        )

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (
            topk_ids
        )

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        self._misc_objects.append(
            dict(
                layer_id=layer_idx,
                num_tokens_per_rank=num_tokens_per_rank.cpu().tolist(),
                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank.cpu().tolist(),
                num_tokens_per_expert=num_tokens_per_expert.cpu().tolist(),
            )
        )

    def reset(self):
        # -1 marks unused slots in the pre-allocated top-k buffer.
        self._topk_ids_of_layer[...] = -1
        self._misc_objects.clear()
        self._metadata = None

    def collect(self) -> Dict:
        # Trim the token dimension to the actual batch size before copying.
        num_tokens = len(self._metadata["input_ids"])
        return dict(
            **self._metadata,
            topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :].clone().cpu(),
            misc_objects=self._misc_objects,
        )
class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
    """Base for gatherers that accumulate per-layer Python lists on the CPU."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # layer_idx -> element-wise accumulated list of counts
        self._objects_of_layer = {}

    def _on_layer_data(self, layer_idx: int, objects: List[int]):
        """Element-wise accumulate `objects` into the running list for `layer_idx`."""
        assert 0 <= layer_idx < self._expert_location_metadata.num_layers
        existing = self._objects_of_layer.get(layer_idx)
        if existing is None:
            self._objects_of_layer[layer_idx] = objects
        else:
            self._objects_of_layer[layer_idx] = _list_sum(existing, objects)

    def reset(self):
        self._objects_of_layer.clear()

    def _collect_objects(self, pad_len: int) -> torch.Tensor:
        """Stack the per-layer lists into a (num_layers, pad_len) tensor,
        zero-filling layers that produced no data."""
        rows = []
        for layer_index in range(self._expert_location_metadata.num_layers):
            rows.append(self._objects_of_layer.get(layer_index) or ([0] * pad_len))
        return torch.tensor(rows)
def _list_sum(a: List, b: List) -> List:
return [x + y for x, y in zip(a, b, strict=True)]
class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
    """Base for gatherers that accumulate counts into one GPU tensor of shape
    (num_layers, num_experts)."""

    def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
        super().__init__(*args, **kwargs)
        # When True, the expert axis spans all physical experts globally;
        # otherwise only this rank's local physical experts.
        self._enable_global_physical_experts = enable_global_physical_experts
        self._data = torch.zeros(
            (
                self._expert_location_metadata.num_layers,
                (
                    self._expert_location_metadata.num_physical_experts
                    if enable_global_physical_experts
                    else self._expert_location_metadata.num_local_physical_experts
                ),
            ),
            dtype=torch.int,
            device="cuda",
        )

    def reset(self):
        self._data[...] = 0

    def collect(self) -> Dict:
        """Return counts keyed as `global_physical_count`, widening local
        counts to the global expert axis when needed."""
        if self._enable_global_physical_experts:
            global_physical_count = self._data
        else:
            # Can optimize if bottleneck
            global_physical_count = _convert_local_to_global_physical_count(
                self._data,
                rank=self._rank,
                num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
                num_physical_experts=self._expert_location_metadata.num_physical_experts,
            )
        return dict(global_physical_count=global_physical_count)
class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
    """Counts expert selections directly from router topk_ids (exact statistics)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, enable_global_physical_experts=True)

    # can optimize (e.g. fuse / compile)
    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        topk_ids = topk_ids.flatten()
        mask = topk_ids != -1
        # Histogram the chosen expert ids. -1 ("no expert") entries are
        # redirected to index 0 but contribute src=0, so they are no-ops.
        self._data[layer_idx, :].scatter_add_(
            dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
        )
class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
    """Collects per-expert token counts reported by DeepEP normal dispatch.

    Statistics are approximate (see the startup warning below).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Warn only once per job (from rank 0).
        if torch.distributed.get_rank() == 0:
            logger.info(
                "DeepepNormalSinglePassGatherer gathers approximate statistics. "
                "If used with small batch size, consider using expert_distribution_recorder_mode=stat."
            )

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        assert isinstance(local_physical_count_of_layer, list)
        self._on_layer_data(layer_idx, local_physical_count_of_layer)

    def collect(self) -> Dict:
        """Return counts widened from this rank's local slice to the global axis."""
        local_physical_count = super()._collect_objects(
            pad_len=self._expert_location_metadata.num_local_physical_experts
        )
        global_physical_count = _convert_local_to_global_physical_count(
            local_physical_count,
            rank=self._rank,
            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
            num_physical_experts=self._expert_location_metadata.num_physical_experts,
        )
        return dict(global_physical_count=global_physical_count)
class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
    """Accumulates per-expert counts reported by DeepEP low-latency dispatch
    (rank-local experts only; widened to global on collect by the base class)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, enable_global_physical_experts=False)

    def on_deepep_dispatch_low_latency(
        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
    ):
        # Most naive implementation, can optimize later
        self._data[layer_idx, :] += local_physical_count_of_layer
def _convert_local_to_global_physical_count(
local_physical_count: torch.Tensor,
rank: int,
num_local_physical_experts: int,
num_physical_experts: int,
) -> torch.Tensor:
dtype = local_physical_count.dtype
device = local_physical_count.device
num_layers, _ = local_physical_count.shape
ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device)
ans[
:, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1)
] = local_physical_count
return ans
# --------------------------------------- Accumulator -----------------------------------------
# Key used for the default (non-TBO) single-pass gatherer.
_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary"
class _Accumulator(ABC):
    """Accumulates single-pass gatherer outputs across forward passes and
    knows how to dump them; subclasses select what is retained."""

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ) -> "_Accumulator":
        return _Accumulator.get_class(server_args)(
            server_args, expert_location_metadata, rank
        )

    @staticmethod
    def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
        # "stat*" modes keep aggregated counts; "per_*" modes keep raw records.
        return {
            "stat": _StatAccumulator,
            "stat_approx": _StatAccumulator,
            "per_pass": _DetailAccumulator,
            "per_token": _DetailAccumulator,
        }[server_args.expert_distribution_recorder_mode]

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        self._server_args = server_args
        self._expert_location_metadata = expert_location_metadata
        self._rank = rank

    def get_single_pass_gatherer_keys(self):
        # Only the primary gatherer by default (no two-batch overlap).
        return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]

    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
        return _SINGLE_PASS_GATHERER_KEY_PRIMARY

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        # Subclasses record the data; the base class ignores it.
        pass

    def reset(self):
        pass

    def dump(self, output_mode: _OutputMode):
        pass
class _UtilizationRateAccumulatorMixin(_Accumulator):
    """Optionally logs GPU load "balancedness" over sliding windows when
    enable_expert_distribution_metrics is set."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._enable = self._server_args.enable_expert_distribution_metrics

        if self._enable:
            # Moving averages over the last 10 / 100 / 1000 forward passes.
            window_sizes = [10, 100, 1000]
            self._history = _DequeCollection(maxlens=window_sizes)
            self._rank = torch.distributed.get_rank()

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)
        if self._enable:
            self._append_utilization_rate(
                forward_pass_id, single_pass_data["global_physical_count"]
            )

    def reset(self):
        super().reset()
        if self._enable:
            self._history.clear()

    def _append_utilization_rate(
        self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
    ):
        """Reduce per-GPU counts to rank 0 and log the balancedness metric there."""
        gpu_physical_count = compute_gpu_physical_count(
            single_pass_global_physical_count,
            num_gpu=self._expert_location_metadata.ep_size,
        )
        gpu_physical_count = gpu_physical_count.to(self._server_args.device)
        # Sum contributions from all ranks onto rank 0.
        torch.distributed.reduce(
            gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM
        )
        if self._rank == 0:
            utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
            utilization_rate = torch.mean(utilization_rate_tensor).item()
            self._history.append(utilization_rate)
            gpu_physical_count_sum = gpu_physical_count.sum().item()
            logger.info(
                f"[Expert Balancedness] "
                f"forward_pass_id={forward_pass_id} "
                f"current_pass_balancedness={utilization_rate:.03f} "
                f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
                f"gpu_physical_count_sum={gpu_physical_count_sum}"
                # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
            )
class _DequeCollection:
def __init__(self, maxlens: List[int]):
self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]
def append(self, value):
for d in self._dequeues:
d.append(value)
def clear(self):
for d in self._dequeues:
d.clear()
def mean(self) -> Dict[int, float]:
return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
class _DetailAccumulator(_UtilizationRateAccumulatorMixin):
    """Keeps every per-pass record verbatim and dumps them all to a file."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._records = []

    def get_single_pass_gatherer_keys(self):
        if False:  # TODO `server_args.enable_two_batch_overlap`
            return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"]
        return super().get_single_pass_gatherer_keys()

    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
        if False:  # TODO `server_args.enable_two_batch_overlap`
            return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY
        return super().get_single_pass_gatherer_key(debug_name)

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)

        def _process_object(obj):
            # Detach tensors from the device so records can be saved later.
            if isinstance(obj, torch.Tensor):
                return obj.cpu().clone()
            return obj

        single_pass_data_processed = {
            k: _process_object(v) for k, v in single_pass_data.items()
        }
        self._records.append(
            dict(
                forward_pass_id=forward_pass_id,
                rank=self._rank,
                gatherer_key=gatherer_key,
                **single_pass_data_processed,
            )
        )

    def reset(self):
        super().reset()
        self._records.clear()

    def dump(self, output_mode: _OutputMode):
        # Detail dumps are always written to disk, one file per rank.
        assert output_mode == "file"
        output = dict(
            records=self._records,
            # NOTE: This may change during recording, so here we say it is the "last" one
            last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
        )
        _dump_to_file(
            f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output
        )
class _StatAccumulator(_UtilizationRateAccumulatorMixin):
    """Buffers per-pass global physical counts and dumps logical-expert
    counts aggregated (all-reduced) across ranks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._global_physical_count_of_buffered_step = _Buffer.init_new(
            item_shape=(
                self._expert_location_metadata.num_layers,
                # Cannot use local_physical_count to support select_experts
                self._expert_location_metadata.num_physical_experts,
            ),
            buffer_size=self._server_args.expert_distribution_recorder_buffer_size,
            dtype=torch.int32,
            device=self._server_args.device,
        )
        self._first_dump = True

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)
        # Can optimize if overhead here is large
        self._global_physical_count_of_buffered_step.append(
            single_pass_data["global_physical_count"]
        )

    def reset(self):
        super().reset()
        self._global_physical_count_of_buffered_step.reset()

    def dump(self, output_mode: _OutputMode):
        """Fold buffered physical counts onto logical experts, all-reduce,
        then write to file (rank 0) or return the dict (`output_mode="object"`)."""
        logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count(
            self._global_physical_count_of_buffered_step.get_all(),
            num_layers=self._expert_location_metadata.num_layers,
            num_logical_experts=self._expert_location_metadata.num_logical_experts,
            physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
        )
        if self._first_dump:
            self._first_dump = False
            # Free cached blocks once before the first all_reduce allocates.
            torch.cuda.empty_cache()
        torch.distributed.all_reduce(
            logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
        )
        output = dict(
            rank=self._rank,
            logical_count=logical_count_of_buffered_step,
        )
        if output_mode == "file":
            if self._rank == 0:
                _dump_to_file(f"expert_distribution_recorder_{time.time()}.pt", output)
        elif output_mode == "object":
            return output
        else:
            raise NotImplementedError
def _dump_to_file(name, data):
    """torch.save `data` as `name` under SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR
    (default /tmp), creating the directory if needed."""
    save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
    path_output = save_dir / name
    logger.info(f"Write expert distribution to {path_output}")
    if not save_dir.exists():
        save_dir.mkdir(parents=True, exist_ok=True)
    torch.save(data, str(path_output))
class _Buffer:
    """Append-only tensor buffer interface; circular when bounded, growable
    when unbounded."""

    @staticmethod
    def init_new(item_shape: Tuple, buffer_size: int, dtype, device):
        # A negative buffer_size means "unbounded".
        if buffer_size < 0:
            return _InfiniteBuffer(item_shape, dtype=dtype, device=device)
        else:
            return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)

    def append(self, value: torch.Tensor):
        raise NotImplementedError

    def get_all(self) -> torch.Tensor:
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError
class _CircularBuffer(_Buffer):
    """Fixed-capacity ring buffer: once full, new items overwrite the oldest."""

    def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):
        self._buffer = torch.zeros(
            (buffer_size, *item_shape), dtype=dtype, device=device
        )
        self._write_pos = 0

    def append(self, value: torch.Tensor):
        capacity = len(self._buffer)
        self._buffer[self._write_pos] = value
        self._write_pos = (self._write_pos + 1) % capacity

    def get_all(self) -> torch.Tensor:
        # NOTE: rows are in storage order, not insertion order.
        return self._buffer

    def reset(self):
        self._buffer[...] = 0
class _InfiniteBuffer(_Buffer):
    """Growable buffer: capacity starts at 128 rows and doubles when full."""

    def __init__(self, item_shape: Tuple, dtype, device):
        self._item_shape = item_shape
        self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)
        self._size = 0

    def append(self, value: torch.Tensor):
        capacity = len(self._buffer)
        if self._size == capacity:
            # Double the capacity and carry the existing rows over.
            grown = torch.zeros(
                (2 * capacity, *self._item_shape),
                dtype=self._buffer.dtype,
                device=self._buffer.device,
            )
            grown[:capacity] = self._buffer
            self._buffer = grown
        self._buffer[self._size] = value
        self._size += 1

    def get_all(self) -> torch.Tensor:
        # Only the filled prefix is returned.
        return self._buffer[: self._size]

    def reset(self):
        self._buffer[...] = 0
        self._size = 0
def _convert_global_physical_count_to_logical_count(
# (whatever, num_layers, num_physical_experts)
global_physical_count: torch.Tensor,
num_layers: int,
num_logical_experts: int,
physical_to_logical_map: torch.Tensor,
):
dim_extra, _, _ = global_physical_count.shape
dtype = global_physical_count.dtype
device = global_physical_count.device
logical_count = torch.zeros(
(dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device
)
logical_count.scatter_add_(
dim=2,
index=physical_to_logical_map.unsqueeze(0)
.expand(dim_extra, -1, -1)
.to(torch.int64),
src=global_physical_count,
)
return logical_count
def compute_gpu_physical_count(
    physical_count_of_whatever: torch.Tensor,  # (..., num_layer, num_physical_expert)
    num_gpu: int,
):
    """output: gpu_physical_count_of_batch (..., num_layer, num_gpu)

    Sums the per-physical-expert counts over the experts co-located on each
    GPU, assuming experts are laid out contiguously per GPU (the same layout
    `_convert_local_to_global_physical_count` produces).

    Implemented with native torch (reshape + sum) instead of einops.reduce;
    the reshape enforces that num_physical_expert divides evenly by num_gpu,
    as the einops pattern did.
    """
    *leading, num_layer, num_physical_expert = physical_count_of_whatever.shape
    # (..., num_layer, num_gpu, num_expert_per_gpu) -> sum over the last axis.
    return physical_count_of_whatever.reshape(
        *leading, num_layer, num_gpu, num_physical_expert // num_gpu
    ).sum(dim=-1)
def compute_utilization_rate(
    gpu_physical_count_of_batch: torch.Tensor,  # (..., num_layer, num_gpu)
):
    """output: utilization_rate (..., num_layer)

    Balancedness metric in (0, 1]: mean load over GPUs divided by max load.
    1.0 means perfectly balanced; the epsilon keeps the all-zero case
    well-defined (~1.0 instead of NaN).

    Implemented with native torch reductions (mean/amax over the GPU axis)
    instead of einops.reduce; behavior is identical.
    """
    counts = gpu_physical_count_of_batch.float()
    max_per_layer = counts.amax(dim=-1)
    avg_per_layer = counts.mean(dim=-1)
    return (avg_per_layer + 1e-5) / (max_per_layer + 1e-5)

View File

@@ -1,448 +0,0 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
import random
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import torch
import torch.distributed
import torch.nn.functional as F
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers import eplb_algorithms
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
@dataclass
class ExpertLocationMetadata:
    """Static description of where each (replicated) physical expert lives.

    Terminology:
      * logical expert  - expert id as defined by the model
      * physical expert - one replica slot on some GPU; several physical
        experts may serve the same logical expert
    """

    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
    physical_to_logical_map_cpu: torch.Tensor
    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
    # (layers, num_logical_experts)
    logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        return self.physical_to_logical_map.shape[0]

    @property
    def num_physical_experts(self) -> int:
        return self.physical_to_logical_map.shape[1]

    @property
    def num_local_physical_experts(self) -> int:
        # Physical experts are split evenly across EP ranks.
        ans, remainder = divmod(self.num_physical_experts, self.ep_size)
        assert remainder == 0
        return ans

    @property
    def num_logical_experts(self) -> int:
        return self.logical_to_all_physical_map.shape[1]

    @property
    def ep_size(self):
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        # Cross-check that all tensors agree on layer/expert dimensions.
        num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
        num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
            self.logical_to_all_physical_map.shape
        )
        num_layers_2, num_logical_experts_1 = (
            self.logical_to_all_physical_map_num_valid.shape
        )
        assert num_layers_0 == num_layers_1 == num_layers_2
        assert num_logical_experts_0 == num_logical_experts_1
        assert num_physical_experts_0 == num_physical_experts_1

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location - logical expert i corresponds to physical expert i"""
        common = ExpertLocationMetadata._init_common(server_args, model_config)
        num_physical_experts = common["num_physical_experts"]
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_layers = model_config_for_expert_location.num_layers
        num_logical_experts = model_config_for_expert_location.num_logical_experts

        # Redundant physical slots (beyond num_logical_experts) wrap around
        # and replicate the first logical experts.
        physical_to_logical_map = (
            torch.arange(0, num_physical_experts).repeat(num_layers, 1)
            % num_logical_experts
        )

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=physical_to_logical_map,
        )

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        """Construct from an explicit physical->logical map (tensor or nested list)."""
        if not isinstance(physical_to_logical_map, torch.Tensor):
            physical_to_logical_map = torch.tensor(physical_to_logical_map)
        physical_to_logical_map = physical_to_logical_map.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        model_config_for_expert_location = common["model_config_for_expert_location"]
        # Invert the mapping to logical -> [physical replicas].
        logical_to_all_physical_map = _compute_logical_to_all_physical_map(
            physical_to_logical_map,
            num_logical_experts=model_config_for_expert_location.num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            server_args=server_args,
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def init_by_eplb(
        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
    ):
        """Construct by rebalancing experts from observed per-expert token counts."""
        if not isinstance(logical_count, torch.Tensor):
            logical_count = torch.tensor(logical_count)
        if len(logical_count.shape) == 2:
            # Add the leading "stat batch" dimension expected by the algorithms.
            logical_count = logical_count.unsqueeze(0)
        logical_count = logical_count.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_physical_experts = common["num_physical_experts"]
        num_groups = model_config_for_expert_location.num_groups
        num_nodes = server_args.nnodes

        physical_to_logical_map, logical_to_all_physical_map, expert_count = (
            eplb_algorithms.rebalance_experts(
                tokens_per_expert=logical_count,
                num_physical_experts=num_physical_experts,
                num_local_physical_experts=num_physical_experts // common["ep_size"],
                num_groups=num_groups,
                num_nodes=num_nodes,
                algorithm=eplb_algorithms.compute_algorithm(
                    raw_algorithm=server_args.eplb_algorithm,
                    num_groups=num_groups,
                    num_nodes=num_nodes,
                ),
            )
        )

        return ExpertLocationMetadata._init_raw(
            server_args=server_args,
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map.to(server_args.device),
            logical_to_all_physical_map=logical_to_all_physical_map.to(
                server_args.device
            ),
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        # Derive the sizing info shared by all constructors.
        model_config_for_expert_location = (
            ModelConfigForExpertLocation.from_model_config(model_config)
        )

        num_physical_experts = (
            model_config_for_expert_location.num_logical_experts
            + server_args.ep_num_redundant_experts
        )
        ep_size = server_args.ep_size
        assert num_physical_experts % ep_size == 0
        num_local_physical_experts = num_physical_experts // ep_size

        return dict(
            model_config_for_expert_location=model_config_for_expert_location,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            ep_size=ep_size,
        )

    @staticmethod
    def _init_raw(
        server_args: ServerArgs,
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        _, num_physical_experts = physical_to_logical_map.shape

        # Pad the replica axis to num_physical_experts with -1 sentinels so all
        # metadata tensors share a rectangular shape.
        logical_to_all_physical_map_padded = F.pad(
            logical_to_all_physical_map,
            (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
            value=-1,
        )

        logical_to_all_physical_map_num_valid = torch.count_nonzero(
            logical_to_all_physical_map != -1, dim=-1
        )

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            physical_to_logical_map_cpu=physical_to_logical_map.cpu(),
            logical_to_all_physical_map=logical_to_all_physical_map_padded,
            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
            logical_to_rank_dispatch_physical_map=(
                compute_logical_to_rank_dispatch_physical_map(
                    logical_to_all_physical_map=logical_to_all_physical_map,
                    num_gpus=ep_size,
                    num_physical_experts=num_physical_experts,
                    # TODO improve when we have real EP rank
                    ep_rank=torch.distributed.get_rank() % ep_size,
                )
                if server_args.ep_dispatch_algorithm == "static"
                else None
            ),
        )

    # -------------------------------- mutation ------------------------------------

    def update(
        self,
        other: "ExpertLocationMetadata",
        update_layer_ids: List[int],
    ):
        """In-place copy of `other`'s metadata for the given layers only."""
        for field in [
            "ep_size",
        ]:
            assert getattr(self, field) == getattr(other, field)

        for field in [
            "physical_to_logical_map",
            "physical_to_logical_map_cpu",
            "logical_to_all_physical_map",
            "logical_to_all_physical_map_num_valid",
            "logical_to_rank_dispatch_physical_map",
        ]:
            other_field = getattr(other, field)
            self_field = getattr(self, field)
            assert (other_field is not None) == (self_field is not None)
            if self_field is not None:
                # Boolean mask over layers, broadcast to the field's full shape.
                mask_update = torch.tensor(
                    [i in update_layer_ids for i in range(self.num_layers)]
                )
                mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1)))
                mask_update = mask_update.to(self_field.device, non_blocking=True)
                self_field[...] = torch.where(mask_update, other_field, self_field)

    # -------------------------------- usage ------------------------------------

    def logical_to_all_physical(
        self, layer_id: int, logical_expert_id: int
    ) -> List[int]:
        """All physical replicas of one logical expert (padding -1 removed)."""
        return [
            physical_expert_id
            for physical_expert_id in self.logical_to_all_physical_map[
                layer_id, logical_expert_id
            ].tolist()
            if physical_expert_id != -1
        ]
# Process-wide singleton holding the expert location metadata.
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None


def get_global_expert_location_metadata():
    """Return the process-wide ExpertLocationMetadata (None before init)."""
    return _global_expert_location_metadata


def set_global_expert_location_metadata(value):
    """Install the process-wide ExpertLocationMetadata; may only be set once."""
    global _global_expert_location_metadata
    assert _global_expert_location_metadata is None
    _global_expert_location_metadata = value
def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    """Invert the physical->logical mapping into per-logical replica lists,
    right-padded with -1 into a rectangular tensor.

    This is rarely called, so clarity is preferred over speed.
    """
    num_layers, num_physical_experts = physical_to_logical_map.shape

    mapping = [[[] for _ in range(num_logical_experts)] for _ in range(num_layers)]
    for layer_id in range(num_layers):
        layer_row = physical_to_logical_map[layer_id]
        for physical_expert_id in range(num_physical_experts):
            logical_expert_id = layer_row[physical_expert_id].item()
            mapping[layer_id][logical_expert_id].append(physical_expert_id)

    padded = _pad_nested_array(mapping, pad_value=-1)
    return torch.tensor(padded, device=physical_to_logical_map.device)
def _pad_nested_array(arr, pad_value):
max_len = max(len(inner) for outer in arr for inner in outer)
padded = [
[inner + [pad_value] * (max_len - len(inner)) for inner in outer]
for outer in arr
]
return padded
# TODO optimize performance (rewrite and/or run in separate process with overlap)
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    seed: int = 42,
):
    """For each (layer, logical expert), choose the single physical replica
    this rank should statically dispatch to.

    Preference per GPU: a replica hosted on that GPU itself; any GPUs left
    unassigned receive a fair random choice among all replicas. The RNG is
    seeded, so every rank computes the same full table and then slices out
    its own (num_layers, num_logical_experts) view.
    """
    r = random.Random(seed)

    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    dtype = logical_to_all_physical_map.dtype

    # -1 marks "not yet assigned"; every slot must be filled below.
    logical_to_rank_dispatch_physical_map = torch.full(
        size=(num_gpus, num_layers, num_logical_experts),
        fill_value=-1,
        dtype=dtype,
    )

    for layer_id in range(num_layers):
        for logical_expert_id in range(num_logical_experts):
            candidate_physical_expert_ids = _logical_to_all_physical_raw(
                logical_to_all_physical_map, layer_id, logical_expert_id
            )
            output_partial = logical_to_rank_dispatch_physical_map[
                :, layer_id, logical_expert_id
            ]

            # Prefer a replica located on the destination GPU itself.
            for gpu_id in range(num_gpus):
                same_gpu_physical_expert_ids = [
                    physical_expert_id
                    for physical_expert_id in candidate_physical_expert_ids
                    if _compute_gpu_id_of_physical_expert(
                        physical_expert_id, num_local_physical_experts
                    )
                    == gpu_id
                ]
                if len(same_gpu_physical_expert_ids) > 0:
                    output_partial[gpu_id] = same_gpu_physical_expert_ids[0]

            # Fill remaining GPUs with a balanced random assignment.
            num_remain = torch.sum(output_partial == -1).item()
            output_partial[output_partial == -1] = torch.tensor(
                _fair_choices(candidate_physical_expert_ids, k=num_remain, r=r),
                dtype=dtype,
            )

    assert torch.all(logical_to_rank_dispatch_physical_map != -1)

    device = logical_to_all_physical_map.device
    return logical_to_rank_dispatch_physical_map[ep_rank, :, :].to(device)
def _logical_to_all_physical_raw(
logical_to_all_physical_map, layer_id: int, logical_expert_id: int
) -> List[int]:
return [
physical_expert_id
for physical_expert_id in logical_to_all_physical_map[
layer_id, logical_expert_id
].tolist()
if physical_expert_id != -1
]
def _compute_gpu_id_of_physical_expert(
physical_expert_id: int, num_local_physical_experts: int
) -> int:
return physical_expert_id // num_local_physical_experts
def _fair_choices(arr: List, k: int, r: random.Random) -> List:
quotient, remainder = divmod(k, len(arr))
ans = arr * quotient + r.sample(arr, k=remainder)
r.shuffle(ans)
return ans
@dataclass
class ModelConfigForExpertLocation:
    """Minimal per-model information needed to plan expert placement."""

    num_layers: int
    num_logical_experts: int
    # Expert-group count (e.g. for grouped routing); None if the model has none.
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        # Placeholder for models without MoE expert-location support.
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        """Ask the model class for its expert-location config, falling back to
        a dummy config when the model does not define one."""
        model_class, _ = get_model_architecture(model_config)
        if hasattr(model_class, "get_model_config_for_expert_location"):
            return model_class.get_model_config_for_expert_location(
                model_config.hf_config
            )
        else:
            return ModelConfigForExpertLocation.init_dummy()
def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    """Build the initial ExpertLocationMetadata from ServerArgs.init_expert_location.

    Accepts "trivial", a path to a .pt / .json file, or an inline JSON string;
    the payload's keys select init_by_mapping vs init_by_eplb.
    """
    data = server_args.init_expert_location
    if data == "trivial":
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if data.endswith(".pt"):
        data_dict = torch.load(data, weights_only=True)
    elif data.endswith(".json"):
        data_dict = json.loads(Path(data).read_text())
    else:
        data_dict = json.loads(data)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_mapping(
            server_args, model_config, **data_dict
        )
    elif "logical_count" in data_dict:
        logger.info(
            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_eplb(
            server_args, model_config, logical_count=data_dict["logical_count"]
        )
    else:
        raise NotImplementedError(
            f"Unknown init_expert_location format ({list(data_dict.keys())=})"
        )

View File

@@ -1,108 +0,0 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass
from typing import Literal, Optional
import torch
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
@dataclass
class ExpertLocationDispatchInfo:
    """Per-layer slices of the global expert location metadata needed to map
    logical expert ids to physical expert ids during dispatch."""

    ep_dispatch_algorithm: Literal["static", "random"]
    # (num_logical_experts,)
    partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
    # (num_logical_experts, X)
    partial_logical_to_all_physical_map: torch.Tensor
    # (num_logical_experts,)
    partial_logical_to_all_physical_map_num_valid: torch.Tensor
    num_physical_experts: int

    @classmethod
    def init_new(cls, layer_id: int):
        """Snapshot the global metadata for `layer_id`; None when dispatch is disabled."""
        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
        expert_location_metadata = get_global_expert_location_metadata()

        if ep_dispatch_algorithm is None:
            return None

        return cls(
            ep_dispatch_algorithm=ep_dispatch_algorithm,
            # The static dispatch table only exists for the "static" algorithm.
            partial_logical_to_rank_dispatch_physical_map=(
                expert_location_metadata.logical_to_rank_dispatch_physical_map[
                    layer_id, :
                ]
                if expert_location_metadata.logical_to_rank_dispatch_physical_map
                is not None
                else None
            ),
            partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[
                layer_id, :
            ],
            num_physical_experts=expert_location_metadata.num_physical_experts,
        )
def transform_select_experts_inputs(
    router_logits: torch.Tensor,
    correction_bias: Optional[torch.Tensor],
    info: Optional[ExpertLocationDispatchInfo],
):
    """Under the "fake" dispatch algorithm, replace routing inputs with noise
    so expert selection becomes random; otherwise pass inputs through."""
    if info is not None and info.ep_dispatch_algorithm == "fake":
        router_logits = torch.randn_like(router_logits)
        if correction_bias is not None:
            correction_bias = torch.zeros_like(correction_bias)
    return router_logits, correction_bias
def topk_ids_logical_to_physical(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Map router-chosen logical expert ids to physical expert ids.

    With no dispatch info the ids pass through unchanged.
    """
    if info is None:
        return topk_ids
    algorithm = info.ep_dispatch_algorithm
    if algorithm == "static":
        return _topk_ids_logical_to_physical_static(topk_ids, info)
    if algorithm in ("dynamic", "fake"):
        return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
    raise NotImplementedError(f"Unknown algorithm {algorithm}")
def _topk_ids_logical_to_physical_static(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    # Static dispatch: look each logical id up in this rank's precomputed table.
    # NOTE(review): assumes info is not None despite the Optional annotation;
    # the caller (topk_ids_logical_to_physical) guards this.
    return info.partial_logical_to_rank_dispatch_physical_map[topk_ids]
def _topk_ids_logical_to_physical_dynamic(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Randomly pick one physical replica per chosen logical expert.

    NOTE(review): assumes info is not None despite the Optional annotation;
    the caller (topk_ids_logical_to_physical) guards this.
    """
    topk_ids_original_shape = topk_ids.shape
    device = topk_ids.device
    topk_ids = topk_ids.flatten()

    # Replica index in [0, num_valid) for each selected logical expert.
    # NOTE(review): randint(0, 65536) % num_valid has slight modulo bias when
    # 65536 is not a multiple of the replica count — presumably negligible here.
    chosen_dispatch_index = (
        torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device)
        % info.partial_logical_to_all_physical_map_num_valid[topk_ids]
    )
    topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index]

    topk_ids = topk_ids.view(topk_ids_original_shape)
    return topk_ids

View File

@@ -58,6 +58,7 @@ from sglang.srt.disaggregation.utils import (
prepare_abort,
)
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.hf_transformers_utils import (
get_processor,
get_tokenizer,
@@ -65,9 +66,6 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.expert_distribution import (
get_global_expert_distribution_recorder,
)
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,