diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 36c33aeef..114204d17 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -346,7 +346,7 @@ async def stop_profile_async(): @app.api_route("/start_expert_distribution_record", methods=["GET", "POST"]) async def start_expert_distribution_record_async(): """Start recording the expert distribution. Clear the previous record if any.""" - _global_state.tokenizer_manager.start_expert_distribution_record() + await _global_state.tokenizer_manager.start_expert_distribution_record() return Response( content="Start recording the expert distribution.\n", status_code=200, @@ -356,7 +356,7 @@ async def start_expert_distribution_record_async(): @app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"]) async def stop_expert_distribution_record_async(): """Stop recording the expert distribution.""" - _global_state.tokenizer_manager.stop_expert_distribution_record() + await _global_state.tokenizer_manager.stop_expert_distribution_record() return Response( content="Stop recording the expert distribution.\n", status_code=200, @@ -366,7 +366,7 @@ async def stop_expert_distribution_record_async(): @app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"]) async def dump_expert_distribution_record_async(): """Dump expert distribution record.""" - _global_state.tokenizer_manager.dump_expert_distribution_record() + await _global_state.tokenizer_manager.dump_expert_distribution_record() return Response( content="Dump expert distribution record.\n", status_code=200, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 7e2fcbb5f..a94d1968e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -664,6 +664,11 @@ class ExpertDistributionReq(Enum): DUMP_RECORD = 3 +@dataclass +class ExpertDistributionReqOutput: + pass + + @dataclass class ProfileReq: type: ProfileReqType diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8b1f521de..f1f0b896f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, ExpertDistributionReq, + ExpertDistributionReqOutput, FlushCacheReq, GetInternalStateReq, GetInternalStateReqOutput, @@ -1905,6 +1906,7 @@ class Scheduler( expert_distribution_recorder.dump_record() else: raise ValueError("Unrecognized ExpertDistributionReq value") + return ExpertDistributionReqOutput() def open_session(self, recv_req: OpenSessionReqInput): # handle error diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 25834d32a..05ecbe9c0 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -61,6 +61,7 @@ from sglang.srt.managers.io_struct import ( ConfigureLoggingReq, EmbeddingReqInput, ExpertDistributionReq, + ExpertDistributionReqOutput, FlushCacheReq, GenerateReqInput, GetInternalStateReq, @@ -264,6 +265,9 @@ class TokenizerManager: self.get_internal_state_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.expert_distribution_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self._result_dispatcher = TypeBasedDispatcher( [ @@ -313,6 +317,10 @@ class TokenizerManager: GetInternalStateReqOutput, self.get_internal_state_communicator.handle_recv, ), + ( + ExpertDistributionReqOutput, + self.expert_distribution_communicator.handle_recv, + ), (HealthCheckOutput, lambda x: None), ] ) @@ -639,17 +647,14 @@ class TokenizerManager: req = ProfileReq(type=ProfileReqType.STOP_PROFILE) self.send_to_scheduler.send_pyobj(req) - def start_expert_distribution_record(self): - req = ExpertDistributionReq.START_RECORD - self.send_to_scheduler.send_pyobj(req) + async def start_expert_distribution_record(self): + await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD) - def stop_expert_distribution_record(self): - req = ExpertDistributionReq.STOP_RECORD - self.send_to_scheduler.send_pyobj(req) + async def stop_expert_distribution_record(self): + await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD) - def dump_expert_distribution_record(self): - req = ExpertDistributionReq.DUMP_RECORD - self.send_to_scheduler.send_pyobj(req) + async def dump_expert_distribution_record(self): + await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD) async def update_weights_from_disk( self, diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py index 57f5d7d8d..cd8fb7e3e 100755 --- a/test/srt/test_expert_distribution.py +++ b/test/srt/test_expert_distribution.py @@ -28,9 +28,13 @@ class TestExpertDistribution(unittest.TestCase): def test_expert_distribution_record(self): """Test expert distribution record endpoints""" process = popen_launch_server( - DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, + # The feature is only implemented in deepseek_v2.py + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", DEFAULT_URL_FOR_TEST, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + ], ) try: @@ -68,7 +72,9 @@ class TestExpertDistribution(unittest.TestCase): # Verify the dumped file exists and has correct format csv_files = glob.glob("expert_distribution_*.csv") self.assertEqual( - len(csv_files), 1, "Expected exactly one expert distribution CSV file" + len(csv_files), + 1, + f"Expected exactly one expert distribution CSV file {csv_files=}", ) # Check CSV file format @@ -97,11 +103,17 @@ class TestExpertDistribution(unittest.TestCase): # Verify data types layer_id, expert_id, count = row - self.assertTrue(layer_id.isdigit(), "layer_id should be an integer") self.assertTrue( - expert_id.isdigit(), "expert_id should be an integer" + layer_id.isdigit(), + f"layer_id should be an integer {row=} {rows=}", + ) + self.assertTrue( + expert_id.isdigit(), + f"expert_id should be an integer {row=} {rows=}", + ) + self.assertTrue( + count.isdigit(), f"count should be an integer {row=} {rows=}" ) - self.assertTrue(count.isdigit(), "count should be an integer") finally: kill_process_tree(process.pid)