Add endpoints to dump selected expert ids (#4435)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
@@ -343,6 +343,36 @@ async def stop_profile_async():
     )


+@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
+async def start_expert_distribution_record_async():
+    """Start recording the expert distribution. Clear the previous record if any."""
+    _global_state.tokenizer_manager.start_expert_distribution_record()
+    return Response(
+        content="Start recording the expert distribution.\n",
+        status_code=200,
+    )
+
+
+@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
+async def stop_expert_distribution_record_async():
+    """Stop recording the expert distribution."""
+    _global_state.tokenizer_manager.stop_expert_distribution_record()
+    return Response(
+        content="Stop recording the expert distribution.\n",
+        status_code=200,
+    )
+
+
+@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
+async def dump_expert_distribution_record_async():
+    """Dump the expert distribution record."""
+    _global_state.tokenizer_manager.dump_expert_distribution_record()
+    return Response(
+        content="Dump expert distribution record.\n",
+        status_code=200,
+    )
+
+
 @app.post("/update_weights_from_disk")
 async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
     """Update the weights from disk inplace without re-launching the server."""
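The three endpoints are driven over plain HTTP. A minimal usage sketch in Python (assuming a server already running on the default `http://localhost:30000`; the `/generate` payload is only illustrative traffic):

import requests

base = "http://localhost:30000"  # assumed server address

# Begin recording; any previous record is cleared.
requests.post(f"{base}/start_expert_distribution_record")

# Drive some traffic so the MoE layers actually select experts.
requests.post(
    f"{base}/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 32}},
)

# Stop recording, then write one CSV per rank and reset the recorder.
requests.post(f"{base}/stop_expert_distribution_record")
requests.post(f"{base}/dump_expert_distribution_record")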
@@ -21,6 +21,10 @@ from sglang.srt.utils import get_compiler_backend, is_cuda

 _is_cuda = is_cuda()

+from sglang.srt.managers.utils import ExpertDistributionRecorder
+
+expert_distribution_recorder = ExpertDistributionRecorder()
+

 def fused_topk_native(
     hidden_states: torch.Tensor,
@@ -223,4 +227,6 @@ def select_experts(
         renormalize=renormalize,
     )

+    expert_distribution_recorder.record_new_token(topk_ids)
+
     return topk_weights, topk_ids
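Here `topk_ids` is a `[num_tokens, k]` integer tensor holding the expert ids chosen for each token. For intuition, a standalone sketch of the counting the recorder later performs on these rows (made-up values, CPU only):

from collections import defaultdict

import torch

topk_ids = torch.tensor([[0, 3], [3, 5], [0, 3]])  # 3 tokens, k = 2
counts = defaultdict(int)
for row in topk_ids.tolist():
    for expert_idx in row:
        counts[expert_idx] += 1
print(dict(counts))  # {0: 2, 3: 3, 5: 1}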
@@ -658,6 +658,12 @@ class ProfileReqType(Enum):
     STOP_PROFILE = 2


+class ExpertDistributionReq(Enum):
+    START_RECORD = 1
+    STOP_RECORD = 2
+    DUMP_RECORD = 3
+
+
 @dataclass
 class ProfileReq:
     type: ProfileReqType
@@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
+    ExpertDistributionReq,
     FlushCacheReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -128,6 +129,8 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

+expert_distribution_recorder = ExpertDistributionRecorder()
+
 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
@@ -403,6 +406,7 @@ class Scheduler(
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
+                (ExpertDistributionReq, self.expert_distribution_handle),
             ]
         )

@@ -1892,6 +1896,16 @@ class Scheduler(
                 ProfileReqOutput(success=True, message="Succeeded.")
             )

+    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
+        if recv_req == ExpertDistributionReq.START_RECORD:
+            expert_distribution_recorder.start_record()
+        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+            expert_distribution_recorder.stop_record()
+        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+            expert_distribution_recorder.dump_record()
+        else:
+            raise ValueError("Unrecognized ExpertDistributionReq value")
+
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
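The new handler is registered with the scheduler's `TypeBasedDispatcher`, which routes each incoming pyobj to the handler registered for its type. A minimal illustrative sketch of that dispatch pattern (not sglang's actual implementation):

from enum import Enum


class ExpertDistributionReq(Enum):
    START_RECORD = 1
    STOP_RECORD = 2
    DUMP_RECORD = 3


class TypeBasedDispatcherSketch:
    def __init__(self, mapping):
        self._mapping = list(mapping)

    def __call__(self, obj):
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"unknown request type: {type(obj)}")


dispatcher = TypeBasedDispatcherSketch(
    [(ExpertDistributionReq, lambda req: print("handling", req))]
)
dispatcher(ExpertDistributionReq.DUMP_RECORD)  # -> handling ExpertDistributionReq.DUMP_RECORD

Because `ExpertDistributionReq` is sent as a bare enum member rather than a dataclass, `isinstance` still matches (enum members are instances of their enum class), and the handler then compares the member against the three values.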
@@ -60,6 +60,7 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
+    ExpertDistributionReq,
     FlushCacheReq,
     GenerateReqInput,
     GetInternalStateReq,
@@ -638,6 +639,18 @@ class TokenizerManager:
         req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
         self.send_to_scheduler.send_pyobj(req)

+    def start_expert_distribution_record(self):
+        req = ExpertDistributionReq.START_RECORD
+        self.send_to_scheduler.send_pyobj(req)
+
+    def stop_expert_distribution_record(self):
+        req = ExpertDistributionReq.STOP_RECORD
+        self.send_to_scheduler.send_pyobj(req)
+
+    def dump_expert_distribution_record(self):
+        req = ExpertDistributionReq.DUMP_RECORD
+        self.send_to_scheduler.send_pyobj(req)
+
     async def update_weights_from_disk(
         self,
         obj: UpdateWeightFromDiskReqInput,
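Each method simply ships the enum member to the scheduler process; `send_pyobj` pickles the object onto a ZMQ socket. A self-contained sketch of that pattern with pyzmq (toy in-process sockets, not sglang's actual wiring):

import zmq

ctx = zmq.Context()
tx = ctx.socket(zmq.PUSH)  # plays the tokenizer side
rx = ctx.socket(zmq.PULL)  # plays the scheduler side
port = tx.bind_to_random_port("tcp://127.0.0.1")
rx.connect(f"tcp://127.0.0.1:{port}")

tx.send_pyobj("DUMP_RECORD")  # stands in for ExpertDistributionReq.DUMP_RECORD
print(rx.recv_pyobj())        # -> DUMP_RECORD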
@@ -1,6 +1,11 @@
+import json
 import logging
+import time
+from collections import defaultdict
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional, Tuple
+
+import torch

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

@@ -42,3 +47,77 @@ def validate_input_length(
         return error_msg

     return None
+
+
+# global expert distribution recording
+class ExpertDistributionRecorder:
+    # This class is a singleton
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
+        return cls.instance
+
+    def __init__(self):
+        # the length of the dictionary is the number of layers
+        # the length of the list is the number of tokens
+        # the length of the tuple is topk's k value
+        self._expert_distribution_record: Dict[int, List[Tuple[int, ...]]] = defaultdict(
+            list
+        )
+        self._record = False
+        self._current_layer_id = "UNKNOWN"
+
+    def set_current_layer(self, layer_idx):
+        self._current_layer_id = layer_idx
+
+    def record_new_token(self, topk_ids):
+        if not self._record:
+            return
+        # Synchronize before reading the non-blocking device-to-host copy.
+        topk_ids_cpu = topk_ids.to("cpu", non_blocking=True)
+        torch.cuda.synchronize()
+        topk_ids_list = topk_ids_cpu.numpy().tolist()
+        for i in topk_ids_list:
+            self._expert_distribution_record[self._current_layer_id].append(tuple(i))
+
+    def reset(self):
+        """Reset the expert distribution recorder."""
+        logger.info("Resetting expert distribution record...")
+        self._record = False
+        self._expert_distribution_record.clear()
+        self._current_layer_id = "UNKNOWN"
+
+    def start_record(self):
+        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
+        if self._record:
+            logger.warning(
+                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
+            )
+        self.reset()
+        self._record = True
+
+    def stop_record(self):
+        """Stop recording the expert distribution. Set the recording flag to False."""
+        if not self._record:
+            logger.warning(
+                "SGLang server has not been recording expert ids. Did you forget to start recording by sending a request to the `/start_expert_distribution_record` endpoint?"
+            )
+        self._record = False
+
+    def dump_record(self):
+        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
+        results = {}
+        for layer_idx, layer_record in self._expert_distribution_record.items():
+            results[layer_idx] = defaultdict(int)
+            for token_record in layer_record:
+                for expert_idx in token_record:
+                    results[layer_idx][expert_idx] += 1
+        with open(
+            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
+            "w",
+        ) as fd:
+            fd.write("layer_id,expert_id,count\n")
+            for layer_idx, layer_results in results.items():
+                for expert_idx, count in layer_results.items():
+                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
+        self.reset()
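Each rank writes a flat `layer_id,expert_id,count` CSV. A quick way to aggregate one dump with the standard library (the file name below is just an example of the generated pattern):

import csv
from collections import defaultdict

totals = defaultdict(int)  # expert_id -> total count across all layers
with open("expert_distribution_rank0_timestamp1742000000.0.csv") as fd:  # example name
    for row in csv.DictReader(fd):
        totals[int(row["expert_id"])] += int(row["count"])

for expert_id, count in sorted(totals.items()):
    print(expert_id, count)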
@@ -68,6 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.managers.utils import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
@@ -80,6 +81,8 @@ if _is_cuda:
 else:
     from vllm import _custom_ops as ops

+expert_distribution_recorder = ExpertDistributionRecorder()
+

 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -1160,6 +1163,7 @@ class DeepseekV2Model(nn.Module):

         residual = None
         for i in range(len(self.layers)):
+            expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
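Because `ExpertDistributionRecorder` is a singleton, the instance created here in `deepseek_v2.py` is the same object that `topk.py` records into and that the scheduler controls. A condensed, runnable sketch of the resulting call order, with dummy stand-ins for the recorder and the model's layers:

class DummyRecorder:
    """Stand-in mirroring the set_current_layer/record_new_token call order."""
    def __init__(self):
        self.layer = None
        self.events = []
    def set_current_layer(self, i):
        self.layer = i
    def record_new_token(self, topk_ids):
        self.events.append((self.layer, tuple(topk_ids)))

recorder = DummyRecorder()
# Each "layer" selects experts internally, like select_experts() does.
layers = [lambda i=i: recorder.record_new_token([i, i + 1]) for i in range(2)]
for i, layer in enumerate(layers):
    recorder.set_current_layer(i)  # tag subsequent records with layer i
    layer()
print(recorder.events)  # [(0, (0, 1)), (1, (1, 2))]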