From d2b8d0b8d8a8b891e427dbfd62d8f56163f7d782 Mon Sep 17 00:00:00 2001
From: Ravi Theja
Date: Thu, 24 Apr 2025 21:27:05 +0530
Subject: [PATCH] Add example to use sgl engine with fastapi (#5648)

Co-authored-by: Ravi Theja Desetty
---
 .../engine/fastapi_engine_inference.py | 189 ++++++++++++++++++
 examples/runtime/engine/readme.md      |   5 +
 2 files changed, 194 insertions(+)
 create mode 100644 examples/runtime/engine/fastapi_engine_inference.py

diff --git a/examples/runtime/engine/fastapi_engine_inference.py b/examples/runtime/engine/fastapi_engine_inference.py
new file mode 100644
index 000000000..57b83bcba
--- /dev/null
+++ b/examples/runtime/engine/fastapi_engine_inference.py
@@ -0,0 +1,189 @@
+"""
+FastAPI server example for text generation using the SGLang Engine, including a demonstration client.
+
+When run as a script, it starts the server, sends requests to it, and prints the responses.
+
+Usage:
+python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000
+"""
+
+import os
+import subprocess
+import time
+from contextlib import asynccontextmanager
+
+import requests
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+
+import sglang as sgl
+from sglang.utils import terminate_process
+
+engine = None
+
+
+# Use FastAPI's lifespan manager to initialize/shut down the engine
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Manages SGLang engine initialization during server startup."""
+    global engine
+    # Initialize the SGLang engine when the server starts.
+    # Adjust model_path and other engine arguments as needed; the fallbacks
+    # let `uvicorn fastapi_engine_inference:app` also work standalone.
+    print("Loading SGLang engine...")
+    engine = sgl.Engine(
+        model_path=os.getenv("MODEL_PATH", "Qwen/Qwen2.5-0.5B-Instruct"),
+        tp_size=int(os.getenv("TP_SIZE", "1")),
+    )
+    print("SGLang engine loaded.")
+    yield
+    # Clean up engine resources when the server stops (optional, depends on engine needs)
+    print("Shutting down SGLang engine...")
+    # engine.shutdown()  # Or other cleanup if available/necessary
+    print("SGLang engine shutdown.")
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+@app.post("/generate")
+async def generate_text(request: Request):
+    """FastAPI endpoint to handle text generation requests."""
+    global engine
+    if not engine:
+        # Returning a (dict, status) tuple does not set the HTTP status code in
+        # FastAPI; use JSONResponse so clients actually see a 503.
+        return JSONResponse(
+            status_code=503, content={"error": "Engine not initialized"}
+        )
+
+    try:
+        data = await request.json()
+        prompt = data.get("prompt")
+        max_new_tokens = data.get("max_new_tokens", 128)
+        temperature = data.get("temperature", 0.7)
+
+        if not prompt:
+            return JSONResponse(
+                status_code=400, content={"error": "Prompt is required"}
+            )
+
+        # Use async_generate for non-blocking generation
+        state = await engine.async_generate(
+            prompt,
+            sampling_params={
+                "max_new_tokens": max_new_tokens,
+                "temperature": temperature,
+                # Add other sampling parameters like stop, top_p, etc. as needed
+            },
+        )
+
+        return {"generated_text": state["text"]}
+    except Exception as e:
+        return JSONResponse(status_code=500, content={"error": str(e)})
+
+
+# Helper function to start the server
+def start_server(args, timeout=60):
+    """Starts the Uvicorn server as a subprocess and waits for it to be ready."""
+    base_url = f"http://{args.host}:{args.port}"
+    command = [
+        "python",
+        "-m",
+        "uvicorn",
+        "fastapi_engine_inference:app",
+        f"--host={args.host}",
+        f"--port={args.port}",
+    ]
+
+    process = subprocess.Popen(command, stdout=None, stderr=None)
+
+    start_time = time.time()
+    with requests.Session() as session:
+        while time.time() - start_time < timeout:
+            try:
+                # Check the /docs endpoint, which FastAPI provides by default
+                response = session.get(f"{base_url}/docs", timeout=5)
+                if response.status_code == 200:
+                    print(f"Server {base_url} is ready (responded on /docs)")
+                    return process
+            except requests.ConnectionError:
+                # Connection refused, DNS failure, etc. -- the server is not up yet
+                pass
+            except requests.Timeout:
+                print(f"Health check to {base_url}/docs timed out, retrying...")
+            except requests.RequestException as e:
+                # Catch any other request exception
+                print(f"Health check request error: {e}, retrying...")
+            # Use a short sleep interval for faster startup detection
+            time.sleep(1)
+
+    # The loop timed out: terminate the failed process before raising
+    print("Server failed to start within timeout, attempting to terminate process...")
+    terminate_process(process)
+    raise TimeoutError(
+        f"Server failed to start at {base_url} within the timeout period."
+    )
+
+
+def send_requests(server_url, prompts, max_new_tokens, temperature):
+    """Sends generation requests to the running server for a list of prompts."""
+    # Iterate through the prompts and send one request per prompt
+    for i, prompt in enumerate(prompts):
+        print(f"\n[{i+1}/{len(prompts)}] Sending prompt: '{prompt}'")
+        payload = {
+            "prompt": prompt,
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+        }
+
+        try:
+            response = requests.post(f"{server_url}/generate", json=payload, timeout=60)
+            # Surface HTTP errors (e.g. 400/503) instead of failing on the JSON below
+            response.raise_for_status()
+            result = response.json()
+            print(f"Prompt: {prompt}\nResponse: {result['generated_text']}")
+        except requests.exceptions.Timeout:
+            print(f"  Error: Request timed out for prompt '{prompt}'")
+        except requests.exceptions.RequestException as e:
+            print(f"  Error sending request for prompt '{prompt}': {e}")
+
+
+if __name__ == "__main__":
+    # Main entry point for the script
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="127.0.0.1")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
+    parser.add_argument("--tp_size", type=int, default=1)
+    args = parser.parse_args()
+
+    # Pass the model to the child uvicorn process via environment variables
+    os.environ["MODEL_PATH"] = args.model_path
+    os.environ["TP_SIZE"] = str(args.tp_size)
+
+    # Start the server
+    process = start_server(args)
+
+    # Define the prompts and sampling parameters
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    max_new_tokens = 64
+    temperature = 0.1
+
+    # Define the server URL
+    server_url = f"http://{args.host}:{args.port}"
+
+    # Send requests to the server
+    send_requests(server_url, prompts, max_new_tokens, temperature)
+
+    # Terminate the server process
+    terminate_process(process)
diff --git a/examples/runtime/engine/readme.md b/examples/runtime/engine/readme.md
index d5ac93671..e1161a9a3 100644
--- a/examples/runtime/engine/readme.md
+++ b/examples/runtime/engine/readme.md
@@ -6,6 +6,7 @@ SGLang provides a direct inference engine without the need for an HTTP server. T
 1. **Offline Batch Inference**
 2. **Embedding Generation**
 3. **Custom Server on Top of the Engine**
+4. **Inference Using FastAPI**
 
 ## Examples
 
@@ -47,3 +48,7 @@ This will send both non-streaming and streaming requests to the server.
 ### [Token-In-Token-Out for RLHF](../token_in_token_out)
 
 In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
+
+### [Inference Using FastAPI](fastapi_engine_inference.py)
+
+This example demonstrates how to create a FastAPI server that uses the SGLang engine for text generation.
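
A quick way to sanity-check the example, once the server from this patch is running (via
the Usage command in the module docstring), is a minimal standalone client. The sketch
below is illustrative and not part of the patch: it assumes the default host/port
(127.0.0.1:8000) and uses only the request fields the /generate endpoint reads
(prompt, max_new_tokens, temperature).

import requests

# Query the /generate endpoint of the running FastAPI server.
resp = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "prompt": "The capital of France is",
        "max_new_tokens": 64,
        "temperature": 0.1,
    },
    timeout=60,
)
resp.raise_for_status()  # Raises on non-2xx responses, e.g. 503 before the engine loads
print(resp.json()["generated_text"])

The raise_for_status() call matters here because the endpoint reports failures (missing
prompt, engine not ready) via HTTP status codes rather than in the generated_text field.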