Add endpoints to dump selected expert ids (#4435)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,9 @@
|
||||
"- `/update_weights`\n",
|
||||
"- `/encode`(embedding model)\n",
|
||||
"- `/classify`(reward model)\n",
|
||||
"- `/start_expert_distribution_record`\n",
|
||||
"- `/stop_expert_distribution_record`\n",
|
||||
"- `/dump_expert_distribution_record`\n",
|
||||
"\n",
|
||||
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`."
|
||||
]
|
||||
@@ -362,6 +365,67 @@
|
||||
"terminate_process(reward_process)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Capture expert selection distribution in MoE models\n",
|
||||
"\n",
|
||||
"SGLang Runtime supports recording the number of times an expert is selected in a MoE model run for each expert in the model. This is useful when analyzing the throughput of the model and plan for optimization."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"expert_record_server_process, port = launch_server_cmd(\n",
|
||||
" \"python -m sglang.launch_server --model-path Qwen/Qwen1.5-MoE-A2.7B --host 0.0.0.0\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"wait_for_server(f\"http://localhost:{port}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = requests.post(f\"http://localhost:{port}/start_expert_distribution_record\")\n",
|
||||
"print_highlight(response)\n",
|
||||
"\n",
|
||||
"url = f\"http://localhost:{port}/generate\"\n",
|
||||
"data = {\"text\": \"What is the capital of France?\"}\n",
|
||||
"\n",
|
||||
"response = requests.post(url, json=data)\n",
|
||||
"print_highlight(response.json())\n",
|
||||
"\n",
|
||||
"response = requests.post(f\"http://localhost:{port}/stop_expert_distribution_record\")\n",
|
||||
"print_highlight(response)\n",
|
||||
"\n",
|
||||
"response = requests.post(f\"http://localhost:{port}/dump_expert_distribution_record\")\n",
|
||||
"print_highlight(response)\n",
|
||||
"\n",
|
||||
"import glob\n",
|
||||
"\n",
|
||||
"output_file = glob.glob(\"expert_distribution_*.csv\")[0]\n",
|
||||
"with open(output_file, \"r\") as f:\n",
|
||||
" print_highlight(\"Content of dumped record:\")\n",
|
||||
" for line in f:\n",
|
||||
" print_highlight(line.strip())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"terminate_process(expert_record_server_process)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
||||
@@ -343,6 +343,36 @@ async def stop_profile_async():
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
|
||||
async def start_expert_distribution_record_async():
|
||||
"""Start recording the expert distribution. Clear the previous record if any."""
|
||||
_global_state.tokenizer_manager.start_expert_distribution_record()
|
||||
return Response(
|
||||
content="Start recording the expert distribution.\n",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
|
||||
async def stop_expert_distribution_record_async():
|
||||
"""Stop recording the expert distribution."""
|
||||
_global_state.tokenizer_manager.stop_expert_distribution_record()
|
||||
return Response(
|
||||
content="Stop recording the expert distribution.\n",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
|
||||
async def dump_expert_distribution_record_async():
|
||||
"""Dump expert distribution record."""
|
||||
_global_state.tokenizer_manager.dump_expert_distribution_record()
|
||||
return Response(
|
||||
content="Dump expert distribution record.\n",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_disk")
|
||||
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
||||
"""Update the weights from disk inplace without re-launching the server."""
|
||||
|
||||
@@ -21,6 +21,10 @@ from sglang.srt.utils import get_compiler_backend, is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
from sglang.srt.managers.utils import ExpertDistributionRecorder
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
|
||||
def fused_topk_native(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -223,4 +227,6 @@ def select_experts(
|
||||
renormalize=renormalize,
|
||||
)
|
||||
|
||||
expert_distribution_recorder.record_new_token(topk_ids)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
@@ -658,6 +658,12 @@ class ProfileReqType(Enum):
|
||||
STOP_PROFILE = 2
|
||||
|
||||
|
||||
class ExpertDistributionReq(Enum):
|
||||
START_RECORD = 1
|
||||
STOP_RECORD = 2
|
||||
DUMP_RECORD = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReq:
|
||||
type: ProfileReqType
|
||||
|
||||
@@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
CloseSessionReqInput,
|
||||
ExpertDistributionReq,
|
||||
FlushCacheReq,
|
||||
GetInternalStateReq,
|
||||
GetInternalStateReqOutput,
|
||||
@@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
@@ -128,6 +129,8 @@ from sglang.srt.utils import (
|
||||
)
|
||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Test retract decode for debugging purposes
|
||||
@@ -403,6 +406,7 @@ class Scheduler(
|
||||
(GetInternalStateReq, self.get_internal_state),
|
||||
(SetInternalStateReq, self.set_internal_state),
|
||||
(RpcReqInput, self.handle_rpc_request),
|
||||
(ExpertDistributionReq, self.expert_distribution_handle),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1892,6 +1896,16 @@ class Scheduler(
|
||||
ProfileReqOutput(success=True, message="Succeeded.")
|
||||
)
|
||||
|
||||
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
|
||||
if recv_req == ExpertDistributionReq.START_RECORD:
|
||||
expert_distribution_recorder.start_record()
|
||||
elif recv_req == ExpertDistributionReq.STOP_RECORD:
|
||||
expert_distribution_recorder.stop_record()
|
||||
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
|
||||
expert_distribution_recorder.dump_record()
|
||||
else:
|
||||
raise ValueError("Unrecognized ExpertDistributionReq value")
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput):
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
|
||||
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
|
||||
CloseSessionReqInput,
|
||||
ConfigureLoggingReq,
|
||||
EmbeddingReqInput,
|
||||
ExpertDistributionReq,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
GetInternalStateReq,
|
||||
@@ -638,6 +639,18 @@ class TokenizerManager:
|
||||
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def start_expert_distribution_record(self):
|
||||
req = ExpertDistributionReq.START_RECORD
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def stop_expert_distribution_record(self):
|
||||
req = ExpertDistributionReq.STOP_RECORD
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
def dump_expert_distribution_record(self):
|
||||
req = ExpertDistributionReq.DUMP_RECORD
|
||||
self.send_to_scheduler.send_pyobj(req)
|
||||
|
||||
async def update_weights_from_disk(
|
||||
self,
|
||||
obj: UpdateWeightFromDiskReqInput,
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||
|
||||
@@ -42,3 +47,75 @@ def validate_input_length(
|
||||
return error_msg
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# global expert distribution recording
|
||||
class ExpertDistributionRecorder:
|
||||
# This class is a singleton class
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, "instance"):
|
||||
cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
def __init__(self):
|
||||
# the length of the dictionary is the number of layers
|
||||
# the length of the list is the number of tokens
|
||||
# the length of the tuple is topk's k value
|
||||
self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
|
||||
list
|
||||
)
|
||||
self._record = False
|
||||
self._current_layer_id = "UNKNOWN"
|
||||
|
||||
def set_current_layer(self, layer_idx):
|
||||
self._current_layer_id = layer_idx
|
||||
|
||||
def record_new_token(self, topk_ids):
|
||||
if not self._record:
|
||||
return
|
||||
topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
|
||||
torch.cuda.synchronize()
|
||||
for i in topk_ids_list:
|
||||
self._expert_distribution_record[self._current_layer_id].append(tuple(i))
|
||||
|
||||
def reset(self):
|
||||
"""Reset the expert distribution recorder."""
|
||||
logger.info("Resetting expert distribution record...")
|
||||
self._record = False
|
||||
self._expert_distribution_record.clear()
|
||||
self._current_layer_id = "UNKNOWN"
|
||||
|
||||
def start_record(self):
|
||||
"""Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
|
||||
if self._record == True:
|
||||
logger.warning(
|
||||
"SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
|
||||
)
|
||||
self.reset()
|
||||
self._record = True
|
||||
|
||||
def stop_record(self):
|
||||
"""Stop recording the expert distribution. Set the recording flag to False."""
|
||||
if self._record == False:
|
||||
logger.warning(
|
||||
"SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
|
||||
)
|
||||
self._record = False
|
||||
|
||||
def dump_record(self):
|
||||
"""Dump the expert distribution record to a file. Reset the recorder after dumping."""
|
||||
results = {}
|
||||
for layer_idx, layer_record in self._expert_distribution_record.items():
|
||||
results[layer_idx] = defaultdict(int)
|
||||
for token_record in layer_record:
|
||||
for expert_idx in token_record:
|
||||
results[layer_idx][expert_idx] += 1
|
||||
with open(
|
||||
f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
|
||||
"w",
|
||||
) as fd:
|
||||
fd.write("layer_id,expert_id,count\n")
|
||||
for layer_idx, layer_results in results.items():
|
||||
for expert_idx, count in layer_results.items():
|
||||
fd.write(f"{layer_idx},{expert_idx},{count}\n")
|
||||
self.reset()
|
||||
|
||||
@@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.managers.utils import ExpertDistributionRecorder
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
|
||||
@@ -80,6 +81,8 @@ if _is_cuda:
|
||||
else:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
def __init__(
|
||||
@@ -1160,6 +1163,7 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
expert_distribution_recorder.set_current_layer(i)
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
|
||||
@@ -30,6 +30,7 @@ suites = {
|
||||
TestFile("test_ebnf_constrained.py"),
|
||||
TestFile("test_fp8_kernel.py", 2),
|
||||
TestFile("test_embedding_openai_server.py", 36),
|
||||
TestFile("test_expert_distribution.py", 31),
|
||||
TestFile("test_gguf.py", 78),
|
||||
TestFile("test_gptqmodel_dynamic.py", 72),
|
||||
TestFile("test_hidden_states.py", 55),
|
||||
|
||||
111
test/srt/test_expert_distribution.py
Executable file
111
test/srt/test_expert_distribution.py
Executable file
@@ -0,0 +1,111 @@
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestExpertDistribution(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Clean up any existing expert distribution files before each test
|
||||
for f in glob.glob("expert_distribution_*.csv"):
|
||||
os.remove(f)
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up any expert distribution files after each test
|
||||
for f in glob.glob("expert_distribution_*.csv"):
|
||||
os.remove(f)
|
||||
|
||||
def test_expert_distribution_record(self):
|
||||
"""Test expert distribution record endpoints"""
|
||||
process = popen_launch_server(
|
||||
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
)
|
||||
|
||||
try:
|
||||
# Start recording
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/start_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Make some requests to generate expert distribution data
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/generate",
|
||||
json={
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 32,
|
||||
},
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Stop recording
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/stop_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Dump the recorded data
|
||||
response = requests.post(
|
||||
f"{DEFAULT_URL_FOR_TEST}/dump_expert_distribution_record"
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Verify the dumped file exists and has correct format
|
||||
csv_files = glob.glob("expert_distribution_*.csv")
|
||||
self.assertEqual(
|
||||
len(csv_files), 1, "Expected exactly one expert distribution CSV file"
|
||||
)
|
||||
|
||||
# Check CSV file format
|
||||
with open(csv_files[0], "r") as f:
|
||||
csv_reader = csv.reader(f)
|
||||
|
||||
# Check header
|
||||
header = next(csv_reader)
|
||||
self.assertEqual(
|
||||
header,
|
||||
["layer_id", "expert_id", "count"],
|
||||
"CSV header should be 'layer_id,expert_id,count'",
|
||||
)
|
||||
|
||||
# Check data rows
|
||||
rows = list(csv_reader)
|
||||
self.assertGreater(len(rows), 0, "CSV file should contain data rows")
|
||||
|
||||
for row in rows:
|
||||
# Verify each row has 3 columns
|
||||
self.assertEqual(
|
||||
len(row),
|
||||
3,
|
||||
"Each row should have layer_id, expert_id and count",
|
||||
)
|
||||
|
||||
# Verify data types
|
||||
layer_id, expert_id, count = row
|
||||
self.assertTrue(layer_id.isdigit(), "layer_id should be an integer")
|
||||
self.assertTrue(
|
||||
expert_id.isdigit(), "expert_id should be an integer"
|
||||
)
|
||||
self.assertTrue(count.isdigit(), "count should be an integer")
|
||||
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user