Support dispatching logical to physical experts (#6385)

This commit is contained in:
fzyzcjy
2025-05-20 13:13:55 +08:00
committed by GitHub
parent 69af3ec35f
commit e98afbe042
9 changed files with 184 additions and 5 deletions

View File

@@ -23,9 +23,10 @@ import torch
import torch.distributed
from sglang.srt.managers.expert_location import ExpertLocationMetadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable
from sglang.srt.utils import Withable, get_bool_env_var
logger = logging.getLogger(__name__)

View File

@@ -33,6 +33,7 @@ class ExpertLocationMetadata:
physical_to_logical_map: torch.Tensor # (layers, num_physical_experts)
logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X)
logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts)
logical_to_rank_dispatch_physical_map: torch.Tensor # (layers, num_logical_experts)
# -------------------------------- properties ------------------------------------
@@ -67,9 +68,11 @@ class ExpertLocationMetadata:
num_layers_2, num_logical_experts_1 = (
self.logical_to_all_physical_map_num_valid.shape
)
# TODO pr-chain: enable this later
# assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
# assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
num_layers_3, num_logical_experts_2 = (
self.logical_to_rank_dispatch_physical_map.shape
)
assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
assert num_physical_experts_0 == num_physical_experts_1
# -------------------------------- construction ------------------------------------
@@ -196,6 +199,13 @@ class ExpertLocationMetadata:
physical_to_logical_map=physical_to_logical_map,
logical_to_all_physical_map=logical_to_all_physical_map_padded,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
logical_to_all_physical_map=logical_to_all_physical_map,
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
num_gpus=ep_size,
num_physical_experts=num_physical_experts,
ep_rank=torch.distributed.get_rank(),
),
)
# -------------------------------- usage ------------------------------------
@@ -262,6 +272,51 @@ def _pad_nested_array(arr, pad_value):
return padded
# TODO use more sophisticated approaches
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    logical_to_all_physical_map_num_valid: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    base_seed: int = 42,
):
    """For each (layer, logical expert), choose the one physical replica this
    rank dispatches to.

    The choice starts as a seeded random pick among the valid replicas, and is
    then overridden by a replica hosted on this rank's own GPU whenever one
    exists, so local replicas are preferred.
    """
    device = logical_to_all_physical_map.device
    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape
    out_shape = (num_layers, num_logical_experts)

    # Seed per-rank so different ranks spread load across different replicas.
    rng = torch.Generator(device=device)
    rng.manual_seed(base_seed + ep_rank)

    raw_draw = torch.randint(
        0, 65536, out_shape, dtype=torch.int32, device=device, generator=rng
    )
    # Reduce the draw modulo the per-expert valid-replica count so the index
    # always points at a real (non-padding) replica slot.
    chosen_index = raw_draw % logical_to_all_physical_map_num_valid
    dispatch_map = torch.gather(
        logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
    ).squeeze(-1)
    assert dispatch_map.shape == out_shape

    # Prefer a replica that lives on this GPU when the logical expert has one;
    # later slots overwrite earlier ones, matching the original scan order.
    for replica_slot in range(logical_to_all_physical_map_num_valid.max().item()):
        candidate = logical_to_all_physical_map[:, :, replica_slot]
        on_this_gpu = (candidate // num_local_physical_experts) == ep_rank
        usable = (candidate != -1) & on_this_gpu
        dispatch_map = torch.where(usable, candidate, dispatch_map)

    assert torch.all(dispatch_map != -1)
    return dispatch_map
@dataclass
class ModelConfigForExpertLocation:
num_layers: int

View File

@@ -0,0 +1,91 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass
from typing import Literal, Optional
import torch
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
@dataclass
class ExpertLocationDispatchInfo:
    """Per-layer slices of the global expert-location metadata needed to map
    logical expert ids to physical ones at dispatch time."""

    # NOTE(review): corrected from Literal["static", "random"] —
    # `topk_ids_logical_to_physical` dispatches on "static" / "dynamic", and
    # "random" would fall through to NotImplementedError.
    ep_dispatch_algorithm: Literal["static", "dynamic"]
    # (num_logical_experts,)
    partial_logical_to_rank_dispatch_physical_map: torch.Tensor
    # (num_logical_experts, X)
    partial_logical_to_all_physical_map: torch.Tensor
    # (num_logical_experts,)
    partial_logical_to_all_physical_map_num_valid: torch.Tensor
    num_physical_experts: int

    @classmethod
    def init_new(cls, layer_id: int):
        """Build the dispatch info for `layer_id` from the global metadata,
        or return None when no dispatch algorithm is configured."""
        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
        expert_location_metadata = get_global_expert_location_metadata()
        if ep_dispatch_algorithm is None:
            return None
        return cls(
            ep_dispatch_algorithm=ep_dispatch_algorithm,
            partial_logical_to_rank_dispatch_physical_map=expert_location_metadata.logical_to_rank_dispatch_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map=expert_location_metadata.logical_to_all_physical_map[
                layer_id, :
            ],
            partial_logical_to_all_physical_map_num_valid=expert_location_metadata.logical_to_all_physical_map_num_valid[
                layer_id, :
            ],
            num_physical_experts=expert_location_metadata.num_physical_experts,
        )
def topk_ids_logical_to_physical(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Translate logical top-k expert ids to physical ids.

    Passes ids through unchanged when dispatch is disabled (`info is None`);
    otherwise delegates to the configured algorithm.
    """
    if info is None:
        return topk_ids
    algorithm = info.ep_dispatch_algorithm
    if algorithm == "static":
        return _topk_ids_logical_to_physical_static(topk_ids, info)
    elif algorithm == "dynamic":
        return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
    raise NotImplementedError
def _topk_ids_logical_to_physical_static(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Static dispatch: look each logical id up in this rank's precomputed map."""
    dispatch_map = info.partial_logical_to_rank_dispatch_physical_map
    return dispatch_map[topk_ids]
def _topk_ids_logical_to_physical_dynamic(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Dynamic dispatch: for every top-k entry, pick one of the valid physical
    replicas of the chosen logical expert via a fresh random draw."""
    original_shape = topk_ids.shape
    device = topk_ids.device
    flat_ids = topk_ids.flatten()
    # Random draw per entry, reduced modulo the per-expert valid-replica count
    # so the replica index stays in range for each logical expert.
    raw_draw = torch.randint(
        0, 65536, flat_ids.shape, dtype=torch.int32, device=device
    )
    replica_index = (
        raw_draw % info.partial_logical_to_all_physical_map_num_valid[flat_ids]
    )
    physical_ids = info.partial_logical_to_all_physical_map[flat_ids, replica_index]
    return physical_ids.view(original_shape)

View File

@@ -83,6 +83,7 @@ global_server_args_dict = {
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"ep_dispatch_algorithm": ServerArgs.ep_dispatch_algorithm,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"sampling_backend": ServerArgs.sampling_backend,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,