Simplify tokenizer manager (#2254)
This commit is contained in:
@@ -24,7 +24,6 @@ import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
@@ -94,7 +93,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
|
||||
# Fast API
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -105,7 +104,7 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
tokenizer_manager: TokenizerManager = None
|
||||
_max_total_num_tokens = None
|
||||
scheduler_info: Dict = None
|
||||
|
||||
##### Native API endpoints #####
|
||||
|
||||
@@ -171,16 +170,6 @@ async def flush_cache():
|
||||
)
|
||||
|
||||
|
||||
def start_profile():
|
||||
"""Start profiling."""
|
||||
tokenizer_manager.start_profile()
|
||||
|
||||
|
||||
def stop_profile():
|
||||
"""Stop profiling."""
|
||||
tokenizer_manager.stop_profile()
|
||||
|
||||
|
||||
@app.get("/start_profile")
|
||||
@app.post("/start_profile")
|
||||
async def start_profile_async():
|
||||
@@ -245,6 +234,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||
@app.api_route("/generate", methods=["POST", "PUT"])
|
||||
@time_func_latency
|
||||
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
"""Handle a generate request."""
|
||||
@@ -278,11 +269,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
# fastapi implicitly converts json in the request to obj (dataclass)
|
||||
app.post("/generate")(generate_request)
|
||||
app.put("/generate")(generate_request)
|
||||
|
||||
|
||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
@time_func_latency
|
||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
"""Handle an embedding request."""
|
||||
@@ -295,10 +282,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
app.post("/encode")(encode_request)
|
||||
app.put("/encode")(encode_request)
|
||||
|
||||
|
||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
@time_func_latency
|
||||
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
||||
@@ -311,10 +295,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||
)
|
||||
|
||||
|
||||
app.post("/classify")(classify_request)
|
||||
app.put("/classify")(classify_request)
|
||||
|
||||
|
||||
##### OpenAI-compatible API endpoints #####
|
||||
|
||||
|
||||
@@ -392,11 +372,11 @@ def launch_engine(
|
||||
server_args: ServerArgs,
|
||||
):
|
||||
"""
|
||||
Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
|
||||
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
|
||||
"""
|
||||
|
||||
global tokenizer_manager
|
||||
global _max_total_num_tokens
|
||||
global scheduler_info
|
||||
|
||||
# Configure global environment
|
||||
configure_logger(server_args)
|
||||
@@ -462,8 +442,8 @@ def launch_engine(
|
||||
if server_args.chat_template:
|
||||
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
|
||||
|
||||
# Wait for model to finish loading & get max token nums
|
||||
scheduler_info = []
|
||||
# Wait for model to finish loading
|
||||
scheduler_infos = []
|
||||
for i in range(len(scheduler_pipe_readers)):
|
||||
data = scheduler_pipe_readers[i].recv()
|
||||
|
||||
@@ -471,10 +451,10 @@ def launch_engine(
|
||||
raise RuntimeError(
|
||||
"Initialization failed. Please see the error messages above."
|
||||
)
|
||||
scheduler_info.append(data)
|
||||
scheduler_infos.append(data)
|
||||
|
||||
# Assume all schedulers have same max_total_num_tokens
|
||||
_max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
||||
scheduler_info = scheduler_infos[0]
|
||||
|
||||
|
||||
def launch_server(
|
||||
@@ -488,12 +468,12 @@ def launch_server(
|
||||
|
||||
1. HTTP server: A FastAPI server that routes requests to the engine.
|
||||
2. SRT engine:
|
||||
1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
|
||||
1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
|
||||
2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
|
||||
3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
|
||||
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
|
||||
|
||||
Note:
|
||||
1. The HTTP server and Tokenizer Manager both run in the main process.
|
||||
1. The HTTP server and TokenizerManager both run in the main process.
|
||||
2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
|
||||
"""
|
||||
launch_engine(server_args=server_args)
|
||||
@@ -502,7 +482,7 @@ def launch_server(
|
||||
if server_args.api_key:
|
||||
add_api_key_middleware(app, server_args.api_key)
|
||||
|
||||
# add prometheus middleware
|
||||
# Add prometheus middleware
|
||||
if server_args.enable_metrics:
|
||||
add_prometheus_middleware(app)
|
||||
enable_func_timer()
|
||||
@@ -514,7 +494,7 @@ def launch_server(
|
||||
t.start()
|
||||
|
||||
try:
|
||||
# Listen for HTTP requests
|
||||
# Update logging configs
|
||||
LOGGING_CONFIG["formatters"]["default"][
|
||||
"fmt"
|
||||
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
||||
@@ -523,6 +503,8 @@ def launch_server(
|
||||
"fmt"
|
||||
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
||||
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
# Listen for HTTP requests
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
@@ -538,8 +520,7 @@ def launch_server(
|
||||
async def _get_server_info():
|
||||
return {
|
||||
**dataclasses.asdict(tokenizer_manager.server_args), # server args
|
||||
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
|
||||
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
|
||||
**scheduler_info,
|
||||
"version": __version__,
|
||||
}
|
||||
|
||||
@@ -645,6 +626,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
kill_process_tree(os.getpid())
|
||||
return
|
||||
|
||||
# Debug print
|
||||
# logger.info(f"{res.json()=}")
|
||||
|
||||
logger.info("The server is fired up and ready to roll!")
|
||||
@@ -821,18 +803,11 @@ class Engine:
|
||||
launching the HTTP server adds unnecessary complexity or overhead,
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
def __init__(self, log_level: str = "error", *args, **kwargs):
|
||||
# before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
|
||||
atexit.register(self.shutdown)
|
||||
|
||||
# runtime server default log level is log
|
||||
# offline engine works in scripts, so we set it to error
|
||||
|
||||
if "log_level" not in kwargs:
|
||||
kwargs["log_level"] = "error"
|
||||
|
||||
server_args = ServerArgs(*args, **kwargs)
|
||||
server_args = ServerArgs(*args, log_level=log_level, **kwargs)
|
||||
launch_engine(server_args=server_args)
|
||||
|
||||
def generate(
|
||||
@@ -955,5 +930,11 @@ class Engine:
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(encode_request(obj, None))
|
||||
|
||||
def start_profile(self):
|
||||
tokenizer_manager.start_profile()
|
||||
|
||||
def stop_profile(self):
|
||||
tokenizer_manager.stop_profile()
|
||||
|
||||
async def get_server_info(self):
|
||||
return await _get_server_info()
|
||||
|
||||
Reference in New Issue
Block a user