diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb index 11971dd90..f1b2b6bf3 100644 --- a/docs/backend/function_calling.ipynb +++ b/docs/backend/function_calling.ipynb @@ -426,7 +426,7 @@ "from sglang.srt.managers.io_struct import Tool, Function\n", "\n", "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", - "tokenizer = llm.tokenizer_manager.tokenizer\n", + "tokenizer = llm.orchestrator.tokenizer\n", "input_ids = tokenizer.apply_chat_template(\n", " messages, tokenize=True, add_generation_prompt=True, tools=tools\n", ")\n", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 13642f580..a7a0716e3 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -48,8 +48,8 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api +from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( @@ -74,12 +74,12 @@ class Engine: The entry point to the inference engine. - The engine consists of three components: - 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler. 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. Note: - 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 1. The HTTP server, Engine, and StdOrchestrator both run in the main process. 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. 
""" @@ -102,10 +102,8 @@ class Engine: atexit.register(self.shutdown) # Launch subprocesses - tokenizer_manager, scheduler_info = _launch_subprocesses( - server_args=server_args - ) - self.tokenizer_manager = tokenizer_manager + orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args) + self.orchestrator = orchestrator self.scheduler_info = scheduler_info def generate( @@ -147,7 +145,7 @@ class Engine: stream=stream, ) loop = asyncio.get_event_loop() - generator = self.tokenizer_manager.generate_request(obj, None) + generator = self.orchestrator.generate_request(obj, None) if stream: @@ -197,7 +195,7 @@ class Engine: stream=stream, custom_logit_processor=custom_logit_processor, ) - generator = self.tokenizer_manager.generate_request(obj, None) + generator = self.orchestrator.generate_request(obj, None) if stream is True: return generator @@ -215,7 +213,7 @@ class Engine: obj = EmbeddingReqInput(text=prompt) loop = asyncio.get_event_loop() - generator = self.tokenizer_manager.generate_request(obj, None) + generator = self.orchestrator.generate_request(obj, None) ret = loop.run_until_complete(generator.__anext__()) return ret @@ -224,14 +222,14 @@ class Engine: kill_process_tree(os.getpid(), include_parent=False) def start_profile(self): - self.tokenizer_manager.start_profile() + self.orchestrator.start_profile() def stop_profile(self): - self.tokenizer_manager.stop_profile() + self.orchestrator.stop_profile() def get_server_info(self): return { - **dataclasses.asdict(self.tokenizer_manager.server_args), # server args + **dataclasses.asdict(self.orchestrator.server_args), # server args **self.scheduler_info, "version": __version__, } @@ -256,7 +254,7 @@ class Engine: ) loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.init_weights_update_group(obj, None) + self.orchestrator.init_weights_update_group(obj, None) ) def update_weights_from_distributed(self, name: str, dtype, shape): @@ -268,7 +266,7 @@ class Engine: ) loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.update_weights_from_distributed(obj, None) + self.orchestrator.update_weights_from_distributed(obj, None) ) def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): @@ -278,23 +276,21 @@ class Engine: ) loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.update_weights_from_tensor(obj, None) + self.orchestrator.update_weights_from_tensor(obj, None) ) def get_weights_by_name(self, name: str, truncate_size: int = 100): """Get weights by parameter name.""" obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) loop = asyncio.get_event_loop() - return loop.run_until_complete( - self.tokenizer_manager.get_weights_by_name(obj, None) - ) + return loop.run_until_complete(self.orchestrator.get_weights_by_name(obj, None)) def release_memory_occupation(self): """Release GPU occupation temporarily.""" obj = ReleaseMemoryOccupationReqInput() loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.release_memory_occupation(obj, None) + self.orchestrator.release_memory_occupation(obj, None) ) def resume_memory_occupation(self): @@ -302,7 +298,7 @@ class Engine: obj = ResumeMemoryOccupationReqInput() loop = asyncio.get_event_loop() return loop.run_until_complete( - self.tokenizer_manager.resume_memory_occupation(obj, None) + self.orchestrator.resume_memory_occupation(obj, None) ) @@ -351,9 +347,9 @@ def _set_envs_and_config(server_args: 
ServerArgs): mp.set_start_method("spawn", force=True) -def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]: +def _launch_subprocesses(server_args: ServerArgs) -> Tuple[StdOrchestrator, Dict]: """ - Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. + Launch the StdOrchestrator in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. """ # Configure global environment configure_logger(server_args) @@ -436,10 +432,10 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic detoken_proc.start() # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args) + orchestrator = StdOrchestrator(server_args, port_args) if server_args.chat_template: load_chat_template_for_openai_api( - tokenizer_manager, server_args.chat_template, server_args.model_path + orchestrator, server_args.chat_template, server_args.model_path ) # Wait for the model to finish loading @@ -463,5 +459,5 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic # Assume all schedulers have the same scheduler_info scheduler_info = scheduler_infos[0] - tokenizer_manager.configure_max_req_input_len(scheduler_info["max_req_input_len"]) - return tokenizer_manager, scheduler_info + orchestrator.configure_max_req_input_len(scheduler_info["max_req_input_len"]) + return orchestrator, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 2b2421a37..71828a91b 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -54,7 +54,6 @@ from sglang.srt.managers.io_struct import ( UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, ) -from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.metrics.func_timer import enable_func_timer from sglang.srt.openai_api.adapter import ( v1_batches, @@ -69,6 +68,7 @@ from sglang.srt.openai_api.adapter import ( v1_retrieve_file_content, ) from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.orchestration.std.orchestrator import StdOrchestrator from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( add_api_key_middleware, @@ -97,7 +97,7 @@ app.add_middleware( # Store global states @dataclasses.dataclass class _GlobalState: - tokenizer_manager: TokenizerManager + orchestrator: StdOrchestrator scheduler_info: Dict @@ -124,7 +124,7 @@ async def health_generate(request: Request) -> Response: sampling_params = {"max_new_tokens": 1, "temperature": 0.7} - if _global_state.tokenizer_manager.is_generation: + if _global_state.orchestrator.is_generation: gri = GenerateReqInput( input_ids=[0], sampling_params=sampling_params, log_metrics=False ) @@ -134,7 +134,7 @@ async def health_generate(request: Request) -> Response: ) try: - async for _ in _global_state.tokenizer_manager.generate_request(gri, request): + async for _ in _global_state.orchestrator.generate_request(gri, request): break return Response(status_code=200) except Exception as e: @@ -146,9 +146,9 @@ async def health_generate(request: Request) -> Response: async def get_model_info(): """Get the model information.""" result = { - "model_path": _global_state.tokenizer_manager.model_path, - "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path, - "is_generation": 
_global_state.tokenizer_manager.is_generation, + "model_path": _global_state.orchestrator.model_path, + "tokenizer_path": _global_state.orchestrator.server_args.tokenizer_path, + "is_generation": _global_state.orchestrator.is_generation, } return result @@ -156,7 +156,7 @@ async def get_model_info(): @app.get("/get_server_info") async def get_server_info(): return { - **dataclasses.asdict(_global_state.tokenizer_manager.server_args), + **dataclasses.asdict(_global_state.orchestrator.server_args), **_global_state.scheduler_info, "version": __version__, } @@ -170,7 +170,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def stream_results() -> AsyncIterator[bytes]: try: - async for out in _global_state.tokenizer_manager.generate_request( + async for out in _global_state.orchestrator.generate_request( obj, request ): yield b"data: " + orjson.dumps( @@ -186,11 +186,11 @@ async def generate_request(obj: GenerateReqInput, request: Request): return StreamingResponse( stream_results(), media_type="text/event-stream", - background=_global_state.tokenizer_manager.create_abort_task(obj), + background=_global_state.orchestrator.create_abort_task(obj), ) else: try: - ret = await _global_state.tokenizer_manager.generate_request( + ret = await _global_state.orchestrator.generate_request( obj, request ).__anext__() return ret @@ -203,7 +203,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def encode_request(obj: EmbeddingReqInput, request: Request): """Handle an embedding request.""" try: - ret = await _global_state.tokenizer_manager.generate_request( + ret = await _global_state.orchestrator.generate_request( obj, request ).__anext__() return ret @@ -215,7 +215,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): async def classify_request(obj: EmbeddingReqInput, request: Request): """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" try: - ret = await _global_state.tokenizer_manager.generate_request( + ret = await _global_state.orchestrator.generate_request( obj, request ).__anext__() return ret @@ -226,7 +226,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): @app.post("/flush_cache") async def flush_cache(): """Flush the radix cache.""" - _global_state.tokenizer_manager.flush_cache() + _global_state.orchestrator.flush_cache() return Response( content="Cache flushed.\nPlease check backend logs for more details. " "(When there are running or waiting requests, the operation will not be performed.)\n", @@ -237,7 +237,7 @@ async def flush_cache(): @app.api_route("/start_profile", methods=["GET", "POST"]) async def start_profile_async(): """Start profiling.""" - _global_state.tokenizer_manager.start_profile() + _global_state.orchestrator.start_profile() return Response( content="Start profiling.\n", status_code=200, @@ -247,7 +247,7 @@ async def start_profile_async(): @app.api_route("/stop_profile", methods=["GET", "POST"]) async def stop_profile_async(): """Stop profiling.""" - _global_state.tokenizer_manager.stop_profile() + _global_state.orchestrator.stop_profile() return Response( content="Stop profiling. 
This will take some time.\n", status_code=200, @@ -257,7 +257,7 @@ async def stop_profile_async(): @app.post("/update_weights_from_disk") async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): """Update the weights from disk in-place without re-launching the server.""" - success, message = await _global_state.tokenizer_manager.update_weights_from_disk( + success, message = await _global_state.orchestrator.update_weights_from_disk( obj, request ) content = {"success": success, "message": message} @@ -278,7 +278,7 @@ async def init_weights_update_group( obj: InitWeightsUpdateGroupReqInput, request: Request ): """Initialize the parameter update group.""" - success, message = await _global_state.tokenizer_manager.init_weights_update_group( + success, message = await _global_state.orchestrator.init_weights_update_group( obj, request ) content = {"success": success, "message": message} @@ -293,10 +293,8 @@ async def update_weights_from_distributed( obj: UpdateWeightsFromDistributedReqInput, request: Request ): """Update model parameter from distributed online.""" - success, message = ( - await _global_state.tokenizer_manager.update_weights_from_distributed( - obj, request - ) + success, message = await _global_state.orchestrator.update_weights_from_distributed( + obj, request ) content = {"success": success, "message": message} if success: @@ -309,7 +307,7 @@ async def update_weights_from_distributed( async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): """Get model parameter by name.""" try: - ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request) + ret = await _global_state.orchestrator.get_weights_by_name(obj, request) if ret is None: return _create_error_response("Get parameter by name failed") else: @@ -324,7 +322,7 @@ async def release_memory_occupation( ): """Release GPU occupation temporarily""" try: - await _global_state.tokenizer_manager.release_memory_occupation(obj, request) + await _global_state.orchestrator.release_memory_occupation(obj, request) except Exception as e: return _create_error_response(e) @@ -335,7 +333,7 @@ async def resume_memory_occupation( ): """Resume GPU occupation""" try: - await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) + await _global_state.orchestrator.resume_memory_occupation(obj, request) except Exception as e: return _create_error_response(e) @@ -344,7 +342,7 @@ async def resume_memory_occupation( async def open_session(obj: OpenSessionReqInput, request: Request): """Open a session, and return its unique session id.""" try: - session_id = await _global_state.tokenizer_manager.open_session(obj, request) + session_id = await _global_state.orchestrator.open_session(obj, request) if session_id is None: raise Exception( "Failed to open the session. Check if a session with the same id is still open." 
@@ -358,7 +356,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request): async def close_session(obj: CloseSessionReqInput, request: Request): """Close the session""" try: - await _global_state.tokenizer_manager.close_session(obj, request) + await _global_state.orchestrator.close_session(obj, request) return Response(status_code=200) except Exception as e: return _create_error_response(e) @@ -367,7 +365,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request): @app.api_route("/configure_logging", methods=["GET", "POST"]) async def configure_logging(obj: ConfigureLoggingReq, request: Request): """Close the session""" - _global_state.tokenizer_manager.configure_logging(obj) + _global_state.orchestrator.configure_logging(obj) return Response(status_code=200) @@ -398,24 +396,24 @@ async def function_call_request(obj: FunctionCallReqInput, request: Request): @app.post("/v1/completions") async def openai_v1_completions(raw_request: Request): - return await v1_completions(_global_state.tokenizer_manager, raw_request) + return await v1_completions(_global_state.orchestrator, raw_request) @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) + return await v1_chat_completions(_global_state.orchestrator, raw_request) @app.post("/v1/embeddings", response_class=ORJSONResponse) async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(_global_state.tokenizer_manager, raw_request) + response = await v1_embeddings(_global_state.orchestrator, raw_request) return response @app.get("/v1/models", response_class=ORJSONResponse) def available_models(): """Show available models.""" - served_model_names = [_global_state.tokenizer_manager.served_model_name] + served_model_names = [_global_state.orchestrator.served_model_name] model_cards = [] for served_model_name in served_model_names: model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) @@ -425,7 +423,7 @@ def available_models(): @app.post("/v1/files") async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): return await v1_files_create( - file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth + file, purpose, _global_state.orchestrator.server_args.file_storage_pth ) @@ -437,13 +435,13 @@ async def delete_file(file_id: str): @app.post("/v1/batches") async def openai_v1_batches(raw_request: Request): - return await v1_batches(_global_state.tokenizer_manager, raw_request) + return await v1_batches(_global_state.orchestrator, raw_request) @app.post("/v1/batches/{batch_id}/cancel") async def cancel_batches(batch_id: str): # https://platform.openai.com/docs/api-reference/batch/cancel - return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id) + return await v1_cancel_batch(_global_state.orchestrator, batch_id) @app.get("/v1/batches/{batch_id}") @@ -492,18 +490,18 @@ def launch_server( - HTTP server: A FastAPI server that routes requests to the engine. - The engine consists of three components: - 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler. 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. 3. 
DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. Note: - 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 1. The HTTP server, Engine, and StdOrchestrator both run in the main process. 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. """ - tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) + orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args) set_global_state( _GlobalState( - tokenizer_manager=tokenizer_manager, + orchestrator=orchestrator, scheduler_info=scheduler_info, ) ) @@ -523,7 +521,7 @@ def launch_server( args=( server_args, pipe_finish_writer, - _global_state.tokenizer_manager.image_token_id, + _global_state.orchestrator.image_token_id, ), ) t.start() diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index 57fc4a6b4..c3d4b33c6 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -241,7 +241,7 @@ class LlavaImageProcessor(BaseImageProcessor): return pixel_values, image_hash, image.size except Exception: - logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) + logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback()) async def _process_single_image( self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str @@ -491,7 +491,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor): return pixel_values, image_hash, image.size, image_grid_thws except Exception: - logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) + logger.error("Exception in StdOrchestrator:\n" + get_exception_traceback()) async def _process_single_image(self, image_data: Union[bytes, str]): if self.executor is not None: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 67225cf84..3530c2e4d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -13,7 +13,7 @@ # ============================================================================== """ The definition of objects transfered between different -processes (TokenizerManager, DetokenizerManager, Controller). +processes (StdOrchestrator, DetokenizerManager, Controller). 
""" import uuid diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9c00c8b25..a2f8b0479 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -173,7 +173,7 @@ class Scheduler: ) if server_args.skip_tokenizer_init: - # Directly send to the TokenizerManager + # Directly send to the StdOrchestrator self.send_to_detokenizer = get_zmq_socket( context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 0556f852a..cf257e9c3 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -117,7 +117,7 @@ def create_streaming_error_response( return json_str -def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, model_path): +def load_chat_template_for_openai_api(orchestrator, chat_template_arg, model_path): global chat_template_name logger.info( @@ -133,9 +133,7 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode if chat_template_arg.endswith(".jinja"): with open(chat_template_arg, "r") as f: chat_template = "".join(f.readlines()).strip("\n") - tokenizer_manager.tokenizer.chat_template = chat_template.replace( - "\\n", "\n" - ) + orchestrator.tokenizer.chat_template = chat_template.replace("\\n", "\n") chat_template_name = None else: assert chat_template_arg.endswith( @@ -231,7 +229,7 @@ async def v1_delete_file(file_id: str): return FileDeleteResponse(id=file_id, deleted=True) -async def v1_batches(tokenizer_manager, raw_request: Request): +async def v1_batches(orchestrator, raw_request: Request): try: body = await raw_request.json() @@ -252,7 +250,7 @@ async def v1_batches(tokenizer_manager, raw_request: Request): batch_storage[batch_id] = batch_response # Start processing the batch asynchronously - asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request)) + asyncio.create_task(process_batch(orchestrator, batch_id, batch_request)) # Return the initial batch_response return batch_response @@ -263,7 +261,7 @@ async def v1_batches(tokenizer_manager, raw_request: Request): return {"error": str(e)} -async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest): +async def process_batch(orchestrator, batch_id: str, batch_request: BatchRequest): try: # Update the batch status to "in_progress" batch_storage[batch_id].status = "in_progress" @@ -306,7 +304,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe if end_point == "/v1/chat/completions": adapted_request, request = v1_chat_generate_request( - all_requests, tokenizer_manager, request_ids=request_ids + all_requests, orchestrator, request_ids=request_ids ) elif end_point == "/v1/completions": adapted_request, request = v1_generate_request( @@ -314,7 +312,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe ) try: - ret = await tokenizer_manager.generate_request(adapted_request).__anext__() + ret = await orchestrator.generate_request(adapted_request).__anext__() if not isinstance(ret, list): ret = [ret] if end_point == "/v1/chat/completions": @@ -322,12 +320,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe request, ret, to_file=True, - cache_report=tokenizer_manager.server_args.enable_cache_report, - tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + 
cache_report=orchestrator.server_args.enable_cache_report, + tool_call_parser=orchestrator.server_args.tool_call_parser, ) else: responses = v1_generate_response( - request, ret, tokenizer_manager, to_file=True + request, ret, orchestrator, to_file=True ) except Exception as e: @@ -399,7 +397,7 @@ async def v1_retrieve_batch(batch_id: str): return batch_response -async def v1_cancel_batch(tokenizer_manager, batch_id: str): +async def v1_cancel_batch(orchestrator, batch_id: str): # Retrieve the batch job from the in-memory storage batch_response = batch_storage.get(batch_id) if batch_response is None: @@ -410,7 +408,7 @@ async def v1_cancel_batch(tokenizer_manager, batch_id: str): # Start cancelling the batch asynchronously asyncio.create_task( cancel_batch( - tokenizer_manager=tokenizer_manager, + orchestrator=orchestrator, batch_id=batch_id, input_file_id=batch_response.input_file_id, ) @@ -427,7 +425,7 @@ async def v1_cancel_batch(tokenizer_manager, batch_id: str): ) -async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str): +async def cancel_batch(orchestrator, batch_id: str, input_file_id: str): try: # Update the batch status to "cancelling" batch_storage[batch_id].status = "cancelling" @@ -451,7 +449,7 @@ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str): # Cancel requests by request_ids for rid in request_ids: - tokenizer_manager.abort_request(rid=rid) + orchestrator.abort_request(rid=rid) retrieve_batch = batch_storage[batch_id] retrieve_batch.status = "cancelled" @@ -579,7 +577,7 @@ def v1_generate_request( return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] -def v1_generate_response(request, ret, tokenizer_manager, to_file=False): +def v1_generate_response(request, ret, orchestrator, to_file=False): choices = [] echo = False @@ -591,15 +589,13 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False): elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list): # for the case of multiple token ids prompts prompts = [ - tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True) + orchestrator.tokenizer.decode(prompt, skip_special_tokens=True) for prompt in request.prompt ] elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): # for the case of single token ids prompt prompts = [ - tokenizer_manager.tokenizer.decode( - request.prompt, skip_special_tokens=True - ) + orchestrator.tokenizer.decode(request.prompt, skip_special_tokens=True) ] else: # for the case of single str prompt @@ -709,7 +705,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False): return response -async def v1_completions(tokenizer_manager, raw_request: Request): +async def v1_completions(orchestrator, raw_request: Request): request_json = await raw_request.json() all_requests = [CompletionRequest(**request_json)] adapted_request, request = v1_generate_request(all_requests) @@ -722,7 +718,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request): prompt_tokens = {} completion_tokens = {} try: - async for content in tokenizer_manager.generate_request( + async for content in orchestrator.generate_request( adapted_request, raw_request ): index = content.get("index", 0) @@ -745,14 +741,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request): prompts = request.prompt[index // request.n] elif isinstance(request.prompt[0], int): # for the case of single token ids prompt - prompts = tokenizer_manager.tokenizer.decode( + 
prompts = orchestrator.tokenizer.decode( request.prompt, skip_special_tokens=True ) elif isinstance(request.prompt[0], list) and isinstance( request.prompt[0][0], int ): # for the case of multiple token ids prompts - prompts = tokenizer_manager.tokenizer.decode( + prompts = orchestrator.tokenizer.decode( request.prompt[index // request.n], skip_special_tokens=True, ) @@ -847,12 +843,12 @@ async def v1_completions(tokenizer_manager, raw_request: Request): return StreamingResponse( generate_stream_resp(), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), + background=orchestrator.create_abort_task(adapted_request), ) # Non-streaming response. try: - ret = await tokenizer_manager.generate_request( + ret = await orchestrator.generate_request( adapted_request, raw_request ).__anext__() except ValueError as e: @@ -861,13 +857,13 @@ async def v1_completions(tokenizer_manager, raw_request: Request): if not isinstance(ret, list): ret = [ret] - response = v1_generate_response(request, ret, tokenizer_manager) + response = v1_generate_response(request, ret, orchestrator) return response def v1_chat_generate_request( all_requests: List[ChatCompletionRequest], - tokenizer_manager, + orchestrator, request_ids: List[str] = None, ): input_ids = [] @@ -922,7 +918,7 @@ def v1_chat_generate_request( assistant_prefix = None try: - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + prompt_ids = orchestrator.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, add_generation_prompt=True, @@ -933,7 +929,7 @@ def v1_chat_generate_request( # has a different tools input format that is not compatiable # with openAI's apply_chat_template tool_call format, like Mistral. tools = [t if "function" in t else {"function": t} for t in tools] - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + prompt_ids = orchestrator.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, add_generation_prompt=True, @@ -941,11 +937,8 @@ def v1_chat_generate_request( ) if assistant_prefix: - encoded = tokenizer_manager.tokenizer.encode(assistant_prefix) - if ( - encoded - and encoded[0] == tokenizer_manager.tokenizer.bos_token_id - ): + encoded = orchestrator.tokenizer.encode(assistant_prefix) + if encoded and encoded[0] == orchestrator.tokenizer.bos_token_id: encoded = encoded[1:] prompt_ids += encoded stop = request.stop @@ -962,7 +955,7 @@ def v1_chat_generate_request( stop.append(request.stop) else: stop.extend(request.stop) - prompt_ids = tokenizer_manager.tokenizer.encode(prompt) + prompt_ids = orchestrator.tokenizer.encode(prompt) else: # Use the raw prompt and stop strings if the messages is already a string. 
prompt_ids = request.messages @@ -1201,10 +1194,10 @@ def v1_chat_generate_response( return response -async def v1_chat_completions(tokenizer_manager, raw_request: Request): +async def v1_chat_completions(orchestrator, raw_request: Request): request_json = await raw_request.json() all_requests = [ChatCompletionRequest(**request_json)] - adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) + adapted_request, request = v1_chat_generate_request(all_requests, orchestrator) if adapted_request.stream: parser_dict = {} @@ -1216,7 +1209,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): prompt_tokens = {} completion_tokens = {} try: - async for content in tokenizer_manager.generate_request( + async for content in orchestrator.generate_request( adapted_request, raw_request ): index = content.get("index", 0) @@ -1306,7 +1299,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): if index not in parser_dict: parser_dict[index] = FunctionCallParser( tools=request.tools, - tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + tool_call_parser=orchestrator.server_args.tool_call_parser, ) parser = parser_dict[index] @@ -1438,12 +1431,12 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): return StreamingResponse( generate_stream_resp(), media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(adapted_request), + background=orchestrator.create_abort_task(adapted_request), ) # Non-streaming response. try: - ret = await tokenizer_manager.generate_request( + ret = await orchestrator.generate_request( adapted_request, raw_request ).__anext__() except ValueError as e: @@ -1454,14 +1447,14 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): response = v1_chat_generate_response( request, ret, - cache_report=tokenizer_manager.server_args.enable_cache_report, - tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + cache_report=orchestrator.server_args.enable_cache_report, + tool_call_parser=orchestrator.server_args.tool_call_parser, ) return response -def v1_embedding_request(all_requests, tokenizer_manager): +def v1_embedding_request(all_requests, orchestrator): prompts = [] sampling_params_list = [] first_prompt_type = type(all_requests[0].input) @@ -1516,13 +1509,13 @@ def v1_embedding_response(ret, model_path, to_file=False): ) -async def v1_embeddings(tokenizer_manager, raw_request: Request): +async def v1_embeddings(orchestrator, raw_request: Request): request_json = await raw_request.json() all_requests = [EmbeddingRequest(**request_json)] - adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager) + adapted_request, request = v1_embedding_request(all_requests, orchestrator) try: - ret = await tokenizer_manager.generate_request( + ret = await orchestrator.generate_request( adapted_request, raw_request ).__anext__() except ValueError as e: @@ -1531,7 +1524,7 @@ async def v1_embeddings(tokenizer_manager, raw_request: Request): if not isinstance(ret, list): ret = [ret] - response = v1_embedding_response(ret, tokenizer_manager.model_path) + response = v1_embedding_response(ret, orchestrator.model_path) return response diff --git a/python/sglang/srt/orchestration/__init__.py b/python/sglang/srt/orchestration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sglang/srt/orchestration/std/__init__.py b/python/sglang/srt/orchestration/std/__init__.py new file mode 100644 index 
000000000..e69de29bb diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/orchestration/std/orchestrator.py similarity index 97% rename from python/sglang/srt/managers/tokenizer_manager.py rename to python/sglang/srt/orchestration/std/orchestrator.py index 0a7b95d52..cd4a7633c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/orchestration/std/orchestrator.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TokenizerManager is a process that tokenizes the text.""" import asyncio import logging @@ -66,8 +65,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) logger = logging.getLogger(__name__) -class TokenizerManager: - """TokenizerManager is a process that tokenizes the text.""" +class StdOrchestrator: + """StdOrchestrator is the primary entry point of the orchestration.std package.""" def __init__( self, @@ -439,20 +438,20 @@ async def print_exception_wrapper(func): await func() except Exception: traceback = get_exception_traceback() - logger.error(f"TokenizerManager hit an exception: {traceback}") + logger.error(f"StdOrchestrator hit an exception: {traceback}") kill_process_tree(os.getpid(), include_parent=True) sys.exit(1) class SignalHandler: - def __init__(self, tokenizer_manager): - self.tokenizer_manager = tokenizer_manager + def __init__(self, orchestrator): + self.orchestrator = orchestrator def signal_handler(self, signum=None, frame=None): logger.warning( f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..." ) - self.tokenizer_manager.gracefully_exit = True + self.orchestrator.gracefully_exit = True T = TypeVar("T") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b4c6a1224..994fb121b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1039,7 +1039,7 @@ class PortArgs: if dp_rank is None: scheduler_input_port = ( port_base + 2 - ) # TokenizerManager to DataParallelController + ) # StdOrchestrator to DataParallelController else: scheduler_input_port = port_base + 2 + 1 + dp_rank
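
Downstream code that imported TokenizerManager or reached into engine.tokenizer_manager needs a matching update. The Python sketch below is illustrative only and is not part of this diff: the try/except import fallback, the example model, and the sample prompt are assumptions; the attribute and import-path changes themselves mirror the hunks above (docs/backend/function_calling.ipynb and entrypoints/engine.py).

# Migration sketch for the TokenizerManager -> StdOrchestrator rename.
# The fallback import is an assumption for older trees, not something this diff adds.
import sglang as sgl

try:
    # New module path introduced by this change.
    from sglang.srt.orchestration.std.orchestrator import StdOrchestrator
except ImportError:
    # Older trees only expose the class under its previous name and path.
    from sglang.srt.managers.tokenizer_manager import TokenizerManager as StdOrchestrator

llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")

# Before this diff: llm.tokenizer_manager.tokenizer
# After this diff:  llm.orchestrator.tokenizer
assert isinstance(llm.orchestrator, StdOrchestrator)
tokenizer = llm.orchestrator.tokenizer

# Same tokenizer usage as in the docs notebook, minus the tools argument.
messages = [{"role": "user", "content": "What is the capital of France?"}]
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)
print(input_ids[:16])

llm.shutdown()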