diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 0f258a9d9..695ccefeb 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -712,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
     return ORJSONResponse(content=response_data, status_code=200)
 
 
+@app.post("/pause_generation")
+async def pause_generation(request: Request):
+    """Pause generation."""
+    await _global_state.tokenizer_manager.pause_generation()
+    return ORJSONResponse(
+        content={"message": "Generation paused successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
+@app.post("/continue_generation")
+async def continue_generation(request: Request):
+    """Continue generation."""
+    await _global_state.tokenizer_manager.continue_generation()
+    return ORJSONResponse(
+        content={"message": "Generation continued successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
 ##### OpenAI-compatible API endpoints #####
 
 
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 15635f5c1..a030bf367 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -203,6 +203,8 @@ class TokenizerManager:
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
+        self._updating = False
+        self._cond = asyncio.Condition()
 
         if self.model_config.is_multimodal:
             import_processors()
@@ -421,6 +423,9 @@ class TokenizerManager:
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
+        async with self._cond:
+            await self._cond.wait_for(lambda: not self._updating)
+
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()
 
@@ -902,6 +907,16 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
+    async def pause_generation(self):
+        async with self._cond:
+            self._updating = True
+            self.abort_request(abort_all=True)
+
+    async def continue_generation(self):
+        async with self._cond:
+            self._updating = False
+            self._cond.notify_all()
+
     async def update_weights_from_disk(
         self,
         obj: UpdateWeightFromDiskReqInput,