Support gathering expert distribution details (#6665)
This commit is contained in:
@@ -18,7 +18,7 @@ from abc import ABC
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Type
|
||||
|
||||
import einops
|
||||
import torch
|
||||
@@ -293,6 +293,79 @@ class _SinglePassGatherer(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class _DetailSinglePassGatherer(_SinglePassGatherer):
    """Gather detailed expert-selection data for a single forward pass.

    Between two ``reset()`` calls the gatherer holds:

    * ``_metadata``: batch fields captured in ``on_forward_pass_start``;
    * ``_topk_ids_of_layer``: the ``topk_ids`` written per layer in
      ``on_select_experts``;
    * ``_misc_objects``: one dict per ``on_deepep_dispatch_normal`` call.

    ``collect()`` packages these into one dict.
    """

    # DeepSeek V3 has this value; should generalize later
    _TOP_K_NUM = 8

    def __init__(
        self,
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        """Allocate the per-layer top-k buffer on ``server_args.device``.

        Raises:
            AssertionError: if ``server_args.enable_two_batch_overlap`` is set.
        """
        super().__init__(expert_location_metadata, rank)
        self._metadata: Optional[Dict[str, Any]] = None
        # Filled with -1, the same "unwritten" marker that reset() writes, so
        # slots no layer wrote look identical on the first pass and on later
        # ones (a zero fill would read as expert id 0).
        self._topk_ids_of_layer = torch.full(
            (
                expert_location_metadata.num_layers,
                # TODO determine the max number
                server_args.chunked_prefill_size * 8,
                self._TOP_K_NUM,
            ),
            -1,
            dtype=torch.int32,
            device=server_args.device,
        )
        self._misc_objects: List[Dict[str, Any]] = []
        assert (
            not server_args.enable_two_batch_overlap
        ), "DetailSinglePassGatherer does not support TBO yet"
        # TODO assert shared experts fusion is disabled, o/w data is wrong

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        """Capture batch metadata; requires that no metadata is pending.

        Raises:
            AssertionError: if metadata from an earlier pass was not reset.
        """
        assert self._metadata is None
        self._metadata = dict(
            # TODO pr-chain
            # rids=forward_batch.rids,
            input_ids=forward_batch.input_ids.cpu().tolist(),
            positions=forward_batch.positions.cpu().tolist(),
            extend_seq_lens=forward_batch.extend_seq_lens_cpu,
            forward_mode=forward_batch.forward_mode.value,
        )

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        """Write ``topk_ids`` into the first ``topk_ids.shape[0]`` rows of the layer's buffer."""
        self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        """Record the three token-count tensors of this call as CPU lists.

        ``local_physical_count_of_layer`` is accepted but not recorded.
        """
        self._misc_objects.append(
            dict(
                layer_id=layer_idx,
                num_tokens_per_rank=num_tokens_per_rank.cpu().tolist(),
                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank.cpu().tolist(),
                num_tokens_per_expert=num_tokens_per_expert.cpu().tolist(),
            )
        )

    def reset(self):
        """Clear all per-pass state so the gatherer can be reused."""
        self._topk_ids_of_layer[...] = -1
        self._misc_objects.clear()
        self._metadata = None

    def collect(self) -> Dict:
        """Return everything gathered for the current pass as one dict.

        The result holds the metadata fields, ``topk_ids_of_layer`` (a CPU copy
        of the buffer truncated to ``len(input_ids)`` tokens) and
        ``misc_objects``.

        Raises:
            TypeError: if called while no metadata is set (``_metadata`` is None).
        """
        num_tokens = len(self._metadata["input_ids"])
        return dict(
            **self._metadata,
            topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :].clone().cpu(),
            # Hand out a copy: reset() empties the internal list in place, which
            # would otherwise also empty the list held by whoever kept this result.
            misc_objects=list(self._misc_objects),
        )
|
||||
|
||||
|
||||
class _LayerBasedSinglePassGatherer(_SinglePassGatherer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -438,9 +511,8 @@ class _Accumulator(ABC):
|
||||
def get_class(server_args: ServerArgs) -> Type["_Accumulator"]:
    """Return the accumulator class for the configured recorder mode.

    The mode is read from ``server_args.expert_distribution_recorder_mode``;
    a mode that is not listed below raises ``KeyError``.
    """
    accumulator_class_of_mode = {
        "stat": _StatAccumulator,
        "per_pass": _DetailAccumulator,
        "per_token": _DetailAccumulator,
    }
    mode = server_args.expert_distribution_recorder_mode
    return accumulator_class_of_mode[mode]
|
||||
|
||||
def __init__(
|
||||
@@ -547,6 +619,63 @@ class _DequeCollection:
|
||||
return {d.maxlen: sum(d) / len(d) for d in self._dequeues}
|
||||
|
||||
|
||||
class _DetailAccumulator(_UtilizationRateAccumulatorMixin):
    """Accumulator that keeps one detailed record per appended forward pass.

    Records pile up in ``self._records`` until ``reset()``; ``dump()`` writes
    them to a file together with the current physical-to-logical expert map.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._records = []

    def get_single_pass_gatherer_keys(self):
        """Return the gatherer keys; currently always the parent's answer."""
        # Placeholder branch, never taken until TBO support is wired in.
        if False:  # TODO `server_args.enable_two_batch_overlap`
            return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"]
        return super().get_single_pass_gatherer_keys()

    def get_single_pass_gatherer_key(self, debug_name: Optional[str]):
        """Map ``debug_name`` to a gatherer key; currently always the parent's answer."""
        # Placeholder branch, never taken until TBO support is wired in.
        if False:  # TODO `server_args.enable_two_batch_overlap`
            return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY
        return super().get_single_pass_gatherer_key(debug_name)

    def append(
        self,
        forward_pass_id: int,
        gatherer_key: str,
        single_pass_data: Dict,
    ):
        """Store one pass's data as a record, after the parent has processed it.

        Tensor values are replaced by CPU clones; every other value is stored
        as the same object that was passed in.
        """
        super().append(forward_pass_id, gatherer_key, single_pass_data)

        snapshot = {}
        for field_name, value in single_pass_data.items():
            if isinstance(value, torch.Tensor):
                value = value.cpu().clone()
            snapshot[field_name] = value

        record = dict(
            forward_pass_id=forward_pass_id,
            rank=self._rank,
            gatherer_key=gatherer_key,
            **snapshot,
        )
        self._records.append(record)

    def reset(self):
        """Reset the parent's state and drop all stored records."""
        super().reset()
        self._records.clear()

    def dump(self, output_mode: _OutputMode):
        """Write all records to a timestamped, per-rank ``.pt`` file.

        Only ``output_mode == "file"`` is supported; anything else fails the
        assertion.
        """
        assert output_mode == "file"
        payload = {
            "records": self._records,
            # NOTE: This may change during recording, so here we say it is the "last" one
            "last_physical_to_logical_map": self._expert_location_metadata.physical_to_logical_map,
        }
        file_name = f"expert_distribution_recorder_{time.time()}_{self._rank}.pt"
        _dump_to_file(file_name, payload)
|
||||
|
||||
|
||||
class _StatAccumulator(_UtilizationRateAccumulatorMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user