Simplify tokenizer manager (#2254)
@@ -21,14 +21,13 @@ from typing import Dict, List, Optional, Tuple

 import numpy as np

-from sglang.api import Engine
 from sglang.bench_serving import (
     get_dataset,
     get_tokenizer,
     sample_random_requests,
     set_ulimit,
 )
-from sglang.srt.server import Runtime, start_profile, stop_profile
+from sglang.srt.server import Engine, Runtime
 from sglang.srt.server_args import ServerArgs

@@ -204,12 +203,12 @@ def throughput_test_once(

     st = time.perf_counter()
     if profile:
-        start_profile()
+        backend.start_profile()

     gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)

     if profile:
-        stop_profile()
+        backend.stop_profile()
         monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))

     latency = time.perf_counter() - st
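
The hunk above moves profiling from module-level functions to methods on the backend object. The timing-plus-optional-profiling pattern it uses can be factored into a small helper; this is a minimal sketch, assuming `backend` exposes the `start_profile`/`stop_profile` hooks added in this PR:

import os
import time
from contextlib import contextmanager

@contextmanager
def timed_generate(backend, profile: bool = False):
    """Time a block, optionally wrapping it in a profiler trace."""
    st = time.perf_counter()
    if profile:
        backend.start_profile()
    try:
        yield
    finally:
        if profile:
            backend.stop_profile()
            # Trace files land in the directory configured for the torch profiler.
            print("trace dir:", os.getenv("SGLANG_TORCH_PROFILER_DIR"))
        print(f"latency: {time.perf_counter() - st:.3f}s")

# Usage:
#   with timed_generate(backend, profile=True):
#       backend.generate(prompt="hi", sampling_params={"max_new_tokens": 8})
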
@@ -338,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             "pixel_values": pixel_values,
             "image_hashes": image_hashes,
             "image_sizes": image_sizes,
-            "modalities": request_obj.modalities,
+            "modalities": request_obj.modalities or ["image"],
             "image_grid_thws": image_grid_thws,
         }

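
The `or ["image"]` fallback substitutes a default whenever `request_obj.modalities` is `None` or an empty list; both are falsy, so both fall through. A minimal illustration of the pattern:

def resolve_modalities(modalities):
    # None and [] are falsy, so both fall back to the default.
    return modalities or ["image"]

assert resolve_modalities(None) == ["image"]
assert resolve_modalities([]) == ["image"]
assert resolve_modalities(["video"]) == ["video"]
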
@@ -376,16 +376,6 @@ class ProfileReq(Enum):
     STOP_PROFILE = 2


-@dataclass
-class GetMemPoolSizeReq:
-    pass
-
-
-@dataclass
-class GetMemPoolSizeReqOutput:
-    size: int
-
-
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
@@ -38,8 +38,6 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -521,10 +519,6 @@ class Scheduler:
             self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
         elif isinstance(recv_req, CloseSessionReqInput):
             self.close_session(recv_req)
-        elif isinstance(recv_req, GetMemPoolSizeReq):
-            self.send_to_tokenizer.send_pyobj(
-                GetMemPoolSizeReqOutput(self.max_total_num_tokens)
-            )
         else:
             raise ValueError(f"Invalid request: {recv_req}")

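
The dispatch above receives pickled dataclasses over ZMQ and branches on their type. A self-contained sketch of that pattern (the socket endpoint and the single-process wiring are illustrative, not sglang's actual topology):

from dataclasses import dataclass

import zmq

@dataclass
class OpenSessionReqInput:
    capacity_of_str_len: int

ctx = zmq.Context()
sender = ctx.socket(zmq.PUSH)
sender.bind("ipc:///tmp/demo_scheduler")
receiver = ctx.socket(zmq.PULL)
receiver.connect("ipc:///tmp/demo_scheduler")

sender.send_pyobj(OpenSessionReqInput(capacity_of_str_len=16))
recv_req = receiver.recv_pyobj()

# Branch on the message type, as the Scheduler's request loop does.
if isinstance(recv_req, OpenSessionReqInput):
    print("open session:", recv_req.capacity_of_str_len)
else:
    raise ValueError(f"Invalid request: {recv_req}")
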
@@ -45,8 +45,6 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -218,7 +216,7 @@ class TokenizerManager:
             input_ids = obj.input_ids

         if self.is_generation:
-            image_inputs = await self.image_processor.process_images_async(
+            image_inputs: Dict = await self.image_processor.process_images_async(
                 obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
@@ -406,25 +404,6 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

-    async def get_memory_pool_size(self):
-        if self.to_create_loop:
-            self.create_handle_loop()
-
-        req = GetMemPoolSizeReq()
-
-        self.send_to_scheduler.send_pyobj(req)
-        self.mem_pool_size = asyncio.Future()
-
-        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
-        if self.server_args.dp_size == 1:
-            res = await self.mem_pool_size
-            return res.size
-        else:  # self.server_args.dp_size > 1
-            self.mem_pool_size_tmp = []
-            res = await self.mem_pool_size
-            ret = [r.size for r in res]
-            return ret
-
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
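
With `get_memory_pool_size` deleted, the value no longer needs a request/reply round trip to the scheduler: `max_total_num_tokens` arrives once in the startup handshake and is served through `_get_server_info` (see the server.py hunks below). A hedged client-side sketch, assuming the info is exposed at a `/get_server_info` route on the default port:

import requests

# Route name and port are assumptions based on _get_server_info below.
info = requests.get("http://127.0.0.1:30000/get_server_info").json()
print(info["max_total_num_tokens"])  # merged in via **scheduler_info
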
@@ -552,15 +531,6 @@ class TokenizerManager:
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
                 continue
-            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
-                if self.server_args.dp_size == 1:
-                    self.mem_pool_size.set_result(recv_obj)
-                else:  # self.sever_args.dp_size > 1
-                    self.mem_pool_size_tmp.append(recv_obj)
-                    # set future if the all results are received
-                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
-                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
-                continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
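
The deleted branch implemented a small fan-in: one shared `asyncio.Future` plus a temporary list, resolved only after `dp_size` replies arrive (the FIXME above noted that sharing one future across requests was fragile). A standalone sketch of the technique, with illustrative names:

import asyncio

async def main():
    dp_size = 2
    results = []
    future = asyncio.get_running_loop().create_future()

    def on_reply(reply):
        results.append(reply)
        # Resolve the shared future only once every dp rank has answered.
        if len(results) == dp_size:
            future.set_result(results)

    on_reply({"size": 1024})
    on_reply({"size": 1024})
    print(await future)  # [{'size': 1024}, {'size': 1024}]

asyncio.run(main())
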
@@ -24,7 +24,6 @@ import logging
 import multiprocessing as mp
 import os
 import signal
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -94,7 +93,7 @@ logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


 # Fast API
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -105,7 +104,7 @@ app.add_middleware(
 )

 tokenizer_manager: TokenizerManager = None
-_max_total_num_tokens = None
+scheduler_info: Dict = None


 ##### Native API endpoints #####
@@ -171,16 +170,6 @@ async def flush_cache():
     )


-def start_profile():
-    """Start profiling."""
-    tokenizer_manager.start_profile()
-
-
-def stop_profile():
-    """Stop profiling."""
-    tokenizer_manager.stop_profile()
-
-
-@app.get("/start_profile")
+@app.post("/start_profile")
 async def start_profile_async():
@@ -245,6 +234,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
     )


+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
 @time_func_latency
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
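
The decorator added here registers the same handler that the `app.post(...)(generate_request)` calls removed in the next hunk used to register; `@app.api_route` simply declares both HTTP methods up front. A minimal FastAPI sketch of the equivalence, with toy handlers:

from fastapi import FastAPI

app = FastAPI()

# Style kept by the diff: one decorator, several methods.
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request():
    return {"ok": True}

# Style removed by the diff: registering an existing function after the fact.
async def encode_request():
    return {"ok": True}

app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
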
@@ -278,11 +269,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     )


-# fastapi implicitly converts json in the request to obj (dataclass)
-app.post("/generate")(generate_request)
-app.put("/generate")(generate_request)
-
-
+@app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
@@ -295,10 +282,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
     )


-app.post("/encode")(encode_request)
-app.put("/encode")(encode_request)
-
-
+@app.api_route("/classify", methods=["POST", "PUT"])
 @time_func_latency
 async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
@@ -311,10 +295,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
     )


-app.post("/classify")(classify_request)
-app.put("/classify")(classify_request)
-
-
 ##### OpenAI-compatible API endpoints #####

@@ -392,11 +372,11 @@ def launch_engine(
     server_args: ServerArgs,
 ):
     """
-    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
+    Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """

     global tokenizer_manager
-    global _max_total_num_tokens
+    global scheduler_info

     # Configure global environment
     configure_logger(server_args)
@@ -462,8 +442,8 @@ def launch_engine(
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

-    # Wait for model to finish loading & get max token nums
-    scheduler_info = []
+    # Wait for model to finish loading
+    scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
         data = scheduler_pipe_readers[i].recv()
@@ -471,10 +451,10 @@
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
             )
-        scheduler_info.append(data)
+        scheduler_infos.append(data)

     # Assume all schedulers have same max_total_num_tokens
-    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
+    scheduler_info = scheduler_infos[0]


 def launch_server(
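
The loop above is a readiness handshake: each scheduler subprocess writes one status dict into its pipe, and the parent blocks on `recv()` until every child reports in. A standalone sketch (the payload keys mirror the diff; everything else is illustrative):

import multiprocessing as mp

def scheduler_proc(pipe_writer):
    # ... model loading would happen here ...
    pipe_writer.send({"status": "ready", "max_total_num_tokens": 65536})

if __name__ == "__main__":
    readers = []
    for _ in range(2):  # e.g. one scheduler per data-parallel rank
        reader, writer = mp.Pipe(duplex=False)
        mp.Process(target=scheduler_proc, args=(writer,)).start()
        readers.append(reader)

    scheduler_infos = [r.recv() for r in readers]  # blocks until all report in
    scheduler_info = scheduler_infos[0]  # assume all ranks agree
    print(scheduler_info["max_total_num_tokens"])
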
@@ -488,12 +468,12 @@ def launch_server(

     1. HTTP server: A FastAPI server that routes requests to the engine.
     2. SRT engine:
-        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
+        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
         2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
-        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

     Note:
-    1. The HTTP server and Tokenizer Manager both run in the main process.
+    1. The HTTP server and TokenizerManager both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
     launch_engine(server_args=server_args)
@@ -502,7 +482,7 @@ def launch_server(
     if server_args.api_key:
         add_api_key_middleware(app, server_args.api_key)

-    # add prometheus middleware
+    # Add prometheus middleware
     if server_args.enable_metrics:
         add_prometheus_middleware(app)
         enable_func_timer()
@@ -514,7 +494,7 @@ def launch_server(
         t.start()

     try:
-        # Listen for HTTP requests
+        # Update logging configs
         LOGGING_CONFIG["formatters"]["default"][
             "fmt"
         ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -523,6 +503,8 @@ def launch_server(
             "fmt"
         ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
         LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -538,8 +520,7 @@ def launch_server(
 async def _get_server_info():
     return {
         **dataclasses.asdict(tokenizer_manager.server_args), # server args
-        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
-        "max_total_num_tokens": _max_total_num_tokens, # max total num tokens
+        **scheduler_info,
         "version": __version__,
     }
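
`_get_server_info` now builds its response by dict unpacking: server args first, then the scheduler handshake payload, then the version, with later keys overriding earlier ones on conflict. A small illustration of that merge order:

import dataclasses

@dataclasses.dataclass
class DemoServerArgs:  # stand-in for sglang's ServerArgs
    model_path: str = "demo-model"
    port: int = 30000

scheduler_info = {"max_total_num_tokens": 65536}

server_info = {
    **dataclasses.asdict(DemoServerArgs()),  # server args
    **scheduler_info,                        # startup handshake payload
    "version": "0.0.0",                      # later keys win on conflict
}
print(server_info)
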
@@ -645,6 +626,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         kill_process_tree(os.getpid())
         return

+    # Debug print
     # logger.info(f"{res.json()=}")

     logger.info("The server is fired up and ready to roll!")
@@ -821,18 +803,11 @@ class Engine:
     launching the HTTP server adds unnecessary complexity or overhead,
     """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, log_level: str = "error", *args, **kwargs):
         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)

-        # runtime server default log level is log
-        # offline engine works in scripts, so we set it to error
-
-        if "log_level" not in kwargs:
-            kwargs["log_level"] = "error"
-
-        server_args = ServerArgs(*args, **kwargs)
+        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
         launch_engine(server_args=server_args)

     def generate(
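
Promoting `log_level` to an explicit keyword keeps the offline engine quiet by default while letting callers opt into verbosity without patching `kwargs`. A usage sketch (the model path is a placeholder):

from sglang.srt.server import Engine

# Defaults to log_level="error", which suits offline scripts;
# pass it explicitly to get verbose logs instead.
engine = Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    log_level="info",
)
out = engine.generate(prompt="Hello", sampling_params={"max_new_tokens": 8})
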
@@ -955,5 +930,11 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))

+    def start_profile(self):
+        tokenizer_manager.start_profile()
+
+    def stop_profile(self):
+        tokenizer_manager.stop_profile()
+
     async def get_server_info(self):
         return await _get_server_info()
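
The new `Engine` methods are thin wrappers over the TokenizerManager and `_get_server_info`, giving the offline engine the same profiling and introspection hooks as the HTTP server. A hedged end-to-end sketch (placeholder model path; `get_server_info` is async, so it is driven from the event loop the same way `encode` is above):

import asyncio

from sglang.srt.server import Engine

engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder

engine.start_profile()
engine.generate(prompt="profile me", sampling_params={"max_new_tokens": 4})
engine.stop_profile()

info = asyncio.get_event_loop().run_until_complete(engine.get_server_info())
print(info.get("max_total_num_tokens"))
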