diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index f91b8d5a6..e32e053c0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -5,6 +5,7 @@ import torch from torch.nn import Module from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM +from sglang.srt.managers.expert_location import get_global_expert_location_metadata from sglang.srt.managers.schedule_batch import global_server_args_dict try: @@ -425,6 +426,28 @@ class EPMoE(torch.nn.Module): weight_name: str, shard_id: str, expert_id: int, + ) -> None: + physical_expert_ids = ( + get_global_expert_location_metadata().logical_to_all_physical( + self.layer_id, expert_id + ) + ) + for physical_expert_id in physical_expert_ids: + self._weight_loader_physical( + param=param, + loaded_weight=loaded_weight, + weight_name=weight_name, + shard_id=shard_id, + expert_id=physical_expert_id, + ) + + def _weight_loader_physical( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: str, + expert_id: int, ) -> None: if expert_id < self.start_expert_id or expert_id > self.end_expert_id: return diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index 44496cdde..befb3c1f4 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -15,7 +15,7 @@ import json import logging from dataclasses import dataclass from pathlib import Path -from typing import Optional +from typing import List, Optional import torch import torch.distributed @@ -163,6 +163,19 @@ class ExpertLocationMetadata: logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, ) + # -------------------------------- usage ------------------------------------ + + def logical_to_all_physical( + self, layer_id: int, logical_expert_id: int + ) -> List[int]: + return [ + physical_expert_id + for physical_expert_id in self.logical_to_all_physical_map[ + layer_id, logical_expert_id + ].tolist() + if physical_expert_id != -1 + ] + _global_expert_location_metadata: Optional[ExpertLocationMetadata] = None