Add endpoints to dump selected expert ids (#4435)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
yuhsaun-t
2025-03-24 21:34:19 -07:00
committed by GitHub
parent 6b7038babd
commit 199bb01d00
10 changed files with 328 additions and 2 deletions

View File

@@ -343,6 +343,36 @@ async def stop_profile_async():
)
@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
async def start_expert_distribution_record_async():
    """Begin a fresh expert-distribution recording session.

    Any previously accumulated record is discarded by the recorder before
    recording starts.
    """
    _global_state.tokenizer_manager.start_expert_distribution_record()
    body = "Start recording the expert distribution.\n"
    return Response(content=body, status_code=200)
@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
async def stop_expert_distribution_record_async():
    """Halt the in-progress expert-distribution recording.

    The accumulated record is kept; dump it via
    `/dump_expert_distribution_record`.
    """
    _global_state.tokenizer_manager.stop_expert_distribution_record()
    body = "Stop recording the expert distribution.\n"
    return Response(content=body, status_code=200)
@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
async def dump_expert_distribution_record_async():
    """Write the recorded expert distribution to disk and clear the record."""
    _global_state.tokenizer_manager.dump_expert_distribution_record()
    body = "Dump expert distribution record.\n"
    return Response(content=body, status_code=200)
@app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk inplace without re-launching the server."""

View File

@@ -21,6 +21,10 @@ from sglang.srt.utils import get_compiler_backend, is_cuda
_is_cuda = is_cuda()
from sglang.srt.managers.utils import ExpertDistributionRecorder
expert_distribution_recorder = ExpertDistributionRecorder()
def fused_topk_native(
hidden_states: torch.Tensor,
@@ -223,4 +227,6 @@ def select_experts(
renormalize=renormalize,
)
expert_distribution_recorder.record_new_token(topk_ids)
return topk_weights, topk_ids

View File

@@ -658,6 +658,12 @@ class ProfileReqType(Enum):
STOP_PROFILE = 2
class ExpertDistributionReq(Enum):
    """Control message sent from the tokenizer manager to the scheduler to
    start, stop, or dump the expert-distribution record."""

    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3
@dataclass
class ProfileReq:
    # Which profiling action the scheduler should take (see ProfileReqType).
    type: ProfileReqType

View File

@@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
ExpertDistributionReq,
FlushCacheReq,
GetInternalStateReq,
GetInternalStateReqOutput,
@@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -128,6 +129,8 @@ from sglang.srt.utils import (
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
@@ -403,6 +406,7 @@ class Scheduler(
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
(ExpertDistributionReq, self.expert_distribution_handle),
]
)
@@ -1892,6 +1896,16 @@ class Scheduler(
ProfileReqOutput(success=True, message="Succeeded.")
)
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
    """Route an expert-distribution control message to the global recorder."""
    dispatch = {
        ExpertDistributionReq.START_RECORD: expert_distribution_recorder.start_record,
        ExpertDistributionReq.STOP_RECORD: expert_distribution_recorder.stop_record,
        ExpertDistributionReq.DUMP_RECORD: expert_distribution_recorder.dump_record,
    }
    action = dispatch.get(recv_req)
    if action is None:
        raise ValueError("Unrecognized ExpertDistributionReq value")
    action()
def open_session(self, recv_req: OpenSessionReqInput):
# handle error
session_id = recv_req.session_id

View File

@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
ExpertDistributionReq,
FlushCacheReq,
GenerateReqInput,
GetInternalStateReq,
@@ -638,6 +639,18 @@ class TokenizerManager:
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
self.send_to_scheduler.send_pyobj(req)
def start_expert_distribution_record(self):
    """Ask the scheduler process to begin recording expert ids."""
    self.send_to_scheduler.send_pyobj(ExpertDistributionReq.START_RECORD)
def stop_expert_distribution_record(self):
    """Ask the scheduler process to stop recording expert ids."""
    self.send_to_scheduler.send_pyobj(ExpertDistributionReq.STOP_RECORD)
def dump_expert_distribution_record(self):
    """Ask the scheduler process to dump the recorded expert ids to disk."""
    self.send_to_scheduler.send_pyobj(ExpertDistributionReq.DUMP_RECORD)
async def update_weights_from_disk(
self,
obj: UpdateWeightFromDiskReqInput,

View File

@@ -1,6 +1,11 @@
import json
import logging
import time
from collections import defaultdict
from http import HTTPStatus
from typing import Optional
from typing import Dict, List, Optional, Tuple
import torch
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
@@ -42,3 +47,75 @@ def validate_input_length(
return error_msg
return None
# global expert distribution recording
class ExpertDistributionRecorder:
    """Process-wide singleton that records which experts each token is routed to.

    The record maps a layer id to a list of per-token tuples, each tuple
    holding the top-k expert ids selected for one token.  Recording is
    toggled with start_record()/stop_record(), and dump_record() aggregates
    the counts into a per-rank CSV file and resets the recorder.
    """

    # Singleton: every instantiation returns the same shared instance.
    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # __init__ runs on *every* ExpertDistributionRecorder() call even
        # though __new__ returns the shared instance; without this guard a
        # later instantiation (e.g. from another module's import) would
        # silently wipe an in-progress record.
        if hasattr(self, "_expert_distribution_record"):
            return
        # layer id -> list of per-token tuples; each tuple has length k
        # (the topk value) and contains the expert ids chosen for one token.
        self._expert_distribution_record: Dict[int, List[Tuple[int, ...]]] = (
            defaultdict(list)
        )
        # Whether recording is currently active.
        self._record = False
        # Layer index tagged onto subsequent record_new_token() calls.
        self._current_layer_id = "UNKNOWN"

    def set_current_layer(self, layer_idx):
        """Tag subsequent record_new_token() calls with this layer index."""
        self._current_layer_id = layer_idx

    def record_new_token(self, topk_ids):
        """Append each token's selected expert ids to the current layer's record.

        `topk_ids` is presumably a (num_tokens, k) integer tensor on the
        GPU -- TODO confirm against select_experts callers.
        """
        if not self._record:
            return
        topk_ids_cpu = topk_ids.to("cpu", non_blocking=True)
        # The device->host copy above is asynchronous: synchronize *before*
        # reading the CPU tensor, otherwise stale data may be recorded.
        torch.cuda.synchronize()
        for token_topk in topk_ids_cpu.numpy().tolist():
            self._expert_distribution_record[self._current_layer_id].append(
                tuple(token_topk)
            )

    def reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting expert distribution record...")
        self._record = False
        self._expert_distribution_record.clear()
        self._current_layer_id = "UNKNOWN"

    def start_record(self):
        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
        if self._record:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self.reset()
        self._record = True

    def stop_record(self):
        """Stop recording the expert distribution. Set the recording flag to False."""
        if not self._record:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
            )
        self._record = False

    def dump_record(self):
        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
        # Aggregate per-layer, per-expert token counts.
        results = {}
        for layer_idx, layer_record in self._expert_distribution_record.items():
            counts = defaultdict(int)
            for token_record in layer_record:
                for expert_idx in token_record:
                    counts[expert_idx] += 1
            results[layer_idx] = counts
        filename = (
            f"expert_distribution_rank{torch.distributed.get_rank()}"
            f"_timestamp{time.time()}.csv"
        )
        with open(filename, "w") as fd:
            fd.write("layer_id,expert_id,count\n")
            for layer_idx, layer_results in results.items():
                for expert_idx, count in layer_results.items():
                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
        self.reset()

View File

@@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.managers.utils import ExpertDistributionRecorder
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
@@ -80,6 +81,8 @@ if _is_cuda:
else:
from vllm import _custom_ops as ops
expert_distribution_recorder = ExpertDistributionRecorder()
class DeepseekV2MLP(nn.Module):
def __init__(
@@ -1160,6 +1163,7 @@ class DeepseekV2Model(nn.Module):
residual = None
for i in range(len(self.layers)):
expert_distribution_recorder.set_current_layer(i)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual