Support loading weights when physical experts are different from logical experts (#6386)
@@ -5,6 +5,7 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 try:
@@ -425,6 +426,28 @@ class EPMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
     ) -> None:
+        physical_expert_ids = (
+            get_global_expert_location_metadata().logical_to_all_physical(
+                self.layer_id, expert_id
+            )
+        )
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
             return
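The idea behind the change: once expert locations can be remapped, one logical expert from the checkpoint may be replicated across several physical expert slots, and each expert-parallel rank owns only the slot range [start_expert_id, end_expert_id]. The new weight_loader therefore fans each checkpoint tensor out to every physical replica of the logical expert, while _weight_loader_physical silently drops replicas owned by other ranks. Below is a minimal, self-contained sketch of that fan-out; ToyExpertLocationMetadata, ToyEPMoE, and all shapes are illustrative stand-ins, not sglang's actual classes.

# Minimal sketch (hypothetical names, not the sglang implementation):
# logical expert 0 is replicated on physical slots 0 and 3; this rank owns
# only slots 2..3, so loading expert 0 fills slot 3 and skips slot 0.

import torch


class ToyExpertLocationMetadata:
    """Stand-in for sglang's expert location metadata (hypothetical)."""

    def __init__(self, logical_to_physical):
        # {logical_expert_id: [physical expert ids]}
        self._map = logical_to_physical

    def logical_to_all_physical(self, layer_id, logical_expert_id):
        # The real metadata is per-layer; this toy ignores layer_id.
        return self._map[logical_expert_id]


class ToyEPMoE:
    def __init__(self, metadata, start_expert_id, end_expert_id, hidden=4):
        self.metadata = metadata
        self.layer_id = 0
        self.start_expert_id = start_expert_id
        self.end_expert_id = end_expert_id
        # One weight row per *local* physical expert slot on this rank.
        self.w = torch.zeros(end_expert_id - start_expert_id + 1, hidden)

    def weight_loader(self, loaded_weight, logical_expert_id):
        # Fan out: the checkpoint tensor goes to every physical replica.
        physical_ids = self.metadata.logical_to_all_physical(
            self.layer_id, logical_expert_id
        )
        for physical_expert_id in physical_ids:
            self._weight_loader_physical(loaded_weight, physical_expert_id)

    def _weight_loader_physical(self, loaded_weight, physical_expert_id):
        # Replicas owned by other EP ranks are simply skipped.
        if (
            physical_expert_id < self.start_expert_id
            or physical_expert_id > self.end_expert_id
        ):
            return
        self.w[physical_expert_id - self.start_expert_id] = loaded_weight


meta = ToyExpertLocationMetadata({0: [0, 3], 1: [1], 2: [2]})
moe = ToyEPMoE(meta, start_expert_id=2, end_expert_id=3)
moe.weight_loader(torch.ones(4), logical_expert_id=0)
print(moe.w)  # row for slot 2 stays zero; row for slot 3 is all ones

Keeping the rank-range check inside _weight_loader_physical, as the commit does, lets the fan-out loop stay oblivious to which rank owns which slot: every rank runs the same loop over all replicas, and out-of-range writes are no-ops.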