From 87068b5cc7e908385b2ea15c6d42f31d87e0e9f8 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Wed, 28 May 2025 06:32:59 +0800 Subject: [PATCH] Support gathering expert distribution details (#6665) --- .../srt/managers/expert_distribution.py | 137 +++++++++++++++++- test/srt/test_expert_distribution.py | 5 +- 2 files changed, 135 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 7191bedd8..a36ca5850 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ b/python/sglang/srt/managers/expert_distribution.py @@ -18,7 +18,7 @@ from abc import ABC from collections import deque from contextlib import contextmanager from pathlib import Path -from typing import Dict, List, Literal, Optional, Tuple, Type +from typing import Any, Dict, List, Literal, Optional, Tuple, Type import einops import torch @@ -293,6 +293,79 @@ class _SinglePassGatherer(ABC): raise NotImplementedError +class _DetailSinglePassGatherer(_SinglePassGatherer): + # DeepSeek V3 has this value; should generalize later + _TOP_K_NUM = 8 + + def __init__( + self, + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, + ): + super().__init__(expert_location_metadata, rank) + self._metadata: Optional[Dict[str, Any]] = None + self._topk_ids_of_layer = torch.zeros( + ( + expert_location_metadata.num_layers, + # TODO determine the max number + server_args.chunked_prefill_size * 8, + self._TOP_K_NUM, + ), + dtype=torch.int32, + device=server_args.device, + ) + self._misc_objects: List[Dict[str, Any]] = [] + assert ( + not server_args.enable_two_batch_overlap + ), "DetailSinglePassGatherer does not support TBO yet" + # TODO assert shared experts fusion is disabled, o/w data is wrong + + def on_forward_pass_start(self, forward_batch: ForwardBatch): + assert self._metadata is None + self._metadata = dict( + # TODO pr-chain + # rids=forward_batch.rids, + input_ids=forward_batch.input_ids.cpu().tolist(), + positions=forward_batch.positions.cpu().tolist(), + extend_seq_lens=forward_batch.extend_seq_lens_cpu, + forward_mode=forward_batch.forward_mode.value, + ) + + def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor): + self._topk_ids_of_layer[layer_idx, : topk_ids.shape[0], :] = topk_ids + + def on_deepep_dispatch_normal( + self, + layer_idx: int, + local_physical_count_of_layer: List[int], + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + ): + self._misc_objects.append( + dict( + layer_id=layer_idx, + num_tokens_per_rank=num_tokens_per_rank.cpu().tolist(), + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank.cpu().tolist(), + num_tokens_per_expert=num_tokens_per_expert.cpu().tolist(), + ) + ) + + def reset(self): + self._topk_ids_of_layer[...] = -1 + self._misc_objects.clear() + self._metadata = None + + def collect(self) -> Dict: + num_tokens = len(self._metadata["input_ids"]) + return dict( + **self._metadata, + topk_ids_of_layer=self._topk_ids_of_layer[:, :num_tokens, :].clone().cpu(), + misc_objects=self._misc_objects, + ) + + class _LayerBasedSinglePassGatherer(_SinglePassGatherer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -438,9 +511,8 @@ class _Accumulator(ABC): def get_class(server_args: ServerArgs) -> Type["_Accumulator"]: return { "stat": _StatAccumulator, - # TODO pr-chain: enable this later - # "per_pass": _DetailAccumulator, - # "per_token": _DetailAccumulator, + "per_pass": _DetailAccumulator, + "per_token": _DetailAccumulator, }[server_args.expert_distribution_recorder_mode] def __init__( @@ -547,6 +619,63 @@ class _DequeCollection: return {d.maxlen: sum(d) / len(d) for d in self._dequeues} +class _DetailAccumulator(_UtilizationRateAccumulatorMixin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._records = [] + + def get_single_pass_gatherer_keys(self): + if False: # TODO `server_args.enable_two_batch_overlap` + return [_SINGLE_PASS_GATHERER_KEY_PRIMARY, "child_a", "child_b"] + return super().get_single_pass_gatherer_keys() + + def get_single_pass_gatherer_key(self, debug_name: Optional[str]): + if False: # TODO `server_args.enable_two_batch_overlap` + return debug_name or _SINGLE_PASS_GATHERER_KEY_PRIMARY + return super().get_single_pass_gatherer_key(debug_name) + + def append( + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_data: Dict, + ): + super().append(forward_pass_id, gatherer_key, single_pass_data) + + def _process_object(obj): + if isinstance(obj, torch.Tensor): + return obj.cpu().clone() + return obj + + single_pass_data_processed = { + k: _process_object(v) for k, v in single_pass_data.items() + } + + self._records.append( + dict( + forward_pass_id=forward_pass_id, + rank=self._rank, + gatherer_key=gatherer_key, + **single_pass_data_processed, + ) + ) + + def reset(self): + super().reset() + self._records.clear() + + def dump(self, output_mode: _OutputMode): + assert output_mode == "file" + output = dict( + records=self._records, + # NOTE: This may change during recording, so here we say it is the "last" one + last_physical_to_logical_map=self._expert_location_metadata.physical_to_logical_map, + ) + _dump_to_file( + f"expert_distribution_recorder_{time.time()}_{self._rank}.pt", output + ) + + class _StatAccumulator(_UtilizationRateAccumulatorMixin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index b0efcfb38..f98c97766 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -23,9 +23,8 @@ class TestExpertDistribution(CustomTestCase): dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"), dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"), dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2), - # TODO enable in next PR - # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"), - # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"), + dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"), + dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"), ]: with self.subTest(info=info): self._execute_core(**info)