From 199bb01d00af4cf0307c1e4ddee6074f30771f03 Mon Sep 17 00:00:00 2001 From: yuhsaun-t <12108766+yuhsaun-t@users.noreply.github.com> Date: Mon, 24 Mar 2025 21:34:19 -0700 Subject: [PATCH] Add endpoints to dump selected expert ids (#4435) Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> --- docs/backend/native_api.ipynb | 64 ++++++++++ python/sglang/srt/entrypoints/http_server.py | 30 +++++ python/sglang/srt/layers/moe/topk.py | 6 + python/sglang/srt/managers/io_struct.py | 6 + python/sglang/srt/managers/scheduler.py | 16 ++- .../sglang/srt/managers/tokenizer_manager.py | 13 ++ python/sglang/srt/managers/utils.py | 79 ++++++++++++- python/sglang/srt/models/deepseek_v2.py | 4 + test/srt/run_suite.py | 1 + test/srt/test_expert_distribution.py | 111 ++++++++++++++++++ 10 files changed, 328 insertions(+), 2 deletions(-) create mode 100755 test/srt/test_expert_distribution.py diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 48fc617ea..72b65c6ca 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -17,6 +17,9 @@ "- `/update_weights`\n", "- `/encode`(embedding model)\n", "- `/classify`(reward model)\n", + "- `/start_expert_distribution_record`\n", + "- `/stop_expert_distribution_record`\n", + "- `/dump_expert_distribution_record`\n", "\n", "We mainly use `requests` to test these APIs in the following examples. You can also use `curl`." ] @@ -362,6 +365,67 @@ "terminate_process(reward_process)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Capture expert selection distribution in MoE models\n", + "\n", + "SGLang Runtime supports recording the number of times an expert is selected in a MoE model run for each expert in the model. This is useful when analyzing the throughput of the model and plan for optimization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "expert_record_server_process, port = launch_server_cmd(\n", + " \"python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n", + ")\n", + "\n", + "wait_for_server(f\"http://localhost:{port}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = requests.post(f\"http://localhost:{port}/start_expert_distribution_record\")\n", + "print_highlight(response)\n", + "\n", + "url = f\"http://localhost:{port}/generate\"\n", + "data = {\"text\": \"What is the capital of France?\"}\n", + "\n", + "response = requests.post(url, json=data)\n", + "print_highlight(response.json())\n", + "\n", + "response = requests.post(f\"http://localhost:{port}/stop_expert_distribution_record\")\n", + "print_highlight(response)\n", + "\n", + "response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n", + "print_highlight(response)\n", + "\n", + "import glob\n", + "\n", + "output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n", + "with open(output_file, \"r\") as f:\n", + " print_highlight(\"Content of dumped record:\")\n", + " for line in f:\n", + " print_highlight(line.strip())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(expert_record_server_process)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 62b151162..36c33aeef 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -343,6 +343,36 @@ async def stop_profile_async(): ) +@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"]) +async def start_expert_distribution_record_async(): + """Start recording the expert distribution. Clear the previous record if any.""" + _global_state.tokenizer_manager.start_expert_distribution_record() + return Response( + content="Start recording the expert distribution.\n", + status_code=200, + ) + + +@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"]) +async def stop_expert_distribution_record_async(): + """Stop recording the expert distribution.""" + _global_state.tokenizer_manager.stop_expert_distribution_record() + return Response( + content="Stop recording the expert distribution.\n", + status_code=200, + ) + + +@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"]) +async def dump_expert_distribution_record_async(): + """Dump expert distribution record.""" + _global_state.tokenizer_manager.dump_expert_distribution_record() + return Response( + content="Dump expert distribution record.\n", + status_code=200, + ) + + @app.post("/update_weights_from_disk") async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): """Update the weights from disk inplace without re-launching the server.""" diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index e975819a9..378eef795 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -21,6 +21,10 @@ from sglang.srt.utils import get_compiler_backend, is_cuda _is_cuda = is_cuda() +from sglang.srt.managers.utils import ExpertDistributionRecorder + +expert_distribution_recorder = ExpertDistributionRecorder() + def fused_topk_native( hidden_states: torch.Tensor, @@ -223,4 +227,6 @@ def select_experts( renormalize=renormalize, ) + expert_distribution_recorder.record_new_token(topk_ids) + return topk_weights, topk_ids diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index dad8c2ef1..7e2fcbb5f 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -658,6 +658,12 @@ class ProfileReqType(Enum): STOP_PROFILE = 2 +class ExpertDistributionReq(Enum): + START_RECORD = 1 + STOP_RECORD = 2 + DUMP_RECORD = 3 + + @dataclass class ProfileReq: type: ProfileReqType diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index bf5ab8dc1..8b1f521de 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, + ExpertDistributionReq, FlushCacheReq, GetInternalStateReq, GetInternalStateReqOutput, @@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import ( from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient -from sglang.srt.managers.utils import validate_input_length +from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache @@ -128,6 +129,8 @@ from sglang.srt.utils import ( ) from sglang.utils import TypeBasedDispatcher, get_exception_traceback +expert_distribution_recorder = ExpertDistributionRecorder() + logger = logging.getLogger(__name__) # Test retract decode for debugging purposes @@ -403,6 +406,7 @@ class Scheduler( (GetInternalStateReq, self.get_internal_state), (SetInternalStateReq, self.set_internal_state), (RpcReqInput, self.handle_rpc_request), + (ExpertDistributionReq, self.expert_distribution_handle), ] ) @@ -1892,6 +1896,16 @@ class Scheduler( ProfileReqOutput(success=True, message="Succeeded.") ) + def expert_distribution_handle(self, recv_req: ExpertDistributionReq): + if recv_req == ExpertDistributionReq.START_RECORD: + expert_distribution_recorder.start_record() + elif recv_req == ExpertDistributionReq.STOP_RECORD: + expert_distribution_recorder.stop_record() + elif recv_req == ExpertDistributionReq.DUMP_RECORD: + expert_distribution_recorder.dump_record() + else: + raise ValueError("Unrecognized ExpertDistributionReq value") + def open_session(self, recv_req: OpenSessionReqInput): # handle error session_id = recv_req.session_id diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8d5ef72ce..25834d32a 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import ( CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, + ExpertDistributionReq, FlushCacheReq, GenerateReqInput, GetInternalStateReq, @@ -638,6 +639,18 @@ class TokenizerManager: req = ProfileReq(type=ProfileReqType.STOP_PROFILE) self.send_to_scheduler.send_pyobj(req) + def start_expert_distribution_record(self): + req = ExpertDistributionReq.START_RECORD + self.send_to_scheduler.send_pyobj(req) + + def stop_expert_distribution_record(self): + req = ExpertDistributionReq.STOP_RECORD + self.send_to_scheduler.send_pyobj(req) + + def dump_expert_distribution_record(self): + req = ExpertDistributionReq.DUMP_RECORD + self.send_to_scheduler.send_pyobj(req) + async def update_weights_from_disk( self, obj: UpdateWeightFromDiskReqInput, diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index 10a120963..2730075ff 100644 --- a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -1,6 +1,11 @@ +import json import logging +import time +from collections import defaultdict from http import HTTPStatus -from typing import Optional +from typing import Dict, List, Optional, Tuple + +import torch from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req @@ -42,3 +47,75 @@ def validate_input_length( return error_msg return None + + +# global expert distribution recording +class ExpertDistributionRecorder: + # This class is a singleton class + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls) + return cls.instance + + def __init__(self): + # the length of the dictionary is the number of layers + # the length of the list is the number of tokens + # the length of the tuple is topk's k value + self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict( + list + ) + self._record = False + self._current_layer_id = "UNKNOWN" + + def set_current_layer(self, layer_idx): + self._current_layer_id = layer_idx + + def record_new_token(self, topk_ids): + if not self._record: + return + topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist() + torch.cuda.synchronize() + for i in topk_ids_list: + self._expert_distribution_record[self._current_layer_id].append(tuple(i)) + + def reset(self): + """Reset the expert distribution recorder.""" + logger.info("Resetting expert distribution record...") + self._record = False + self._expert_distribution_record.clear() + self._current_layer_id = "UNKNOWN" + + def start_record(self): + """Start recording the expert distribution. Reset the recorder and set the recording flag to True.""" + if self._record == True: + logger.warning( + "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?" + ) + self.reset() + self._record = True + + def stop_record(self): + """Stop recording the expert distribution. Set the recording flag to False.""" + if self._record == False: + logger.warning( + "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?" + ) + self._record = False + + def dump_record(self): + """Dump the expert distribution record to a file. Reset the recorder after dumping.""" + results = {} + for layer_idx, layer_record in self._expert_distribution_record.items(): + results[layer_idx] = defaultdict(int) + for token_record in layer_record: + for expert_idx in token_record: + results[layer_idx][expert_idx] += 1 + with open( + f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv", + "w", + ) as fd: + fd.write("layer_id,expert_id,count\n") + for layer_idx, layer_results in results.items(): + for expert_idx, count in layer_results.items(): + fd.write(f"{layer_idx},{expert_idx},{count}\n") + self.reset() diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f847b601c..0555e0cd2 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.managers.utils import ExpertDistributionRecorder from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip @@ -80,6 +81,8 @@ if _is_cuda: else: from vllm import _custom_ops as ops +expert_distribution_recorder = ExpertDistributionRecorder() + class DeepseekV2MLP(nn.Module): def __init__( @@ -1160,6 +1163,7 @@ class DeepseekV2Model(nn.Module): residual = None for i in range(len(self.layers)): + expert_distribution_recorder.set_current_layer(i) layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, forward_batch, residual diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e4e5f3252..24978502b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -30,6 +30,7 @@ suites = { TestFile("test_ebnf_constrained.py"), TestFile("test_fp8_kernel.py", 2), TestFile("test_embedding_openai_server.py", 36), + TestFile("test_expert_distribution.py", 31), TestFile("test_gguf.py", 78), TestFile("test_gptqmodel_dynamic.py", 72), TestFile("test_hidden_states.py", 55), diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py new file mode 100755 index 000000000..57f5d7d8d --- /dev/null +++ b/test/srt/test_expert_distribution.py @@ -0,0 +1,111 @@ +import csv +import glob +import os +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestExpertDistribution(unittest.TestCase): + def setUp(self): + # Clean up any existing expert distribution files before each test + for f in glob.glob("expert_distribution_*.csv"): + os.remove(f) + + def tearDown(self): + # Clean up any expert distribution files after each test + for f in glob.glob("expert_distribution_*.csv"): + os.remove(f) + + def test_expert_distribution_record(self): + """Test expert distribution record endpoints""" + process = popen_launch_server( + DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) + + try: + # Start recording + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record" + ) + self.assertEqual(response.status_code, 200) + + # Make some requests to generate expert distribution data + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, + }, + ) + self.assertEqual(response.status_code, 200) + + # Stop recording + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record" + ) + self.assertEqual(response.status_code, 200) + + # Dump the recorded data + response = requests.post( + f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record" + ) + self.assertEqual(response.status_code, 200) + + # Verify the dumped file exists and has correct format + csv_files = glob.glob("expert_distribution_*.csv") + self.assertEqual( + len(csv_files), 1, "Expected exactly one expert distribution CSV file" + ) + + # Check CSV file format + with open(csv_files[0], "r") as f: + csv_reader = csv.reader(f) + + # Check header + header = next(csv_reader) + self.assertEqual( + header, + ["layer_id", "expert_id", "count"], + "CSV header should be 'layer_id,expert_id,count'", + ) + + # Check data rows + rows = list(csv_reader) + self.assertGreater(len(rows), 0, "CSV file should contain data rows") + + for row in rows: + # Verify each row has 3 columns + self.assertEqual( + len(row), + 3, + "Each row should have layer_id, expert_id and count", + ) + + # Verify data types + layer_id, expert_id, count = row + self.assertTrue(layer_id.isdigit(), "layer_id should be an integer") + self.assertTrue( + expert_id.isdigit(), "expert_id should be an integer" + ) + self.assertTrue(count.isdigit(), "count should be an integer") + + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main()