feature(eplb): add min-rebalancing-utilization-threshold for eplb (#8345)

Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
hzh0425
2025-08-30 11:24:29 +08:00
committed by GitHub
parent 591e6c5983
commit c2a26e725c
3 changed files with 62 additions and 4 deletions

View File

@@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================
import logging
import math
import os
import time
from abc import ABC
@@ -614,8 +615,8 @@ class _UtilizationRateAccumulatorMixin(_Accumulator):
self._enable = self._server_args.enable_expert_distribution_metrics
if self._enable:
window_sizes = [10, 100, 1000]
self._history = _DequeCollection(maxlens=window_sizes)
self.window_sizes = [10, 100, 1000]
self._history = _DequeCollection(maxlens=self.window_sizes)
self._rank = torch.distributed.get_rank()
def append(
@@ -787,6 +788,7 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
output = dict(
rank=self._rank,
logical_count=logical_count_of_buffered_step,
average_utilization_rate_over_window=self._get_global_average_utilization_rate(),
)
if output_mode == "file":
@@ -797,6 +799,31 @@ class _StatAccumulator(_UtilizationRateAccumulatorMixin):
else:
raise NotImplementedError
def _get_global_average_utilization_rate(self):
if not self._enable or math.isclose(
self._server_args.eplb_min_rebalancing_utilization_threshold, 1.0
):
return None
if self._rank == 0:
utilization_mean_rates = self._history.mean()
window_index = self.window_sizes[-1]
average_utilization_rate_over_window = (
utilization_mean_rates[window_index]
if window_index in utilization_mean_rates
else 0
)
avg_rate_tensor = torch.tensor(
[average_utilization_rate_over_window],
dtype=torch.float32,
device="cuda",
)
else:
avg_rate_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
torch.distributed.broadcast(avg_rate_tensor, src=0)
return avg_rate_tensor.item()
def _dump_to_file(name, data):
save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))