Separate two entry points: Engine and HTTP server (#2996)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
579  python/sglang/srt/entrypoints/http_server.py  Normal file
@@ -0,0 +1,579 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The entry point of the inference server. (SRT = SGLang Runtime)

This file implements the HTTP APIs for the inference engine via FastAPI.
"""
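# A quick-start sketch (illustrative only; nothing here is executed by this module).
# The server is typically started from the command line, after which the endpoints
# defined below are reachable over plain HTTP. The model path and port are placeholders.
#
#   python -m sglang.launch_server --model-path <your-model-path> --port 30000
#   curl http://127.0.0.1:30000/get_model_info
#   curl http://127.0.0.1:30000/generate \
#       -H "Content-Type: application/json" \
#       -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 8}}'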

import asyncio
import dataclasses
import logging
import multiprocessing
import os
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional

# Work around a bug in Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import orjson
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    OpenSessionReqInput,
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import (
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    add_prometheus_middleware,
    delete_directory,
    kill_process_tree,
    set_uvicorn_logging_configs,
)
from sglang.utils import get_exception_traceback
from sglang.version import __version__

logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# Fast API
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Store global states
@dataclasses.dataclass
class _GlobalState:
    tokenizer_manager: TokenizerManager
    scheduler_info: Dict


_global_state: Optional[_GlobalState] = None


def set_global_state(global_state: _GlobalState):
    global _global_state
    _global_state = global_state


##### Native API endpoints #####


@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""

    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}

    if _global_state.tokenizer_manager.is_generation:
        gri = GenerateReqInput(
            input_ids=[0], sampling_params=sampling_params, log_metrics=False
        )
    else:
        gri = EmbeddingReqInput(
            input_ids=[0], sampling_params=sampling_params, log_metrics=False
        )

    try:
        async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""
    result = {
        "model_path": _global_state.tokenizer_manager.model_path,
        "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
        "is_generation": _global_state.tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_info")
async def get_server_info():
    return {
        **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
        **_global_state.scheduler_info,
        "version": __version__,
    }


# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results() -> AsyncIterator[bytes]:
            try:
                async for out in _global_state.tokenizer_manager.generate_request(
                    obj, request
                ):
                    yield b"data: " + orjson.dumps(
                        out, option=orjson.OPT_NON_STR_KEYS
                    ) + b"\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(
                    out, option=orjson.OPT_NON_STR_KEYS
                ) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=_global_state.tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await _global_state.tokenizer_manager.generate_request(
                obj, request
            ).__anext__()
            return ret
        except ValueError as e:
            logger.error(f"Error: {e}")
            return _create_error_response(e)


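# A minimal client-side sketch for the streaming path of /generate above (illustrative
# only; this helper is never called by the server, and the host, port, and prompt are
# placeholders). It reuses the module-level `requests` and `orjson` imports.
def _example_stream_generate(base_url: str = "http://127.0.0.1:30000") -> None:
    """Illustrative helper; not used by the server."""
    resp = requests.post(
        base_url + "/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 16, "temperature": 0},
            "stream": True,
        },
        stream=True,
    )
    for line in resp.iter_lines():
        # Each event is prefixed with b"data: "; the stream ends with b"data: [DONE]".
        if line.startswith(b"data: "):
            payload = line[len(b"data: "):]
            if payload == b"[DONE]":
                break
            print(orjson.loads(payload))

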
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||
"""Handle an embedding request."""
|
||||
try:
|
||||
ret = await _global_state.tokenizer_manager.generate_request(
|
||||
obj, request
|
||||
).__anext__()
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/classify", methods=["POST", "PUT"])
|
||||
async def classify_request(obj: EmbeddingReqInput, request: Request):
|
||||
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
|
||||
try:
|
||||
ret = await _global_state.tokenizer_manager.generate_request(
|
||||
obj, request
|
||||
).__anext__()
|
||||
return ret
|
||||
except ValueError as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.post("/flush_cache")
|
||||
async def flush_cache():
|
||||
"""Flush the radix cache."""
|
||||
_global_state.tokenizer_manager.flush_cache()
|
||||
return Response(
|
||||
content="Cache flushed.\nPlease check backend logs for more details. "
|
||||
"(When there are running or waiting requests, the operation will not be performed.)\n",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/start_profile", methods=["GET", "POST"])
|
||||
async def start_profile_async():
|
||||
"""Start profiling."""
|
||||
_global_state.tokenizer_manager.start_profile()
|
||||
return Response(
|
||||
content="Start profiling.\n",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@app.api_route("/stop_profile", methods=["GET", "POST"])
|
||||
async def stop_profile_async():
|
||||
"""Stop profiling."""
|
||||
_global_state.tokenizer_manager.stop_profile()
|
||||
return Response(
|
||||
content="Stop profiling. This will take some time.\n",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_disk")
|
||||
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
|
||||
"""Update the weights from disk in-place without re-launching the server."""
|
||||
success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
|
||||
obj, request
|
||||
)
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
return ORJSONResponse(
|
||||
content,
|
||||
status_code=HTTPStatus.OK,
|
||||
)
|
||||
else:
|
||||
return ORJSONResponse(
|
||||
content,
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
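# A hypothetical usage sketch for /update_weights_from_disk above (illustrative only;
# this helper is never called by the server). The URL and checkpoint path are
# placeholders, and `model_path` is assumed to be the relevant field of
# UpdateWeightFromDiskReqInput. A successful swap returns HTTP 200 with
# {"success": true, ...}; failures return HTTP 400.
def _example_update_weights_from_disk(base_url: str = "http://127.0.0.1:30000") -> None:
    """Illustrative helper; not used by the server."""
    resp = requests.post(
        base_url + "/update_weights_from_disk",
        json={"model_path": "/path/to/new/checkpoint"},
    )
    print(resp.status_code, resp.json())

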
@app.post("/init_weights_update_group")
|
||||
async def init_weights_update_group(
|
||||
obj: InitWeightsUpdateGroupReqInput, request: Request
|
||||
):
|
||||
"""Initialize the parameter update group."""
|
||||
success, message = await _global_state.tokenizer_manager.init_weights_update_group(
|
||||
obj, request
|
||||
)
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
return ORJSONResponse(content, status_code=200)
|
||||
else:
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_distributed")
|
||||
async def update_weights_from_distributed(
|
||||
obj: UpdateWeightsFromDistributedReqInput, request: Request
|
||||
):
|
||||
"""Update model parameter from distributed online."""
|
||||
success, message = (
|
||||
await _global_state.tokenizer_manager.update_weights_from_distributed(
|
||||
obj, request
|
||||
)
|
||||
)
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
return ORJSONResponse(content, status_code=200)
|
||||
else:
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
|
||||
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
|
||||
"""Get model parameter by name."""
|
||||
try:
|
||||
ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)
|
||||
if ret is None:
|
||||
return _create_error_response("Get parameter by name failed")
|
||||
else:
|
||||
return ORJSONResponse(ret, status_code=200)
|
||||
except Exception as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
|
||||
async def release_memory_occupation(
|
||||
obj: ReleaseMemoryOccupationReqInput, request: Request
|
||||
):
|
||||
"""Release GPU occupation temporarily"""
|
||||
try:
|
||||
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
|
||||
except Exception as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
|
||||
async def resume_memory_occupation(
|
||||
obj: ResumeMemoryOccupationReqInput, request: Request
|
||||
):
|
||||
"""Resume GPU occupation"""
|
||||
try:
|
||||
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
|
||||
except Exception as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||
async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||
"""Open a session, and return its unique session id."""
|
||||
try:
|
||||
session_id = await _global_state.tokenizer_manager.open_session(obj, request)
|
||||
if session_id is None:
|
||||
raise Exception(
|
||||
"Failed to open the session. Check if a session with the same id is still open."
|
||||
)
|
||||
return session_id
|
||||
except Exception as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/close_session", methods=["GET", "POST"])
|
||||
async def close_session(obj: CloseSessionReqInput, request: Request):
|
||||
"""Close the session"""
|
||||
try:
|
||||
await _global_state.tokenizer_manager.close_session(obj, request)
|
||||
return Response(status_code=200)
|
||||
except Exception as e:
|
||||
return _create_error_response(e)
|
||||
|
||||
|
||||
@app.api_route("/configure_logging", methods=["GET", "POST"])
|
||||
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||
"""Close the session"""
|
||||
_global_state.tokenizer_manager.configure_logging(obj)
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
##### OpenAI-compatible API endpoints #####
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def openai_v1_completions(raw_request: Request):
|
||||
return await v1_completions(_global_state.tokenizer_manager, raw_request)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def openai_v1_chat_completions(raw_request: Request):
|
||||
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
|
||||
|
||||
|
||||
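# A client-side sketch for the OpenAI-compatible endpoints above (illustrative only;
# this helper is never called by the server, and the host, port, and model name are
# placeholders). Plain `requests` is used here to stay self-contained; any OpenAI-style
# client should work as well, and the response is assumed to follow the standard
# chat-completions schema.
def _example_openai_chat_completion(base_url: str = "http://127.0.0.1:30000") -> None:
    """Illustrative helper; not used by the server."""
    resp = requests.post(
        base_url + "/v1/chat/completions",
        json={
            "model": "default",
            "messages": [{"role": "user", "content": "Say hello in one word."}],
            "max_tokens": 8,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])

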
@app.post("/v1/embeddings", response_class=ORJSONResponse)
|
||||
async def openai_v1_embeddings(raw_request: Request):
|
||||
response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/v1/models", response_class=ORJSONResponse)
|
||||
def available_models():
|
||||
"""Show available models."""
|
||||
served_model_names = [_global_state.tokenizer_manager.served_model_name]
|
||||
model_cards = []
|
||||
for served_model_name in served_model_names:
|
||||
model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
|
||||
@app.post("/v1/files")
|
||||
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
|
||||
return await v1_files_create(
|
||||
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
|
||||
)
|
||||
|
||||
|
||||
@app.delete("/v1/files/{file_id}")
|
||||
async def delete_file(file_id: str):
|
||||
# https://platform.openai.com/docs/api-reference/files/delete
|
||||
return await v1_delete_file(file_id)
|
||||
|
||||
|
||||
@app.post("/v1/batches")
|
||||
async def openai_v1_batches(raw_request: Request):
|
||||
return await v1_batches(_global_state.tokenizer_manager, raw_request)
|
||||
|
||||
|
||||
@app.post("/v1/batches/{batch_id}/cancel")
|
||||
async def cancel_batches(batch_id: str):
|
||||
# https://platform.openai.com/docs/api-reference/batch/cancel
|
||||
return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
|
||||
|
||||
|
||||
@app.get("/v1/batches/{batch_id}")
|
||||
async def retrieve_batch(batch_id: str):
|
||||
return await v1_retrieve_batch(batch_id)
|
||||
|
||||
|
||||
@app.get("/v1/files/{file_id}")
|
||||
async def retrieve_file(file_id: str):
|
||||
# https://platform.openai.com/docs/api-reference/files/retrieve
|
||||
return await v1_retrieve_file(file_id)
|
||||
|
||||
|
||||
@app.get("/v1/files/{file_id}/content")
|
||||
async def retrieve_file_content(file_id: str):
|
||||
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
||||
return await v1_retrieve_file_content(file_id)
|
||||
|
||||
|
||||
def _create_error_response(e):
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
):
    """
    Launch the SRT (SGLang Runtime) Server.

    The SRT server consists of an HTTP server and an SRT engine.

    - HTTP server: A FastAPI server that routes requests to the engine.
    - The engine consists of three components:
        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the TokenizerManager, schedules batches, forwards them, and sends the output tokens to the DetokenizerManager.
        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the TokenizerManager.

    Note:
    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """
    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
            scheduler_info=scheduler_info,
        )
    )

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Add prometheus middleware
    if server_args.enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup,
        args=(
            server_args,
            pipe_finish_writer,
            _global_state.tokenizer_manager.image_token_id,
        ),
    )
    t.start()

    try:
        # Update logging configs
        set_uvicorn_logging_configs()

        # Listen for HTTP requests
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


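# A minimal programmatic-launch sketch for launch_server above (illustrative only;
# this helper is never called by this module). The model path is a placeholder and
# all other ServerArgs fields are left at their defaults.
def _example_launch_server() -> None:
    """Illustrative helper; not used by the server."""
    server_args = ServerArgs(model_path="/path/to/model")
    launch_server(server_args)  # Blocks while uvicorn serves HTTP requests.

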
def _wait_and_warmup(server_args, pipe_finish_writer, image_token_id):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()
            pass

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    # Debug print
    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")

    if server_args.delete_ckpt_after_loading:
        delete_directory(server_args.model_path)