Support loading weights when physical experts are different from logical experts (#6386)
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
|||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
|
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -425,6 +426,28 @@ class EPMoE(torch.nn.Module):
|
|||||||
weight_name: str,
|
weight_name: str,
|
||||||
shard_id: str,
|
shard_id: str,
|
||||||
expert_id: int,
|
expert_id: int,
|
||||||
|
) -> None:
|
||||||
|
physical_expert_ids = (
|
||||||
|
get_global_expert_location_metadata().logical_to_all_physical(
|
||||||
|
self.layer_id, expert_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for physical_expert_id in physical_expert_ids:
|
||||||
|
self._weight_loader_physical(
|
||||||
|
param=param,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
weight_name=weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=physical_expert_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _weight_loader_physical(
|
||||||
|
self,
|
||||||
|
param: torch.nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
shard_id: str,
|
||||||
|
expert_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
|
if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
@@ -163,6 +163,19 @@ class ExpertLocationMetadata:
|
|||||||
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# -------------------------------- usage ------------------------------------
|
||||||
|
|
||||||
|
def logical_to_all_physical(
|
||||||
|
self, layer_id: int, logical_expert_id: int
|
||||||
|
) -> List[int]:
|
||||||
|
return [
|
||||||
|
physical_expert_id
|
||||||
|
for physical_expert_id in self.logical_to_all_physical_map[
|
||||||
|
layer_id, logical_expert_id
|
||||||
|
].tolist()
|
||||||
|
if physical_expert_id != -1
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
|
_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user