OAI Server Skeleton & Core Utility Endpoints (#7179)
python/sglang/srt/entrypoints/openai/api_server.py (new file, 363 lines)
@@ -0,0 +1,363 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
SGLang OpenAI-Compatible API Server.

This file implements OpenAI-compatible HTTP APIs for the inference engine via FastAPI.
"""

import argparse
import asyncio
import logging
import multiprocessing
import os
import threading
import time
from contextlib import asynccontextmanager
from typing import Callable, Dict, Optional

import numpy as np
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response

from sglang.srt.disaggregation.utils import (
    FakeBootstrapHost,
    register_disaggregation_server,
)
from sglang.srt.entrypoints.engine import Engine, _launch_subprocesses
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    add_prometheus_middleware,
    delete_directory,
    get_bool_env_var,
    kill_process_tree,
    set_uvicorn_logging_configs,
)
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Store global states
class AppState:
    engine: Optional[Engine] = None
    server_args: Optional[ServerArgs] = None
    tokenizer_manager: Optional[TokenizerManager] = None
    scheduler_info: Optional[Dict] = None

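# Handlers below read shared objects through FastAPI's app.state (e.g.
# app.state.tokenizer_manager); AppState above documents the attributes
# that lifespan() is expected to populate.
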
@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.server_args.enable_metrics = True  # By default, we enable metrics

    server_args = app.state.server_args

    # Initialize engine
    logger.info(f"SGLang OpenAI server (PID: {os.getpid()}) is initializing...")

    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
    app.state.tokenizer_manager = tokenizer_manager
    app.state.scheduler_info = scheduler_info

    if server_args.enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()

    # Initialize engine state attribute to None for now
    app.state.engine = None

    if server_args.warmups is not None:
        await execute_warmups(
            server_args.warmups.split(","), app.state.tokenizer_manager
        )
        logger.info("Warmup ended")

    warmup_thread = getattr(app, "warmup_thread", None)
    if warmup_thread is not None:
        warmup_thread.start()

    yield

    # Lifespan shutdown
    if hasattr(app.state, "engine") and app.state.engine is not None:
        logger.info("SGLang engine is shutting down.")
        # Add engine cleanup logic here when implemented

# FastAPI app with CORS enabled
app = FastAPI(
    lifespan=lifespan,
    # TODO: check where /openapi.json is created or why we use this
    openapi_url=None if get_bool_env_var("DISABLE_OPENAPI_DOC") else "/openapi.json",
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.api_route("/health", methods=["GET"])
|
||||
async def health() -> Response:
|
||||
"""Health check. Used for readiness and liveness probes."""
|
||||
# In the future, this could check engine health more deeply
|
||||
# For now, if the server is up, it's healthy.
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
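# Illustrative probe of the endpoint above (not part of this file's logic);
# host/port follow whatever ServerArgs was launched with, e.g. the defaults:
#
#   curl -i http://127.0.0.1:30000/health   # -> HTTP/1.1 200 OK
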
@app.api_route("/v1/models", methods=["GET"])
|
||||
async def show_models():
|
||||
"""Show available models. Currently, it returns the served model name.
|
||||
|
||||
This endpoint is compatible with the OpenAI API standard.
|
||||
"""
|
||||
served_model_names = [app.state.tokenizer_manager.served_model_name]
|
||||
model_cards = []
|
||||
for served_model_name in served_model_names:
|
||||
model_cards.append(
|
||||
ModelCard(
|
||||
id=served_model_name,
|
||||
root=served_model_name,
|
||||
max_model_len=app.state.tokenizer_manager.model_config.context_len,
|
||||
)
|
||||
)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
|
||||
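# Roughly the OpenAI-style payload this produces (values illustrative; any
# extra ModelCard fields beyond the ones set above are elided):
#
#   {"object": "list", "data": [{"id": "<served-model-name>",
#    "root": "<served-model-name>", "max_model_len": <context_len>, ...}]}
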
@app.get("/get_model_info")
|
||||
async def get_model_info():
|
||||
"""Get the model information."""
|
||||
result = {
|
||||
"model_path": app.state.tokenizer_manager.model_path,
|
||||
"tokenizer_path": app.state.tokenizer_manager.server_args.tokenizer_path,
|
||||
"is_generation": app.state.tokenizer_manager.is_generation,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def openai_v1_completions(raw_request: Request):
|
||||
pass
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def openai_v1_chat_completions(raw_request: Request):
|
||||
pass
|
||||
|
||||
|
||||
@app.post("/v1/embeddings")
|
||||
async def openai_v1_embeddings(raw_request: Request):
|
||||
pass
|
||||
|
||||
|
||||
@app.post("/v1/score")
|
||||
async def v1_score_request(raw_request: Request):
|
||||
"""Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
|
||||
pass
|
||||
|
||||
|
||||
# Additional API endpoints will be implemented in separate serving_*.py modules
|
||||
# and mounted as APIRouters in future PRs
|
||||
|
||||
|
||||
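# For reference, mounting would look roughly like this (module and router
# names below are placeholders, not part of this PR):
#
#   from sglang.srt.entrypoints.openai.serving_chat import router as chat_router
#   app.include_router(chat_router)
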
def _wait_and_warmup(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[multiprocessing.connection.Connection],
    image_token_text: str,
    launch_callback: Optional[Callable[[], None]] = None,
):
    return
    # TODO: The early return above disables this warmup path for now. Keep the
    # code below until the /generate implementation is complete, then confirm
    # whether modifications are needed before removing it.

    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    # TODO: Replace with OpenAI API
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
        # TODO: Workaround for the bug that embedding errors out on a list of size 1
        if server_args.dp_size == 1:
            json_data["input_ids"] = json_data["input_ids"][0]
    else:
        json_data["text"] = ["The capital city of France is"] * server_args.dp_size
        # TODO: Workaround for the bug that embedding errors out on a list of size 1
        if server_args.dp_size == 1:
            json_data["text"] = json_data["text"][0]

    # Debug dumping
    if server_args.debug_tensor_dump_input_file:
        json_data.pop("text", None)
        json_data["input_ids"] = np.load(
            server_args.debug_tensor_dump_input_file
        ).tolist()
        json_data["sampling_params"]["max_new_tokens"] = 0

    try:
        if server_args.disaggregation_mode == "null":
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
        else:
            logger.info("Start of prefill warmup ...")
            json_data = {
                "sampling_params": {
                    "temperature": 0.0,
                    "max_new_tokens": 8,
                    "ignore_eos": True,
                },
                # This is a hack to ensure fake transfer is enabled during prefill warmup
                "bootstrap_host": [FakeBootstrapHost] * server_args.dp_size,
                # Ensure each dp rank has a unique bootstrap_room during prefill warmup
                # by carving the 63-bit room space into dp_size slices,
                # e.g. dp_size=2, tp_size=1 -> rooms [0, 2**62]
                "bootstrap_room": [
                    i * (2**63 // server_args.dp_size) + (i % server_args.tp_size)
                    for i in range(server_args.dp_size)
                ],
                "input_ids": [[0, 1, 2, 3]] * server_args.dp_size,
            }
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                # Long timeout: DeepGEMM precompilation can take a long time
                # when nothing has been pre-cached yet.
                timeout=1800,
            )
            logger.info(
                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
            )

    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    # Debug print
    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")

    if server_args.delete_ckpt_after_loading:
        delete_directory(server_args.model_path)

    if server_args.debug_tensor_dump_input_file:
        kill_process_tree(os.getpid())

    if server_args.pdlb_url is not None:
        register_disaggregation_server(
            server_args.disaggregation_mode,
            server_args.port,
            server_args.disaggregation_bootstrap_port,
            server_args.pdlb_url,
        )

    if launch_callback is not None:
        launch_callback()

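# Example launch (flag spellings come from ServerArgs; the model path below
# is illustrative):
#
#   python -m sglang.srt.entrypoints.openai.api_server \
#       --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
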
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SGLang OpenAI-Compatible API Server")
    # Add arguments from ServerArgs. This allows reuse of existing CLI definitions.
    ServerArgs.add_cli_args(parser)
    # Potentially add server-specific arguments here in the future if needed

    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)

    # Store server_args in app.state for access in lifespan and endpoints
    app.state.server_args = server_args

    # Configure logging
    logging.basicConfig(
        level=server_args.log_level.upper(),
        format="%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s",
    )

    # Send a warmup request: we create the thread here and launch it
    # in the lifespan after all other warmups have fired.
    warmup_thread = threading.Thread(
        target=_wait_and_warmup,
        args=(
            server_args,
            None,
            None,  # Never used
            None,
        ),
    )
    app.warmup_thread = warmup_thread

    try:
        # Start the server
        set_uvicorn_logging_configs()
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level.lower(),
            timeout_keep_alive=60,  # Increased keep-alive for potentially long requests
            loop="uvloop",  # Use uvloop for better performance if available
        )
    finally:
        warmup_thread.join()