# From cba1cdbc46013dd34ab14e9e04ef5adec7c8d5d3 Mon Sep 17 00:00:00 2001
# From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
# Date: Tue, 20 May 2025 12:06:21 +0800
# Subject: [PATCH] Support DeepSeek EPLB algorithm with static distributions (#6387)
#
# Patch summary (from the diffstat):
#   python/sglang/srt/managers/deepseek_eplb.py   | 278 lines added (new file)
#   python/sglang/srt/managers/expert_location.py |  49 lines changed
#   2 files changed, 319 insertions(+), 8 deletions(-)
#
# The code below is the content of python/sglang/srt/managers/deepseek_eplb.py.
# The hunks for expert_location.py only show fragments of that file, so they
# are kept as comments at the end of this file.

# This file is copied from https://github.com/deepseek-ai/EPLB/blob/main/eplb.py since that one is not a pypi package

from typing import Literal, Tuple

import torch


def pack_groups(tokens_per_group: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Greedily assign expert groups to nodes, balancing tokens per node.

    For every layer the groups are visited from the heaviest to the lightest.
    Each group goes to the node that currently holds the fewest tokens among
    the nodes that still have a free slot (ties go to the lowest node index).
    Every node ends up with exactly ``num_groups // num_nodes`` groups.

    Args:
        tokens_per_group: ``[num_layers, num_groups]`` token counts.
        num_nodes: Number of nodes; must divide ``num_groups``.

    Returns:
        A CPU ``int64`` tensor of shape ``[num_layers, num_groups]``. Entry
        ``[layer, group]`` is ``node * groups_per_node + k`` where ``node`` is
        the node chosen for the group and ``k`` is how many groups that node
        already held, i.e. each row is a permutation of ``range(num_groups)``.

    Raises:
        AssertionError: If ``num_groups`` is not a multiple of ``num_nodes``.
    """
    num_layers, num_groups = tokens_per_group.shape
    assert num_groups % num_nodes == 0, f"{num_groups=} {num_nodes=}"
    groups_per_rank = num_groups // num_nodes

    # Sort once for all layers; the greedy loop itself runs on the host.
    indices = tokens_per_group.float().sort(-1, descending=True).indices.cpu()
    ret = torch.full_like(
        tokens_per_group, fill_value=-1, dtype=torch.int64, device="cpu"
    )
    for layer in range(num_layers):
        node_tokens = [0] * num_nodes
        node_groups = [0] * num_nodes

        def key_func(rank: int) -> Tuple[int, object]:
            # Full nodes sort after every non-full node; among non-full nodes
            # the one with the smallest accumulated load wins.
            if node_groups[rank] >= groups_per_rank:
                return 1, 0
            return 0, node_tokens[rank]

        for group in indices[layer]:
            rank = min(range(num_nodes), key=key_func)
            assert node_groups[rank] < groups_per_rank
            ret[layer, group] = rank * groups_per_rank + node_groups[rank]
            node_tokens[rank] += tokens_per_group[layer, group]
            node_groups[rank] += 1
    return ret


def make_redundant_experts_chunkwise(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_physical_experts_per_chunk: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Choose replica experts chunk by chunk and build the expert maps.

    The physical experts are cut into
    ``num_physical_experts // num_physical_experts_per_chunk`` equal chunks and
    the logical experts into the same number of contiguous groups. Chunk ``c``
    hosts the logical experts of group ``c`` followed by that group's replicas.
    Replicas are picked greedily, one per chunk per iteration. When
    ``num_local_physical_experts > 1`` the physical experts of each chunk are
    finally re-ordered so that consecutive blocks of
    ``num_local_physical_experts`` slots carry similar load.

    Args:
        tokens_per_expert: ``[num_steps, num_moe_layers, num_logical_experts]``
            token counts.
        num_physical_experts: Total number of physical expert slots.
        num_local_physical_experts: Size of one block of consecutive physical
            slots (the original comment calls the blocks GPUs).
        num_physical_experts_per_chunk: Number of physical slots per chunk.

    Returns:
        A tuple ``(physical_to_logical_map, logical_to_physical_map,
        logical_count)``:

        * ``physical_to_logical_map``: ``[num_moe_layers,
          num_physical_experts]`` int32, logical expert held by each slot.
        * ``logical_to_physical_map``: ``[num_moe_layers, num_logical_experts,
          num_redundancy_experts + 1]`` int32, slots of each logical expert,
          padded with ``-1``.
        * ``logical_count``: ``[num_moe_layers, num_logical_experts]`` int32,
          number of slots of each logical expert.

    Raises:
        AssertionError: If the chunk, group or block sizes do not divide
            evenly.
    """
    num_steps, num_moe_layers, num_logical_experts = tokens_per_expert.shape
    num_redundancy_experts = num_physical_experts - num_logical_experts

    physical_to_logical_map = torch.empty(
        num_moe_layers,
        num_physical_experts,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    logical_to_physical_map = torch.full(
        (num_moe_layers, num_logical_experts, num_redundancy_experts + 1),
        -1,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )
    logical_count = torch.ones(
        num_moe_layers,
        num_logical_experts,
        dtype=torch.int,
        device=tokens_per_expert.device,
    )

    assert (
        num_physical_experts % num_physical_experts_per_chunk == 0
    ), f"{num_physical_experts=} {num_physical_experts_per_chunk=}"
    num_chunks = num_physical_experts // num_physical_experts_per_chunk
    assert num_logical_experts % num_chunks == 0, f"{num_logical_experts=} {num_chunks=}"
    num_logical_experts_per_group = num_logical_experts // num_chunks
    assert (
        num_redundancy_experts % num_chunks == 0
    ), f"{num_redundancy_experts=} {num_chunks=}"
    num_redundancy_experts_per_group = num_redundancy_experts // num_chunks

    arange_num_moe_layers_num_groups = torch.arange(
        num_moe_layers * num_chunks, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_logical_experts = torch.arange(
        num_logical_experts, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_logical_experts_per_group = torch.arange(
        num_logical_experts_per_group, dtype=torch.int, device=tokens_per_expert.device
    )
    arange_num_groups = torch.arange(
        num_chunks, dtype=torch.int, device=tokens_per_expert.device
    )

    # The first slots of every chunk hold the chunk's own logical experts,
    # in order; replica slots are filled by the loop below.
    physical_to_logical_map.view(
        num_moe_layers, num_chunks, num_physical_experts_per_chunk
    )[:, :, :num_logical_experts_per_group] = arange_num_logical_experts.view(
        num_chunks, num_logical_experts_per_group
    )
    # Column 0 of the inverse map is the slot of that original copy.
    logical_to_physical_map[:, :, 0] = (
        arange_num_logical_experts_per_group.expand(
            num_chunks, num_logical_experts_per_group
        )
        + arange_num_groups[:, None] * num_physical_experts_per_chunk
    ).view(num_logical_experts)

    # A small index-dependent offset so that equal token counts do not tie.
    tokens_per_expert_all_diff = tokens_per_expert + arange_num_logical_experts * 1e-4
    for i in range(num_redundancy_experts_per_group):
        # score: per-copy load right now; score1: per-copy load after one more
        # copy of the expert is added.
        score = (
            tokens_per_expert_all_diff / logical_count
        )  # NOTE: Values in score must be different from each other
        score1 = tokens_per_expert / (logical_count + 1)
        score = score.view(
            num_steps, num_moe_layers, num_chunks, num_logical_experts_per_group
        )
        score1 = score1.view_as(score)
        # values[..., e] becomes the chunk's peak per-copy load if expert e is
        # the one replicated: unchanged for every expert except the current
        # peak, for which the peak is recomputed after replication.
        values, indices = score.max(-1, keepdim=True)
        values = values.expand_as(score).contiguous()
        score.scatter_(-1, indices, score1.gather(-1, indices))
        values.scatter_(-1, indices, score.max(-1, keepdim=True).values)
        # Pick, per (layer, chunk), the expert minimising the peak summed
        # over all steps.
        redundancy_indices = values.sum(0).argmin(-1)
        physical_to_logical_map.view(
            num_moe_layers, num_chunks, num_physical_experts_per_chunk
        )[:, :, num_logical_experts_per_group + i] = (
            redundancy_indices + arange_num_groups * num_logical_experts_per_group
        )
        # Current copy count of the chosen expert == next free column in
        # logical_to_physical_map.
        redundancy_count = (
            logical_count.view(
                num_moe_layers * num_chunks, num_logical_experts_per_group
            )
            .gather(-1, redundancy_indices.view(num_moe_layers * num_chunks, 1))
            .squeeze(1)
        )
        # Slot index of the i-th replica position in every chunk.
        physical_redundancy_indices = (
            (
                arange_num_groups * num_physical_experts_per_chunk
                + num_logical_experts_per_group
                + i
            )
            .expand(num_moe_layers, num_chunks)
            .flatten()
        )
        logical_to_physical_map.view(
            num_moe_layers * num_chunks,
            num_logical_experts_per_group,
            num_redundancy_experts + 1,
        )[
            arange_num_moe_layers_num_groups,
            redundancy_indices.view(num_moe_layers * num_chunks),
            redundancy_count,
        ] = physical_redundancy_indices
        logical_count.view(num_moe_layers * num_chunks, num_logical_experts_per_group)[
            arange_num_moe_layers_num_groups,
            redundancy_indices.view(num_moe_layers * num_chunks),
        ] += 1

    if num_local_physical_experts > 1:
        # Load-balancing between GPUs
        physical_to_logical_map_int64 = physical_to_logical_map.to(torch.int64)
        counts = logical_count.gather(-1, physical_to_logical_map_int64)
        score = tokens_per_expert.sum(0).gather(-1, physical_to_logical_map_int64)
        score = score / counts
        score = score.view(num_moe_layers, num_chunks, num_physical_experts_per_chunk)
        # Order the slots of each chunk by per-copy load, heaviest first, and
        # turn chunk-local positions into global slot indices.
        indices = score.argsort(-1, descending=True)
        indices += torch.arange(
            0,
            num_physical_experts,
            num_physical_experts_per_chunk,
            dtype=indices.dtype,
            device=indices.device,
        )[None, :, None]

        assert (
            num_physical_experts_per_chunk % num_local_physical_experts == 0
        ), f"{num_physical_experts_per_chunk=} {num_local_physical_experts=}"
        num_local_groups = num_physical_experts_per_chunk // num_local_physical_experts
        # Deal the sorted slots out in rows of num_local_groups, reversing
        # every other row, then transpose so that each block of
        # num_local_physical_experts consecutive slots takes one entry per row.
        indices = indices.view(
            num_moe_layers, num_chunks, num_local_physical_experts, num_local_groups
        )
        indices[:, :, 1::2, :] = indices[:, :, 1::2, :].flip(-1)
        indices = indices.transpose(2, 3)
        indices = indices.reshape(num_moe_layers, num_physical_experts)
        physical_to_logical_map = physical_to_logical_map.gather(-1, indices)
        # Remap the inverse map through the inverse permutation. The -1
        # padding is replaced by a valid index for the gather and restored
        # afterwards.
        mask = logical_to_physical_map == -1
        logical_to_physical_map[mask] = 0
        logical_to_physical_map = (
            indices.argsort(-1)
            .gather(
                -1, logical_to_physical_map.view(num_moe_layers, -1).to(torch.int64)
            )
            .view_as(logical_to_physical_map)
            .to(torch.int)
        )
        logical_to_physical_map[mask] = -1

    return physical_to_logical_map, logical_to_physical_map, logical_count


def decode_rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Rebalance for the decode phase, treating all slots as a single chunk.

    Args:
        tokens_per_expert: ``[num_steps, num_moe_layers, num_logical_experts]``
            token counts.
        num_physical_experts: Total number of physical expert slots.
        num_local_physical_experts: Size of one block of consecutive physical
            slots.

    Returns:
        The ``(physical_to_logical_map, logical_to_physical_map,
        logical_count)`` tuple of :func:`make_redundant_experts_chunkwise`
        called with ``num_physical_experts_per_chunk=num_physical_experts``.
    """
    return make_redundant_experts_chunkwise(
        tokens_per_expert,
        num_physical_experts,
        num_local_physical_experts,
        num_physical_experts,
    )


def prefill_rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: int,
    num_nodes: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Rebalance for the prefill phase, keeping every expert group on one node.

    Expert groups are first packed onto nodes with :func:`pack_groups`. The
    logical experts are then renumbered so that the groups of one node are
    contiguous ("mlog" numbering). Replicas are chosen per node with
    :func:`make_redundant_experts_chunkwise`, and the resulting maps are
    translated back to the original logical numbering.

    Args:
        tokens_per_expert: ``[num_steps, num_moe_layers, num_logical_experts]``
            token counts; converted to float32 on the CPU before use.
        num_physical_experts: Total number of physical expert slots.
        num_local_physical_experts: Size of one block of consecutive physical
            slots.
        num_groups: Number of expert groups; must divide the number of logical
            experts.
        num_nodes: Number of nodes; must divide ``num_groups``.

    Returns:
        ``(phy2log, log2phy, log_count)`` with the same shapes as the result of
        :func:`make_redundant_experts_chunkwise`, expressed in the original
        logical expert numbering. The tensors live on the CPU.

    Raises:
        AssertionError: If the group or node counts do not divide evenly.
    """
    tokens_per_expert = tokens_per_expert.float().cpu()

    num_steps, _, num_logical_experts = tokens_per_expert.shape
    assert num_logical_experts % num_groups == 0, f"{num_logical_experts=} {num_groups=}"
    group_size = num_logical_experts // num_groups
    assert num_groups % num_nodes == 0, f"{num_groups=} {num_nodes=}"

    tokens_per_group = tokens_per_expert.sum(0).unflatten(-1, (num_groups, -1)).sum(-1)
    # [num_moe_layers, num_groups] -> [num_moe_layers, num_groups]: the new
    # position of every group, with the groups of one node placed contiguously.
    group_perm = pack_groups(tokens_per_group, num_nodes)

    # log2mlog [layers, #logexp] -> [layers, #logexp]
    log2mlog = (
        (group_perm * group_size).unsqueeze(-1)
        + torch.arange(group_size, dtype=torch.int64, device=group_perm.device)
    ).flatten(-2)

    # mlog2log [layers, #logexp] -> [layers, #logexp], inverse of log2mlog
    mlog2log = torch.empty_like(log2mlog)
    arange = torch.arange(
        num_logical_experts, dtype=torch.int64, device=mlog2log.device
    )
    mlog2log.scatter_(1, log2mlog, arange.expand(log2mlog.size(0), -1))

    # tokens_per_mlog[i][j][k] = tokens_per_expert[i][j][mlog2log[j][k]]
    tokens_per_mlog = tokens_per_expert.gather(
        2, mlog2log.unsqueeze(0).expand(num_steps, -1, -1)
    )

    phy2mlog, mlog2phy, mlog_count = make_redundant_experts_chunkwise(
        tokens_per_mlog,
        num_physical_experts,
        num_local_physical_experts,
        num_physical_experts // num_nodes,
    )

    # phy2log[i][j] = mlog2log[i][phy2mlog[i][j]]
    phy2log = mlog2log.gather(1, phy2mlog.to(torch.int64))

    # mlog2phy: [num_moe_layers, num_logical_experts, ...]
    # log2phy[i][j][k] = mlog2phy[i][log2mlog[i][j]][k]
    log2phy = mlog2phy.gather(
        1, log2mlog.unsqueeze(-1).expand(-1, -1, mlog2phy.size(-1)).to(torch.int64)
    )

    # log_count[i][j] = mlog_count[i][log2mlog[i][j]]
    log_count = mlog_count.gather(1, log2mlog)
    return phy2log, log2phy, log_count


def rebalance_experts(
    tokens_per_expert: torch.Tensor,
    num_physical_experts: int,
    num_local_physical_experts: int,
    num_groups: int,
    num_nodes: int,
    phase: Literal["prefill", "decode"],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Entry point: dispatch to the prefill or decode rebalancing algorithm.

    Args:
        tokens_per_expert: ``[num_steps, num_moe_layers, num_logical_experts]``
            token counts.
        num_physical_experts: Total number of physical expert slots.
        num_local_physical_experts: Size of one block of consecutive physical
            slots.
        num_groups: Number of expert groups (used by ``"prefill"`` only).
        num_nodes: Number of nodes (used by ``"prefill"`` only).
        phase: ``"prefill"`` or ``"decode"``.

    Returns:
        ``(physical_to_logical_map, logical_to_physical_map, logical_count)``
        as produced by :func:`prefill_rebalance_experts` or
        :func:`decode_rebalance_experts`.

    Raises:
        NotImplementedError: If ``phase`` is neither ``"prefill"`` nor
            ``"decode"``.
    """
    if phase == "prefill":
        return prefill_rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
        )
    if phase == "decode":
        return decode_rebalance_experts(
            tokens_per_expert=tokens_per_expert,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
        )
    raise NotImplementedError(
        f"Unsupported EPLB phase {phase!r}; expected 'prefill' or 'decode'"
    )


# ---------------------------------------------------------------------------
# Remaining part of the patch: changes to
# python/sglang/srt/managers/expert_location.py. Only these hunks are visible
# here, so they are reproduced as comments. Indentation is reconstructed
# because the original whitespace was collapsed.
#
# diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py
# index befb3c1f4..c8b8db7c4 100644
# --- a/python/sglang/srt/managers/expert_location.py
# +++ b/python/sglang/srt/managers/expert_location.py
# @@ -117,6 +117,41 @@ class ExpertLocationMetadata:
#              logical_to_all_physical_map=logical_to_all_physical_map,
#          )
#
# +    @staticmethod
# +    def init_by_eplb(
# +        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
# +    ):
# +        if not isinstance(logical_count, torch.Tensor):
# +            logical_count = torch.tensor(logical_count)
# +        if len(logical_count.shape) == 2:
# +            logical_count = logical_count.unsqueeze(0)
# +        logical_count = logical_count.to(server_args.device)
# +
# +        common = ExpertLocationMetadata._init_common(server_args, model_config)
# +        model_config_for_expert_location = common["model_config_for_expert_location"]
# +        num_physical_experts = common["num_physical_experts"]
# +
# +        phase = server_args.disaggregation_mode
# +        if phase == "null":
# +            phase = "decode"
# +
# +        physical_to_logical_map, logical_to_all_physical_map, expert_count = (
# +            deepseek_eplb.rebalance_experts(
# +                tokens_per_expert=logical_count,
# +                num_physical_experts=num_physical_experts,
# +                num_local_physical_experts=num_physical_experts // common["ep_size"],
# +                num_groups=model_config_for_expert_location.num_groups,
# +                num_nodes=server_args.nnodes,
# +                phase=phase,
# +            )
# +        )
# +
# +        return ExpertLocationMetadata._init_raw(
# +            ep_size=common["ep_size"],
# +            physical_to_logical_map=physical_to_logical_map,
# +            logical_to_all_physical_map=logical_to_all_physical_map,
# +        )
# +
#      @staticmethod
#      def _init_common(server_args: ServerArgs, model_config: ModelConfig):
#          model_config_for_expert_location = (
# @@ -272,14 +307,12 @@ def compute_initial_expert_location_metadata(
#              server_args, model_config, **data_dict
#          )
#      elif "logical_count" in data_dict:
# -        # TODO pr-chain: enable this later
# -        raise NotImplementedError
# -        # logger.info(
# -        #     "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
# -        # )
# -        # return ExpertLocationMetadata.init_by_eplb(
# -        #     server_args, model_config, logical_count=data_dict["logical_count"]
# -        # )
# +        logger.info(
# +            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
# +        )
# +        return ExpertLocationMetadata.init_by_eplb(
# +            server_args, model_config, logical_count=data_dict["logical_count"]
# +        )
#      else:
#          raise NotImplementedError(
#              f"Unknown init_expert_location format ({list(data_dict.keys())=})"