Support loading weights when physical experts are different from logical experts (#6386)

This commit is contained in:
fzyzcjy
2025-05-20 12:05:53 +08:00
committed by GitHub
parent d0443275f0
commit c471d39eb9
2 changed files with 37 additions and 1 deletions

View File

@@ -5,6 +5,7 @@ import torch
from torch.nn import Module from torch.nn import Module
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
try: try:
@@ -425,6 +426,28 @@ class EPMoE(torch.nn.Module):
weight_name: str, weight_name: str,
shard_id: str, shard_id: str,
expert_id: int, expert_id: int,
) -> None:
physical_expert_ids = (
get_global_expert_location_metadata().logical_to_all_physical(
self.layer_id, expert_id
)
)
for physical_expert_id in physical_expert_ids:
self._weight_loader_physical(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=physical_expert_id,
)
def _weight_loader_physical(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None: ) -> None:
if expert_id < self.start_expert_id or expert_id > self.end_expert_id: if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
return return

View File

@@ -15,7 +15,7 @@ import json
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Optional from typing import List, Optional
import torch import torch
import torch.distributed import torch.distributed
@@ -163,6 +163,19 @@ class ExpertLocationMetadata:
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
) )
# -------------------------------- usage ------------------------------------
def logical_to_all_physical(
self, layer_id: int, logical_expert_id: int
) -> List[int]:
return [
physical_expert_id
for physical_expert_id in self.logical_to_all_physical_map[
layer_id, logical_expert_id
].tolist()
if physical_expert_id != -1
]
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None _global_expert_location_metadata: Optional[ExpertLocationMetadata] = None