Add endpoints to dump selected expert ids (#4435)

Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
yuhsaun-t
2025-03-24 21:34:19 -07:00
committed by GitHub
parent 6b7038babd
commit 199bb01d00
10 changed files with 328 additions and 2 deletions

View File

@@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
ExpertDistributionReq,
FlushCacheReq,
GetInternalStateReq,
GetInternalStateReqOutput,
@@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -128,6 +129,8 @@ from sglang.srt.utils import (
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
expert_distribution_recorder = ExpertDistributionRecorder()
logger = logging.getLogger(__name__)
# Test retract decode for debugging purposes
@@ -403,6 +406,7 @@ class Scheduler(
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
(ExpertDistributionReq, self.expert_distribution_handle),
]
)
@@ -1892,6 +1896,16 @@ class Scheduler(
ProfileReqOutput(success=True, message="Succeeded.")
)
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
if recv_req == ExpertDistributionReq.START_RECORD:
expert_distribution_recorder.start_record()
elif recv_req == ExpertDistributionReq.STOP_RECORD:
expert_distribution_recorder.stop_record()
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
expert_distribution_recorder.dump_record()
else:
raise ValueError("Unrecognized ExpertDistributionReq value")
def open_session(self, recv_req: OpenSessionReqInput):
# handle error
session_id = recv_req.session_id