Add example to use SGLang engine with FastAPI (#5648)
Co-authored-by: Ravi Theja Desetty <ravitheja@Ravis-MacBook-Pro.local>
examples/runtime/engine/fastapi_engine_inference.py (new file, 189 lines)
@@ -0,0 +1,189 @@
"""
FastAPI server example for text generation using the SGLang Engine, demonstrating client usage.

Starts the server, sends requests to it, and prints the responses.

Usage:
    python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000
"""

import os
import subprocess
import time
from contextlib import asynccontextmanager

import requests
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

import sglang as sgl
from sglang.utils import terminate_process

engine = None


# Use FastAPI's lifespan manager to initialize and shut down the engine
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages SGLang engine initialization during server startup."""
    global engine
    # Initialize the SGLang engine when the server starts.
    # Adjust model_path and other engine arguments as needed.
    print("Loading SGLang engine...")
    engine = sgl.Engine(
        model_path=os.getenv("MODEL_PATH", "Qwen/Qwen2.5-0.5B-Instruct"),
        tp_size=int(os.getenv("TP_SIZE", "1")),
    )
    print("SGLang engine loaded.")
    yield
    # Clean up engine resources when the server stops (optional, depends on engine needs)
    print("Shutting down SGLang engine...")
    # engine.shutdown()  # Or other cleanup if available/necessary
    print("SGLang engine shutdown.")


app = FastAPI(lifespan=lifespan)


@app.post("/generate")
async def generate_text(request: Request):
    """FastAPI endpoint to handle text generation requests."""
    global engine
    if not engine:
        # Returning a bare (dict, status) tuple would be serialized as a JSON
        # array with status 200, so use JSONResponse to set the status code.
        return JSONResponse(status_code=503, content={"error": "Engine not initialized"})

    try:
        data = await request.json()
        prompt = data.get("prompt")
        max_new_tokens = data.get("max_new_tokens", 128)
        temperature = data.get("temperature", 0.7)

        if not prompt:
            return JSONResponse(status_code=400, content={"error": "Prompt is required"})

        # Use async_generate for non-blocking generation
        state = await engine.async_generate(
            prompt,
            sampling_params={
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
            },
            # Add other parameters like stop, top_p, etc. as needed
        )

        return {"generated_text": state["text"]}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


# Helper function to start the server
def start_server(args, timeout=60):
    """Starts the Uvicorn server as a subprocess and waits for it to be ready."""
    base_url = f"http://{args.host}:{args.port}"
    command = [
        "python",
        "-m",
        "uvicorn",
        "fastapi_engine_inference:app",
        f"--host={args.host}",
        f"--port={args.port}",
    ]

    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                # Check the /docs endpoint, which FastAPI provides by default
                response = session.get(f"{base_url}/docs", timeout=5)
                if response.status_code == 200:
                    print(f"Server {base_url} is ready (responded on /docs)")
                    return process
            except requests.ConnectionError:
                # Connection refused, DNS errors, etc. -- the server is not up yet
                pass
            except requests.Timeout:
                print(f"Health check to {base_url}/docs timed out, retrying...")
            except requests.RequestException as e:
                print(f"Health check request error: {e}, retrying...")
            # Poll at a short interval for faster startup detection
            time.sleep(1)

    # The loop finished without a successful health check; terminate the
    # failed process before raising the timeout error.
    if process:
        print("Server failed to start within timeout, terminating process...")
        terminate_process(process)
    raise TimeoutError(f"Server failed to start at {base_url} within the timeout period.")


def send_requests(server_url, prompts, max_new_tokens, temperature):
    """Sends generation requests to the running server for a list of prompts."""
    for i, prompt in enumerate(prompts):
        print(f"\n[{i+1}/{len(prompts)}] Sending prompt: '{prompt}'")
        payload = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }

        try:
            response = requests.post(f"{server_url}/generate", json=payload, timeout=60)
            response.raise_for_status()  # Surface HTTP errors instead of a KeyError below
            result = response.json()
            print(f"Prompt: {prompt}\nResponse: {result['generated_text']}")
        except requests.exceptions.Timeout:
            print(f"Error: Request timed out for prompt '{prompt}'")
        except requests.exceptions.RequestException as e:
            print(f"Error sending request for prompt '{prompt}': {e}")


if __name__ == "__main__":
    # Main entry point for the script.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--tp_size", type=int, default=1)
    args = parser.parse_args()

    # Pass the model settings to the child uvicorn process via env vars
    os.environ["MODEL_PATH"] = args.model_path
    os.environ["TP_SIZE"] = str(args.tp_size)

    # Start the server
    process = start_server(args)

    # Define the prompts and sampling parameters
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    max_new_tokens = 64
    temperature = 0.1

    # Define the server URL
    server_url = f"http://{args.host}:{args.port}"

    # Send requests to the server
    send_requests(server_url, prompts, max_new_tokens, temperature)

    # Terminate the server process
    terminate_process(process)
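
As an aside, a minimal standalone client for the `/generate` endpoint might look like the sketch below. It assumes the server above is already running on its default `127.0.0.1:8000`; the prompt and sampling values are arbitrary.

```python
import requests

# Query the /generate endpoint exposed by fastapi_engine_inference.py.
resp = requests.post(
    "http://127.0.0.1:8000/generate",
    json={"prompt": "The capital of France is", "max_new_tokens": 32, "temperature": 0.7},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```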
@@ -6,6 +6,7 @@ SGLang provides a direct inference engine without the need for an HTTP server. T
1. **Offline Batch Inference**
2. **Embedding Generation**
3. **Custom Server on Top of the Engine**
4. **Inference Using FastAPI**

## Examples
@@ -47,3 +48,7 @@ This will send both non-streaming and streaming requests to the server.
### [Token-In-Token-Out for RLHF](../token_in_token_out)

In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
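
A rough sketch of that flow follows. It is an illustration only, not the linked example verbatim: the `skip_tokenizer_init` flag and the `output_ids` result field are assumptions about the Engine API.

```python
import sglang as sgl
from transformers import AutoTokenizer

# Tokenize outside the engine; skip_tokenizer_init makes the engine accept
# and return token IDs rather than text (assumed behavior, see the example).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct", skip_tokenizer_init=True)

input_ids = tokenizer("The capital of France is")["input_ids"]
out = engine.generate(input_ids=[input_ids], sampling_params={"max_new_tokens": 16})
print(out[0]["output_ids"])  # assumed field holding the generated token IDs
```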

### [Inference Using FastAPI](fastapi_engine_inference.py)

This example demonstrates how to create a FastAPI server that uses the SGLang engine for text generation.
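Run it with the command from the example's docstring, e.g. `python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000`; the script starts the server, sends a few prompts, prints the responses, and then shuts the server down.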