From e98afbe042cf1dc40a6f87a81655bdc68c6d89c1 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Tue, 20 May 2025 13:13:55 +0800 Subject: [PATCH] Support dispatching logical to physical experts (#6385) --- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 + python/sglang/srt/layers/moe/topk.py | 18 ++++ .../srt/managers/expert_distribution.py | 3 +- python/sglang/srt/managers/expert_location.py | 61 ++++++++++++- .../srt/managers/expert_location_dispatch.py | 91 +++++++++++++++++++ python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 2 + python/sglang/srt/server_args.py | 7 ++ 9 files changed, 184 insertions(+), 5 deletions(-) create mode 100644 python/sglang/srt/managers/expert_location_dispatch.py diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index e32e053c0..cd244d570 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -6,6 +6,7 @@ from torch.nn import Module from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.managers.expert_location import get_global_expert_location_metadata +from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.managers.schedule_batch import global_server_args_dict try: @@ -237,6 +238,9 @@ class EPMoE(torch.nn.Module): correction_bias=self.correction_bias, custom_routing_function=self.custom_routing_function, routed_scaling_factor=self.routed_scaling_factor, + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + layer_id=self.layer_id, + ), ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 075587dc0..8895e6be6 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ 
b/python/sglang/srt/layers/moe/topk.py @@ -22,6 +22,10 @@ from sglang.srt.managers.expert_distribution import ( ExpertDistributionRecorder, get_global_expert_distribution_recorder, ) +from sglang.srt.managers.expert_location_dispatch import ( + ExpertLocationDispatchInfo, + topk_ids_logical_to_physical, +) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip @@ -100,6 +104,7 @@ def grouped_topk( n_share_experts_fusion: int = 0, routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -140,6 +145,7 @@ def grouped_topk( topk_weights = topk_weights / topk_weights_sum topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) return topk_weights, topk_ids @@ -155,6 +161,7 @@ def biased_grouped_topk_impl( n_share_experts_fusion: int = 0, routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" @@ -202,6 +209,7 @@ def biased_grouped_topk_impl( topk_weights = topk_weights / topk_weights_sum topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) return topk_weights, topk_ids @@ -232,6 +240,7 @@ def biased_grouped_topk( n_share_experts_fusion: int = 0, routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = 
None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): assert ( routed_scaling_factor is not None @@ -252,6 +261,8 @@ def biased_grouped_topk( n_share_experts_fusion, routed_scaling_factor, ) + # TODO merge into kernel for this branch + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) # TODO will fuse this into kernel, thus use slow manual operation now torch.compile( _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend() @@ -276,6 +287,7 @@ def biased_grouped_topk( n_share_experts_fusion=n_share_experts_fusion, routed_scaling_factor=routed_scaling_factor, num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, ) @@ -292,6 +304,7 @@ def select_experts( torch_native: bool = False, routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"] # DeepSeek V2/V3/R1 series models use grouped_top_k @@ -309,6 +322,7 @@ def select_experts( n_share_experts_fusion=n_share_experts_fusion, routed_scaling_factor=routed_scaling_factor, num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, ) else: topk_weights, topk_ids = biased_grouped_topk( @@ -322,11 +336,13 @@ def select_experts( n_share_experts_fusion=n_share_experts_fusion, routed_scaling_factor=routed_scaling_factor, num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, ) elif torch_native and custom_routing_function is None: assert ( num_token_non_padded is None ), "num_token_non_padded is not yet supported in fused_topk_native" + assert expert_location_dispatch_info is None topk_weights, topk_ids = fused_topk_native( hidden_states=hidden_states, gating_output=router_logits, @@ -337,6 +353,7 @@ def 
select_experts( assert ( num_token_non_padded is None ), "num_token_non_padded is not yet supported in fused_topk" + assert expert_location_dispatch_info is None topk_weights, topk_ids = fused_topk( hidden_states=hidden_states, gating_output=router_logits, @@ -347,6 +364,7 @@ def select_experts( assert ( num_token_non_padded is None ), "num_token_non_padded is not yet supported in custom_routing_function" + assert expert_location_dispatch_info is None topk_weights, topk_ids = custom_routing_function( hidden_states=hidden_states, gating_output=router_logits, diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index c32cafbb8..a4bf17c1e 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -23,9 +23,10 @@ import torch import torch.distributed from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import Withable +from sglang.srt.utils import Withable, get_bool_env_var logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index c8b8db7c4..b31e51557 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -33,6 +33,7 @@ class ExpertLocationMetadata: physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts) + logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts) # -------------------------------- properties ------------------------------------ @@ -67,9 +68,11 @@ class 
ExpertLocationMetadata: num_layers_2, num_logical_experts_1 = ( self.logical_to_all_physical_map_num_valid.shape ) - # TODO pr-chain: enable this later - # assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3 - # assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2 + num_layers_3, num_logical_experts_2 = ( + self.logical_to_rank_dispatch_physical_map.shape + ) + assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3 + assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2 assert num_physical_experts_0 == num_physical_experts_1 # -------------------------------- construction ------------------------------------ @@ -196,6 +199,13 @@ class ExpertLocationMetadata: physical_to_logical_map=physical_to_logical_map, logical_to_all_physical_map=logical_to_all_physical_map_padded, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, + logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map=logical_to_all_physical_map, + logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, + num_gpus=ep_size, + num_physical_experts=num_physical_experts, + ep_rank=torch.distributed.get_rank(), + ), ) # -------------------------------- usage ------------------------------------ @@ -262,6 +272,51 @@ def _pad_nested_array(arr, pad_value): return padded +# TODO use more sophisticated approaches +def compute_logical_to_rank_dispatch_physical_map( + logical_to_all_physical_map: torch.Tensor, + logical_to_all_physical_map_num_valid: torch.Tensor, + num_gpus: int, + num_physical_experts: int, + ep_rank: int, + base_seed: int = 42, +): + device = logical_to_all_physical_map.device + + num_local_physical_experts = num_physical_experts // num_gpus + num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape + + g = torch.Generator(device=device) + g.manual_seed(base_seed + ep_rank) + + output_shape = 
(num_layers, num_logical_experts) + chosen_index = ( + torch.randint( + 0, 65536, output_shape, dtype=torch.int32, device=device, generator=g + ) + % logical_to_all_physical_map_num_valid + ) + logical_to_rank_dispatch_physical_map = torch.gather( + logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1) + ).squeeze(-1) + assert logical_to_rank_dispatch_physical_map.shape == output_shape + + for index in range(logical_to_all_physical_map_num_valid.max().item()): + partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index] + is_valid = partial_logical_to_all_physical_map != -1 + is_same_gpu = ( + partial_logical_to_all_physical_map // num_local_physical_experts + ) == ep_rank + logical_to_rank_dispatch_physical_map = torch.where( + is_valid & is_same_gpu, + partial_logical_to_all_physical_map, + logical_to_rank_dispatch_physical_map, + ) + + assert torch.all(logical_to_rank_dispatch_physical_map != -1) + return logical_to_rank_dispatch_physical_map + + @dataclass class ModelConfigForExpertLocation: num_layers: int diff --git a/python/sglang/srt/managers/expert_location_dispatch.py b/python/sglang/srt/managers/expert_location_dispatch.py new file mode 100644 index 000000000..1e4d7b06e --- /dev/null +++ b/python/sglang/srt/managers/expert_location_dispatch.py @@ -0,0 +1,91 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from dataclasses import dataclass +from typing import Literal, Optional + +import torch + +from sglang.srt.managers.expert_location import get_global_expert_location_metadata +from sglang.srt.managers.schedule_batch import global_server_args_dict + + +@dataclass +class ExpertLocationDispatchInfo: + ep_dispatch_algorithm: Literal["static", "dynamic"] + # (num_logical_experts,) + partial_logical_to_rank_dispatch_physical_map: torch.Tensor + # (num_logical_experts, X) + partial_logical_to_all_physical_map: torch.Tensor + # (num_logical_experts,) + partial_logical_to_all_physical_map_num_valid: torch.Tensor + num_physical_experts: int + + @classmethod + def init_new(cls, layer_id: int): + ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"] + expert_location_metadata = get_global_expert_location_metadata() + + if ep_dispatch_algorithm is None: + return None + + return cls( + ep_dispatch_algorithm=ep_dispatch_algorithm, + partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[ + layer_id, : + ], + partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[ + layer_id, : + ], + partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[ + layer_id, : + ], + num_physical_experts=expert_location_metadata.num_physical_experts, + ) + + +def topk_ids_logical_to_physical( + topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo] +) -> torch.Tensor: + if info is None: + return topk_ids + + if info.ep_dispatch_algorithm == "static": + return _topk_ids_logical_to_physical_static(topk_ids, info) + if info.ep_dispatch_algorithm == "dynamic": + return _topk_ids_logical_to_physical_dynamic(topk_ids, info) + raise NotImplementedError + + +def _topk_ids_logical_to_physical_static( + topk_ids: torch.Tensor, info: 
Optional[ExpertLocationDispatchInfo] +) -> torch.Tensor: + return info.partial_logical_to_rank_dispatch_physical_map[topk_ids] + + +def _topk_ids_logical_to_physical_dynamic( + topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo] +) -> torch.Tensor: + topk_ids_original_shape = topk_ids.shape + device = topk_ids.device + topk_ids = topk_ids.flatten() + + chosen_dispatch_index = ( + torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device) + % info.partial_logical_to_all_physical_map_num_valid[topk_ids] + ) + topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] + + topk_ids = topk_ids.view(topk_ids_original_shape) + return topk_ids diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 10f91ed20..9981fe776 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -83,6 +83,7 @@ global_server_args_dict = { "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "max_micro_batch_size": ServerArgs.max_micro_batch_size, "moe_dense_tp_size": ServerArgs.moe_dense_tp_size, + "ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm, "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "sampling_backend": ServerArgs.sampling_backend, "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 78a94a898..5b4614585 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -13,7 +13,6 @@ # ============================================================================== """ModelRunner runs the forward passes of the models.""" -import collections import datetime import gc import inspect @@ -196,6 +195,7 @@ class ModelRunner: "deepep_config": server_args.deepep_config, 
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "moe_dense_tp_size": server_args.moe_dense_tp_size, + "ep_dispatch_algorithm": server_args.ep_dispatch_algorithm, "n_share_experts_fusion": server_args.n_share_experts_fusion, "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "torchao_config": server_args.torchao_config, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0b6642a23..b11734d85 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -80,6 +80,7 @@ from sglang.srt.managers.expert_distribution import ( get_global_expert_distribution_recorder, ) from sglang.srt.managers.expert_location import ModelConfigForExpertLocation +from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -113,6 +114,7 @@ if _is_hip: decode_attention_fwd_grouped_rope, ) + logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 59e8dccc1..f0c862cc4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -170,6 +170,7 @@ class ServerArgs: enable_ep_moe: bool = False enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" + ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None init_expert_location: str = "trivial" expert_distribution_recorder_mode: Optional[ Literal["stat", "per_pass", "per_token"] @@ -1271,6 +1272,12 @@ class ServerArgs: default="auto", help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. 
Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.", ) + parser.add_argument( + "--ep-dispatch-algorithm", + type=str, + default=ServerArgs.ep_dispatch_algorithm, + help="The algorithm to choose ranks for redundant experts in expert parallel.", + ) parser.add_argument( "--init-expert-location", type=str,