diff --git a/python/sglang/srt/eplb/eplb_manager.py b/python/sglang/srt/eplb/eplb_manager.py
index 604e2c464..7db74057a 100644
--- a/python/sglang/srt/eplb/eplb_manager.py
+++ b/python/sglang/srt/eplb/eplb_manager.py
@@ -58,9 +58,18 @@ class EPLBManager:
             torch.cuda.synchronize()
             time_start = time.time()
 
-        logical_count = get_global_expert_distribution_recorder().dump_record(
+        dump_record_output = get_global_expert_distribution_recorder().dump_record(
             output_mode="object"
-        )["logical_count"]
+        )
+        logical_count = dump_record_output["logical_count"]
+        average_utilization_rate_over_window = dump_record_output[
+            "average_utilization_rate_over_window"
+        ]
+
+        # Check whether rebalancing is needed
+        if not self._check_rebalance_needed(average_utilization_rate_over_window):
+            return
+
         expert_location_metadata = ExpertLocationMetadata.init_by_eplb(
             self._server_args, self._model_runner.model_config, logical_count
         )
@@ -81,6 +90,26 @@ class EPLBManager:
             msg += f" time={time_end - time_start:.3f}s"
         logger.info(msg)
 
+    def _check_rebalance_needed(self, average_utilization_rate_over_window):
+        """Return True if rebalancing should proceed, False to skip it.
+
+        Rebalancing is skipped only when a utilization rate is given (not
+        ``None``) and it is strictly greater than
+        ``eplb_min_rebalancing_utilization_threshold``.
+        """
+        if average_utilization_rate_over_window is None:
+            # No utilization figure was reported, so never block rebalancing.
+            return True
+
+        threshold = self._server_args.eplb_min_rebalancing_utilization_threshold
+        if average_utilization_rate_over_window > threshold:
+            logger.info(
+                f"[EPLBManager] Skipped ep rebalancing: current GPU utilization {average_utilization_rate_over_window:.2f} > minimum rebalance threshold {threshold:.2f}"
+            )
+            return False
+
+        return True
+
     def _compute_update_layer_ids_chunks(self) -> List[List[int]]:
         all_layer_ids = sorted(
             list(self._model_runner.model.routed_experts_weights_of_layer.keys())
diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py
index c4a2c38f9..1b3d573d8 100644
--- a/python/sglang/srt/eplb/expert_distribution.py
+++ b/python/sglang/srt/eplb/expert_distribution.py
@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 import logging
+import math
 import os
 import time
 from abc import ABC
@@ -614,8 +615,8 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
         self._enable = self._server_args.enable_expert_distribution_metrics
 
         if self._enable:
-            window_sizes = [10, 100, 1000]
-            self._history = _DequeCollection(maxlens=window_sizes)
+            self.window_sizes = [10, 100, 1000]
+            self._history = _DequeCollection(maxlens=self.window_sizes)
             self._rank = torch.distributed.get_rank()
 
     def append(
@@ -787,6 +788,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
         output = dict(
             rank=self._rank,
             logical_count=logical_count_of_buffered_step,
+            average_utilization_rate_over_window=self._get_global_average_utilization_rate(),
         )
 
         if output_mode == "file":
@@ -797,6 +799,38 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
         else:
             raise NotImplementedError
 
+    def _get_global_average_utilization_rate(self):
+        """Return rank 0's mean utilization rate over the largest window.
+
+        Returns ``None`` when utilization metrics are disabled or when
+        ``eplb_min_rebalancing_utilization_threshold`` is (close to) 1.0.
+        Otherwise rank 0 reads the mean for the last entry of
+        ``self.window_sizes`` (0 if that window has no entry) and broadcasts
+        it, so every participating rank returns the same float.
+        """
+        if not self._enable or math.isclose(
+            self._server_args.eplb_min_rebalancing_utilization_threshold, 1.0
+        ):
+            return None
+
+        if self._rank == 0:
+            # NOTE(review): assumes `_history.mean()` returns a dict keyed by
+            # window size -- confirm against `_DequeCollection`.
+            window_index = self.window_sizes[-1]
+            average_utilization_rate_over_window = self._history.mean().get(
+                window_index, 0
+            )
+            avg_rate_tensor = torch.tensor(
+                [average_utilization_rate_over_window],
+                dtype=torch.float32,
+                device="cuda",
+            )
+        else:
+            # Placeholder that the broadcast below overwrites with rank 0's value.
+            avg_rate_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
+        torch.distributed.broadcast(avg_rate_tensor, src=0)
+        return avg_rate_tensor.item()
+
 
 def _dump_to_file(name, data):
     save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 68f7db4a3..8114a81aa 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -274,6 +274,7 @@ class ServerArgs:
     eplb_algorithm: str = "auto"
     eplb_rebalance_num_iterations: int = 1000
     eplb_rebalance_layers_per_chunk: Optional[int] = None
+    eplb_min_rebalancing_utilization_threshold: float = 1.0
     expert_distribution_recorder_mode: Optional[
         Literal["stat", "stat_approx", "per_pass", "per_token"]
     ] = None
@@ -1595,6 +1596,12 @@ class ServerArgs:
             default=ServerArgs.eplb_rebalance_layers_per_chunk,
             help="Number of layers to rebalance per forward pass.",
         )
+        parser.add_argument(
+            "--eplb-min-rebalancing-utilization-threshold",
+            type=float,
+            default=ServerArgs.eplb_min_rebalancing_utilization_threshold,
+            help="Skip EPLB rebalancing when the average GPU utilization rate exceeds this threshold. Must be in the range [0.0, 1.0]; the default 1.0 disables the check so rebalancing always runs.",
+        )
         parser.add_argument(
             "--expert-distribution-recorder-mode",
             type=str,