Support logging expert balancedness metrics (#6482)
This commit is contained in:
@@ -15,10 +15,12 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from collections import deque
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Literal, Optional, Tuple, Type
|
from typing import Dict, List, Literal, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import einops
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
@@ -472,7 +474,80 @@ class _Accumulator(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class _StatAccumulator(_Accumulator):
|
class _UtilizationRateAccumulatorMixin(_Accumulator):
    """Accumulator layer that optionally logs expert balancedness per pass.

    The feature is switched by ``server_args.enable_expert_distribution_metrics``.
    When it is off, every override simply defers to the parent class.
    When it is on, each appended pass is turned into a single balancedness
    number (see ``compute_utilization_rate``), stored in rolling windows and
    logged from rank 0.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._enable = self._server_args.enable_expert_distribution_metrics
        if not self._enable:
            return

        # Rolling-average windows, measured in number of appended passes.
        self._history = _DequeCollection(maxlens=[10, 100, 1000])
        self._rank = torch.distributed.get_rank()

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        """Forward the pass to the parent, then record its balancedness if enabled."""
        super().append(forward_pass_id, gatherer_key, single_pass_data)
        if not self._enable:
            return

        physical_count = single_pass_data["global_physical_count"]
        self._append_utilization_rate(forward_pass_id, physical_count)

    def reset(self):
        """Reset the parent state and, if enabled, drop the rolling history."""
        super().reset()
        if not self._enable:
            return

        self._history.clear()

    def _append_utilization_rate(
        self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
    ):
        """Reduce per-GPU counts onto rank 0 and log the resulting balancedness.

        Every rank takes part in the ``torch.distributed.reduce`` call; only
        rank 0 goes on to compute, store and log the metric.
        """
        per_gpu_count = compute_gpu_physical_count(
            single_pass_global_physical_count,
            num_gpu=self._expert_location_metadata.ep_size,
        ).to(self._server_args.device)
        torch.distributed.reduce(
            per_gpu_count, dst=0, op=torch.distributed.ReduceOp.SUM
        )

        if self._rank != 0:
            return

        per_layer_rate = compute_utilization_rate(per_gpu_count)
        current_rate = torch.mean(per_layer_rate).item()
        self._history.append(current_rate)

        total_count = per_gpu_count.sum().item()

        # One "last_<window>_average_balancedness=<value> " entry per window.
        window_report = "".join(
            f"last_{size}_average_balancedness={value:.03f} "
            for size, value in self._history.mean().items()
        )

        # Per-layer values remain available in `per_layer_rate` should more
        # detailed logging be wanted.
        logger.info(
            f"[Expert Balancedness] "
            f"forward_pass_id={forward_pass_id} "
            f"current_pass_balancedness={current_rate:.03f} "
            f"{window_report} "
            f"gpu_physical_count_sum={total_count}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
class _DequeCollection:
    """Feed one stream of values into several bounded sliding windows.

    One ``deque`` is created per entry in ``maxlens``; each keeps at most that
    many of the most recently appended values.
    """

    def __init__(self, maxlens: List[int]):
        """Create one bounded deque for every window size in ``maxlens``."""
        self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]

    def append(self, value):
        """Push ``value`` into every window."""
        for d in self._dequeues:
            d.append(value)

    def clear(self):
        """Empty every window while keeping the configured sizes."""
        for d in self._dequeues:
            d.clear()

    def mean(self) -> Dict[int, float]:
        """Return the average of each window, keyed by that window's ``maxlen``.

        A window with no values (fresh collection, or right after ``clear()``)
        maps to NaN instead of raising ``ZeroDivisionError``, so the returned
        keys are always the full set of configured window sizes.
        """
        return {
            d.maxlen: (sum(d) / len(d) if d else float("nan"))
            for d in self._dequeues
        }
|
||||||
|
|
||||||
|
|
||||||
|
class _StatAccumulator(_UtilizationRateAccumulatorMixin):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._global_physical_count_of_buffered_step = _Buffer.init_new(
|
self._global_physical_count_of_buffered_step = _Buffer.init_new(
|
||||||
@@ -619,3 +694,34 @@ def _convert_global_physical_count_to_logical_count(
|
|||||||
src=global_physical_count,
|
src=global_physical_count,
|
||||||
)
|
)
|
||||||
return logical_count
|
return logical_count
|
||||||
|
|
||||||
|
|
||||||
|
def compute_gpu_physical_count(
    physical_count_of_whatever: torch.Tensor,  # (..., num_layer, num_physical_expert)
    num_gpu: int,
):
    """Sum per-expert counts into per-GPU counts.

    The last axis is split into ``num_gpu`` equal groups of experts and each
    group is summed.

    output: gpu_physical_count_of_batch (..., num_layer, num_gpu)
    """
    # Split the expert axis as (gpu, expert-on-that-gpu) and sum the inner part.
    group_by_gpu = "... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu"
    return einops.reduce(
        physical_count_of_whatever,
        group_by_gpu,
        "sum",
        num_gpu=num_gpu,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def compute_utilization_rate(
    gpu_physical_count_of_batch: torch.Tensor,  # (..., num_layer, num_gpu)
):
    """Return, per layer, the mean GPU count divided by the max GPU count.

    The input is cast to float first. A small constant is added to both the
    numerator and the denominator, so a layer whose counts are all zero yields
    1 rather than a division by zero.

    output: utilization_rate (..., num_layer)
    """
    counts = gpu_physical_count_of_batch.float()
    collapse_gpu_axis = "... num_layer num_gpu -> ... num_layer"
    eps = 1e-5

    busiest = einops.reduce(counts, collapse_gpu_axis, "max")
    average = einops.reduce(counts, collapse_gpu_axis, "mean")
    return (average + eps) / (busiest + eps)
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ class ServerArgs:
|
|||||||
Literal["stat", "per_pass", "per_token"]
|
Literal["stat", "per_pass", "per_token"]
|
||||||
] = None
|
] = None
|
||||||
expert_distribution_recorder_buffer_size: Optional[int] = None
|
expert_distribution_recorder_buffer_size: Optional[int] = None
|
||||||
|
enable_expert_distribution_metrics: bool = False
|
||||||
deepep_config: Optional[str] = None
|
deepep_config: Optional[str] = None
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
@@ -1304,6 +1305,11 @@ class ServerArgs:
|
|||||||
default=ServerArgs.expert_distribution_recorder_buffer_size,
|
default=ServerArgs.expert_distribution_recorder_buffer_size,
|
||||||
help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
|
help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-expert-distribution-metrics",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable logging metrics for expert balancedness",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--deepep-config",
|
"--deepep-config",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user