From fe97a2d40f9faeac16dbf58fae9161718cdb4b31 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 29 Nov 2024 02:18:51 -0800 Subject: [PATCH] Simplify tokenizer manager (#2254) --- python/sglang/bench_offline_throughput.py | 7 +- python/sglang/srt/managers/image_processor.py | 2 +- python/sglang/srt/managers/io_struct.py | 10 --- python/sglang/srt/managers/scheduler.py | 6 -- .../sglang/srt/managers/tokenizer_manager.py | 32 +------- python/sglang/srt/server.py | 77 +++++++------------ test/srt/test_srt_endpoint.py | 3 - 7 files changed, 34 insertions(+), 103 deletions(-) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index 70fbb9add..2e9eb1ad2 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -21,14 +21,13 @@ from typing import Dict, List, Optional, Tuple import numpy as np -from sglang.api import Engine from sglang.bench_serving import ( get_dataset, get_tokenizer, sample_random_requests, set_ulimit, ) -from sglang.srt.server import Runtime, start_profile, stop_profile +from sglang.srt.server import Engine, Runtime from sglang.srt.server_args import ServerArgs @@ -204,12 +203,12 @@ def throughput_test_once( st = time.perf_counter() if profile: - start_profile() + backend.start_profile() gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params) if profile: - stop_profile() + backend.stop_profile() monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR")) latency = time.perf_counter() - st diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index 4cfd210ef..7120fa48d 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -338,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor): "pixel_values": pixel_values, "image_hashes": image_hashes, "image_sizes": image_sizes, - "modalities": request_obj.modalities, + "modalities": request_obj.modalities or ["image"], "image_grid_thws": image_grid_thws, } diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8b1f88fa2..25cf459af 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -376,16 +376,6 @@ class ProfileReq(Enum): STOP_PROFILE = 2 -@dataclass -class GetMemPoolSizeReq: - pass - - -@dataclass -class GetMemPoolSizeReqOutput: - size: int - - @dataclass class OpenSessionReqInput: capacity_of_str_len: int diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c7e831811..e0a67b435 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -38,8 +38,6 @@ from sglang.srt.managers.io_struct import ( BatchTokenIDOut, CloseSessionReqInput, FlushCacheReq, - GetMemPoolSizeReq, - GetMemPoolSizeReqOutput, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, @@ -521,10 +519,6 @@ class Scheduler: self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id)) elif isinstance(recv_req, CloseSessionReqInput): self.close_session(recv_req) - elif isinstance(recv_req, GetMemPoolSizeReq): - self.send_to_tokenizer.send_pyobj( - GetMemPoolSizeReqOutput(self.max_total_num_tokens) - ) else: raise ValueError(f"Invalid request: {recv_req}") diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 15518e9e5..3b3998bec 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ 
b/python/sglang/srt/managers/tokenizer_manager.py @@ -45,8 +45,6 @@ from sglang.srt.managers.io_struct import ( EmbeddingReqInput, FlushCacheReq, GenerateReqInput, - GetMemPoolSizeReq, - GetMemPoolSizeReqOutput, OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, @@ -218,7 +216,7 @@ class TokenizerManager: input_ids = obj.input_ids if self.is_generation: - image_inputs = await self.image_processor.process_images_async( + image_inputs: Dict = await self.image_processor.process_images_async( obj.image_data, input_text or input_ids, obj ) if image_inputs and "input_ids" in image_inputs: @@ -406,25 +404,6 @@ class TokenizerManager: req = ProfileReq.STOP_PROFILE self.send_to_scheduler.send_pyobj(req) - async def get_memory_pool_size(self): - if self.to_create_loop: - self.create_handle_loop() - - req = GetMemPoolSizeReq() - - self.send_to_scheduler.send_pyobj(req) - self.mem_pool_size = asyncio.Future() - - # FIXME: Each request should have its own future instead of using `self.mem_pool_size`. - if self.server_args.dp_size == 1: - res = await self.mem_pool_size - return res.size - else: # self.server_args.dp_size > 1 - self.mem_pool_size_tmp = [] - res = await self.mem_pool_size - ret = [r.size for r in res] - return ret - async def update_weights( self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None ): @@ -552,15 +531,6 @@ class TokenizerManager: if len(self.model_update_tmp) == self.server_args.dp_size: self.model_update_result.set_result(self.model_update_tmp) continue - elif isinstance(recv_obj, GetMemPoolSizeReqOutput): - if self.server_args.dp_size == 1: - self.mem_pool_size.set_result(recv_obj) - else: # self.sever_args.dp_size > 1 - self.mem_pool_size_tmp.append(recv_obj) - # set future if the all results are received - if len(self.mem_pool_size_tmp) == self.server_args.dp_size: - self.mem_pool_size.set_result(self.mem_pool_size_tmp) - continue elif isinstance(recv_obj, OpenSessionReqOutput): self.session_futures[recv_obj.session_id].set_result( recv_obj.session_id diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index c95893067..32523bb9d 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -24,7 +24,6 @@ import logging import multiprocessing as mp import os import signal -import sys import threading import time from http import HTTPStatus @@ -94,7 +93,7 @@ logger = logging.getLogger(__name__) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - +# Fast API app = FastAPI() app.add_middleware( CORSMiddleware, @@ -105,7 +104,7 @@ app.add_middleware( ) tokenizer_manager: TokenizerManager = None -_max_total_num_tokens = None +scheduler_info: Dict = None ##### Native API endpoints ##### @@ -171,16 +170,6 @@ async def flush_cache(): ) -def start_profile(): - """Start profiling.""" - tokenizer_manager.start_profile() - - -def stop_profile(): - """Stop profiling.""" - tokenizer_manager.stop_profile() - - @app.get("/start_profile") @app.post("/start_profile") async def start_profile_async(): @@ -245,6 +234,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request): ) +# fastapi implicitly converts json in the request to obj (dataclass) +@app.api_route("/generate", methods=["POST", "PUT"]) @time_func_latency async def generate_request(obj: GenerateReqInput, request: Request): """Handle a generate request.""" @@ -278,11 +269,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): ) -# fastapi implicitly converts json in the request to obj (dataclass) 
-app.post("/generate")(generate_request) -app.put("/generate")(generate_request) - - +@app.api_route("/encode", methods=["POST", "PUT"]) @time_func_latency async def encode_request(obj: EmbeddingReqInput, request: Request): """Handle an embedding request.""" @@ -295,10 +282,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): ) -app.post("/encode")(encode_request) -app.put("/encode")(encode_request) - - +@app.api_route("/encode", methods=["POST", "PUT"]) @time_func_latency async def classify_request(obj: EmbeddingReqInput, request: Request): """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" @@ -311,10 +295,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): ) -app.post("/classify")(classify_request) -app.put("/classify")(classify_request) - - ##### OpenAI-compatible API endpoints ##### @@ -392,11 +372,11 @@ def launch_engine( server_args: ServerArgs, ): """ - Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess. + Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. """ global tokenizer_manager - global _max_total_num_tokens + global scheduler_info # Configure global environment configure_logger(server_args) @@ -462,8 +442,8 @@ def launch_engine( if server_args.chat_template: load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - # Wait for model to finish loading & get max token nums - scheduler_info = [] + # Wait for model to finish loading + scheduler_infos = [] for i in range(len(scheduler_pipe_readers)): data = scheduler_pipe_readers[i].recv() @@ -471,10 +451,10 @@ def launch_engine( raise RuntimeError( "Initialization failed. Please see the error messages above." ) - scheduler_info.append(data) + scheduler_infos.append(data) # Assume all schedulers have same max_total_num_tokens - _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"] + scheduler_info = scheduler_infos[0] def launch_server( @@ -488,12 +468,12 @@ def launch_server( 1. HTTP server: A FastAPI server that routes requests to the engine. 2. SRT engine: - 1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler. + 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. - 3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. + 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. Note: - 1. The HTTP server and Tokenizer Manager both run in the main process. + 1. The HTTP server and TokenizerManager both run in the main process. 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. 
""" launch_engine(server_args=server_args) @@ -502,7 +482,7 @@ def launch_server( if server_args.api_key: add_api_key_middleware(app, server_args.api_key) - # add prometheus middleware + # Add prometheus middleware if server_args.enable_metrics: add_prometheus_middleware(app) enable_func_timer() @@ -514,7 +494,7 @@ def launch_server( t.start() try: - # Listen for HTTP requests + # Update logging configs LOGGING_CONFIG["formatters"]["default"][ "fmt" ] = "[%(asctime)s] %(levelprefix)s %(message)s" @@ -523,6 +503,8 @@ def launch_server( "fmt" ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s' LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" + + # Listen for HTTP requests uvicorn.run( app, host=server_args.host, @@ -538,8 +520,7 @@ def launch_server( async def _get_server_info(): return { **dataclasses.asdict(tokenizer_manager.server_args), # server args - "memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size - "max_total_num_tokens": _max_total_num_tokens, # max total num tokens + **scheduler_info, "version": __version__, } @@ -645,6 +626,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer): kill_process_tree(os.getpid()) return + # Debug print # logger.info(f"{res.json()=}") logger.info("The server is fired up and ready to roll!") @@ -821,18 +803,11 @@ class Engine: launching the HTTP server adds unnecessary complexity or overhead, """ - def __init__(self, *args, **kwargs): - + def __init__(self, log_level: str = "error", *args, **kwargs): # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() atexit.register(self.shutdown) - # runtime server default log level is log - # offline engine works in scripts, so we set it to error - - if "log_level" not in kwargs: - kwargs["log_level"] = "error" - - server_args = ServerArgs(*args, **kwargs) + server_args = ServerArgs(*args, log_level=log_level, **kwargs) launch_engine(server_args=server_args) def generate( @@ -955,5 +930,11 @@ class Engine: loop = asyncio.get_event_loop() return loop.run_until_complete(encode_request(obj, None)) + def start_profile(self): + tokenizer_manager.start_profile() + + def stop_profile(self): + tokenizer_manager.stop_profile() + async def get_server_info(self): return await _get_server_info() diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 006059e03..aff1d4a78 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -220,9 +220,6 @@ class TestSRTEndpoint(unittest.TestCase): max_total_num_tokens = response_json["max_total_num_tokens"] self.assertIsInstance(max_total_num_tokens, int) - memory_pool_size = response_json["memory_pool_size"] - self.assertIsInstance(memory_pool_size, int) - attention_backend = response_json["attention_backend"] self.assertIsInstance(attention_backend, str)