Add endpoints to dump selected expert ids (#4435)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -56,6 +56,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
CloseSessionReqInput,
|
||||
ExpertDistributionReq,
|
||||
FlushCacheReq,
|
||||
GetInternalStateReq,
|
||||
GetInternalStateReqOutput,
|
||||
@@ -104,7 +105,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
@@ -128,6 +129,8 @@ from sglang.srt.utils import (
|
||||
)
|
||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Test retract decode for debugging purposes
|
||||
@@ -403,6 +406,7 @@ class Scheduler(
|
||||
(GetInternalStateReq, self.get_internal_state),
|
||||
(SetInternalStateReq, self.set_internal_state),
|
||||
(RpcReqInput, self.handle_rpc_request),
|
||||
(ExpertDistributionReq, self.expert_distribution_handle),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1892,6 +1896,16 @@ class Scheduler(
|
||||
ProfileReqOutput(success=True, message="Succeeded.")
|
||||
)
|
||||
|
||||
def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
|
||||
if recv_req == ExpertDistributionReq.START_RECORD:
|
||||
expert_distribution_recorder.start_record()
|
||||
elif recv_req == ExpertDistributionReq.STOP_RECORD:
|
||||
expert_distribution_recorder.stop_record()
|
||||
elif recv_req == ExpertDistributionReq.DUMP_RECORD:
|
||||
expert_distribution_recorder.dump_record()
|
||||
else:
|
||||
raise ValueError("Unrecognized ExpertDistributionReq value")
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput):
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
|
||||
Reference in New Issue
Block a user