Simplify tokenizer manager (#2254)

Lianmin Zheng
2024-11-29 02:18:51 -08:00
committed by GitHub
parent 8b48496aaf
commit fe97a2d40f
7 changed files with 34 additions and 103 deletions

@@ -24,7 +24,6 @@ import logging
import multiprocessing as mp
import os
import signal
import sys
import threading
import time
from http import HTTPStatus
@@ -94,7 +93,7 @@ logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Fast API
app = FastAPI()
app.add_middleware(
CORSMiddleware,
@@ -105,7 +104,7 @@ app.add_middleware(
)
tokenizer_manager: TokenizerManager = None
_max_total_num_tokens = None
scheduler_info: Dict = None
##### Native API endpoints #####
@@ -171,16 +170,6 @@ async def flush_cache():
)
def start_profile():
"""Start profiling."""
tokenizer_manager.start_profile()
def stop_profile():
"""Stop profiling."""
tokenizer_manager.stop_profile()
@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile_async():
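
As a quick illustration of how a client exercises the consolidated profiling endpoints, here is a hedged sketch that starts profiling, runs one request, and stops profiling. The base URL and the /generate payload are assumptions (30000 is sglang's default port), not values taken from this diff.

```python
# Hypothetical client-side usage of the profiling endpoints; the base URL
# is an assumption, adjust to your deployment.
import requests

base_url = "http://localhost:30000"

# The handler is registered for both GET and POST, so either verb works.
requests.post(f"{base_url}/start_profile").raise_for_status()

# Run whatever traffic should be profiled, e.g. one generate call.
requests.post(
    f"{base_url}/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 16}},
).raise_for_status()

requests.post(f"{base_url}/stop_profile").raise_for_status()
```
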
@@ -245,6 +234,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
)
# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
@time_func_latency
async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
@@ -278,11 +269,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
)
# fastapi implicitly converts json in the request to obj (dataclass)
app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request."""
@@ -295,10 +282,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
)
app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
@app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency
async def classify_request(obj: EmbeddingReqInput, request: Request):
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
@@ -311,10 +295,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
)
app.post("/classify")(classify_request)
app.put("/classify")(classify_request)
##### OpenAI-compatible API endpoints #####
@@ -392,11 +372,11 @@ def launch_engine(
server_args: ServerArgs,
):
"""
Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
global tokenizer_manager
global _max_total_num_tokens
global scheduler_info
# Configure global environment
configure_logger(server_args)
@@ -462,8 +442,8 @@ def launch_engine(
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
# Wait for model to finish loading & get max token nums
scheduler_info = []
# Wait for model to finish loading
scheduler_infos = []
for i in range(len(scheduler_pipe_readers)):
data = scheduler_pipe_readers[i].recv()
@@ -471,10 +451,10 @@ def launch_engine(
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
scheduler_info.append(data)
scheduler_infos.append(data)
# Assume all schedulers have the same max_total_num_tokens
_max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
scheduler_info = scheduler_infos[0]
def launch_server(
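
For context, here is a minimal standalone sketch of the handshake pattern the hunks above simplify: each scheduler subprocess reports a readiness dict over a pipe, the parent collects them into scheduler_infos, and keeps the first one (assuming all schedulers report the same limits). The payload fields and process count are illustrative.

```python
import multiprocessing as mp


def scheduler_proc(pipe_writer):
    # In the real code this would build the scheduler; here we only
    # send back a fake readiness payload.
    pipe_writer.send({"status": "ready", "max_total_num_tokens": 4096})


if __name__ == "__main__":
    readers, procs = [], []
    for _ in range(2):  # e.g. two scheduler processes
        reader, writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=scheduler_proc, args=(writer,))
        proc.start()
        readers.append(reader)
        procs.append(proc)

    scheduler_infos = [reader.recv() for reader in readers]
    if any(info["status"] != "ready" for info in scheduler_infos):
        raise RuntimeError("Initialization failed.")

    # Assume all schedulers report the same values; keep the first.
    scheduler_info = scheduler_infos[0]
    print(scheduler_info["max_total_num_tokens"])
    for proc in procs:
        proc.join()
```
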
@@ -488,12 +468,12 @@ def launch_server(
1. HTTP server: A FastAPI server that routes requests to the engine.
2. SRT engine:
1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
Note:
1. The HTTP server and Tokenizer Manager both run in the main process.
1. The HTTP server and TokenizerManager both run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
launch_engine(server_args=server_args)
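
The ZMQ-based IPC mentioned in the note can be pictured with a minimal PUSH/PULL pair. This is an illustration of the mechanism only, not sglang's actual socket layout; the port and payload are arbitrary.

```python
import multiprocessing as mp

import zmq


def detokenizer_stub(port: int):
    # Downstream process: pull messages sent by the upstream process.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PULL)
    sock.connect(f"tcp://127.0.0.1:{port}")
    print("received:", sock.recv_pyobj())


if __name__ == "__main__":
    port = 5757  # illustrative port; each process pair uses its own
    proc = mp.Process(target=detokenizer_stub, args=(port,))
    proc.start()

    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUSH)
    sock.bind(f"tcp://127.0.0.1:{port}")
    sock.send_pyobj({"output_ids": [1, 2, 3]})  # queued until the puller connects
    proc.join()
```
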
@@ -502,7 +482,7 @@ def launch_server(
if server_args.api_key:
add_api_key_middleware(app, server_args.api_key)
# add prometheus middleware
# Add prometheus middleware
if server_args.enable_metrics:
add_prometheus_middleware(app)
enable_func_timer()
@@ -514,7 +494,7 @@ def launch_server(
t.start()
try:
# Listen for HTTP requests
# Update logging configs
LOGGING_CONFIG["formatters"]["default"][
"fmt"
] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -523,6 +503,8 @@ def launch_server(
"fmt"
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
# Listen for HTTP requests
uvicorn.run(
app,
host=server_args.host,
@@ -538,8 +520,7 @@ def launch_server(
async def _get_server_info():
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
**scheduler_info,
"version": __version__,
}
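
After this change the scheduler fields (e.g. max_total_num_tokens) are spread directly into the response rather than exposed one by one. A hedged sketch of reading it from a client, assuming the handler is exposed at /get_server_info on the default port 30000:

```python
import requests

# Host, port, and route are assumptions; adjust to your deployment.
info = requests.get("http://localhost:30000/get_server_info").json()
print(info["version"])
print(info.get("max_total_num_tokens"))  # merged in from scheduler_info
```
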
@@ -645,6 +626,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
kill_process_tree(os.getpid())
return
# Debug print
# logger.info(f"{res.json()=}")
logger.info("The server is fired up and ready to roll!")
@@ -821,18 +803,11 @@ class Engine:
launching the HTTP server adds unnecessary complexity or overhead,
"""
def __init__(self, *args, **kwargs):
def __init__(self, log_level: str = "error", *args, **kwargs):
# Before the Python program terminates, call shutdown implicitly so users don't have to call .shutdown() explicitly.
atexit.register(self.shutdown)
# runtime server default log level is log
# offline engine works in scripts, so we set it to error
if "log_level" not in kwargs:
kwargs["log_level"] = "error"
server_args = ServerArgs(*args, **kwargs)
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
launch_engine(server_args=server_args)
def generate(
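
The constructor change makes log_level an explicit keyword with an "error" default instead of patching it into kwargs. A sketch of how an offline script might use it; the import path and model path are assumptions, not taken from this diff.

```python
# Import path and model path are placeholders; adjust to your install.
from sglang.srt.server import Engine

# Quiet by default (log_level="error"); pass log_level="info" to opt back in.
engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

print(engine.generate("The capital of France is"))

# Optional: shutdown is also registered via atexit, so this call is not required.
engine.shutdown()
```
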
@@ -955,5 +930,11 @@ class Engine:
loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None))
def start_profile(self):
tokenizer_manager.start_profile()
def stop_profile(self):
tokenizer_manager.stop_profile()
async def get_server_info(self):
return await _get_server_info()
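
These wrappers drive the async HTTP handlers to completion from synchronous code. Below is a generic, self-contained sketch of that bridge; the names are illustrative, while the real handlers are generate_request / encode_request above.

```python
import asyncio


async def fake_handler(payload):
    # Stand-in for an async endpoint such as encode_request(obj, None).
    await asyncio.sleep(0)
    return {"echo": payload}


def sync_wrapper(payload):
    # Same pattern as Engine.encode: run the coroutine on the event loop.
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(fake_handler(payload))


print(sync_wrapper({"text": "hello"}))
```
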