From 9484eba4ad127abe695e26fdd34eeaed6e08c12a Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 22 May 2025 14:05:33 +0800 Subject: [PATCH] Support logging expert balancedness metrics (#6482) --- .../srt/managers/expert_distribution.py | 108 +++++++++++++++++- python/sglang/srt/server_args.py | 6 + 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index a4bf17c1e..db5b82a6d 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -15,10 +15,12 @@ import logging import os import time from abc import ABC +from collections import deque from contextlib import contextmanager from pathlib import Path from typing import Dict, List, Literal, Optional, Tuple, Type +import einops import torch import torch.distributed @@ -472,7 +474,80 @@ class _Accumulator(ABC): pass -class _StatAccumulator(_Accumulator): +class _UtilizationRateAccumulatorMixin(_Accumulator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._enable = self._server_args.enable_expert_distribution_metrics + + if self._enable: + window_sizes = [10, 100, 1000] + self._history = _DequeCollection(maxlens=window_sizes) + self._rank = torch.distributed.get_rank() + + def append( + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_data: Dict, + ): + super().append(forward_pass_id, gatherer_key, single_pass_data) + if self._enable: + self._append_utilization_rate( + forward_pass_id, single_pass_data["global_physical_count"] + ) + + def reset(self): + super().reset() + if self._enable: + self._history.clear() + + def _append_utilization_rate( + self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor + ): + gpu_physical_count = compute_gpu_physical_count( + single_pass_global_physical_count, + num_gpu=self._expert_location_metadata.ep_size, + ) + gpu_physical_count = gpu_physical_count.to(self._server_args.device) + torch.distributed.reduce( + gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM + ) + + if self._rank == 0: + utilization_rate_tensor = compute_utilization_rate(gpu_physical_count) + utilization_rate = torch.mean(utilization_rate_tensor).item() + self._history.append(utilization_rate) + + gpu_physical_count_sum = gpu_physical_count.sum().item() + + logger.info( + f"[Expert Balancedness] " + f"forward_pass_id={forward_pass_id} " + f"current_pass_balancedness={utilization_rate:.03f} " + f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} " + f"gpu_physical_count_sum={gpu_physical_count_sum}" + # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}" + ) + + +class _DequeCollection: + def __init__(self, maxlens: List[int]): + self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens] + + def append(self, value): + for d in self._dequeues: + d.append(value) + + def clear(self): + for d in self._dequeues: + d.clear() + + def mean(self) -> Dict[int, float]: + return {d.maxlen: sum(d) / len(d) for d in self._dequeues} + + +class _StatAccumulator(_UtilizationRateAccumulatorMixin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._global_physical_count_of_buffered_step = _Buffer.init_new( @@ -619,3 +694,34 @@ def _convert_global_physical_count_to_logical_count( src=global_physical_count, ) return logical_count + + +def compute_gpu_physical_count( + physical_count_of_whatever: torch.Tensor, # (..., num_layer, num_physical_expert) + num_gpu: int, +): + """output: gpu_physical_count_of_batch (..., num_layer, num_gpu)""" + return einops.reduce( + physical_count_of_whatever, + "... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu", + "sum", + num_gpu=num_gpu, + ) + + +def compute_utilization_rate( + gpu_physical_count_of_batch: torch.Tensor, # (..., num_layer, num_gpu) +): + """output: utilization_rate (..., num_layer)""" + gpu_physical_count_of_batch = gpu_physical_count_of_batch.float() + max_gpu_physical_count = einops.reduce( + gpu_physical_count_of_batch, + "... num_layer num_gpu -> ... num_layer", + "max", + ) + avg_gpu_physical_count = einops.reduce( + gpu_physical_count_of_batch, + "... num_layer num_gpu -> ... num_layer", + "mean", + ) + return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d0f641b45..23be9bca5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -177,6 +177,7 @@ class ServerArgs: Literal["stat", "per_pass", "per_token"] ] = None expert_distribution_recorder_buffer_size: Optional[int] = None + enable_expert_distribution_metrics: bool = False deepep_config: Optional[str] = None enable_torch_compile: bool = False torch_compile_max_bs: int = 32 @@ -1304,6 +1305,11 @@ class ServerArgs: default=ServerArgs.expert_distribution_recorder_buffer_size, help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.", ) + parser.add_argument( + "--enable-expert-distribution-metrics", + action="store_true", + help="Enable logging metrics for expert balancedness", + ) parser.add_argument( "--deepep-config", type=str,