diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 04b8ec0ed..189c678c0 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -390,7 +390,7 @@ "outputs": [], "source": [ "expert_record_server_process, port = launch_server_cmd(\n", - " \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n", + " \"python3 -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0 --expert-distribution-recorder-mode stat\"\n", ")\n", "\n", "wait_for_server(f\"http://localhost:{port}\")" @@ -415,19 +415,7 @@ "print_highlight(response)\n", "\n", "response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n", - "print_highlight(response)\n", - "\n", - "import glob\n", - "\n", - "output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n", - "with open(output_file, \"r\") as f:\n", - " print_highlight(\"\\n| Layer ID | Expert ID | Count |\")\n", - " print_highlight(\"|----------|-----------|--------|\")\n", - " next(f)\n", - " for i, line in enumerate(f):\n", - " if i < 9:\n", - " layer_id, expert_id, count = line.strip().split(\",\")\n", - " print_highlight(f\"| {layer_id:8} | {expert_id:9} | {count:6} |\")" + "print_highlight(response)" ] }, { diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 4d165dbd2..b647f456b 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,6 +1,9 @@ import logging from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM +from sglang.srt.managers.expert_distribution import ( + get_global_expert_distribution_recorder, +) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import DeepEPMode, load_json_config @@ -326,6 +329,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): 
config=_DeepEPConfig.get_instance().normal_dispatch_config, ) + get_global_expert_distribution_recorder().on_deepep_dispatch_normal( + num_recv_tokens_per_expert_list, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + num_tokens_per_expert=num_tokens_per_expert, + ) + return ( recv_x, recv_topk_idx, @@ -489,6 +499,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): ): hook() if self.return_recv_hook else event.current_stream_wait() + get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency( + masked_m + ) + reorder_topk_ids = seg_indptr = None return ( diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 4c065e4e5..075587dc0 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -18,7 +18,10 @@ from typing import Callable, Optional import torch import torch.nn.functional as F -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import ( + ExpertDistributionRecorder, + get_global_expert_distribution_recorder, +) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip @@ -31,8 +34,6 @@ if _is_cuda: if _is_cuda or _is_hip: from sgl_kernel import topk_softmax -expert_distribution_recorder = ExpertDistributionRecorder() - def fused_topk_native( hidden_states: torch.Tensor, @@ -353,6 +354,6 @@ def select_experts( renormalize=renormalize, ) - expert_distribution_recorder.record_new_token(topk_ids) + get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) return topk_weights, topk_ids diff --git a/python/sglang/srt/managers/expert_distribution.py b/python/sglang/srt/managers/expert_distribution.py index 226256ed2..c32cafbb8 100644 --- a/python/sglang/srt/managers/expert_distribution.py +++ 
b/python/sglang/srt/managers/expert_distribution.py @@ -1,81 +1,620 @@ -import json +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== import logging +import os import time -from collections import defaultdict -from typing import Dict, List, Tuple +from abc import ABC +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple, Type import torch +import torch.distributed + +from sglang.srt.managers.expert_location import ExpertLocationMetadata +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import Withable logger = logging.getLogger(__name__) +# --------------------------------------- Entrypoint ----------------------------------------- -# global expert distribution recording -class ExpertDistributionRecorder: - # This class is a singleton class - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls) - return cls.instance +_OutputMode = Literal["file", "object"] - def __init__(self): - # the length of the dictionary is the number of layers - # the length of the list is the number of tokens - # the length of the tuple is topk's k value - self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict( - list - ) - self._record = 
class ExpertDistributionRecorder(ABC):
    """Base interface of the process-wide expert distribution recorder.

    All hooks and scoping context managers defined here do nothing, and the
    record-control methods raise, which is the behavior wanted when recording
    is disabled. Call ``init_new`` to get the implementation matching the
    server configuration.
    """

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ):
        """Create the recorder selected by ``expert_distribution_recorder_mode``.

        Returns the no-op recorder when the mode is ``None`` and the real
        recorder otherwise.
        """
        if server_args.expert_distribution_recorder_mode is None:
            return _ExpertDistributionRecorderNoop()
        return _ExpertDistributionRecorderReal(
            server_args, expert_location_metadata, rank
        )

    @contextmanager
    def with_current_layer(self, layer_idx):
        """Scope for code running inside layer ``layer_idx``; no-op here."""
        yield

    @contextmanager
    def with_debug_name(self, debug_name):
        """Scope tagged with ``debug_name``; no-op here."""
        yield

    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        """Scope wrapping one forward pass; no-op here."""
        yield

    def on_select_experts(self, topk_ids: torch.Tensor):
        """Hook called with the selected expert ids; ignored here."""

    def on_deepep_dispatch_normal(
        self,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        """Hook called after a DeepEP normal-mode dispatch; ignored here."""

    def on_deepep_dispatch_low_latency(
        self, local_physical_count_of_layer: torch.Tensor
    ):
        """Hook called after a DeepEP low-latency dispatch; ignored here."""

    def start_record(self):
        """Start recording; unavailable on the base recorder, so it raises."""
        self._on_not_implemented()

    def stop_record(self):
        """Stop recording; unavailable on the base recorder, so it raises."""
        self._on_not_implemented()

    def dump_record(self, output_mode: _OutputMode = "file"):
        """Dump the record; unavailable on the base recorder, so it raises."""
        self._on_not_implemented()

    def _on_not_implemented(self):
        """Raise an error telling the user how to enable the recorder."""
        message = "Please set ServerArgs.expert_distribution_recorder_mode to use ExpertDistributionRecorder."
        raise Exception(message)
self._recording: + return + for gatherer_key, gatherer in self._single_pass_gatherers.items(): + single_pass_data = gatherer.collect() + self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data) + + def on_select_experts(self, topk_ids: torch.Tensor): + self._on_hook("on_select_experts", topk_ids=topk_ids) + + def on_deepep_dispatch_normal( + self, + local_physical_count_of_layer: List[int], + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + ): + self._on_hook( + "on_deepep_dispatch_normal", + local_physical_count_of_layer=local_physical_count_of_layer, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + num_tokens_per_expert=num_tokens_per_expert, + ) + + def on_deepep_dispatch_low_latency( + self, local_physical_count_of_layer: torch.Tensor + ): + self._on_hook( + "on_deepep_dispatch_low_latency", + local_physical_count_of_layer=local_physical_count_of_layer, + ) + + def _on_hook(self, hook_name: str, **kwargs): + if not (self._recording or torch.cuda.is_current_stream_capturing()): + return + gatherer = self._single_pass_gatherers[ + self._accumulator.get_single_pass_gatherer_key( + self._current_debug_name.value + ) + ] + getattr(gatherer, hook_name)(layer_idx=self._current_layer_idx.value, **kwargs) + + def _reset(self): + """Reset the expert distribution recorder.""" + logger.info("Resetting ExpertDistributionRecorder...") + assert ( + self._current_layer_idx.value is None + ), f"{self._current_layer_idx.value=}" + for gatherer in self._single_pass_gatherers.values(): + gatherer.reset() + self._accumulator.reset() + + def start_record(self): + """Start recording the expert distribution.""" + if self._recording: logger.warning( "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" 
class _SinglePassGatherer(ABC):
    """Collects expert-usage data during a single forward pass.

    The hook methods are no-ops by default; concrete gatherers override the
    hooks they care about and must implement ``reset`` and ``collect``.
    """

    @staticmethod
    def init_new(
        server_args: ServerArgs,
        expert_location_metadata: "ExpertLocationMetadata",
        rank: int,
    ) -> "_SinglePassGatherer":
        """Pick the gatherer implementation matching ``server_args``.

        Raises:
            NotImplementedError: DeepEP MoE is enabled with a ``deepep_mode``
                other than ``"normal"`` or ``"low_latency"``.
        """
        # NOTE(review): `_DetailSinglePassGatherer` is not defined in the
        # visible part of this module - confirm it exists before relying on
        # the "per_token" mode.
        if server_args.expert_distribution_recorder_mode == "per_token":
            return _DetailSinglePassGatherer(
                server_args, expert_location_metadata, rank
            )

        if not server_args.enable_deepep_moe:
            return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)

        if server_args.deepep_mode == "normal":
            return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
        if server_args.deepep_mode == "low_latency":
            return _DeepepLowLatencySinglePassGatherer(expert_location_metadata, rank)
        raise NotImplementedError

    def __init__(self, expert_location_metadata: "ExpertLocationMetadata", rank: int):
        self._expert_location_metadata = expert_location_metadata
        self._rank = rank

    def on_forward_pass_start(self, forward_batch: ForwardBatch):
        """Hook called when a forward pass begins; ignored by default."""

    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        """Hook receiving the expert ids selected in ``layer_idx``; ignored by default."""

    def on_deepep_dispatch_normal(
        self,
        layer_idx: int,
        local_physical_count_of_layer: List[int],
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
    ):
        """Hook for DeepEP normal-mode dispatch results; ignored by default."""

    def on_deepep_dispatch_low_latency(
        self, layer_idx: int, local_physical_count_of_layer: torch.Tensor
    ):
        """Hook for DeepEP low-latency dispatch results; ignored by default."""

    def reset(self):
        """Clear the gathered data; must be provided by subclasses."""
        raise NotImplementedError

    def collect(self) -> Dict:
        """Return the gathered data as a dict; must be provided by subclasses."""
        raise NotImplementedError
class _SelectExpertsSinglePassGatherer(_LayerBasedSinglePassGatherer):
    """Gatherer that counts expert usage from the ``select_experts`` hook.

    Per-layer counts are indexed by global physical expert id and are
    accumulated through ``_on_layer_data``.
    """

    # pretty slow, but we will use the DeepEP Gatherer in production
    def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
        """Count how often each expert id occurs in ``topk_ids`` for ``layer_idx``.

        Args:
            layer_idx: Layer the selection belongs to.
            topk_ids: Nested expert ids, iterated row by row
                (presumably (num_tokens, top_k) - TODO confirm).
        """
        topk_ids_cpu = topk_ids.to("cpu", non_blocking=True)
        # The non-blocking device-to-host copy may still be in flight when
        # `.to` returns, so wait for it BEFORE reading the host buffer;
        # reading first can observe stale or partially written data.
        torch.cuda.synchronize()
        topk_ids_list = topk_ids_cpu.numpy().tolist()

        global_physical_count = [
            0
        ] * self._expert_location_metadata.num_physical_experts
        for token_record in topk_ids_list:
            for global_physical_expert_idx in token_record:
                global_physical_count[global_physical_expert_idx] += 1

        self._on_layer_data(layer_idx, global_physical_count)

    def collect(self) -> Dict:
        """Return the accumulated counts under the ``global_physical_count`` key."""
        global_physical_count = super()._collect_objects(
            pad_len=self._expert_location_metadata.num_physical_experts
        )
        return dict(global_physical_count=global_physical_count)
_DeepepLowLatencySinglePassGatherer(_SinglePassGatherer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._data = torch.zeros( + ( + self._expert_location_metadata.num_layers, + self._expert_location_metadata.num_local_physical_experts, + ), + dtype=torch.int, + device="cuda", + ) + + def on_deepep_dispatch_low_latency( + self, layer_idx: int, local_physical_count_of_layer: torch.Tensor + ): + # Most naive implementation, can optimize later + self._data[layer_idx, :] += local_physical_count_of_layer + + def reset(self): + self._data[...] = 0 + + def collect(self) -> Dict: + # Can optimize if bottleneck + global_physical_count = _convert_local_to_global_physical_count( + self._data, + rank=self._rank, + num_local_physical_experts=self._expert_location_metadata.num_local_physical_experts, + num_physical_experts=self._expert_location_metadata.num_physical_experts, + ) + return dict(global_physical_count=global_physical_count) + + +def _convert_local_to_global_physical_count( + local_physical_count: torch.Tensor, + rank: int, + num_local_physical_experts: int, + num_physical_experts: int, +) -> torch.Tensor: + dtype = local_physical_count.dtype + device = local_physical_count.device + num_layers, _ = local_physical_count.shape + + ans = torch.zeros((num_layers, num_physical_experts), dtype=dtype, device=device) + ans[ + :, num_local_physical_experts * rank : num_local_physical_experts * (rank + 1) + ] = local_physical_count + return ans + + +# --------------------------------------- Accumulator ----------------------------------------- + +_SINGLE_PASS_GATHERER_KEY_PRIMARY = "primary" + + +class _Accumulator(ABC): + @staticmethod + def init_new( + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, + ) -> "_Accumulator": + return _Accumulator.get_class(server_args)( + server_args, expert_location_metadata, rank + ) + + @staticmethod + def get_class(server_args: ServerArgs) -> Type["_Accumulator"]: 
+ return { + "stat": _StatAccumulator, + # TODO pr-chain: enable this later + # "per_pass": _DetailAccumulator, + # "per_token": _DetailAccumulator, + }[server_args.expert_distribution_recorder_mode] + + def __init__( + self, + server_args: ServerArgs, + expert_location_metadata: "ExpertLocationMetadata", + rank: int, + ): + self._server_args = server_args + self._expert_location_metadata = expert_location_metadata + self._rank = rank + + def get_single_pass_gatherer_keys(self): + return [_SINGLE_PASS_GATHERER_KEY_PRIMARY] + + def get_single_pass_gatherer_key(self, debug_name: Optional[str]): + return _SINGLE_PASS_GATHERER_KEY_PRIMARY + + def append( + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_data: Dict, + ): + pass + + def reset(self): + pass + + def dump(self, output_mode: _OutputMode): + pass + + +class _StatAccumulator(_Accumulator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._global_physical_count_of_buffered_step = _Buffer.init_new( + item_shape=( + self._expert_location_metadata.num_layers, + # Cannot use local_physical_count to support select_experts + self._expert_location_metadata.num_physical_experts, + ), + buffer_size=self._server_args.expert_distribution_recorder_buffer_size, + dtype=torch.int32, + device=self._server_args.device, + ) + + def append( + self, + forward_pass_id: int, + gatherer_key: str, + single_pass_data: Dict, + ): + super().append(forward_pass_id, gatherer_key, single_pass_data) + # Can optimize if overhead here is large + self._global_physical_count_of_buffered_step.append( + single_pass_data["global_physical_count"] + ) + + def reset(self): + super().reset() + self._global_physical_count_of_buffered_step.reset() + + def dump(self, output_mode: _OutputMode): + logical_count_of_buffered_step = _convert_global_physical_count_to_logical_count( + self._global_physical_count_of_buffered_step.get_all(), + num_layers=self._expert_location_metadata.num_layers, + 
def _dump_to_file(name, data):
    """Save ``data`` with ``torch.save`` under the recorder output directory.

    The directory comes from the ``SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR``
    environment variable and falls back to ``/tmp``; it is created (with
    parents) when missing.

    Args:
        name: File name placed inside the output directory.
        data: Object handed to ``torch.save``.
    """
    save_dir = Path(os.environ.get("SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR", "/tmp"))
    path_output = save_dir / name
    logger.info(f"Write expert distribution to {path_output}")
    # `exist_ok=True` already makes this a no-op for an existing directory, so
    # no separate (check-then-act) `exists()` test is needed; it also fails
    # early with a clear error if the path exists but is not a directory.
    save_dir.mkdir(parents=True, exist_ok=True)
    torch.save(data, str(path_output))
= 0 + + +class _InfiniteBuffer(_Buffer): + def __init__(self, item_shape: Tuple, dtype, device): + self._item_shape = item_shape + self._buffer = torch.zeros((128, *item_shape), dtype=dtype, device=device) + self._size = 0 + + def append(self, value: torch.Tensor): + curr_buffer_size = len(self._buffer) + dtype = self._buffer.dtype + device = self._buffer.device + + if self._size == curr_buffer_size: + new_buffer = torch.zeros( + (2 * curr_buffer_size, *self._item_shape), dtype=dtype, device=device + ) + new_buffer[:curr_buffer_size] = self._buffer + self._buffer = new_buffer + + self._buffer[self._size] = value + self._size += 1 + + def get_all(self) -> torch.Tensor: + return self._buffer[: self._size] + + def reset(self): + self._buffer[...] = 0 + self._size = 0 + + +def _convert_global_physical_count_to_logical_count( + # (whatever, num_layers, num_physical_experts) + global_physical_count: torch.Tensor, + num_layers: int, + num_logical_experts: int, + physical_to_logical_map: torch.Tensor, +): + dim_extra, _, _ = global_physical_count.shape + dtype = global_physical_count.dtype + device = global_physical_count.device + logical_count = torch.zeros( + (dim_extra, num_layers, num_logical_experts), dtype=dtype, device=device + ) + logical_count.scatter_add_( + dim=2, + index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1), + src=global_physical_count, + ) + return logical_count diff --git a/python/sglang/srt/managers/expert_location.py b/python/sglang/srt/managers/expert_location.py new file mode 100644 index 000000000..44496cdde --- /dev/null +++ b/python/sglang/srt/managers/expert_location.py @@ -0,0 +1,273 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
@dataclass
class ExpertLocationMetadata:
    """Mapping tensors between physical and logical experts, per layer."""

    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        """Size of dimension 0 of ``physical_to_logical_map``."""
        return self.physical_to_logical_map.shape[0]

    @property
    def num_physical_experts(self) -> int:
        """Size of dimension 1 of ``physical_to_logical_map``."""
        return self.physical_to_logical_map.shape[1]

    @property
    def num_local_physical_experts(self) -> int:
        """Physical experts per EP rank; asserts the split is even."""
        total = self.num_physical_experts
        group_size = self.ep_size
        assert total % group_size == 0
        return total // group_size

    @property
    def num_logical_experts(self) -> int:
        """Size of dimension 1 of ``logical_to_all_physical_map``."""
        return self.logical_to_all_physical_map.shape[1]

    @property
    def ep_size(self):
        """Expert-parallel size, currently the distributed world size."""
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        """Check tensor ranks (via unpacking) and the physical-expert dimension."""
        _, physical_in_p2l = self.physical_to_logical_map.shape
        _, _, physical_in_l2p = self.logical_to_all_physical_map.shape
        _, _ = self.logical_to_all_physical_map_num_valid.shape
        # TODO pr-chain: enable the cross-tensor checks on the layer and
        # logical-expert dimensions later
        assert physical_in_p2l == physical_in_l2p

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Build the trivial layout: physical expert ``i`` serves logical expert
        ``i % num_logical_experts`` in every layer."""
        shared = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = shared["model_config_for_expert_location"]

        physical_ids = torch.arange(0, shared["num_physical_experts"])
        per_layer_ids = physical_ids.repeat(expert_config.num_layers, 1)
        trivial_map = per_layer_ids % expert_config.num_logical_experts

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=trivial_map,
        )

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        """Build the metadata from an explicit physical-to-logical mapping.

        ``physical_to_logical_map`` may be a tensor or anything accepted by
        ``torch.tensor``; it is moved to ``server_args.device``.
        """
        if isinstance(physical_to_logical_map, torch.Tensor):
            mapping = physical_to_logical_map
        else:
            mapping = torch.tensor(physical_to_logical_map)
        mapping = mapping.to(server_args.device)

        shared = ExpertLocationMetadata._init_common(server_args, model_config)
        expert_config = shared["model_config_for_expert_location"]
        inverse_mapping = _compute_logical_to_all_physical_map(
            mapping,
            num_logical_experts=expert_config.num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=shared["ep_size"],
            physical_to_logical_map=mapping,
            logical_to_all_physical_map=inverse_mapping,
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        """Derive the expert counts shared by all constructors as a dict."""
        expert_config = ModelConfigForExpertLocation.from_model_config(model_config)

        # TODO pr-chain: enable this later
        # (add server_args.ep_num_redundant_experts to the count)
        num_physical_experts = expert_config.num_logical_experts

        ep_size = server_args.ep_size
        assert num_physical_experts % ep_size == 0

        return {
            "model_config_for_expert_location": expert_config,
            "num_physical_experts": num_physical_experts,
            "num_local_physical_experts": num_physical_experts // ep_size,
            "ep_size": ep_size,
        }

    @staticmethod
    def _init_raw(
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        """Pad the logical-to-physical map with -1 and count its valid entries."""
        _, num_physical_experts = physical_to_logical_map.shape

        # Right-pad the last dimension up to num_physical_experts.
        missing_columns = num_physical_experts - logical_to_all_physical_map.shape[-1]
        padded_map = F.pad(logical_to_all_physical_map, (0, missing_columns), value=-1)

        # Entries equal to -1 are padding and therefore not counted.
        is_valid = logical_to_all_physical_map != -1
        num_valid = torch.count_nonzero(is_valid, dim=-1)

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=padded_map,
            logical_to_all_physical_map_num_valid=num_valid,
        )
torch.Tensor, num_logical_experts: int +): + # This is rarely called, so we use for loops for maximum clarity + + num_layers, num_physical_experts = physical_to_logical_map.shape + + logical_to_all_physical_map = [ + [[] for _ in range(num_logical_experts)] for _ in range(num_layers) + ] + for layer_id in range(num_layers): + for physical_expert_id in range(num_physical_experts): + logical_expert_id = physical_to_logical_map[ + layer_id, physical_expert_id + ].item() + logical_to_all_physical_map[layer_id][logical_expert_id].append( + physical_expert_id + ) + + logical_to_all_physical_map = _pad_nested_array( + logical_to_all_physical_map, pad_value=-1 + ) + + return torch.tensor( + logical_to_all_physical_map, device=physical_to_logical_map.device + ) + + +def _pad_nested_array(arr, pad_value): + max_len = max(len(inner) for outer in arr for inner in outer) + padded = [ + [inner + [pad_value] * (max_len - len(inner)) for inner in outer] + for outer in arr + ] + return padded + + +@dataclass +class ModelConfigForExpertLocation: + num_layers: int + num_logical_experts: int + num_groups: Optional[int] = None + + @staticmethod + def init_dummy(): + return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1) + + @staticmethod + def from_model_config(model_config: ModelConfig): + model_class, _ = get_model_architecture(model_config) + if hasattr(model_class, "get_model_config_for_expert_location"): + return model_class.get_model_config_for_expert_location( + model_config.hf_config + ) + else: + return ModelConfigForExpertLocation.init_dummy() + + +def compute_initial_expert_location_metadata( + server_args: ServerArgs, model_config: ModelConfig +) -> ExpertLocationMetadata: + data = server_args.init_expert_location + if data == "trivial": + logger.info("init_expert_location from trivial") + return ExpertLocationMetadata.init_trivial(server_args, model_config) + + # TODO unify with the utils function + if data.endswith(".pt"): + data_dict = torch.load(data, 
weights_only=True) + elif data.endswith(".json"): + data_dict = json.loads(Path(data).read_text()) + else: + data_dict = json.loads(data) + + if "physical_to_logical_map" in data_dict: + logger.info( + "init_expert_location from init_by_mapping using ServerArgs.init_expert_location" + ) + return ExpertLocationMetadata.init_by_mapping( + server_args, model_config, **data_dict + ) + elif "logical_count" in data_dict: + # TODO pr-chain: enable this later + raise NotImplementedError + # logger.info( + # "init_expert_location from init_by_eplb using ServerArgs.init_expert_location" + # ) + # return ExpertLocationMetadata.init_by_eplb( + # server_args, model_config, logical_count=data_dict["logical_count"] + # ) + else: + raise NotImplementedError( + f"Unknown init_expert_location format ({list(data_dict.keys())=})" + ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 0506460b1..72a3f7246 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -59,7 +59,10 @@ from sglang.srt.hf_transformers_utils import ( ) from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import ( + ExpertDistributionRecorder, + get_global_expert_distribution_recorder, +) from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -142,8 +145,6 @@ from sglang.srt.utils import ( ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback -expert_distribution_recorder = ExpertDistributionRecorder() - logger = logging.getLogger(__name__) # Test retract decode for debugging purposes @@ -2162,11 +2163,11 @@ class Scheduler( def expert_distribution_handle(self, recv_req: ExpertDistributionReq): if recv_req == ExpertDistributionReq.START_RECORD: - 
expert_distribution_recorder.start_record() + get_global_expert_distribution_recorder().start_record() elif recv_req == ExpertDistributionReq.STOP_RECORD: - expert_distribution_recorder.stop_record() + get_global_expert_distribution_recorder().stop_record() elif recv_req == ExpertDistributionReq.DUMP_RECORD: - expert_distribution_recorder.dump_record() + get_global_expert_distribution_recorder().dump_record() else: raise ValueError("Unrecognized ExpertDistributionReq value") return ExpertDistributionReqOutput() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4ce681c14..78a94a898 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -52,6 +52,16 @@ from sglang.srt.layers.quantization.deep_gemm import ( from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.lora.lora_manager import LoRAManager +from sglang.srt.managers.expert_distribution import ( + ExpertDistributionRecorder, + get_global_expert_distribution_recorder, + set_global_expert_distribution_recorder, +) +from sglang.srt.managers.expert_location import ( + compute_initial_expert_location_metadata, + get_global_expert_location_metadata, + set_global_expert_location_metadata, +) from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, @@ -161,6 +171,8 @@ class ModelRunner: self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA self.attention_chunk_size = model_config.attention_chunk_size + self.forward_pass_id = 0 + # Model-specific adjustment self.model_specific_adjustment() @@ -219,6 +231,25 @@ class ModelRunner: enable=self.server_args.enable_memory_saver ) + if not self.is_draft_worker: + set_global_expert_location_metadata( + compute_initial_expert_location_metadata(server_args, 
self.model_config) + ) + if self.tp_rank == 0 and get_bool_env_var( + "SGLANG_LOG_EXPERT_LOCATION_METADATA" + ): + logger.info( + f"Initial expert_location_metadata: {get_global_expert_location_metadata().debug_str()}" + ) + + set_global_expert_distribution_recorder( + ExpertDistributionRecorder.init_new( + server_args, + get_global_expert_location_metadata(), + rank=self.tp_rank, + ) + ) + # Load the model self.sampler = Sampler() self.load_model() @@ -1093,6 +1124,22 @@ class ModelRunner: forward_batch: ForwardBatch, skip_attn_backend_init: bool = False, pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: + self.forward_pass_id += 1 + + with get_global_expert_distribution_recorder().with_forward_pass( + self.forward_pass_id, + forward_batch, + ): + return self._forward_raw( + forward_batch, skip_attn_backend_init, pp_proxy_tensors + ) + + def _forward_raw( + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool, + pp_proxy_tensors: Optional[PPProxyTensors], ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: can_run_cuda_graph = bool( forward_batch.forward_mode.is_cuda_graph() diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e422a5038..3fb003ff9 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -77,7 +77,11 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import ( + ExpertDistributionRecorder, + get_global_expert_distribution_recorder, +) +from sglang.srt.managers.expert_location import ModelConfigForExpertLocation from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from 
sglang.srt.model_loader.weight_utils import default_weight_loader @@ -109,8 +113,6 @@ if _is_hip: decode_attention_fwd_grouped_rope, ) -expert_distribution_recorder = ExpertDistributionRecorder() - logger = logging.getLogger(__name__) @@ -302,6 +304,7 @@ class DeepseekV2MoE(nn.Module): def forward( self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None ) -> torch.Tensor: + forward_mode = forward_batch.forward_mode if (not self._enable_deepep_moe) or is_non_idle_and_non_empty( forward_mode, hidden_states ): @@ -1278,7 +1281,7 @@ class DeepseekV2DecoderLayer(nn.Module): ) # Fully Connected - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, forward_batch) # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter # Scatter @@ -1422,11 +1425,11 @@ class DeepseekV2Model(nn.Module): residual = None for i in range(len(self.layers)): - expert_distribution_recorder.set_current_layer(i) - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, forward_batch, residual, zero_allocator - ) + with get_global_expert_distribution_recorder().with_current_layer(i): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual, zero_allocator + ) if not forward_batch.forward_mode.is_idle(): if residual is None: hidden_states = self.norm(hidden_states) @@ -1872,6 +1875,14 @@ class DeepseekV2ForCausalLM(nn.Module): torch.cuda.empty_cache() torch.cuda.synchronize() + @classmethod + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( + num_layers=config.num_hidden_layers, + num_logical_experts=config.n_routed_experts, + num_groups=config.n_group, + ) + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 525498d5b..261b707d7 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ 
b/python/sglang/srt/models/qwen2_moe.py @@ -59,14 +59,16 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder +from sglang.srt.managers.expert_distribution import ( + ExpertDistributionRecorder, + get_global_expert_distribution_recorder, +) +from sglang.srt.managers.expert_location import ModelConfigForExpertLocation from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, make_layers -expert_distribution_recorder = ExpertDistributionRecorder() - logger = logging.getLogger(__name__) @@ -591,11 +593,11 @@ class Qwen2MoeModel(nn.Module): residual = pp_proxy_tensors["residual"] for i in range(self.start_layer, self.end_layer): - expert_distribution_recorder.set_current_layer(i) - layer = self.layers[i] - hidden_states, residual = layer( - positions, hidden_states, forward_batch, residual - ) + with get_global_expert_distribution_recorder().with_current_layer(i): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) if not self.pp_group.is_last_rank: return PPProxyTensors( { @@ -752,5 +754,13 @@ class Qwen2MoeForCausalLM(nn.Module): else: logger.warning(f"Parameter {name} not found in params_dict") + @classmethod + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( + num_layers=config.num_hidden_layers, + num_logical_experts=config.num_experts, + num_groups=None, + ) + EntryClass = Qwen2MoeForCausalLM diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1e650fe71..59e8dccc1 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -170,6 +170,11 @@ class ServerArgs: 
enable_ep_moe: bool = False enable_deepep_moe: bool = False deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" + init_expert_location: str = "trivial" + expert_distribution_recorder_mode: Optional[ + Literal["stat", "per_pass", "per_token"] + ] = None + expert_distribution_recorder_buffer_size: Optional[int] = None deepep_config: Optional[str] = None enable_torch_compile: bool = False torch_compile_max_bs: int = 32 @@ -361,6 +366,15 @@ class ServerArgs: "Pipeline parallelism is incompatible with overlap schedule." ) + if self.expert_distribution_recorder_buffer_size is None: + # TODO pr-chain: enable this later + # if (x := self.eplb_rebalance_num_iterations) is not None: + # self.expert_distribution_recorder_buffer_size = x + if False: + pass + elif self.expert_distribution_recorder_mode is not None: + self.expert_distribution_recorder_buffer_size = 1000 + # Speculative Decoding if self.speculative_algorithm == "NEXTN": # NEXTN shares the same implementation of EAGLE @@ -1257,6 +1271,24 @@ class ServerArgs: default="auto", help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.", ) + parser.add_argument( + "--init-expert-location", + type=str, + default=ServerArgs.init_expert_location, + help="Initial location of EP experts.", + ) + parser.add_argument( + "--expert-distribution-recorder-mode", + type=str, + default=ServerArgs.expert_distribution_recorder_mode, + help="Mode of expert distribution recorder.", + ) + parser.add_argument( + "--expert-distribution-recorder-buffer-size", + type=int, + default=ServerArgs.expert_distribution_recorder_buffer_size, + help="Circular buffer size of expert distribution recorder. 
Set to -1 to denote infinite buffer.", + ) parser.add_argument( "--deepep-config", type=str, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 0c16667ef..884e715fa 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -46,7 +46,19 @@ from importlib.util import find_spec from io import BytesIO from multiprocessing.reduction import ForkingPickler from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Protocol, + Set, + Tuple, + TypeVar, + Union, +) import numpy as np import psutil @@ -2126,3 +2138,25 @@ def load_json_config(data: str): def dispose_tensor(x: torch.Tensor): x.set_(torch.empty((0,), device=x.device, dtype=x.dtype)) + + +T = TypeVar("T") + + +class Withable(Generic[T]): + def __init__(self): + self._value: Optional[T] = None + + @property + def value(self) -> T: + return self._value + + @contextmanager + def with_value(self, new_value: T): + assert self._value is None + self._value = new_value + try: + yield + finally: + assert self._value is new_value + self._value = None diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index e3826303d..b0efcfb38 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -1,9 +1,10 @@ -import csv -import glob import os +import tempfile import unittest +from pathlib import Path import requests +import torch from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -16,108 +17,86 @@ from sglang.test.test_utils import ( class TestExpertDistribution(CustomTestCase): - def setUp(self): - # Clean up any existing expert distribution files before each test - for f in glob.glob("expert_distribution_*.csv"): - os.remove(f) - - def tearDown(self): - # Clean up any expert distribution files after each test - for f in 
glob.glob("expert_distribution_*.csv"): - os.remove(f) - def test_expert_distribution_record(self): + # TODO: Add tests for DeepEP gatherer (currently our CI cannot run that) + for info in [ + dict(model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"), + dict(model_path="Qwen/Qwen1.5-MoE-A2.7B"), + dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", tp_size=2), + # TODO enable in next PR + # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_pass"), + # dict(model_path="Qwen/Qwen1.5-MoE-A2.7B", mode="per_token"), + ]: + with self.subTest(info=info): + self._execute_core(**info) + + def _execute_core(self, model_path: str, mode: str = "stat", tp_size: int = 1): """Test expert distribution record endpoints""" - process = popen_launch_server( - # The feature is only implemented in deepseek_v2.py - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", - DEFAULT_URL_FOR_TEST, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - ], - ) + with tempfile.TemporaryDirectory() as tmp_dir: + os.environ["SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR"] = tmp_dir - try: - # Start recording - response = requests.post( - f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record" - ) - self.assertEqual(response.status_code, 200) - - # Make some requests to generate expert distribution data - response = requests.post( - f"{DEFAULT_URL_FOR_TEST}/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, - }, - }, - ) - self.assertEqual(response.status_code, 200) - - # Stop recording - response = requests.post( - f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record" - ) - self.assertEqual(response.status_code, 200) - - # Dump the recorded data - response = requests.post( - f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record" - ) - self.assertEqual(response.status_code, 200) - - # Verify the dumped file exists and has correct format - csv_files = glob.glob("expert_distribution_*.csv") - 
self.assertEqual( - len(csv_files), - 1, - f"Expected exactly one expert distribution CSV file {csv_files=}", + process = popen_launch_server( + model_path, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp-size", + str(tp_size), + "--expert-distribution-recorder-mode", + mode, + "--disable-cuda-graph", + "--disable-overlap-schedule", + ], ) - # Check CSV file format - with open(csv_files[0], "r") as f: - csv_reader = csv.reader(f) - - # Check header - header = next(csv_reader) - self.assertEqual( - header, - ["layer_id", "expert_id", "count"], - "CSV header should be 'layer_id,expert_id,count'", + try: + # Start recording + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record" ) + self.assertEqual(response.status_code, 200) + + # Make some requests to generate expert distribution data + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, + ) + self.assertEqual(response.status_code, 200) + + # Stop recording + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record" + ) + self.assertEqual(response.status_code, 200) + + # Dump the recorded data + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record" + ) + self.assertEqual(response.status_code, 200) # Check data rows - rows = list(csv_reader) - self.assertGreater(len(rows), 0, "CSV file should contain data rows") + data = torch.load( + list(Path(tmp_dir).glob("*.pt"))[0], weights_only=True + ) + print(f"{data=}") - for row in rows: - # Verify each row has 3 columns - self.assertEqual( - len(row), - 3, - "Each row should have layer_id, expert_id and count", - ) + if mode in ["per_pass", "per_token"]: + self.assertGreater(len(data), 0, "Should contain data rows") + else: + logical_count = data["logical_count"] + 
print(f"{logical_count.sum()=} {logical_count=}") + self.assertTrue(logical_count.sum() > 0) - # Verify data types - layer_id, expert_id, count = row - self.assertTrue( - layer_id.isdigit(), - f"layer_id should be an integer {row=} {rows=}", - ) - self.assertTrue( - expert_id.isdigit(), - f"expert_id should be an integer {row=} {rows=}", - ) - self.assertTrue( - count.isdigit(), f"count should be an integer {row=} {rows=}" - ) - - finally: - kill_process_tree(process.pid) + finally: + kill_process_tree(process.pid) if __name__ == "__main__":