From 0de5e7d40f04768098d61eca1ab8f892ac0efbd0 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 5 Jun 2025 15:05:52 +0800 Subject: [PATCH] Support layerwise rebalancing experts (#6851) --- python/sglang/srt/managers/eplb_manager.py | 62 +++++++++++++++---- python/sglang/srt/managers/expert_location.py | 19 ++++-- .../model_executor/expert_location_updater.py | 38 +++++++----- .../sglang/srt/model_executor/model_runner.py | 7 ++- python/sglang/srt/server_args.py | 7 +++ test/srt/test_eplb.py | 20 +++++- 6 files changed, 115 insertions(+), 38 deletions(-) diff --git a/python/sglang/srt/managers/eplb_manager.py b/python/sglang/srt/managers/eplb_manager.py index 44fcd4555..b74b7f21e 100644 --- a/python/sglang/srt/managers/eplb_manager.py +++ b/python/sglang/srt/managers/eplb_manager.py @@ -1,6 +1,6 @@ import logging import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List import torch.cuda @@ -20,6 +20,10 @@ class EPLBManager: super().__init__() self._model_runner = model_runner self._server_args = model_runner.server_args + self._rebalance_layers_per_chunk = ( + self._server_args.eplb_rebalance_layers_per_chunk + ) + self._rebalance_num_iterations = self._server_args.eplb_rebalance_num_iterations # Otherwise, the circular buffer will contain stale data. If the case is needed, it can be implemented. assert ( @@ -31,17 +35,30 @@ class EPLBManager: get_global_expert_distribution_recorder().start_record() logger.info( - f"[EPLBManager] system started, will rebalance per {self._server_args.eplb_rebalance_num_iterations} iterations." + f"[EPLBManager] system started, will rebalance per {self._rebalance_num_iterations} iterations." ) - def on_forward_pass_end(self, forward_pass_id: int): - if forward_pass_id % self._server_args.eplb_rebalance_num_iterations == 0: - self.rebalance() + self._main_generator = self._entrypoint() + + def on_forward_pass_end(self): + next(self._main_generator) + + # can be more complex if needed + def _entrypoint(self): + while True: + for _ in range(self._rebalance_num_iterations): + yield + + yield from self.rebalance() def rebalance(self): logger.info("[EPLBManager] rebalance start") - torch.cuda.synchronize() - time_start = time.time() + + enable_timing = self._rebalance_layers_per_chunk is None + + if enable_timing: + torch.cuda.synchronize() + time_start = time.time() logical_count = get_global_expert_distribution_recorder().dump_record( output_mode="object" @@ -49,8 +66,31 @@ class EPLBManager: expert_location_metadata = ExpertLocationMetadata.init_by_eplb( self._server_args, self._model_runner.model_config, logical_count ) - self._model_runner.update_expert_location(expert_location_metadata) - torch.cuda.synchronize() - time_end = time.time() - logger.info(f"[EPLBManager] rebalance end time={time_end - time_start:.3f}s") + update_layer_ids_chunks = self._compute_update_layer_ids_chunks() + for chunk_index, update_layer_ids in enumerate(update_layer_ids_chunks): + if len(update_layer_ids_chunks) > 1: + yield + self._model_runner.update_expert_location( + expert_location_metadata, + update_layer_ids=update_layer_ids, + ) + + msg = f"[EPLBManager] rebalance end" + if enable_timing: + torch.cuda.synchronize() + time_end = time.time() + msg += f" time={time_end - time_start:.3f}s" + logger.info(msg) + + def _compute_update_layer_ids_chunks(self) -> List[List[int]]: + all_layer_ids = sorted( + list(self._model_runner.model.routed_experts_weights_of_layer.keys()) + ) + chunk_size = self._rebalance_layers_per_chunk or 1000000 + return list(_chunk_list(all_layer_ids, chunk_size=chunk_size)) + + +def _chunk_list(items: List, chunk_size): + for start_index in range(0, len(items), chunk_size): + yield items[start_index : start_index + chunk_size] diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py index ea4c67a54..13ba9849e 100644 --- a/python/sglang/srt/managers/expert_location.py +++ b/python/sglang/srt/managers/expert_location.py @@ -33,6 +33,7 @@ logger = logging.getLogger(__name__) @dataclass class ExpertLocationMetadata: physical_to_logical_map: torch.Tensor # (layers, num_physical_experts) + physical_to_logical_map_cpu: torch.Tensor logical_to_all_physical_map: torch.Tensor # (layers, num_logical_experts, X) logical_to_all_physical_map_num_valid: torch.Tensor # (layers, num_logical_experts) # (layers, num_logical_experts) @@ -203,6 +204,7 @@ class ExpertLocationMetadata: return ExpertLocationMetadata( physical_to_logical_map=physical_to_logical_map, + physical_to_logical_map_cpu=physical_to_logical_map.cpu(), logical_to_all_physical_map=logical_to_all_physical_map_padded, logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid, logical_to_rank_dispatch_physical_map=( @@ -223,6 +225,7 @@ class ExpertLocationMetadata: def update( self, other: "ExpertLocationMetadata", + update_layer_ids: List[int], ): for field in [ "ep_size", @@ -231,15 +234,21 @@ class ExpertLocationMetadata: for field in [ "physical_to_logical_map", + "physical_to_logical_map_cpu", "logical_to_all_physical_map", "logical_to_all_physical_map_num_valid", "logical_to_rank_dispatch_physical_map", ]: - src = getattr(other, field) - dst = getattr(self, field) - assert (src is not None) == (dst is not None) - if dst is not None: - dst[...] = src + other_field = getattr(other, field) + self_field = getattr(self, field) + assert (other_field is not None) == (self_field is not None) + if self_field is not None: + mask_update = torch.tensor( + [i in update_layer_ids for i in range(self.num_layers)] + ) + mask_update = mask_update.view(*([-1] + [1] * (self_field.dim() - 1))) + mask_update = mask_update.to(self_field.device, non_blocking=True) + self_field[...] = torch.where(mask_update, other_field, self_field) # -------------------------------- usage ------------------------------------ diff --git a/python/sglang/srt/model_executor/expert_location_updater.py b/python/sglang/srt/model_executor/expert_location_updater.py index fff049b41..52cc13524 100644 --- a/python/sglang/srt/model_executor/expert_location_updater.py +++ b/python/sglang/srt/model_executor/expert_location_updater.py @@ -24,6 +24,7 @@ from sglang.srt.managers.expert_location import ( ExpertLocationMetadata, get_global_expert_location_metadata, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_bool_env_var logger = logging.getLogger(__name__) @@ -37,6 +38,7 @@ class ExpertLocationUpdater: self, routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]], new_expert_location_metadata: ExpertLocationMetadata, + update_layer_ids: List[int], nnodes: int, rank: int, ): @@ -46,45 +48,47 @@ class ExpertLocationUpdater: old_expert_location_metadata = get_global_expert_location_metadata() _update_expert_weights( - routed_experts_weights_of_layer, - old_expert_location_metadata, - new_expert_location_metadata, - nnodes, - rank, + routed_experts_weights_of_layer=routed_experts_weights_of_layer, + old_expert_location_metadata=old_expert_location_metadata, + new_expert_location_metadata=new_expert_location_metadata, + update_layer_ids=update_layer_ids, + nnodes=nnodes, + rank=rank, + ) + old_expert_location_metadata.update( + new_expert_location_metadata, + update_layer_ids=update_layer_ids, ) - old_expert_location_metadata.update(new_expert_location_metadata) def _update_expert_weights( routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]], old_expert_location_metadata: ExpertLocationMetadata, new_expert_location_metadata: ExpertLocationMetadata, + update_layer_ids: List[int], nnodes: int, rank: int, ): log_metrics = get_bool_env_var("SGLANG_EXPERT_LOCATION_UPDATER_LOG_METRICS") temp_buffers = create_temp_buffers( - next(iter(routed_experts_weights_of_layer.values())) + routed_experts_weights_of_layer[update_layer_ids[0]] ) world_size = torch.distributed.get_world_size() num_local_physical_experts = old_expert_location_metadata.num_local_physical_experts num_gpu_per_node = world_size // nnodes - old_physical_to_logical_map = ( - old_expert_location_metadata.physical_to_logical_map.tolist() - ) - new_physical_to_logical_map = ( - new_expert_location_metadata.physical_to_logical_map.tolist() - ) - - for layer_id in sorted(routed_experts_weights_of_layer.keys()): + for layer_id in update_layer_ids: update_expert_weights_single_layer( routed_experts_weights=routed_experts_weights_of_layer[layer_id], temp_buffers=temp_buffers, - old_physical_to_logical_map=old_physical_to_logical_map[layer_id], - new_physical_to_logical_map=new_physical_to_logical_map[layer_id], + old_physical_to_logical_map=old_expert_location_metadata.physical_to_logical_map_cpu[ + layer_id + ].tolist(), + new_physical_to_logical_map=new_expert_location_metadata.physical_to_logical_map_cpu[ + layer_id + ].tolist(), num_local_physical_experts=num_local_physical_experts, num_gpu_per_node=num_gpu_per_node, rank=rank, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index adabc897f..30f6a7929 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -611,11 +611,14 @@ class ModelRunner: ) from None def update_expert_location( - self, new_expert_location_metadata: ExpertLocationMetadata + self, + new_expert_location_metadata: ExpertLocationMetadata, + update_layer_ids: List[int], ): self.expert_location_updater.update( self.model.routed_experts_weights_of_layer, new_expert_location_metadata, + update_layer_ids=update_layer_ids, nnodes=self.server_args.nnodes, rank=self.tp_rank, ) @@ -1203,7 +1206,7 @@ class ModelRunner: ) if self.eplb_manager is not None: - self.eplb_manager.on_forward_pass_end(self.forward_pass_id) + self.eplb_manager.on_forward_pass_end() return output diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b3557a472..97229b7f1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -180,6 +180,7 @@ class ServerArgs: enable_eplb: bool = False eplb_algorithm: str = "auto" eplb_rebalance_num_iterations: int = 1000 + eplb_rebalance_layers_per_chunk: Optional[int] = None expert_distribution_recorder_mode: Optional[ Literal["stat", "per_pass", "per_token"] ] = None @@ -1367,6 +1368,12 @@ class ServerArgs: default=ServerArgs.eplb_rebalance_num_iterations, help="Number of iterations to automatically trigger a EPLB re-balance.", ) + parser.add_argument( + "--eplb-rebalance-layers-per-chunk", + type=int, + default=ServerArgs.eplb_rebalance_layers_per_chunk, + help="Number of layers to rebalance per forward pass.", + ) parser.add_argument( "--expert-distribution-recorder-mode", type=str, diff --git a/test/srt/test_eplb.py b/test/srt/test_eplb.py index f9c6fad20..c7eacc949 100755 --- a/test/srt/test_eplb.py +++ b/test/srt/test_eplb.py @@ -5,7 +5,6 @@ from pathlib import Path from types import SimpleNamespace import sglang as sgl -from sglang.srt.managers.expert_distribution_storage import ExpertDistributionStorage from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( @@ -17,7 +16,9 @@ from sglang.test.test_utils import ( ) -class TestDynamicEPLB(CustomTestCase): +class _BaseTestDynamicEPLB(CustomTestCase): + extra_args = [] + @classmethod def setUpClass(cls): cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST @@ -51,8 +52,13 @@ class TestDynamicEPLB(CustomTestCase): "stat", "--ep-dispatch-algorithm", "static", + *cls.extra_args, ], - env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ}, + env={ + "SGL_ENABLE_JIT_DEEPGEMM": "0", + "SGLANG_EXPERT_LOCATION_UPDATER_CANARY": "1", + **os.environ, + }, ) @classmethod @@ -72,6 +78,14 @@ class TestDynamicEPLB(CustomTestCase): self.assertGreater(metrics["score"], 0.5) +class TestDynamicEPLBSimple(_BaseTestDynamicEPLB): + pass + + +class TestDynamicEPLBMultiChunk(_BaseTestDynamicEPLB): + extra_args = ["--eplb-rebalance-layers-per-chunk", "1"] + + class TestStaticEPLB(CustomTestCase): def test_save_expert_distribution_and_init_expert_location(self): os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"