Support logging expert balancedness metrics (#6482)

This commit is contained in:
fzyzcjy
2025-05-22 14:05:33 +08:00
committed by GitHub
parent e9feb48838
commit 9484eba4ad
2 changed files with 113 additions and 1 deletions

View File

@@ -15,10 +15,12 @@ import logging
import os
import time
from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Type
import einops
import torch
import torch.distributed
@@ -472,7 +474,80 @@ class _Accumulator(ABC):
pass
class _StatAccumulator(_Accumulator):
class _UtilizationRateAccumulatorMixin(_Accumulator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._enable = self._server_args.enable_expert_distribution_metrics
if self._enable:
window_sizes = [10, 100, 1000]
self._history = _DequeCollection(maxlens=window_sizes)
self._rank = torch.distributed.get_rank()
def append(
self,
forward_pass_id: int,
gatherer_key: str,
single_pass_data: Dict,
):
super().append(forward_pass_id, gatherer_key, single_pass_data)
if self._enable:
self._append_utilization_rate(
forward_pass_id, single_pass_data["global_physical_count"]
)
def reset(self):
super().reset()
if self._enable:
self._history.clear()
def _append_utilization_rate(
self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
):
gpu_physical_count = compute_gpu_physical_count(
single_pass_global_physical_count,
num_gpu=self._expert_location_metadata.ep_size,
)
gpu_physical_count = gpu_physical_count.to(self._server_args.device)
torch.distributed.reduce(
gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM
)
if self._rank == 0:
utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
utilization_rate = torch.mean(utilization_rate_tensor).item()
self._history.append(utilization_rate)
gpu_physical_count_sum = gpu_physical_count.sum().item()
logger.info(
f"[Expert Balancedness] "
f"forward_pass_id={forward_pass_id} "
f"current_pass_balancedness={utilization_rate:.03f} "
f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
f"gpu_physical_count_sum={gpu_physical_count_sum}"
# f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
)
class _DequeCollection:
def __init__(self, maxlens: List[int]):
self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]
def append(self, value):
for d in self._dequeues:
d.append(value)
def clear(self):
for d in self._dequeues:
d.clear()
def mean(self) -> Dict[int, float]:
return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
class _StatAccumulator(_UtilizationRateAccumulatorMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._global_physical_count_of_buffered_step = _Buffer.init_new(
@@ -619,3 +694,34 @@ def _convert_global_physical_count_to_logical_count(
src=global_physical_count,
)
return logical_count
def compute_gpu_physical_count(
physical_count_of_whatever: torch.Tensor, # (..., num_layer, num_physical_expert)
num_gpu: int,
):
"""output: gpu_physical_count_of_batch (..., num_layer, num_gpu)"""
return einops.reduce(
physical_count_of_whatever,
"... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu",
"sum",
num_gpu=num_gpu,
)
def compute_utilization_rate(
gpu_physical_count_of_batch: torch.Tensor, # (..., num_layer, num_gpu)
):
"""output: utilization_rate (..., num_layer)"""
gpu_physical_count_of_batch = gpu_physical_count_of_batch.float()
max_gpu_physical_count = einops.reduce(
gpu_physical_count_of_batch,
"... num_layer num_gpu -> ... num_layer",
"max",
)
avg_gpu_physical_count = einops.reduce(
gpu_physical_count_of_batch,
"... num_layer num_gpu -> ... num_layer",
"mean",
)
return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5)