Support layerwise rebalancing experts (#6851)
This commit is contained in:
@@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import (
|
||||
ExpertLocationMetadata,
|
||||
get_global_expert_location_metadata,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,6 +38,7 @@ class ExpertLocationUpdater:
|
||||
self,
|
||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
update_layer_ids: List[int],
|
||||
nnodes: int,
|
||||
rank: int,
|
||||
):
|
||||
@@ -46,45 +48,47 @@ class ExpertLocationUpdater:
|
||||
|
||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||
_update_expert_weights(
|
||||
routed_experts_weights_of_layer,
|
||||
old_expert_location_metadata,
|
||||
new_expert_location_metadata,
|
||||
nnodes,
|
||||
rank,
|
||||
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
|
||||
old_expert_location_metadata=old_expert_location_metadata,
|
||||
new_expert_location_metadata=new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
nnodes=nnodes,
|
||||
rank=rank,
|
||||
)
|
||||
old_expert_location_metadata.update(
|
||||
new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
)
|
||||
old_expert_location_metadata.update(new_expert_location_metadata)
|
||||
|
||||
|
||||
def _update_expert_weights(
|
||||
routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
|
||||
old_expert_location_metadata: ExpertLocationMetadata,
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
update_layer_ids: List[int],
|
||||
nnodes: int,
|
||||
rank: int,
|
||||
):
|
||||
log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS")
|
||||
|
||||
temp_buffers = create_temp_buffers(
|
||||
next(iter(routed_experts_weights_of_layer.values()))
|
||||
routed_experts_weights_of_layer[update_layer_ids[0]]
|
||||
)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts
|
||||
num_gpu_per_node = world_size // nnodes
|
||||
|
||||
old_physical_to_logical_map = (
|
||||
old_expert_location_metadata.physical_to_logical_map.tolist()
|
||||
)
|
||||
new_physical_to_logical_map = (
|
||||
new_expert_location_metadata.physical_to_logical_map.tolist()
|
||||
)
|
||||
|
||||
for layer_id in sorted(routed_experts_weights_of_layer.keys()):
|
||||
for layer_id in update_layer_ids:
|
||||
update_expert_weights_single_layer(
|
||||
routed_experts_weights=routed_experts_weights_of_layer[layer_id],
|
||||
temp_buffers=temp_buffers,
|
||||
old_physical_to_logical_map=old_physical_to_logical_map[layer_id],
|
||||
new_physical_to_logical_map=new_physical_to_logical_map[layer_id],
|
||||
old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[
|
||||
layer_id
|
||||
].tolist(),
|
||||
new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[
|
||||
layer_id
|
||||
].tolist(),
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
num_gpu_per_node=num_gpu_per_node,
|
||||
rank=rank,
|
||||
|
||||
@@ -611,11 +611,14 @@ class ModelRunner:
|
||||
) from None
|
||||
|
||||
def update_expert_location(
|
||||
self, new_expert_location_metadata: ExpertLocationMetadata
|
||||
self,
|
||||
new_expert_location_metadata: ExpertLocationMetadata,
|
||||
update_layer_ids: List[int],
|
||||
):
|
||||
self.expert_location_updater.update(
|
||||
self.model.routed_experts_weights_of_layer,
|
||||
new_expert_location_metadata,
|
||||
update_layer_ids=update_layer_ids,
|
||||
nnodes=self.server_args.nnodes,
|
||||
rank=self.tp_rank,
|
||||
)
|
||||
@@ -1203,7 +1206,7 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
if self.eplb_manager is not None:
|
||||
self.eplb_manager.on_forward_pass_end(self.forward_pass_id)
|
||||
self.eplb_manager.on_forward_pass_end()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user