Support layerwise rebalancing experts (#6851)

This commit is contained in:
fzyzcjy
2025-06-05 15:05:52 +08:00
committed by GitHub
parent 72a110f664
commit 0de5e7d40f
6 changed files with 115 additions and 38 deletions

View File

@@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import (
ExpertLocationMetadata,
get_global_expert_location_metadata,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import get_bool_env_var
logger = logging.getLogger(__name__)
@@ -37,6 +38,7 @@ class ExpertLocationUpdater:
self,
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
nnodes: int,
rank: int,
):
@@ -46,45 +48,47 @@ class ExpertLocationUpdater:
old_expert_location_metadata = get_global_expert_location_metadata()
_update_expert_weights(
routed_experts_weights_of_layer,
old_expert_location_metadata,
new_expert_location_metadata,
nnodes,
rank,
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
old_expert_location_metadata=old_expert_location_metadata,
new_expert_location_metadata=new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=nnodes,
rank=rank,
)
old_expert_location_metadata.update(
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
)
old_expert_location_metadata.update(new_expert_location_metadata)
def _update_expert_weights(
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
old_expert_location_metadata: ExpertLocationMetadata,
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
nnodes: int,
rank: int,
):
log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
temp_buffers = create_temp_buffers(
next(iter(routed_experts_weights_of_layer.values()))
routed_experts_weights_of_layer[update_layer_ids[0]]
)
world_size = torch.distributed.get_world_size()
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
num_gpu_per_node = world_size // nnodes
old_physical_to_logical_map = (
old_expert_location_metadata.physical_to_logical_map.tolist()
)
new_physical_to_logical_map = (
new_expert_location_metadata.physical_to_logical_map.tolist()
)
for layer_id in sorted(routed_experts_weights_of_layer.keys()):
for layer_id in update_layer_ids:
update_expert_weights_single_layer(
routed_experts_weights=routed_experts_weights_of_layer[layer_id],
temp_buffers=temp_buffers,
old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[
layer_id
].tolist(),
new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[
layer_id
].tolist(),
num_local_physical_experts=num_local_physical_experts,
num_gpu_per_node=num_gpu_per_node,
rank=rank,

View File

@@ -611,11 +611,14 @@ class ModelRunner:
) from None
def update_expert_location(
self, new_expert_location_metadata: ExpertLocationMetadata
self,
new_expert_location_metadata: ExpertLocationMetadata,
update_layer_ids: List[int],
):
self.expert_location_updater.update(
self.model.routed_experts_weights_of_layer,
new_expert_location_metadata,
update_layer_ids=update_layer_ids,
nnodes=self.server_args.nnodes,
rank=self.tp_rank,
)
@@ -1203,7 +1206,7 @@ class ModelRunner:
)
if self.eplb_manager is not None:
self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
self.eplb_manager.on_forward_pass_end()
return output