sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
53
examples/runtime/engine/custom_server.py
Normal file
53
examples/runtime/engine/custom_server.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from sanic import Sanic, text
from sanic.response import json

import sglang as sgl

# Populated by run_server() before the app starts serving.
engine = None

# Create an instance of the Sanic app
app = Sanic("sanic-server")


# Define an asynchronous route handler
@app.route("/generate", methods=["POST"])
async def generate(request):
    """Non-streaming generation: return the full completion as plain text."""
    prompt = request.json.get("prompt")
    if not prompt:
        return json({"error": "Prompt is required"}, status=400)

    # async_generate returns a dict
    result = await engine.async_generate(prompt)
    return text(result["text"])


@app.route("/generate_stream", methods=["POST"])
async def generate_stream(request):
    """Streaming generation: forward each chunk to the client as it arrives."""
    prompt = request.json.get("prompt")
    if not prompt:
        return json({"error": "Prompt is required"}, status=400)

    # With stream=True, async_generate returns an async generator of dicts.
    chunks = await engine.async_generate(prompt, stream=True)

    # https://sanic.dev/en/guide/advanced/streaming.md#streaming
    response = await request.respond()
    async for chunk in chunks:
        await response.send(chunk["text"])
    await response.eof()


def run_server():
    """Create the global engine, then start the Sanic app."""
    global engine
    engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    app.run(host="0.0.0.0", port=8000, single_process=True)


if __name__ == "__main__":
    run_server()
|
||||
27
examples/runtime/engine/embedding.py
Normal file
27
examples/runtime/engine/embedding.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def main():
    """Generate embeddings for a small batch of prompts and print them."""
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Create an embedding engine (is_embedding switches off text generation).
    llm = sgl.Engine(
        model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct", is_embedding=True
    )

    # Print the embedding vector for each prompt.
    for prompt, output in zip(prompts, llm.encode(prompts)):
        print("===============================")
        print(f"Prompt: {prompt}\nEmbedding vector: {output['embedding']}")


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    main()
|
||||
189
examples/runtime/engine/fastapi_engine_inference.py
Normal file
189
examples/runtime/engine/fastapi_engine_inference.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
FastAPI server example for text generation using SGLang Engine and demonstrating client usage.
|
||||
|
||||
Starts the server, sends requests to it, and prints responses.
|
||||
|
||||
Usage:
|
||||
python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import requests
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.utils import terminate_process
|
||||
|
||||
engine = None
|
||||
|
||||
|
||||
# Use FastAPI's lifespan manager to initialize/shutdown the engine
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages SGLang engine initialization during server startup."""
    global engine
    # Bring the engine up before the server accepts requests; the model path
    # and tp size come from env vars set by the parent launcher process.
    print("Loading SGLang engine...")
    model_path = os.getenv("MODEL_PATH")
    tp_size = int(os.getenv("TP_SIZE"))
    engine = sgl.Engine(model_path=model_path, tp_size=tp_size)
    print("SGLang engine loaded.")
    yield
    # Clean up engine resources when the server stops (optional, depends on engine needs)
    print("Shutting down SGLang engine...")
    # engine.shutdown() # Or other cleanup if available/necessary
    print("SGLang engine shutdown.")


app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.post("/generate")
async def generate_text(request: Request):
    """FastAPI endpoint to handle text generation requests.

    Expects a JSON body with "prompt" (required) and optional
    "max_new_tokens" / "temperature". Returns {"generated_text": ...} on
    success, or a JSON error body with an appropriate HTTP status code.
    """
    # Local import keeps the fix self-contained; fastapi is already imported
    # by this file.
    from fastapi.responses import JSONResponse

    global engine
    if not engine:
        # BUG FIX: the original returned ({...}, 503) tuples. FastAPI does
        # not interpret a returned tuple as (body, status); it serializes the
        # tuple as a JSON array with status 200. JSONResponse sets the real
        # status code.
        return JSONResponse({"error": "Engine not initialized"}, status_code=503)

    try:
        data = await request.json()
        prompt = data.get("prompt")
        max_new_tokens = data.get("max_new_tokens", 128)
        temperature = data.get("temperature", 0.7)

        if not prompt:
            return JSONResponse({"error": "Prompt is required"}, status_code=400)

        # Use async_generate for non-blocking generation
        state = await engine.async_generate(
            prompt,
            sampling_params={
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
            },
            # Add other parameters like stop, top_p etc. as needed
        )

        return {"generated_text": state["text"]}
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
|
||||
|
||||
|
||||
# Helper function to start the server
def start_server(args, timeout=60):
    """Starts the Uvicorn server as a subprocess and waits for it to be ready."""
    base_url = f"http://{args.host}:{args.port}"
    command = [
        "python",
        "-m",
        "uvicorn",
        "fastapi_engine_inference:app",
        f"--host={args.host}",
        f"--port={args.port}",
    ]

    process = subprocess.Popen(command, stdout=None, stderr=None)

    # Poll the health endpoint until the deadline passes.
    deadline = time.perf_counter() + timeout
    with requests.Session() as session:
        while time.perf_counter() < deadline:
            try:
                # FastAPI serves /docs by default, so a 200 there means the
                # app (and its lifespan startup) finished.
                response = session.get(
                    f"{base_url}/docs", timeout=5
                )  # Add a request timeout
                if response.status_code == 200:
                    print(f"Server {base_url} is ready (responded on /docs)")
                    return process
            except requests.ConnectionError:
                # Server not accepting connections yet -- keep polling.
                pass
            except requests.Timeout:
                print(f"Health check to {base_url}/docs timed out, retrying...")
            except requests.RequestException as e:
                print(f"Health check request error: {e}, retrying...")
            # Short poll interval for faster startup detection.
            time.sleep(1)

    # Startup timed out: terminate the child process before raising.
    if process:
        print(
            "Server failed to start within timeout, attempting to terminate process..."
        )
        terminate_process(process)  # Use the imported terminate_process
    raise TimeoutError(
        f"Server failed to start at {base_url} within the timeout period."
    )
|
||||
|
||||
|
||||
def send_requests(server_url, prompts, max_new_tokens, temperature):
    """Sends generation requests to the running server for a list of prompts."""
    for idx, prompt in enumerate(prompts):
        print(f"\n[{idx+1}/{len(prompts)}] Sending prompt: '{prompt}'")
        payload = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }

        # Post one request per prompt; each failure is reported independently.
        try:
            response = requests.post(f"{server_url}/generate", json=payload, timeout=60)
            result = response.json()
            print(f"Prompt: {prompt}\nResponse: {result['generated_text']}")
        except requests.exceptions.Timeout:
            print(f" Error: Request timed out for prompt '{prompt}'")
        except requests.exceptions.RequestException as e:
            print(f" Error sending request for prompt '{prompt}': {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    """Main entry point for the script."""

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--tp_size", type=int, default=1)
    args = parser.parse_args()

    # The uvicorn child process re-imports this module, so hand it the model
    # configuration through environment variables.
    os.environ["MODEL_PATH"] = args.model_path
    os.environ["TP_SIZE"] = str(args.tp_size)

    # Launch the server and block until it answers health checks.
    process = start_server(args)

    # Prompts and sampling parameters for the demo requests.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    max_new_tokens = 64
    temperature = 0.1

    server_url = f"http://{args.host}:{args.port}"

    # Fire the demo requests, then shut the server down.
    send_requests(server_url, prompts, max_new_tokens, temperature)
    terminate_process(process)
|
||||
17
examples/runtime/engine/launch_engine.py
Normal file
17
examples/runtime/engine/launch_engine.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
This example demonstrates how to launch the offline engine.
|
||||
"""
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def main():
|
||||
llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||
llm.generate("What is the capital of France?")
|
||||
llm.shutdown()
|
||||
|
||||
|
||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
43
examples/runtime/engine/offline_batch_inference.py
Normal file
43
examples/runtime/engine/offline_batch_inference.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
def main(
    server_args: ServerArgs,
):
    """Run a small offline batch through an SGLang engine and print results."""
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create a sampling params object.
    sampling_params = {"temperature": 0.8, "top_p": 0.95}

    # Create an LLM.
    llm = sgl.Engine(**dataclasses.asdict(server_args))

    # Generate and print one result per prompt.
    for prompt, output in zip(prompts, llm.generate(prompts, sampling_params)):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    main(server_args)
|
||||
65
examples/runtime/engine/offline_batch_inference_async.py
Normal file
65
examples/runtime/engine/offline_batch_inference_async.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
Usage:
|
||||
python offline_batch_inference_async.py --model-path Qwen/Qwen2-VL-7B-Instruct
|
||||
|
||||
Note:
|
||||
This demo shows the usage of async generation,
|
||||
which is useful to implement an online-like generation with batched inference.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import time
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
class InferenceEngine:
    """Thin async wrapper around sgl.Engine."""

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded verbatim to the engine.
        self.engine = sgl.Engine(**kwargs)

    async def generate(self, prompt, sampling_params):
        """Asynchronously generate a completion dict for a single prompt."""
        return await self.engine.async_generate(prompt, sampling_params)
|
||||
|
||||
|
||||
async def run_server(server_args):
    """Submit a large batch of prompts concurrently and print each result.

    Demonstrates online-like usage: every prompt becomes its own asyncio task
    so the engine can batch them internally.
    """
    inference = InferenceEngine(**dataclasses.asdict(server_args))

    # Sample prompts, repeated to simulate a large batch.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 100

    # Create a sampling params object.
    sampling_params = {"temperature": 0.8, "top_p": 0.95}

    # Run the generation tasks concurrently in async mode.
    tasks = [
        asyncio.create_task(inference.generate(prompt, sampling_params))
        for prompt in prompts
    ]

    # BUG FIX: the original awaited each task and then busy-waited on
    # task.done() with a blocking time.sleep(1). After "await task" the task
    # is always done, so that loop was dead code -- and time.sleep() would
    # have blocked the event loop had it ever run. Awaiting the task already
    # yields its result.
    for task in tasks:
        result = await task
        print(f"Generated text: {result['text']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Parse engine CLI flags into ServerArgs, then drive the async demo.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    server_args = ServerArgs.from_cli_args(parser.parse_args())
    asyncio.run(run_server(server_args))
|
||||
38
examples/runtime/engine/offline_batch_inference_eagle.py
Normal file
38
examples/runtime/engine/offline_batch_inference_eagle.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def main():
    """Offline batch inference with EAGLE speculative decoding."""
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Greedy decoding with short outputs.
    sampling_params = {"temperature": 0, "max_new_tokens": 30}

    # Target model plus an EAGLE draft model for speculative decoding.
    llm = sgl.Engine(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        speculative_algorithm="EAGLE",
        speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-7B",
        speculative_num_steps=3,
        speculative_eagle_topk=4,
        speculative_num_draft_tokens=16,
        cuda_graph_max_bs=8,
    )

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for prompt, output in zip(prompts, outputs):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    main()
|
||||
74
examples/runtime/engine/offline_batch_inference_qwen_1m.py
Normal file
74
examples/runtime/engine/offline_batch_inference_qwen_1m.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 offline_batch_inference_qwen_1m.py
|
||||
"""
|
||||
|
||||
from urllib.request import urlopen
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def load_prompt() -> str:
    """Download the 64k test prompt and return it as a UTF-8 string."""
    # Test cases with various lengths can be found at:
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
    url = (
        "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
        "/Qwen2.5-1M/test-data/64k.txt"
    )
    with urlopen(url, timeout=5) as response:
        return response.read().decode("utf-8")
|
||||
|
||||
|
||||
# Processing the prompt.
def process_requests(llm: sgl.Engine, prompts: list[str]) -> None:
    """Generate completions for *prompts* and print token counts and text."""
    # Create a sampling params object.
    sampling_params = {
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
        "repetition_penalty": 1.05,
        "max_new_tokens": 256,
    }
    # Generate, then report prompt length and generated text per output.
    for output in llm.generate(prompts, sampling_params):
        prompt_token_ids = output["meta_info"]["prompt_tokens"]
        generated_text = output["text"]
        print(
            f"Prompt length: {prompt_token_ids}, " f"Generated text: {generated_text!r}"
        )
|
||||
|
||||
|
||||
# Create an LLM.
def initialize_engine() -> sgl.Engine:
    """Build the 1M-context Qwen engine with the settings this demo needs."""
    return sgl.Engine(
        model_path="Qwen/Qwen2.5-7B-Instruct-1M",
        context_length=1048576,
        page_size=256,
        attention_backend="dual_chunk_flash_attn",
        tp_size=4,
        disable_radix_cache=True,
        enable_mixed_chunk=False,
        enable_torch_compile=False,
        chunked_prefill_size=131072,
        mem_fraction_static=0.6,
        log_level="DEBUG",
    )
|
||||
|
||||
|
||||
def main():
    # Build the engine, fetch the long prompt, and run it through generation.
    engine = initialize_engine()
    process_requests(engine, [load_prompt()])


if __name__ == "__main__":
    main()
|
||||
52
examples/runtime/engine/offline_batch_inference_vlm.py
Normal file
52
examples/runtime/engine/offline_batch_inference_vlm.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Usage:
|
||||
python offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.parser.conversation import chat_templates
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
def main(
    server_args: ServerArgs,
):
    """Run a single image+text prompt through a vision-language engine."""
    vlm = sgl.Engine(**dataclasses.asdict(server_args))

    # Look up the image placeholder token for the selected chat template.
    conv = chat_templates[server_args.chat_template].copy()
    image_token = conv.image_token

    image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
    prompt = f"What's in this image?\n{image_token}"

    # Near-greedy decoding with a short output.
    sampling_params = {
        "temperature": 0.001,
        "max_new_tokens": 30,
    }

    output = vlm.generate(
        prompt=prompt,
        image_data=image_url,
        sampling_params=sampling_params,
    )

    print("===============================")
    print(f"Prompt: {prompt}")
    print(f"Generated text: {output['text']}")

    vlm.shutdown()


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    main(server_args)
|
||||
54
examples/runtime/engine/readme.md
Normal file
54
examples/runtime/engine/readme.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# SGLang Engine
|
||||
|
||||
SGLang provides a direct inference engine without the need for an HTTP server. There are generally these use cases:
|
||||
|
||||
- [Offline Batch Inference](#offline-batch-inference)
|
||||
- [Embedding Generation](#embedding-generation)
|
||||
- [Custom Server](#custom-server)
|
||||
- [Token-In-Token-Out for RLHF](#token-in-token-out-for-rlhf)
|
||||
- [Inference Using FastAPI](#inference-using-fastapi)
|
||||
|
||||
## Examples
|
||||
|
||||
### [Offline Batch Inference](./offline_batch_inference.py)
|
||||
|
||||
In this example, we launch an SGLang engine and feed a batch of inputs for inference. If you provide a very large batch, the engine will intelligently schedule the requests to process efficiently and prevent OOM (Out of Memory) errors.
|
||||
|
||||
### [Embedding Generation](./embedding.py)
|
||||
|
||||
In this example, we launch an SGLang engine and feed a batch of inputs for embedding generation.
|
||||
|
||||
### [Custom Server](./custom_server.py)
|
||||
|
||||
This example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. The server supports both non-streaming and streaming endpoints.
|
||||
|
||||
#### Steps
|
||||
|
||||
1. Install Sanic:
|
||||
|
||||
```bash
|
||||
pip install sanic
|
||||
```
|
||||
|
||||
2. Run the server:
|
||||
|
||||
```bash
|
||||
python custom_server.py
|
||||
```
|
||||
|
||||
3. Send requests:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/generate -H "Content-Type: application/json" -d '{"prompt": "The Transformer architecture is..."}'
|
||||
curl -X POST http://localhost:8000/generate_stream -H "Content-Type: application/json" -d '{"prompt": "The Transformer architecture is..."}' --no-buffer
|
||||
```
|
||||
|
||||
This will send both non-streaming and streaming requests to the server.
|
||||
|
||||
### [Token-In-Token-Out for RLHF](../token_in_token_out)
|
||||
|
||||
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
|
||||
|
||||
### [Inference Using FastAPI](fastapi_engine_inference.py)
|
||||
|
||||
This example demonstrates how to create a FastAPI server that uses the SGLang engine for text generation.
|
||||
59
examples/runtime/engine/save_remote_state.py
Normal file
59
examples/runtime/engine/save_remote_state.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Saves each worker's model state dict directly to a checkpoint, which enables a
|
||||
fast load path for large tensor-parallel models where each worker only needs to
|
||||
read its own shard rather than the entire checkpoint.
|
||||
|
||||
Example usage:
|
||||
|
||||
python save_remote_state.py \
|
||||
--model-path /path/to/load \
|
||||
--tensor-parallel-size 8 \
|
||||
--remote-model-save-url [protocol]://[host]:[port]/[model_name] \
|
||||
|
||||
Then, the model can be loaded with
|
||||
|
||||
llm = Engine(
|
||||
model_path="[protocol]://[host]:[port]/[model_name]",
|
||||
tensor_parallel_size=8,
|
||||
)
|
||||
"""
|
||||
import dataclasses
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
from sglang import Engine, ServerArgs
|
||||
|
||||
parser = ArgumentParser()
ServerArgs.add_cli_args(parser)

parser.add_argument(
    "--remote-model-save-url",
    required=True,
    type=str,
    help="remote address to store model weights",
)
parser.add_argument(
    "--remote-draft-model-save-url",
    default=None,
    type=str,
    help="remote address to store draft model weights",
)


def main(args):
    """Load a local model and push its weights to a remote store."""
    engine_args = ServerArgs.from_cli_args(args)
    model_path = engine_args.model_path
    # Saving remotely only makes sense for a model that exists locally first.
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))
    llm.save_remote_model(
        url=args.remote_model_save_url, draft_url=args.remote_draft_model_save_url
    )
    print("save remote (draft) model successfully")


if __name__ == "__main__":
    main(parser.parse_args())
|
||||
74
examples/runtime/engine/save_sharded_state.py
Normal file
74
examples/runtime/engine/save_sharded_state.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Saves each worker's model state dict directly to a checkpoint, which enables a
|
||||
fast load path for large tensor-parallel models where each worker only needs to
|
||||
read its own shard rather than the entire checkpoint.
|
||||
|
||||
Example usage:
|
||||
|
||||
python save_sharded_state.py \
|
||||
--model-path /path/to/load \
|
||||
--quantization deepspeedfp \
|
||||
--tensor-parallel-size 8 \
|
||||
--output /path/to/save
|
||||
|
||||
Then, the model can be loaded with
|
||||
|
||||
llm = Engine(
|
||||
model_path="/path/to/save",
|
||||
load_format="sharded_state",
|
||||
quantization="deepspeedfp",
|
||||
tensor_parallel_size=8,
|
||||
)
|
||||
"""
|
||||
import dataclasses
|
||||
import os
|
||||
import shutil
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
from sglang import Engine, ServerArgs
|
||||
|
||||
parser = ArgumentParser()
ServerArgs.add_cli_args(parser)

parser.add_argument(
    "--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
    "--file-pattern", type=str, help="string pattern of saved filenames"
)
parser.add_argument(
    "--max-file-size",
    # BUG FIX: the default is an int (5 GiB) but the type was "str", so
    # passing --max-file-size on the command line handed a string to
    # save_sharded_model(max_size=...). Parse as int to match the default.
    type=int,
    default=5 * 1024**3,
    help="max size (in bytes) of each safetensors file",
)


def main(args):
    """Save each worker's sharded model state to args.output, then copy the
    non-weight metadata files (config, tokenizer, ...) alongside it so the
    sharded checkpoint is self-contained."""
    engine_args = ServerArgs.from_cli_args(args)
    model_path = engine_args.model_path
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))
    Path(args.output).mkdir(exist_ok=True)
    llm.save_sharded_model(
        path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
    )

    # Copy metadata files (everything that is not a weight file) to the
    # output directory.
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
            if os.path.isdir(os.path.join(model_path, file)):
                shutil.copytree(
                    os.path.join(model_path, file), os.path.join(args.output, file)
                )
            else:
                shutil.copy(os.path.join(model_path, file), args.output)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
|
||||
Reference in New Issue
Block a user