Support logging expert balancedness metrics (#6482)
This commit is contained in:
@@ -15,10 +15,12 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from abc import ABC
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Type
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
@@ -472,7 +474,80 @@ class _Accumulator(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class _StatAccumulator(_Accumulator):
|
||||
class _UtilizationRateAccumulatorMixin(_Accumulator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._enable = self._server_args.enable_expert_distribution_metrics
|
||||
|
||||
if self._enable:
|
||||
window_sizes = [10, 100, 1000]
|
||||
self._history = _DequeCollection(maxlens=window_sizes)
|
||||
self._rank = torch.distributed.get_rank()
|
||||
|
||||
def append(
|
||||
self,
|
||||
forward_pass_id: int,
|
||||
gatherer_key: str,
|
||||
single_pass_data: Dict,
|
||||
):
|
||||
super().append(forward_pass_id, gatherer_key, single_pass_data)
|
||||
if self._enable:
|
||||
self._append_utilization_rate(
|
||||
forward_pass_id, single_pass_data["global_physical_count"]
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
super().reset()
|
||||
if self._enable:
|
||||
self._history.clear()
|
||||
|
||||
def _append_utilization_rate(
|
||||
self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
|
||||
):
|
||||
gpu_physical_count = compute_gpu_physical_count(
|
||||
single_pass_global_physical_count,
|
||||
num_gpu=self._expert_location_metadata.ep_size,
|
||||
)
|
||||
gpu_physical_count = gpu_physical_count.to(self._server_args.device)
|
||||
torch.distributed.reduce(
|
||||
gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
|
||||
if self._rank == 0:
|
||||
utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
|
||||
utilization_rate = torch.mean(utilization_rate_tensor).item()
|
||||
self._history.append(utilization_rate)
|
||||
|
||||
gpu_physical_count_sum = gpu_physical_count.sum().item()
|
||||
|
||||
logger.info(
|
||||
f"[Expert Balancedness] "
|
||||
f"forward_pass_id={forward_pass_id} "
|
||||
f"current_pass_balancedness={utilization_rate:.03f} "
|
||||
f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
|
||||
f"gpu_physical_count_sum={gpu_physical_count_sum}"
|
||||
# f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
|
||||
)
|
||||
|
||||
|
||||
class _DequeCollection:
|
||||
def __init__(self, maxlens: List[int]):
|
||||
self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]
|
||||
|
||||
def append(self, value):
|
||||
for d in self._dequeues:
|
||||
d.append(value)
|
||||
|
||||
def clear(self):
|
||||
for d in self._dequeues:
|
||||
d.clear()
|
||||
|
||||
def mean(self) -> Dict[int, float]:
|
||||
return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
|
||||
|
||||
|
||||
class _StatAccumulator(_UtilizationRateAccumulatorMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._global_physical_count_of_buffered_step = _Buffer.init_new(
|
||||
@@ -619,3 +694,34 @@ def _convert_global_physical_count_to_logical_count(
|
||||
src=global_physical_count,
|
||||
)
|
||||
return logical_count
|
||||
|
||||
|
||||
def compute_gpu_physical_count(
|
||||
physical_count_of_whatever: torch.Tensor, # (..., num_layer, num_physical_expert)
|
||||
num_gpu: int,
|
||||
):
|
||||
"""output: gpu_physical_count_of_batch (..., num_layer, num_gpu)"""
|
||||
return einops.reduce(
|
||||
physical_count_of_whatever,
|
||||
"... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu",
|
||||
"sum",
|
||||
num_gpu=num_gpu,
|
||||
)
|
||||
|
||||
|
||||
def compute_utilization_rate(
|
||||
gpu_physical_count_of_batch: torch.Tensor, # (..., num_layer, num_gpu)
|
||||
):
|
||||
"""output: utilization_rate (..., num_layer)"""
|
||||
gpu_physical_count_of_batch = gpu_physical_count_of_batch.float()
|
||||
max_gpu_physical_count = einops.reduce(
|
||||
gpu_physical_count_of_batch,
|
||||
"... num_layer num_gpu -> ... num_layer",
|
||||
"max",
|
||||
)
|
||||
avg_gpu_physical_count = einops.reduce(
|
||||
gpu_physical_count_of_batch,
|
||||
"... num_layer num_gpu -> ... num_layer",
|
||||
"mean",
|
||||
)
|
||||
return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5)
|
||||
|
||||
@@ -177,6 +177,7 @@ class ServerArgs:
|
||||
Literal["stat", "per_pass", "per_token"]
|
||||
] = None
|
||||
expert_distribution_recorder_buffer_size: Optional[int] = None
|
||||
enable_expert_distribution_metrics: bool = False
|
||||
deepep_config: Optional[str] = None
|
||||
enable_torch_compile: bool = False
|
||||
torch_compile_max_bs: int = 32
|
||||
@@ -1304,6 +1305,11 @@ class ServerArgs:
|
||||
default=ServerArgs.expert_distribution_recorder_buffer_size,
|
||||
help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-expert-distribution-metrics",
|
||||
action="store_true",
|
||||
help="Enable logging metrics for expert balancedness",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--deepep-config",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user