Support dispatching logical to physical experts (#6385)
This commit is contained in:
@@ -6,6 +6,7 @@ from torch.nn import Module
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
try:
|
||||
@@ -237,6 +238,9 @@ class EPMoE(torch.nn.Module):
|
||||
correction_bias=self.correction_bias,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
|
||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
||||
|
||||
@@ -22,6 +22,10 @@ from sglang.srt.managers.expert_distribution import (
|
||||
ExpertDistributionRecorder,
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.expert_location_dispatch import (
|
||||
ExpertLocationDispatchInfo,
|
||||
topk_ids_logical_to_physical,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
|
||||
|
||||
@@ -100,6 +104,7 @@ def grouped_topk(
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -140,6 +145,7 @@ def grouped_topk(
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
|
||||
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -155,6 +161,7 @@ def biased_grouped_topk_impl(
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -202,6 +209,7 @@ def biased_grouped_topk_impl(
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
|
||||
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -232,6 +240,7 @@ def biased_grouped_topk(
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
assert (
|
||||
routed_scaling_factor is not None
|
||||
@@ -252,6 +261,8 @@ def biased_grouped_topk(
|
||||
n_share_experts_fusion,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
# TODO merge into kernel for this branch
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
# TODO will fuse this into kernel, thus use slow manual operation now
|
||||
torch.compile(
|
||||
_mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
|
||||
@@ -276,6 +287,7 @@ def biased_grouped_topk(
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -292,6 +304,7 @@ def select_experts(
|
||||
torch_native: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
# DeepSeek V2/V3/R1 series models use grouped_top_k
|
||||
@@ -309,6 +322,7 @@ def select_experts(
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = biased_grouped_topk(
|
||||
@@ -322,11 +336,13 @@ def select_experts(
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
elif torch_native and custom_routing_function is None:
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in fused_topk_native"
|
||||
assert expert_location_dispatch_info is None
|
||||
topk_weights, topk_ids = fused_topk_native(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
@@ -337,6 +353,7 @@ def select_experts(
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in fused_topk"
|
||||
assert expert_location_dispatch_info is None
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
@@ -347,6 +364,7 @@ def select_experts(
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in custom_routing_function"
|
||||
assert expert_location_dispatch_info is None
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
|
||||
@@ -23,9 +23,10 @@ import torch
|
||||
import torch.distributed
|
||||
|
||||
from sglang.srt.managers.expert_location import ExpertLocationMetadata
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import Withable
|
||||
from sglang.srt.utils import Withable, get_bool_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ class ExpertLocationMetadata:
|
||||
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
|
||||
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
|
||||
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
|
||||
logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts)
|
||||
|
||||
# -------------------------------- properties ------------------------------------
|
||||
|
||||
@@ -67,9 +68,11 @@ class ExpertLocationMetadata:
|
||||
num_layers_2, num_logical_experts_1 = (
|
||||
self.logical_to_all_physical_map_num_valid.shape
|
||||
)
|
||||
# TODO pr-chain: enable this later
|
||||
# assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
|
||||
# assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
|
||||
num_layers_3, num_logical_experts_2 = (
|
||||
self.logical_to_rank_dispatch_physical_map.shape
|
||||
)
|
||||
assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
|
||||
assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
|
||||
assert num_physical_experts_0 == num_physical_experts_1
|
||||
|
||||
# -------------------------------- construction ------------------------------------
|
||||
@@ -196,6 +199,13 @@ class ExpertLocationMetadata:
|
||||
physical_to_logical_map=physical_to_logical_map,
|
||||
logical_to_all_physical_map=logical_to_all_physical_map_padded,
|
||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
|
||||
logical_to_all_physical_map=logical_to_all_physical_map,
|
||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||
num_gpus=ep_size,
|
||||
num_physical_experts=num_physical_experts,
|
||||
ep_rank=torch.distributed.get_rank(),
|
||||
),
|
||||
)
|
||||
|
||||
# -------------------------------- usage ------------------------------------
|
||||
@@ -262,6 +272,51 @@ def _pad_nested_array(arr, pad_value):
|
||||
return padded
|
||||
|
||||
|
||||
# TODO use more sophisticated approaches
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    logical_to_all_physical_map_num_valid: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    base_seed: int = 42,
):
    """For each (layer, logical expert), pick the one physical replica this rank dispatches to.

    Starts from a seeded-random choice among the valid replicas of each logical
    expert, then overrides that choice with a replica hosted on this rank's own
    GPU whenever one exists (the last such replica wins if there are several).

    Returns a (num_layers, num_logical_experts) tensor of physical expert ids.
    """
    device = logical_to_all_physical_map.device
    experts_per_gpu = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape

    # Seed per-rank so different ranks spread load across different replicas.
    rng = torch.Generator(device=device)
    rng.manual_seed(base_seed + ep_rank)

    out_shape = (num_layers, num_logical_experts)
    # Random replica slot in [0, num_valid) for every logical expert.
    # NOTE: the modulo with the (int64) num_valid tensor promotes the int32
    # draw to int64, which is what torch.gather requires for its index.
    random_slot = (
        torch.randint(
            0, 65536, out_shape, dtype=torch.int32, device=device, generator=rng
        )
        % logical_to_all_physical_map_num_valid
    )
    dispatch_map = torch.gather(
        logical_to_all_physical_map, dim=2, index=random_slot.unsqueeze(-1)
    ).squeeze(-1)
    assert dispatch_map.shape == out_shape

    # Prefer a replica that lives on our own GPU, scanning every replica slot.
    for slot_idx in range(logical_to_all_physical_map_num_valid.max().item()):
        candidates = logical_to_all_physical_map[:, :, slot_idx]
        local_and_valid = (candidates != -1) & (
            (candidates // experts_per_gpu) == ep_rank
        )
        dispatch_map = torch.where(local_and_valid, candidates, dispatch_map)

    # Every logical expert must have resolved to a real physical expert.
    assert torch.all(dispatch_map != -1)
    return dispatch_map
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfigForExpertLocation:
|
||||
num_layers: int
|
||||
|
||||
91
python/sglang/srt/managers/expert_location_dispatch.py
Normal file
91
python/sglang/srt/managers/expert_location_dispatch.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright 2023-2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
|
||||
@dataclass
class ExpertLocationDispatchInfo:
    """Per-layer data needed to map logical expert ids to physical expert ids."""

    # BUGFIX: was Literal["static", "random"]; the dispatcher
    # (topk_ids_logical_to_physical) and ServerArgs.ep_dispatch_algorithm both
    # use "static" / "dynamic", so "random" was never a legal value here.
    ep_dispatch_algorithm: Literal["static", "dynamic"]
    # (num_logical_experts,) — physical expert this rank statically dispatches to
    partial_logical_to_rank_dispatch_physical_map: torch.Tensor
    # (num_logical_experts, X) — all physical replicas per logical expert, -1 padded
    partial_logical_to_all_physical_map: torch.Tensor
    # (num_logical_experts,) — number of valid (non -1) entries per row above
    partial_logical_to_all_physical_map_num_valid: torch.Tensor
    num_physical_experts: int

    @classmethod
    def init_new(cls, layer_id: int):
        """Build dispatch info for `layer_id`, or None when dispatching is disabled."""
        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
        if ep_dispatch_algorithm is None:
            return None

        # Only touch the global metadata once we know dispatching is enabled.
        expert_location_metadata = get_global_expert_location_metadata()
        return cls(
            ep_dispatch_algorithm=ep_dispatch_algorithm,
            partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[
                layer_id, :
            ],
            num_physical_experts=expert_location_metadata.num_physical_experts,
        )
|
||||
|
||||
|
||||
def topk_ids_logical_to_physical(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Translate logical expert ids to physical ids per the configured algorithm.

    When `info` is None, expert-location dispatching is disabled and logical
    ids are already physical ids, so the input is returned unchanged.
    """
    if info is None:
        return topk_ids

    # Dispatch-table equivalent of the algorithm branch.
    handlers = {
        "static": _topk_ids_logical_to_physical_static,
        "dynamic": _topk_ids_logical_to_physical_dynamic,
    }
    handler = handlers.get(info.ep_dispatch_algorithm)
    if handler is None:
        raise NotImplementedError
    return handler(topk_ids, info)
|
||||
|
||||
|
||||
def _topk_ids_logical_to_physical_static(
    topk_ids: torch.Tensor, info: "ExpertLocationDispatchInfo"
) -> torch.Tensor:
    """Map each logical id to this rank's fixed, precomputed physical replica.

    BUGFIX (annotation): `info` was annotated Optional, but it is dereferenced
    unconditionally — the None case is short-circuited by the caller
    (topk_ids_logical_to_physical), so the annotation is now non-Optional.
    """
    return info.partial_logical_to_rank_dispatch_physical_map[topk_ids]
|
||||
|
||||
|
||||
def _topk_ids_logical_to_physical_dynamic(
    topk_ids: torch.Tensor, info: "ExpertLocationDispatchInfo"
) -> torch.Tensor:
    """Map each logical id to a uniformly random valid physical replica.

    BUGFIX (annotation): `info` was annotated Optional, but its maps are read
    unconditionally — the None case is handled by the caller
    (topk_ids_logical_to_physical), so the annotation is now non-Optional.
    """
    topk_ids_original_shape = topk_ids.shape
    device = topk_ids.device
    topk_ids = topk_ids.flatten()

    # Draw a random replica slot in [0, num_valid) per selected expert; 65536
    # keeps the modulo bias negligible for realistic replica counts.
    chosen_dispatch_index = (
        torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device)
        % info.partial_logical_to_all_physical_map_num_valid[topk_ids]
    )
    topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index]

    return topk_ids.view(topk_ids_original_shape)
|
||||
@@ -83,6 +83,7 @@ global_server_args_dict = {
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
|
||||
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
|
||||
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
|
||||
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
|
||||
"sampling_backend": ServerArgs.sampling_backend,
|
||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# ==============================================================================
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
import collections
|
||||
import datetime
|
||||
import gc
|
||||
import inspect
|
||||
@@ -196,6 +195,7 @@ class ModelRunner:
|
||||
"deepep_config": server_args.deepep_config,
|
||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
||||
"ep_dispatch_algorithm": server_args.ep_dispatch_algorithm,
|
||||
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
||||
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
|
||||
"torchao_config": server_args.torchao_config,
|
||||
|
||||
@@ -80,6 +80,7 @@ from sglang.srt.managers.expert_distribution import (
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
@@ -113,6 +114,7 @@ if _is_hip:
|
||||
decode_attention_fwd_grouped_rope,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -170,6 +170,7 @@ class ServerArgs:
|
||||
enable_ep_moe: bool = False
|
||||
enable_deepep_moe: bool = False
|
||||
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
|
||||
ep_dispatch_algorithm: Optional[Literal["static", "dynamic"]] = None
|
||||
init_expert_location: str = "trivial"
|
||||
expert_distribution_recorder_mode: Optional[
|
||||
Literal["stat", "per_pass", "per_token"]
|
||||
@@ -1271,6 +1272,12 @@ class ServerArgs:
|
||||
default="auto",
|
||||
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ep-dispatch-algorithm",
|
||||
type=str,
|
||||
default=ServerArgs.ep_dispatch_algorithm,
|
||||
help="The algorithm to choose ranks for redundant experts in expert parallel.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init-expert-location",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user