# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations

import logging
import math
import os
import time
from abc import ABC
from collections import deque
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type

import einops
import torch
import torch.distributed

from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var, is_npu

_is_npu = is_npu()

if TYPE_CHECKING:
    from sglang.srt.eplb.expert_location import ExpertLocationMetadata

logger = logging.getLogger(__name__)

# --------------------------------------- Entrypoint -----------------------------------------
_OutputMode = Literal["file", "object"]


class ExpertDistributionRecorder(ABC):
    """Global expert distribution recording"""

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: ExpertLocationMetadata,
        rank: int,
    ):
        if server_args.expert_distribution_recorder_mode is not None:
            # NOTE: the continuation strings are concatenated into the assert
            # message; previously they were dangling no-op expressions.
            assert expert_location_metadata is not None, (
                "ExpertLocationMetadata is required for expert distribution recording. "
                "One possible reason is that you are using a model that does not support "
                "expert distribution recording. Try setting "
                "`get_model_config_for_expert_location` in your model."
            )
            return _ExpertDistributionRecorderReal(
                server_args, expert_location_metadata, rank
            )
        else:
            return _ExpertDistributionRecorderNoop()

    @contextmanager
    def with_current_layer(self, layer_idx):
        yield

    @contextmanager
    def with_debug_name(self, debug_name):
        yield

    @contextmanager
    def disable_this_region(self):
        yield

    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        yield

    def on_select_experts(self, topk_ids: torch.Tensor):
        pass

    def on_deepep_dispatch_normal(
        self,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        pass

    def on_deepep_dispatch_low_latency(
        self, local_physical_count_of_layer: torch.Tensor
    ):
        pass

    def start_record(self):
        self._on_not_implemented()

    def stop_record(self):
        self._on_not_implemented()

    def dump_record(self, output_mode: _OutputMode = "file"):
        self._on_not_implemented()

    @property
    def recording(self):
        return False

    def _on_not_implemented(self):
        raise Exception(
            "Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
        )
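

# Illustrative lifecycle sketch (added comment, not upstream code): a server
# typically installs one recorder globally and brackets each forward pass with
# the context managers above. `server_args`, `metadata`, `batch`, and
# `topk_ids` below are hypothetical placeholders for objects that SGLang
# constructs elsewhere.
#
#   recorder = ExpertDistributionRecorder.init_new(server_args, metadata, rank=0)
#   set_global_expert_distribution_recorder(recorder)
#   recorder.start_record()
#   with recorder.with_forward_pass(forward_pass_id=1, forward_batch=batch):
#       with recorder.with_current_layer(0):
#           recorder.on_select_experts(topk_ids=topk_ids)
#   recorder.dump_record()  # writes a .pt file and resets internal state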


class _ExpertDistributionRecorderNoop(ExpertDistributionRecorder):
    pass


class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: ExpertLocationMetadata,
        rank: int,
    ):
        self._server_args = server_args
        self._expert_location_metadata = expert_location_metadata

        self._recording = False
        self._disable_all = False
        self._current_forward_pass_id = Withable()
        self._current_layer_idx = Withable()
        self._current_debug_name = Withable()
        self._accumulator = _Accumulator.init_new(
            server_args, expert_location_metadata, rank
        )
        self._single_pass_gatherers = {
            k: _SinglePassGatherer.init_new(server_args, expert_location_metadata, rank)
            for k in self._accumulator.get_single_pass_gatherer_keys()
        }

        if server_args.enable_expert_distribution_metrics:
            logger.info(
                "ExpertDistributionRecorder auto start record since enable_expert_distribution_metrics"
            )
            self.start_record()

    def with_current_layer(self, layer_idx):
        return self._current_layer_idx.with_value(layer_idx)

    def with_debug_name(self, debug_name):
        return self._current_debug_name.with_value(debug_name)

    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        with self._current_forward_pass_id.with_value(forward_pass_id):
            self._on_forward_pass_start(forward_batch)
            try:
                yield
            finally:
                self._on_forward_pass_end(forward_pass_id)

    @contextmanager
    def disable_this_region(self):
        """Context manager to temporarily disable recording."""
        previous_disable_all = self._disable_all
        self._disable_all = True
        try:
            yield
        finally:
            self._disable_all = previous_disable_all

    def _on_forward_pass_start(self, forward_batch: ForwardBatch):
        if not self._recording:
            return
        for gatherer_key, gatherer in self._single_pass_gatherers.items():
            gatherer.reset()
            gatherer.on_forward_pass_start(forward_batch)

    def _on_forward_pass_end(self, forward_pass_id: int):
        if not self._recording:
            return
        for gatherer_key, gatherer in self._single_pass_gatherers.items():
            single_pass_data = gatherer.collect()
            self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)

    def on_select_experts(self, topk_ids: torch.Tensor):
        self._on_hook("on_select_experts", topk_ids=topk_ids)

    def on_deepep_dispatch_normal(
        self,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        self._on_hook(
            "on_deepep_dispatch_normal",
            local_physical_count_of_layer=local_physical_count_of_layer,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            num_tokens_per_expert=num_tokens_per_expert,
        )

    def on_deepep_dispatch_low_latency(
        self, local_physical_count_of_layer: torch.Tensor
    ):
        self._on_hook(
            "on_deepep_dispatch_low_latency",
            local_physical_count_of_layer=local_physical_count_of_layer,
        )

    def _on_hook(self, hook_name: str, **kwargs):
        if self._disable_all:
            return
        if not (
            self._recording or torch.get_device_module().is_current_stream_capturing()
        ):
            return
        gatherer = self._single_pass_gatherers[
            self._accumulator.get_single_pass_gatherer_key(
                self._current_debug_name.value
            )
        ]
        getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs)

    def _reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting ExpertDistributionRecorder...")
        assert (
            self._current_layer_idx.value is None
        ), f"{self._current_layer_idx.value=}"
        for gatherer in self._single_pass_gatherers.values():
            gatherer.reset()
        self._accumulator.reset()

    def start_record(self):
        """Start recording the expert distribution."""
        if self._recording:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self._reset()
        self._recording = True

    def stop_record(self):
        """Stop recording the expert distribution."""
        if not self._recording:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending a request to the `/start_expert_distribution_record` endpoint?"
            )
        self._recording = False

    def dump_record(self, output_mode: _OutputMode = "file"):
        """Dump the expert distribution record and reset the recorder after dumping."""
        output = self._accumulator.dump(output_mode=output_mode)
        self._reset()
        return output

    @property
    def recording(self):
        return self._recording


_global_expert_distribution_recorder: Optional[ExpertDistributionRecorder] = (
    _ExpertDistributionRecorderNoop()
)


def get_global_expert_distribution_recorder():
    return _global_expert_distribution_recorder


def set_global_expert_distribution_recorder(value):
    global _global_expert_distribution_recorder
    _global_expert_distribution_recorder = value


# --------------------------------------- SinglePassGatherer -----------------------------------------


class _SinglePassGatherer(ABC):
    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: ExpertLocationMetadata,
        rank: int,
    ) -> "_SinglePassGatherer":
        if server_args.expert_distribution_recorder_mode == "per_token":
            return _DetailSinglePassGatherer(
                server_args, expert_location_metadata, rank
            )

        if server_args.expert_distribution_recorder_mode == "stat_approx":
            if server_args.moe_a2a_backend != "none" and (
                server_args.deepep_mode == "normal"
            ):
                return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
            else:
                raise NotImplementedError

        if server_args.moe_a2a_backend != "none":
            if server_args.deepep_mode == "normal":
                return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
            elif server_args.deepep_mode == "low_latency":
                return _DeepepLowLatencySinglePassGatherer(
                    expert_location_metadata, rank
                )
            else:
                raise NotImplementedError

        return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)

    def __init__(self, expert_location_metadata: ExpertLocationMetadata, rank: int):
        self._expert_location_metadata = expert_location_metadata
        self._rank = rank

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        pass

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        pass

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        pass

    def on_deepep_dispatch_low_latency(
        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
    ):
        pass

    def reset(self):
        raise NotImplementedError

    def collect(self) -> Dict:
        raise NotImplementedError
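

# Summary of how `init_new` above resolves a gatherer (comment added for
# exposition; derived directly from the branches above):
#
#   recorder_mode          moe_a2a_backend / deepep_mode    gatherer
#   "per_token"            (any)                            _DetailSinglePassGatherer
#   "stat_approx"          != "none" / "normal"             _DeepepNormalSinglePassGatherer
#   "stat" or "per_pass"   != "none" / "normal"             _SelectExpertsSinglePassGatherer
#   "stat" or "per_pass"   != "none" / "low_latency"        _DeepepLowLatencySinglePassGatherer
#   "stat" or "per_pass"   == "none"                        _SelectExpertsSinglePassGatherer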


class _DetailSinglePassGatherer(_SinglePassGatherer):
    # DeepSeek V3 has this value; should generalize later
    _TOP_K_NUM = 8

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: ExpertLocationMetadata,
        rank: int,
    ):
        super().__init__(expert_location_metadata, rank)
        self._metadata: Optional[Dict[str, Any]] = None
        self._topk_ids_of_layer = torch.zeros(
            (
                expert_location_metadata.num_layers,
                # TODO determine the max number
                server_args.chunked_prefill_size * 8,
                self._TOP_K_NUM,
            ),
            dtype=torch.int32,
            device=server_args.device,
        )
        self._misc_objects: List[Dict[str, Any]] = []
        assert (
            not server_args.enable_two_batch_overlap
        ), "DetailSinglePassGatherer does not support TBO yet"
        # TODO assert shared experts fusion is disabled, o/w data is wrong

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        assert self._metadata is None
        self._metadata = dict(
            # TODO pr-chain
            # rids=forward_batch.rids,
            input_ids=forward_batch.input_ids.cpu().tolist(),
            positions=forward_batch.positions.cpu().tolist(),
            extend_seq_lens=forward_batch.extend_seq_lens_cpu,
            forward_mode=forward_batch.forward_mode.value,
        )

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], : topk_ids.shape[1]] = (
            topk_ids
        )

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        self._misc_objects.append(
            dict(
                layer_id=layer_idx,
                num_tokens_per_rank=num_tokens_per_rank.cpu().tolist(),
                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank.cpu().tolist(),
                num_tokens_per_expert=num_tokens_per_expert.cpu().tolist(),
            )
        )

    def reset(self):
        self._topk_ids_of_layer[...] = -1
        self._misc_objects.clear()
        self._metadata = None

    def collect(self) -> Dict:
        num_tokens = len(self._metadata["input_ids"])
        return dict(
            **self._metadata,
            topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :].clone().cpu(),
            misc_objects=self._misc_objects,
        )


class _LayerBasedCpuSinglePassGatherer(_SinglePassGatherer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._objects_of_layer = {}

    def _on_layer_data(self, layer_idx: int, objects: List[int]):
        assert 0 <= layer_idx < self._expert_location_metadata.num_layers
        if layer_idx in self._objects_of_layer:
            self._objects_of_layer[layer_idx] = _list_sum(
                self._objects_of_layer[layer_idx], objects
            )
        else:
            self._objects_of_layer[layer_idx] = objects

    def reset(self):
        self._objects_of_layer.clear()

    def _collect_objects(self, pad_len: int) -> torch.Tensor:
        data = [
            self._objects_of_layer.get(layer_index) or ([0] * pad_len)
            for layer_index in range(self._expert_location_metadata.num_layers)
        ]
        return torch.tensor(data)


def _list_sum(a: List, b: List) -> List:
    return [x + y for x, y in zip(a, b, strict=True)]
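

# Doctest-style sketch (added comment): element-wise sum with strict length
# checking, as used by `_on_layer_data` to merge repeated dispatches of the
# same layer.
#
#   >>> _list_sum([1, 2, 3], [10, 20, 30])
#   [11, 22, 33]
#   >>> _list_sum([1], [1, 2])  # strict=True raises on length mismatch
#   Traceback (most recent call last):
#       ...
#   ValueError: zip() argument 2 is longer than argument 1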


class _LayerBasedGpuSinglePassGatherer(_SinglePassGatherer):
    def __init__(self, *args, enable_global_physical_experts: bool, **kwargs):
        super().__init__(*args, **kwargs)
        device = "npu" if _is_npu else "cuda"
        self._enable_global_physical_experts = enable_global_physical_experts
        self._data = torch.zeros(
            (
                self._expert_location_metadata.num_layers,
                (
                    self._expert_location_metadata.num_physical_experts
                    if enable_global_physical_experts
                    else self._expert_location_metadata.num_local_physical_experts
                ),
            ),
            dtype=torch.int,
            device=device,
        )

    def reset(self):
        self._data[...] = 0

    def collect(self) -> Dict:
        if self._enable_global_physical_experts:
            global_physical_count = self._data
        else:
            # Can optimize if bottleneck
            global_physical_count = _convert_local_to_global_physical_count(
                self._data,
                rank=self._rank,
                num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
                num_physical_experts=self._expert_location_metadata.num_physical_experts,
            )

        return dict(global_physical_count=global_physical_count)


class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, enable_global_physical_experts=True)

    # can optimize (e.g. fuse / compile)
    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        topk_ids = topk_ids.flatten()
        mask = topk_ids != -1
        self._data[layer_idx, :].scatter_add_(
            dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
        )
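

# Worked example (added comment): `-1` entries in `topk_ids` are padding; they
# are redirected to index 0 with a `src` of 0 so they contribute nothing.
# With 6 physical experts and topk_ids = [[2, 5], [2, -1]]:
#   flattened index -> [2, 5, 2, 0]   src -> [1, 1, 1, 0]
#   self._data[layer_idx] accumulates [0, 0, 2, 0, 0, 1]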


class _DeepepNormalSinglePassGatherer(_LayerBasedCpuSinglePassGatherer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if torch.distributed.get_rank() == 0:
            logger.info(
                "DeepepNormalSinglePassGatherer gathers approximate statistics. "
                "If used with small batch size, consider using expert_distribution_recorder_mode=stat."
            )

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        assert isinstance(local_physical_count_of_layer, list)
        self._on_layer_data(layer_idx, local_physical_count_of_layer)

    def collect(self) -> Dict:
        local_physical_count = super()._collect_objects(
            pad_len=self._expert_location_metadata.num_local_physical_experts
        )
        global_physical_count = _convert_local_to_global_physical_count(
            local_physical_count,
            rank=self._rank,
            num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts,
            num_physical_experts=self._expert_location_metadata.num_physical_experts,
        )
        return dict(global_physical_count=global_physical_count)


class _DeepepLowLatencySinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, enable_global_physical_experts=False)

    def on_deepep_dispatch_low_latency(
        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
    ):
        # Most naive implementation, can optimize later
        self._data[layer_idx, :] += local_physical_count_of_layer


def _convert_local_to_global_physical_count(
    local_physical_count: torch.Tensor,
    rank: int,
    num_local_physical_experts: int,
    num_physical_experts: int,
) -> torch.Tensor:
    dtype = local_physical_count.dtype
    device = local_physical_count.device
    num_layers, _ = local_physical_count.shape

    ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device)
    ans[
        :, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1)
    ] = local_physical_count
    return ans
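

# Worked example (added comment): each rank only observes its own experts, so
# its counts are pasted into the rank-local slice of a zero-filled global
# tensor. With rank=1, num_local_physical_experts=2, num_physical_experts=6:
#   local [[5, 7]]  ->  global [[0, 0, 5, 7, 0, 0]]
# Summing these tensors across ranks then yields the full global count.

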
# --------------------------------------- Accumulator -----------------------------------------
_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary"


class _Accumulator(ABC):
    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: ExpertLocationMetadata,
        rank: int,
    ) -> "_Accumulator":
        return _Accumulator.get_class(server_args)(
            server_args, expert_location_metadata, rank
        )

    @staticmethod
    def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
        return {
            "stat": _StatAccumulator,
            "stat_approx": _StatAccumulator,
            "per_pass": _DetailAccumulator,
            "per_token": _DetailAccumulator,
        }[server_args.expert_distribution_recorder_mode]

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: ExpertLocationMetadata,
        rank: int,
    ):
        self._server_args = server_args
        self._expert_location_metadata = expert_location_metadata
        self._rank = rank

    def get_single_pass_gatherer_keys(self):
        return [_SINGLE_PASS_GATHERER_KEY_PRIMARY]

    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
        return _SINGLE_PASS_GATHERER_KEY_PRIMARY

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        pass

    def reset(self):
        pass

    def dump(self, output_mode: _OutputMode):
        pass


class _UtilizationRateAccumulatorMixin(_Accumulator):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._enable = self._server_args.enable_expert_distribution_metrics

        if self._enable:
            self.window_sizes = [10, 100, 1000]
            self._history = _DequeCollection(maxlens=self.window_sizes)
            self._rank = torch.distributed.get_rank()

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)
        if self._enable:
            self._append_utilization_rate(
                forward_pass_id, single_pass_data["global_physical_count"]
            )

    def reset(self):
        super().reset()
        if self._enable:
            self._history.clear()

    def _append_utilization_rate(
        self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
    ):
        gpu_physical_count = compute_gpu_physical_count(
            single_pass_global_physical_count,
            num_gpu=self._expert_location_metadata.ep_size,
        )
        gpu_physical_count = gpu_physical_count.to(self._server_args.device)
        torch.distributed.reduce(
            gpu_physical_count, dst=0, op=torch.distributed.ReduceOp.SUM
        )
        if self._rank == 0:
            utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
            utilization_rate = torch.mean(utilization_rate_tensor).item()
            self._history.append(utilization_rate)

            gpu_physical_count_sum = gpu_physical_count.sum().item()

            logger.info(
                f"[Expert Balancedness] "
                f"forward_pass_id={forward_pass_id} "
                f"current_pass_balancedness={utilization_rate:.03f} "
                f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
                f"gpu_physical_count_sum={gpu_physical_count_sum}"
                # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
            )


class _DequeCollection:
    def __init__(self, maxlens: List[int]):
        self._dequeues = [deque(maxlen=maxlen) for maxlen in maxlens]

    def append(self, value):
        for d in self._dequeues:
            d.append(value)

    def clear(self):
        for d in self._dequeues:
            d.clear()

    def mean(self) -> Dict[int, float]:
        return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
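

# Behavior sketch (added comment): one append fans out to every window, so
# `mean()` reports a rolling average per window size.
#
#   >>> dc = _DequeCollection(maxlens=[2, 3])
#   >>> for v in (1.0, 2.0, 3.0):
#   ...     dc.append(v)
#   >>> dc.mean()
#   {2: 2.5, 3: 2.0}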


class _DetailAccumulator(_UtilizationRateAccumulatorMixin):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._records = []

    def get_single_pass_gatherer_keys(self):
        if False:  # TODO `server_args.enable_two_batch_overlap`
            return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"]
        return super().get_single_pass_gatherer_keys()

    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
        if False:  # TODO `server_args.enable_two_batch_overlap`
            return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY
        return super().get_single_pass_gatherer_key(debug_name)

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)

        def _process_object(obj):
            if isinstance(obj, torch.Tensor):
                return obj.cpu().clone()
            return obj

        single_pass_data_processed = {
            k: _process_object(v) for k, v in single_pass_data.items()
        }

        self._records.append(
            dict(
                forward_pass_id=forward_pass_id,
                rank=self._rank,
                gatherer_key=gatherer_key,
                **single_pass_data_processed,
            )
        )

    def reset(self):
        super().reset()
        self._records.clear()

    def dump(self, output_mode: _OutputMode):
        assert output_mode == "file"
        output = dict(
            records=self._records,
            # NOTE: This may change during recording, so here we say it is the "last" one
            last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
        )
        _dump_to_file(
            f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output
        )


class _StatAccumulator(_UtilizationRateAccumulatorMixin):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._global_physical_count_of_buffered_step = _Buffer.init_new(
            item_shape=(
                self._expert_location_metadata.num_layers,
                # Cannot use local_physical_count to support select_experts
                self._expert_location_metadata.num_physical_experts,
            ),
            buffer_size=self._server_args.expert_distribution_recorder_buffer_size,
            dtype=torch.int32,
            device=self._server_args.device,
        )
        self._first_dump = True

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        super().append(forward_pass_id, gatherer_key, single_pass_data)
        # Can optimize if overhead here is large
        self._global_physical_count_of_buffered_step.append(
            single_pass_data["global_physical_count"]
        )

    def reset(self):
        super().reset()
        self._global_physical_count_of_buffered_step.reset()

    def dump(self, output_mode: _OutputMode):
        logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count(
            self._global_physical_count_of_buffered_step.get_all(),
            num_layers=self._expert_location_metadata.num_layers,
            num_logical_experts=self._expert_location_metadata.num_logical_experts,
            physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map,
        )

        if self._first_dump:
            self._first_dump = False
            torch.get_device_module().empty_cache()

        torch.distributed.all_reduce(
            logical_count_of_buffered_step, op=torch.distributed.ReduceOp.SUM
        )

        output = dict(
            rank=self._rank,
            logical_count=logical_count_of_buffered_step,
            average_utilization_rate_over_window=self._get_global_average_utilization_rate(),
        )

        if output_mode == "file":
            if self._rank == 0:
                _dump_to_file(f"expert_distribution_recorder_{time.time()}.pt", output)
        elif output_mode == "object":
            return output
        else:
            raise NotImplementedError

    def _get_global_average_utilization_rate(self):
        if not self._enable or math.isclose(
            self._server_args.eplb_min_rebalancing_utilization_threshold, 1.0
        ):
            return None

        if self._rank == 0:
            utilization_mean_rates = self._history.mean()
            window_index = self.window_sizes[-1]
            average_utilization_rate_over_window = (
                utilization_mean_rates[window_index]
                if window_index in utilization_mean_rates
                else 0
            )

            avg_rate_tensor = torch.tensor(
                [average_utilization_rate_over_window],
                dtype=torch.float32,
                device="cuda",
            )
        else:
            avg_rate_tensor = torch.empty(1, dtype=torch.float32, device="cuda")
        torch.distributed.broadcast(avg_rate_tensor, src=0)
        return avg_rate_tensor.item()


def _dump_to_file(name, data):
    save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
    path_output = save_dir / name
    logger.info(f"Write expert distribution to {path_output}")
    if not save_dir.exists():
        save_dir.mkdir(parents=True, exist_ok=True)
    torch.save(data, str(path_output))
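

# For example (added comment): with SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR=/data/dumps,
# a stat-mode dump lands at /data/dumps/expert_distribution_recorder_<timestamp>.pt
# and can be reloaded via `torch.load(...)`; without the env var, files go to /tmp.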


class _Buffer:
    @staticmethod
    def init_new(item_shape: Tuple, buffer_size: int, dtype, device):
        if buffer_size < 0:
            return _InfiniteBuffer(item_shape, dtype=dtype, device=device)
        else:
            return _CircularBuffer(item_shape, buffer_size, dtype=dtype, device=device)

    def append(self, value: torch.Tensor):
        raise NotImplementedError

    def get_all(self) -> torch.Tensor:
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError


class _CircularBuffer(_Buffer):
    def __init__(self, item_shape: Tuple, buffer_size: int, dtype, device):
        self._buffer = torch.zeros(
            (buffer_size, *item_shape), dtype=dtype, device=device
        )
        self._curr_index = 0

    def append(self, value: torch.Tensor):
        self._buffer[self._curr_index] = value
        self._curr_index = (self._curr_index + 1) % len(self._buffer)

    def get_all(self) -> torch.Tensor:
        return self._buffer

    def reset(self):
        self._buffer[...] = 0


class _InfiniteBuffer(_Buffer):
    def __init__(self, item_shape: Tuple, dtype, device):
        self._item_shape = item_shape
        self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device)
        self._size = 0

    def append(self, value: torch.Tensor):
        curr_buffer_size = len(self._buffer)
        dtype = self._buffer.dtype
        device = self._buffer.device

        if self._size == curr_buffer_size:
            new_buffer = torch.zeros(
                (2 * curr_buffer_size, *self._item_shape), dtype=dtype, device=device
            )
            new_buffer[:curr_buffer_size] = self._buffer
            self._buffer = new_buffer

        self._buffer[self._size] = value
        self._size += 1

    def get_all(self) -> torch.Tensor:
        return self._buffer[: self._size]

    def reset(self):
        self._buffer[...] = 0
        self._size = 0
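

# Semantics sketch (added comment): a negative `buffer_size` selects
# `_InfiniteBuffer`, which doubles its backing storage from 128 items as
# needed; otherwise `_CircularBuffer` overwrites the oldest slot in place.
# Note that `_CircularBuffer.get_all()` returns slot order, not arrival order:
#
#   buf = _Buffer.init_new(item_shape=(1,), buffer_size=2, dtype=torch.int32, device="cpu")
#   for v in (1, 2, 3):
#       buf.append(torch.tensor([v], dtype=torch.int32))
#   buf.get_all()  # tensor([[3], [2]]) -- item 1 was overwritten by item 3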


def _convert_global_physical_count_to_logical_count(
    # (whatever, num_layers, num_physical_experts)
    global_physical_count: torch.Tensor,
    num_layers: int,
    num_logical_experts: int,
    physical_to_logical_map: torch.Tensor,
):
    dim_extra, _, _ = global_physical_count.shape
    dtype = global_physical_count.dtype
    device = global_physical_count.device

    logical_count = torch.zeros(
        (dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device
    )
    logical_count.scatter_add_(
        dim=2,
        index=physical_to_logical_map.unsqueeze(0)
        .expand(dim_extra, -1, -1)
        .to(torch.int64),
        src=global_physical_count,
    )
    return logical_count
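

# Worked example (added comment): redundant physical replicas of the same
# logical expert are merged by `scatter_add_` along the expert dimension.
# With physical_to_logical_map = [[0, 1, 0]] (1 layer, 3 physical experts,
# 2 logical experts) and global_physical_count = [[[2, 3, 4]]]:
#   logical_count -> [[[2 + 4, 3]]] = [[[6, 3]]]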


def compute_gpu_physical_count(
    physical_count_of_whatever: torch.Tensor,  # (..., num_layer, num_physical_expert)
    num_gpu: int,
):
    """output: gpu_physical_count_of_batch (..., num_layer, num_gpu)"""
    return einops.reduce(
        physical_count_of_whatever,
        "... num_layer (num_gpu num_expert_per_gpu) -> ... num_layer num_gpu",
        "sum",
        num_gpu=num_gpu,
    )
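

# Worked example (added comment): physical experts are laid out contiguously
# per GPU, so summing each `num_expert_per_gpu`-sized block gives per-GPU
# token counts. With num_gpu=2 and one layer of counts [[1, 2, 3, 4]]:
#   compute_gpu_physical_count(...) -> [[1 + 2, 3 + 4]] = [[3, 7]]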


def compute_utilization_rate(
    gpu_physical_count_of_batch: torch.Tensor,  # (..., num_layer, num_gpu)
):
    """output: utilization_rate (..., num_layer)"""
    gpu_physical_count_of_batch = gpu_physical_count_of_batch.float()
    max_gpu_physical_count = einops.reduce(
        gpu_physical_count_of_batch,
        "... num_layer num_gpu -> ... num_layer",
        "max",
    )
    avg_gpu_physical_count = einops.reduce(
        gpu_physical_count_of_batch,
        "... num_layer num_gpu -> ... num_layer",
        "mean",
    )
    return (avg_gpu_physical_count + 1e-5) / (max_gpu_physical_count + 1e-5)
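

# Worked example (added comment): balancedness is mean/max per layer, so a
# perfectly even load gives 1.0 and the epsilon keeps all-zero layers finite.
# For gpu_physical_count_of_batch = [[3., 7.]]:
#   mean = 5, max = 7  ->  utilization_rate ~= 5 / 7 ~= 0.714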