adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
45
examples/runtime/README.md
Normal file
45
examples/runtime/README.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# Runtime examples
|
||||
|
||||
The below examples will mostly need you to start a server in a separate terminal before you can execute them. Please see in the code for detailed instruction.
|
||||
|
||||
## Native API
|
||||
|
||||
* `lora.py`: An example how to use LoRA adapters.
|
||||
* `multimodal_embedding.py`: An example how to perform [multi modal embedding](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct).
|
||||
* `openai_batch_chat.py`: An example how to process batch requests for chat completions.
|
||||
* `openai_batch_complete.py`: An example how to process batch requests for text completions.
|
||||
* **`openai_chat_with_response_prefill.py`**:
|
||||
An example that demonstrates how to [prefill a response](https://eugeneyan.com/writing/prompting/#prefill-claudes-responses) using the OpenAI API by enabling the `continue_final_message` parameter.
|
||||
When enabled, the final (partial) assistant message is removed and its content is used as a prefill so that the model continues that message rather than starting a new turn. See [Anthropic's prefill example](https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/prefill-claudes-response#example-structured-data-extraction-with-prefilling) for more context.
|
||||
* `reward_model.py`: An example how to extract scores from a reward model.
|
||||
* `vertex_predict.py`: An example how to deploy a model to [Vertex AI](https://cloud.google.com/vertex-ai?hl=en).
|
||||
|
||||
## Engine
|
||||
|
||||
The `engine` folder contains examples that show how to use the [Offline Engine API](https://docs.sglang.ai/backend/offline_engine_api.html#Offline-Engine-API) for common workflows.
|
||||
|
||||
* `custom_server.py`: An example how to deploy a custom server.
|
||||
* `embedding.py`: An example how to extract embeddings.
|
||||
* `launch_engine.py`: An example how to launch the Engine.
|
||||
* `offline_batch_inference_eagle.py`: An example how to perform speculative decoding using [EAGLE](https://docs.sglang.ai/backend/speculative_decoding.html).
|
||||
* `offline_batch_inference_torchrun.py`: An example how to perform inference using [torchrun](https://pytorch.org/docs/stable/elastic/run.html).
|
||||
* `offline_batch_inference_vlm.py`: An example how to use VLMs with the engine.
|
||||
* `offline_batch_inference.py`: An example how to use the engine to perform inference on a batch of examples.
|
||||
|
||||
## Hidden States
|
||||
|
||||
The `hidden_states` folder contains examples on how to extract hidden states using SGLang. Please note that this might degrade throughput due to cuda graph rebuilding.
|
||||
|
||||
* `hidden_states_engine.py`: An example how to extract hidden states using the Engine API.
|
||||
* `hidden_states_server.py`: An example how to extract hidden states using the Server API.
|
||||
|
||||
## Multimodal
|
||||
|
||||
SGLang supports multimodal inputs for various model architectures. The `multimodal` folder contains examples showing how to use urls, files or encoded data to make requests to multimodal models. Examples include querying the [Llava-OneVision](multimodal/llava_onevision_server.py) model (image, multi-image, video), Llava-backed [Qwen-Llava](multimodal/qwen_llava_server.py) and [Llama3-Llava](multimodal/llama3_llava_server.py) models (image, multi-image), and Mistral AI's [Pixtral](multimodal/pixtral_server.py) (image, multi-image).
|
||||
|
||||
|
||||
## Token In, Token Out
|
||||
|
||||
The folder `token_in_token_out` shows how to perform inference, where we provide tokens and get tokens as response.
|
||||
|
||||
* `token_in_token_out_{llm|vlm}_{engine|server}.py`: Shows how to perform token in, token out workflow for llm/vlm using either the engine or native API.
|
||||
53
examples/runtime/engine/custom_server.py
Normal file
53
examples/runtime/engine/custom_server.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from sanic import Sanic, text
|
||||
from sanic.response import json
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
engine = None
|
||||
|
||||
# Create an instance of the Sanic app
|
||||
app = Sanic("sanic-server")
|
||||
|
||||
|
||||
# Define an asynchronous route handler
|
||||
@app.route("/generate", methods=["POST"])
|
||||
async def generate(request):
|
||||
prompt = request.json.get("prompt")
|
||||
if not prompt:
|
||||
return json({"error": "Prompt is required"}, status=400)
|
||||
|
||||
# async_generate returns a dict
|
||||
result = await engine.async_generate(prompt)
|
||||
|
||||
return text(result["text"])
|
||||
|
||||
|
||||
@app.route("/generate_stream", methods=["POST"])
|
||||
async def generate_stream(request):
|
||||
prompt = request.json.get("prompt")
|
||||
|
||||
if not prompt:
|
||||
return json({"error": "Prompt is required"}, status=400)
|
||||
|
||||
# async_generate returns a dict
|
||||
result = await engine.async_generate(prompt, stream=True)
|
||||
|
||||
# https://sanic.dev/en/guide/advanced/streaming.md#streaming
|
||||
# init the response
|
||||
response = await request.respond()
|
||||
|
||||
# result is an async generator
|
||||
async for chunk in result:
|
||||
await response.send(chunk["text"])
|
||||
|
||||
await response.eof()
|
||||
|
||||
|
||||
def run_server():
    """Create the global SGLang engine, then start the Sanic server."""
    global engine
    engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    # NOTE(review): single_process presumably avoids duplicating the engine
    # across Sanic worker processes — confirm before changing.
    app.run(host="0.0.0.0", port=8000, single_process=True)


if __name__ == "__main__":
    run_server()
|
||||
27
examples/runtime/engine/embedding.py
Normal file
27
examples/runtime/engine/embedding.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def main():
    """Embed a small batch of prompts and print the resulting vectors."""
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Launch the engine in embedding mode.
    llm = sgl.Engine(
        model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct", is_embedding=True
    )

    embeddings = llm.encode(prompts)
    # Each output dict carries the vector under the "embedding" key.
    for prompt, item in zip(prompts, embeddings):
        print("===============================")
        print(f"Prompt: {prompt}\nEmbedding vector: {item['embedding']}")


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    main()
|
||||
189
examples/runtime/engine/fastapi_engine_inference.py
Normal file
189
examples/runtime/engine/fastapi_engine_inference.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
FastAPI server example for text generation using SGLang Engine and demonstrating client usage.
|
||||
|
||||
Starts the server, sends requests to it, and prints responses.
|
||||
|
||||
Usage:
|
||||
python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000
|
||||
"""
|
||||
|
||||
import os
import subprocess
import time
from contextlib import asynccontextmanager

import requests
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

import sglang as sgl
from sglang.utils import terminate_process
|
||||
|
||||
engine = None
|
||||
|
||||
|
||||
# Use FastAPI's lifespan manager to initialize/shutdown the engine
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages SGLang engine initialization during server startup.

    Model path and tensor-parallel size arrive via the MODEL_PATH / TP_SIZE
    environment variables set by the parent process (see __main__ below).
    """
    global engine
    # Initialize the SGLang engine when the server starts
    # Adjust model_path and other engine arguments as needed
    print("Loading SGLang engine...")
    engine = sgl.Engine(
        model_path=os.getenv("MODEL_PATH"), tp_size=int(os.getenv("TP_SIZE"))
    )
    print("SGLang engine loaded.")
    yield
    # Clean up engine resources when the server stops (optional, depends on engine needs)
    print("Shutting down SGLang engine...")
    # engine.shutdown() # Or other cleanup if available/necessary
    print("SGLang engine shutdown.")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate_text(request: Request):
|
||||
"""FastAPI endpoint to handle text generation requests."""
|
||||
global engine
|
||||
if not engine:
|
||||
return {"error": "Engine not initialized"}, 503
|
||||
|
||||
try:
|
||||
data = await request.json()
|
||||
prompt = data.get("prompt")
|
||||
max_new_tokens = data.get("max_new_tokens", 128)
|
||||
temperature = data.get("temperature", 0.7)
|
||||
|
||||
if not prompt:
|
||||
return {"error": "Prompt is required"}, 400
|
||||
|
||||
# Use async_generate for non-blocking generation
|
||||
state = await engine.async_generate(
|
||||
prompt,
|
||||
sampling_params={
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"temperature": temperature,
|
||||
},
|
||||
# Add other parameters like stop, top_p etc. as needed
|
||||
)
|
||||
|
||||
return {"generated_text": state["text"]}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 500
|
||||
|
||||
|
||||
# Helper function to start the server
def start_server(args, timeout=60):
    """Starts the Uvicorn server as a subprocess and waits for it to be ready.

    Polls FastAPI's auto-generated /docs endpoint once per second until it
    answers 200 or `timeout` seconds elapse. On timeout the child process is
    terminated and TimeoutError is raised.

    Returns the running subprocess.Popen handle on success.
    """
    base_url = f"http://{args.host}:{args.port}"
    command = [
        "python",
        "-m",
        "uvicorn",
        "fastapi_engine_inference:app",
        f"--host={args.host}",
        f"--port={args.port}",
    ]

    # stdout/stderr are inherited so engine logs remain visible.
    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.perf_counter()
    with requests.Session() as session:
        while time.perf_counter() - start_time < timeout:
            try:
                # Check the /docs endpoint which FastAPI provides by default
                response = session.get(
                    f"{base_url}/docs", timeout=5
                )  # Add a request timeout
                if response.status_code == 200:
                    print(f"Server {base_url} is ready (responded on /docs)")
                    return process
            except requests.ConnectionError:
                # Specific exception for connection refused/DNS error etc.
                pass
            except requests.Timeout:
                # Specific exception for request timeout
                print(f"Health check to {base_url}/docs timed out, retrying...")
                pass
            except requests.RequestException as e:
                # Catch other request exceptions
                print(f"Health check request error: {e}, retrying...")
                pass
            # Use a shorter sleep interval for faster startup detection
            time.sleep(1)

    # If loop finishes, raise the timeout error
    # Attempt to terminate the failed process before raising
    if process:
        print(
            "Server failed to start within timeout, attempting to terminate process..."
        )
        terminate_process(process)  # Use the imported terminate_process
    raise TimeoutError(
        f"Server failed to start at {base_url} within the timeout period."
    )
|
||||
|
||||
|
||||
def send_requests(server_url, prompts, max_new_tokens, temperature):
    """Sends generation requests to the running server for a list of prompts."""
    total = len(prompts)
    for idx, prompt in enumerate(prompts):
        print(f"\n[{idx+1}/{total}] Sending prompt: '{prompt}'")
        payload = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }

        try:
            response = requests.post(f"{server_url}/generate", json=payload, timeout=60)
            result = response.json()
            print(f"Prompt: {prompt}\nResponse: {result['generated_text']}")
        except requests.exceptions.Timeout:
            print(f" Error: Request timed out for prompt '{prompt}'")
        except requests.exceptions.RequestException as e:
            print(f" Error sending request for prompt '{prompt}': {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Main entry point for the script."""
|
||||
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
|
||||
parser.add_argument("--tp_size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Pass the model to the child uvicorn process via an env var
|
||||
os.environ["MODEL_PATH"] = args.model_path
|
||||
os.environ["TP_SIZE"] = str(args.tp_size)
|
||||
|
||||
# Start the server
|
||||
process = start_server(args)
|
||||
|
||||
# Define the prompts and sampling parameters
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
max_new_tokens = 64
|
||||
temperature = 0.1
|
||||
|
||||
# Define server url
|
||||
server_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
# Send requests to the server
|
||||
send_requests(server_url, prompts, max_new_tokens, temperature)
|
||||
|
||||
# Terminate the server process
|
||||
terminate_process(process)
|
||||
17
examples/runtime/engine/launch_engine.py
Normal file
17
examples/runtime/engine/launch_engine.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
This example demonstrates how to launch the offline engine.
|
||||
"""
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def main():
    """Launch an offline engine, run a single generation, and shut it down."""
    engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    engine.generate("What is the capital of France?")
    engine.shutdown()


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    main()
|
||||
43
examples/runtime/engine/offline_batch_inference.py
Normal file
43
examples/runtime/engine/offline_batch_inference.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
def main(
    server_args: ServerArgs,
):
    """Run batched offline inference with the given server arguments."""
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Shared sampling configuration for the whole batch.
    sampling_params = {"temperature": 0.8, "top_p": 0.95}

    # Build the engine directly from the parsed CLI arguments.
    llm = sgl.Engine(**dataclasses.asdict(server_args))

    # generate() schedules the whole batch internally.
    for prompt, output in zip(prompts, llm.generate(prompts, sampling_params)):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    server_args = ServerArgs.from_cli_args(parser.parse_args())
    main(server_args)
|
||||
65
examples/runtime/engine/offline_batch_inference_async.py
Normal file
65
examples/runtime/engine/offline_batch_inference_async.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
Usage:
|
||||
python offline_batch_inference_async.py --model-path Qwen/Qwen2-VL-7B-Instruct
|
||||
|
||||
Note:
|
||||
This demo shows the usage of async generation,
|
||||
which is useful to implement an online-like generation with batched inference.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import time
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
class InferenceEngine:
    """Thin async wrapper around sgl.Engine."""

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded to the engine constructor.
        self.engine = sgl.Engine(**kwargs)

    async def generate(self, prompt, sampling_params):
        """Awaitable single-prompt generation; returns the engine's dict."""
        return await self.engine.async_generate(prompt, sampling_params)
|
||||
|
||||
|
||||
async def run_server(server_args):
    """Submit many prompts concurrently and print the generated texts.

    Demonstrates online-like usage on top of batched inference: all requests
    are in flight at once and the engine batches them internally.
    """
    inference = InferenceEngine(**dataclasses.asdict(server_args))

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ] * 100

    # Create a sampling params object.
    sampling_params = {"temperature": 0.8, "top_p": 0.95}

    # Run the generation tasks concurrently in async mode.
    tasks = []
    for prompt in prompts:
        task = asyncio.create_task(inference.generate(prompt, sampling_params))
        tasks.append(task)

    # BUG FIX: the original awaited each task and then polled task.done()
    # with time.sleep(1). After `await task` the task is always done, so the
    # polling branch was dead code — and time.sleep() would have blocked the
    # event loop. Awaiting the task directly yields its result.
    for task in tasks:
        result = await task
        print(f"Generated text: {result['text']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
asyncio.run(run_server(server_args))
|
||||
38
examples/runtime/engine/offline_batch_inference_eagle.py
Normal file
38
examples/runtime/engine/offline_batch_inference_eagle.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def main():
    """Run speculative decoding with EAGLE on a small prompt batch."""
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Greedy decoding (temperature 0) keeps the output deterministic.
    sampling_params = {"temperature": 0, "max_new_tokens": 30}

    # Target model plus EAGLE draft model and the speculation budget.
    llm = sgl.Engine(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        speculative_algorithm="EAGLE",
        speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-7B",
        speculative_num_steps=3,
        speculative_eagle_topk=4,
        speculative_num_draft_tokens=16,
        cuda_graph_max_bs=8,
    )

    for prompt, output in zip(prompts, llm.generate(prompts, sampling_params)):
        print("===============================")
        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    main()
|
||||
74
examples/runtime/engine/offline_batch_inference_qwen_1m.py
Normal file
74
examples/runtime/engine/offline_batch_inference_qwen_1m.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 offline_batch_inference.py
|
||||
"""
|
||||
|
||||
from urllib.request import urlopen
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def load_prompt() -> str:
    """Download a long (~64k-token) test prompt and return it as a string."""
    # Test cases with various lengths can be found at:
    #
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt

    url = (
        "https://qianwen-res.oss-cn-beijing.aliyuncs.com"
        "/Qwen2.5-1M/test-data/64k.txt"
    )
    with urlopen(url, timeout=5) as response:
        return response.read().decode("utf-8")
|
||||
|
||||
|
||||
# Processing the prompt.
def process_requests(llm: sgl.Engine, prompts: list[str]) -> None:
    """Generate completions for `prompts`; print token counts and text."""
    # Sampling configuration shared by all prompts.
    sampling_params = {
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
        "repetition_penalty": 1.05,
        "max_new_tokens": 256,
    }
    results = llm.generate(prompts, sampling_params)
    for item in results:
        prompt_token_ids = item["meta_info"]["prompt_tokens"]
        generated_text = item["text"]
        print(
            f"Prompt length: {prompt_token_ids}, " f"Generated text: {generated_text!r}"
        )
|
||||
|
||||
|
||||
# Create an LLM.
def initialize_engine() -> sgl.Engine:
    """Build an engine configured for 1M-token context inference.

    NOTE(review): radix cache, mixed chunking, and torch.compile are
    disabled here — presumably requirements of the dual_chunk_flash_attn
    backend; confirm before re-enabling any of them.
    """
    llm = sgl.Engine(
        model_path="Qwen/Qwen2.5-7B-Instruct-1M",
        context_length=1048576,  # 1M tokens
        page_size=256,
        attention_backend="dual_chunk_flash_attn",
        tp_size=4,
        disable_radix_cache=True,
        enable_mixed_chunk=False,
        enable_torch_compile=False,
        chunked_prefill_size=131072,  # prefill in 128k-token chunks
        mem_fraction_static=0.6,
        log_level="DEBUG",
    )
    return llm
|
||||
|
||||
|
||||
def main():
    """Download the long test prompt and run it through the engine."""
    llm = initialize_engine()
    prompt = load_prompt()
    process_requests(llm, [prompt])


if __name__ == "__main__":
    main()
|
||||
52
examples/runtime/engine/offline_batch_inference_vlm.py
Normal file
52
examples/runtime/engine/offline_batch_inference_vlm.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Usage:
|
||||
python offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.parser.conversation import chat_templates
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
def main(
    server_args: ServerArgs,
):
    """Query a vision-language model with one image URL and print the answer."""
    vlm = sgl.Engine(**dataclasses.asdict(server_args))

    # The chat template provides the model-specific image placeholder token.
    conv = chat_templates[server_args.chat_template].copy()
    image_token = conv.image_token

    image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"

    prompt = f"What's in this image?\n{image_token}"

    # Near-zero temperature for an (almost) deterministic answer.
    sampling_params = {
        "temperature": 0.001,
        "max_new_tokens": 30,
    }

    output = vlm.generate(
        prompt=prompt,
        image_data=image_url,
        sampling_params=sampling_params,
    )

    print("===============================")
    print(f"Prompt: {prompt}")
    print(f"Generated text: {output['text']}")

    vlm.shutdown()


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = ServerArgs.from_cli_args(args)
    main(server_args)
|
||||
54
examples/runtime/engine/readme.md
Normal file
54
examples/runtime/engine/readme.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# SGLang Engine
|
||||
|
||||
SGLang provides a direct inference engine without the need for an HTTP server. There are generally these use cases:
|
||||
|
||||
- [Offline Batch Inference](#offline-batch-inference)
|
||||
- [Embedding Generation](#embedding-generation)
|
||||
- [Custom Server](#custom-server)
|
||||
- [Token-In-Token-Out for RLHF](#token-in-token-out-for-rlhf)
|
||||
- [Inference Using FastAPI](#inference-using-fastapi)
|
||||
|
||||
## Examples
|
||||
|
||||
### [Offline Batch Inference](./offline_batch_inference.py)
|
||||
|
||||
In this example, we launch an SGLang engine and feed a batch of inputs for inference. If you provide a very large batch, the engine will intelligently schedule the requests to process efficiently and prevent OOM (Out of Memory) errors.
|
||||
|
||||
### [Embedding Generation](./embedding.py)
|
||||
|
||||
In this example, we launch an SGLang engine and feed a batch of inputs for embedding generation.
|
||||
|
||||
### [Custom Server](./custom_server.py)
|
||||
|
||||
This example demonstrates how to create a custom server on top of the SGLang Engine. We use [Sanic](https://sanic.dev/en/) as an example. The server supports both non-streaming and streaming endpoints.
|
||||
|
||||
#### Steps
|
||||
|
||||
1. Install Sanic:
|
||||
|
||||
```bash
|
||||
pip install sanic
|
||||
```
|
||||
|
||||
2. Run the server:
|
||||
|
||||
```bash
|
||||
python custom_server.py
|
||||
```
|
||||
|
||||
3. Send requests:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/generate -H "Content-Type: application/json" -d '{"prompt": "The Transformer architecture is..."}'
|
||||
curl -X POST http://localhost:8000/generate_stream -H "Content-Type: application/json" -d '{"prompt": "The Transformer architecture is..."}' --no-buffer
|
||||
```
|
||||
|
||||
This will send both non-streaming and streaming requests to the server.
|
||||
|
||||
### [Token-In-Token-Out for RLHF](../token_in_token_out)
|
||||
|
||||
In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.
|
||||
|
||||
### [Inference Using FastAPI](fastapi_engine_inference.py)
|
||||
|
||||
This example demonstrates how to create a FastAPI server that uses the SGLang engine for text generation.
|
||||
59
examples/runtime/engine/save_remote_state.py
Normal file
59
examples/runtime/engine/save_remote_state.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Saves each worker's model state dict directly to a checkpoint, which enables a
|
||||
fast load path for large tensor-parallel models where each worker only needs to
|
||||
read its own shard rather than the entire checkpoint.
|
||||
|
||||
Example usage:
|
||||
|
||||
python save_remote_state.py \
|
||||
--model-path /path/to/load \
|
||||
--tensor-parallel-size 8 \
|
||||
--remote-model-save-url [protocol]://[host]:[port]/[model_name] \
|
||||
|
||||
Then, the model can be loaded with
|
||||
|
||||
llm = Engine(
|
||||
model_path="[protocol]://[host]:[port]/[model_name]",
|
||||
tensor_parallel_size=8,
|
||||
)
|
||||
"""
|
||||
import dataclasses
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
from sglang import Engine, ServerArgs
|
||||
|
||||
parser = ArgumentParser()
|
||||
ServerArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--remote-model-save-url",
|
||||
required=True,
|
||||
type=str,
|
||||
help="remote address to store model weights",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote-draft-model-save-url",
|
||||
default=None,
|
||||
type=str,
|
||||
help="remote address to store draft model weights",
|
||||
)
|
||||
|
||||
|
||||
def main(args):
    """Save each worker's model shard to the remote URL given on the CLI."""
    engine_args = ServerArgs.from_cli_args(args)
    model_path = engine_args.model_path
    # A remote save only makes sense when loading from a local checkpoint.
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))
    llm.save_remote_model(
        url=args.remote_model_save_url, draft_url=args.remote_draft_model_save_url
    )
    print("save remote (draft) model successfully")


if __name__ == "__main__":
    main(parser.parse_args())
|
||||
74
examples/runtime/engine/save_sharded_state.py
Normal file
74
examples/runtime/engine/save_sharded_state.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Saves each worker's model state dict directly to a checkpoint, which enables a
|
||||
fast load path for large tensor-parallel models where each worker only needs to
|
||||
read its own shard rather than the entire checkpoint.
|
||||
|
||||
Example usage:
|
||||
|
||||
python save_sharded_state.py \
|
||||
--model-path /path/to/load \
|
||||
--quantization deepspeedfp \
|
||||
--tensor-parallel-size 8 \
|
||||
--output /path/to/save
|
||||
|
||||
Then, the model can be loaded with
|
||||
|
||||
llm = Engine(
|
||||
model_path="/path/to/save",
|
||||
load_format="sharded_state",
|
||||
quantization="deepspeedfp",
|
||||
tensor_parallel_size=8,
|
||||
)
|
||||
"""
|
||||
import dataclasses
|
||||
import os
|
||||
import shutil
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
from sglang import Engine, ServerArgs
|
||||
|
||||
# CLI: engine/server flags plus checkpoint-output options.
parser = ArgumentParser()
ServerArgs.add_cli_args(parser)

parser.add_argument(
    "--output", "-o", required=True, type=str, help="path to output checkpoint"
)
parser.add_argument(
    "--file-pattern", type=str, help="string pattern of saved filenames"
)
parser.add_argument(
    "--max-file-size",
    # BUG FIX: was type=str, so a CLI-supplied value arrived as a string while
    # the default is an int; the value is a byte count and must be numeric.
    type=int,
    default=5 * 1024**3,
    help="max size (in bytes) of each safetensors file",
)
|
||||
|
||||
|
||||
def main(args):
    """Save a sharded checkpoint plus metadata files to args.output."""
    engine_args = ServerArgs.from_cli_args(args)
    model_path = engine_args.model_path
    # Sharded save only makes sense when loading from a local checkpoint.
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = Engine(**dataclasses.asdict(engine_args))
    Path(args.output).mkdir(exist_ok=True)
    llm.save_sharded_model(
        path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
    )

    # Copy metadata files to output directory
    # (everything that is not a weight file: config, tokenizer, etc.)
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
            if os.path.isdir(os.path.join(model_path, file)):
                shutil.copytree(
                    os.path.join(model_path, file), os.path.join(args.output, file)
                )
            else:
                shutil.copy(os.path.join(model_path, file), args.output)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
|
||||
66
examples/runtime/hidden_states/hidden_states_engine.py
Normal file
66
examples/runtime/hidden_states/hidden_states_engine.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Usage:
|
||||
python hidden_states.py
|
||||
|
||||
Note that each time you change the `return_hidden_states` parameter,
|
||||
the cuda graph will be recaptured, which might lead to a performance hit.
|
||||
So avoid getting hidden states and completions alternately.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
def main():
    """Generate short completions and print the per-token hidden states."""
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Create an LLM.
    llm = sgl.Engine(
        model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        enable_return_hidden_states=True,
    )

    sampling_params = {
        "temperature": 0.8,
        "top_p": 0.95,
        "max_new_tokens": 10,
    }

    outputs = llm.generate(
        prompts, sampling_params=sampling_params, return_hidden_states=True
    )

    llm.shutdown()

    for prompt, output in zip(prompts, outputs):
        # Hidden states arrive as nested lists; convert each entry back into
        # a bfloat16 tensor in place.
        states = output["meta_info"]["hidden_states"]
        for i in range(len(states)):
            states[i] = torch.tensor(states[i], dtype=torch.bfloat16)
        print("===============================")
        print(
            f"Prompt: {prompt}\n"
            f"Generated text: {output['text']}\n"
            f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
            f"Completion_tokens: {output['meta_info']['completion_tokens']}"
        )
        print("Hidden states: ")
        # Some entries are 2-D (multiple positions), others 1-D (one step);
        # promote the 1-D ones so everything concatenates along dim 0.
        hidden_states = torch.cat(
            [s.unsqueeze(0) if s.dim() == 1 else s for s in states]
        )
        print(hidden_states)
        print()


# The __main__ condition is necessary here because we use "spawn" to create subprocesses
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
if __name__ == "__main__":
    main()
|
||||
81
examples/runtime/hidden_states/hidden_states_server.py
Normal file
81
examples/runtime/hidden_states/hidden_states_server.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
python hidden_states_server.py
|
||||
|
||||
Note that each time you change the `return_hidden_states` parameter,
|
||||
the cuda graph will be recaptured, which might lead to a performance hit.
|
||||
So avoid getting hidden states and completions alternately.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from sglang.test.test_utils import is_in_ci
|
||||
from sglang.utils import terminate_process, wait_for_server
|
||||
|
||||
if is_in_ci():
|
||||
from docs.backend.patch import launch_server_cmd
|
||||
else:
|
||||
from sglang.utils import launch_server_cmd
|
||||
|
||||
|
||||
def main():
|
||||
# Launch the server
|
||||
server_process, port = launch_server_cmd(
|
||||
"python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-1.5B-instruct --enable-return-hidden-states --host 0.0.0.0"
|
||||
)
|
||||
wait_for_server(f"http://localhost:{port}")
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
sampling_params = {
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.95,
|
||||
"max_new_tokens": 10,
|
||||
}
|
||||
|
||||
json_data = {
|
||||
"text": prompts,
|
||||
"sampling_params": sampling_params,
|
||||
"return_hidden_states": True,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"http://localhost:{port}/generate",
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
terminate_process(server_process)
|
||||
|
||||
outputs = response.json()
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
for i in range(len(output["meta_info"]["hidden_states"])):
|
||||
output["meta_info"]["hidden_states"][i] = torch.tensor(
|
||||
output["meta_info"]["hidden_states"][i], dtype=torch.bfloat16
|
||||
)
|
||||
print("===============================")
|
||||
print(
|
||||
f"Prompt: {prompt}\n"
|
||||
f"Generated text: {output['text']}\n"
|
||||
f"Prompt_Tokens: {output['meta_info']['prompt_tokens']}\t"
|
||||
f"Completion_tokens: {output['meta_info']['completion_tokens']}"
|
||||
)
|
||||
print("Hidden states: ")
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
i.unsqueeze(0) if len(i.shape) == 1 else i
|
||||
for i in output["meta_info"]["hidden_states"]
|
||||
]
|
||||
)
|
||||
print(hidden_states)
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
37
examples/runtime/lora.py
Normal file
37
examples/runtime/lora.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# launch server
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4

# send requests
# lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length
# use None to specify base-only prompt, e.x. "lora_path": [None, "/home/ying/test_lora"]
import json

import requests

url = "http://127.0.0.1:30000"

# One LoRA adapter (or None for the base model) per prompt, positionally matched.
adapters = [
    "/home/ying/test_lora",
    "lora1",
    "lora2",
    "lora1",
    "lora2",
    None,
    None,
]
payload = {
    "text": [f"prompt {n}" for n in range(1, 8)],
    "sampling_params": {"max_new_tokens": 32},
    "lora_path": adapters,
}
response = requests.post(
    url + "/generate",
    json=payload,
)
print(json.dumps(response.json()))
|
||||
111
examples/runtime/multimodal/llama3_llava_server.py
Normal file
111
examples/runtime/multimodal/llama3_llava_server.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Usage:
|
||||
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
|
||||
# Installing latest sglang.
|
||||
|
||||
# Endpoint Service CLI:
|
||||
python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
|
||||
|
||||
python3 llama3_llava_server.py
|
||||
|
||||
Output:
|
||||
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from llava.conversation import conv_llava_llama_3
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
    """POST `data` as JSON to `url` after an optional delay; return the parsed JSON reply."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            return await resp.json()
|
||||
|
||||
|
||||
async def test_concurrent(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
prompt = "<image>\nPlease generate caption towards this image."
|
||||
conv_template = copy.deepcopy(conv_llava_llama_3)
|
||||
conv_template.append_message(role=conv_template.roles[0], message=prompt)
|
||||
conv_template.append_message(role=conv_template.roles[1], message=None)
|
||||
prompt_with_template = conv_template.get_prompt()
|
||||
response = []
|
||||
for i in range(1):
|
||||
response.append(
|
||||
send_request(
|
||||
url + "/generate",
|
||||
{
|
||||
"text": prompt_with_template,
|
||||
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 2,
|
||||
"frequency_penalty": 2,
|
||||
"stop": "<|eot_id|>",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
rets = await asyncio.gather(*response)
|
||||
for ret in rets:
|
||||
print(ret["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
prompt = "<image>\nPlease generate caption towards this image."
|
||||
conv_template = copy.deepcopy(conv_llava_llama_3)
|
||||
conv_template.append_message(role=conv_template.roles[0], message=prompt)
|
||||
conv_template.append_message(role=conv_template.roles[1], message=None)
|
||||
prompt_with_template = conv_template.get_prompt()
|
||||
pload = {
|
||||
"text": prompt_with_template,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 2,
|
||||
"frequency_penalty": 2,
|
||||
"stop": "<|eot_id|>",
|
||||
},
|
||||
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
|
||||
"stream": True,
|
||||
}
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json=pload,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
prev = 0
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
asyncio.run(test_concurrent(args))
|
||||
test_streaming(args)
|
||||
264
examples/runtime/multimodal/llava_onevision_server.py
Normal file
264
examples/runtime/multimodal/llava_onevision_server.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8
|
||||
|
||||
python3 llava_onevision_server.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
# pip install httpx==0.23.3
|
||||
# pip install decord
|
||||
# pip install protobuf==3.20.0
|
||||
|
||||
|
||||
def download_video(url, cache_dir):
|
||||
file_path = os.path.join(cache_dir, "jobs.mp4")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
print(f"File downloaded and saved to: {file_path}")
|
||||
return file_path
|
||||
|
||||
|
||||
def create_openai_client(base_url):
|
||||
return openai.Client(api_key="EMPTY", base_url=base_url)
|
||||
|
||||
|
||||
def image_stream_request_test(client):
|
||||
print("----------------------Image Stream Request Test----------------------")
|
||||
stream_request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please describe this image. Please list the benchmarks and the models.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=1024,
|
||||
stream=True,
|
||||
)
|
||||
stream_response = ""
|
||||
|
||||
for chunk in stream_request:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
content = chunk.choices[0].delta.content
|
||||
stream_response += content
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.flush()
|
||||
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def multi_image_stream_request_test(client):
|
||||
print(
|
||||
"----------------------Multi-Images Stream Request Test----------------------"
|
||||
)
|
||||
stream_request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
|
||||
},
|
||||
"modalities": "multi-images",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
|
||||
},
|
||||
"modalities": "multi-images",
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "I have shown you two images. Please describe the two images to me.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=1024,
|
||||
stream=True,
|
||||
)
|
||||
stream_response = ""
|
||||
|
||||
for chunk in stream_request:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
content = chunk.choices[0].delta.content
|
||||
stream_response += content
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.flush()
|
||||
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def video_stream_request_test(client, video_path):
|
||||
print("------------------------Video Stream Request Test----------------------")
|
||||
messages = prepare_video_messages(video_path)
|
||||
|
||||
video_request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
stream=True,
|
||||
)
|
||||
print("-" * 30)
|
||||
video_response = ""
|
||||
|
||||
for chunk in video_request:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
content = chunk.choices[0].delta.content
|
||||
video_response += content
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.flush()
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def image_speed_test(client):
|
||||
print("----------------------Image Speed Test----------------------")
|
||||
start_time = time.perf_counter()
|
||||
request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please describe this image. Please list the benchmarks and the models.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
response = request.choices[0].message.content
|
||||
print(response)
|
||||
print("-" * 30)
|
||||
print_speed_test_results(request, start_time, end_time)
|
||||
|
||||
|
||||
def video_speed_test(client, video_path):
|
||||
print("------------------------Video Speed Test------------------------")
|
||||
messages = prepare_video_messages(video_path)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
video_request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
video_response = video_request.choices[0].message.content
|
||||
print(video_response)
|
||||
print("-" * 30)
|
||||
print_speed_test_results(video_request, start_time, end_time)
|
||||
|
||||
|
||||
def prepare_video_messages(video_path):
|
||||
max_frames_num = 32
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
total_frame_num = len(vr)
|
||||
uniform_sampled_frames = np.linspace(
|
||||
0, total_frame_num - 1, max_frames_num, dtype=int
|
||||
)
|
||||
frame_idx = uniform_sampled_frames.tolist()
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
|
||||
base64_frames = []
|
||||
for frame in frames:
|
||||
pil_img = Image.fromarray(frame)
|
||||
buff = io.BytesIO()
|
||||
pil_img.save(buff, format="JPEG")
|
||||
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
|
||||
base64_frames.append(base64_str)
|
||||
|
||||
messages = [{"role": "user", "content": []}]
|
||||
|
||||
for base64_frame in base64_frames:
|
||||
frame_format = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"},
|
||||
"modalities": "video",
|
||||
}
|
||||
messages[0]["content"].append(frame_format)
|
||||
|
||||
prompt = {"type": "text", "text": "Please describe the video in detail."}
|
||||
messages[0]["content"].append(prompt)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def print_speed_test_results(request, start_time, end_time):
|
||||
total_tokens = request.usage.total_tokens
|
||||
completion_tokens = request.usage.completion_tokens
|
||||
prompt_tokens = request.usage.prompt_tokens
|
||||
|
||||
print(f"Total tokens: {total_tokens}")
|
||||
print(f"Completion tokens: {completion_tokens}")
|
||||
print(f"Prompt tokens: {prompt_tokens}")
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
print(f"Token per second: {total_tokens / (end_time - start_time)}")
|
||||
print(f"Completion token per second: {completion_tokens / (end_time - start_time)}")
|
||||
print(f"Prompt token per second: {prompt_tokens / (end_time - start_time)}")
|
||||
|
||||
|
||||
def main():
|
||||
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
|
||||
cache_dir = os.path.expanduser("~/.cache")
|
||||
video_path = download_video(url, cache_dir)
|
||||
|
||||
client = create_openai_client("http://127.0.0.1:30000/v1")
|
||||
|
||||
image_stream_request_test(client)
|
||||
multi_image_stream_request_test(client)
|
||||
video_stream_request_test(client, video_path)
|
||||
image_speed_test(client)
|
||||
video_speed_test(client, video_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
127
examples/runtime/multimodal/pixtral_server.py
Normal file
127
examples/runtime/multimodal/pixtral_server.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Usage:
|
||||
# Run a Pixtral model with SGLang:
|
||||
# HuggingFace:
|
||||
python -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
|
||||
# ModelScope:
|
||||
python -m sglang.launch_server --model-path AI-ModelScope/pixtral-12b --port=30000
|
||||
|
||||
# Then test it with:
|
||||
python pixtral_server.py
|
||||
|
||||
This script tests Pixtral model with both single and multiple images.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
IMAGE_TOKEN_SEP = "\n[IMG]"
|
||||
ROUTE = "/generate"
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
|
||||
await asyncio.sleep(delay)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=data) as resp:
|
||||
output = await resp.json()
|
||||
return output
|
||||
|
||||
|
||||
async def test_concurrent(args):
|
||||
url = f"{args.host}:{args.port}{ROUTE}"
|
||||
|
||||
# Single image test
|
||||
if args.single_image:
|
||||
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
|
||||
image_url = "https://picsum.photos/id/237/400/300"
|
||||
modality = ["image"]
|
||||
# Multiple images test
|
||||
else:
|
||||
image_urls = [
|
||||
"https://picsum.photos/id/237/400/300",
|
||||
"https://picsum.photos/id/27/500/500",
|
||||
]
|
||||
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
|
||||
image_url = image_urls
|
||||
modality = ["multi-images"]
|
||||
|
||||
response = await send_request(
|
||||
url,
|
||||
{
|
||||
"text": prompt,
|
||||
"image_data": image_url,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 100,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
"modalities": modality,
|
||||
},
|
||||
)
|
||||
|
||||
print(f"Response: {response}")
|
||||
if "text" in response:
|
||||
print("\nOutput text:", response["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
|
||||
url = f"{args.host}:{args.port}/generate"
|
||||
|
||||
# Single image test
|
||||
if args.single_image:
|
||||
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
|
||||
image_data = "https://picsum.photos/id/237/400/300"
|
||||
modality = ["image"]
|
||||
# Multiple images test
|
||||
else:
|
||||
image_urls = [
|
||||
"https://picsum.photos/id/237/400/300",
|
||||
"https://picsum.photos/id/27/500/500",
|
||||
]
|
||||
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
|
||||
image_data = image_urls
|
||||
modality = ["multi-images"]
|
||||
|
||||
pload = {
|
||||
"text": prompt,
|
||||
"image_data": image_data,
|
||||
"sampling_params": {"max_new_tokens": 100, "temperature": 0.7, "top_p": 0.9},
|
||||
"modalities": modality,
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
response = requests.post(url, json=pload, stream=True)
|
||||
|
||||
print("Streaming response:")
|
||||
prev = 0
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
print("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
parser.add_argument(
|
||||
"--single-image",
|
||||
action="store_true",
|
||||
help="Test with single image instead of multiple images",
|
||||
)
|
||||
parser.add_argument("--no-stream", action="store_true", help="Don't test streaming")
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(test_concurrent(args))
|
||||
if not args.no_stream:
|
||||
test_streaming(args)
|
||||
111
examples/runtime/multimodal/qwen_llava_server.py
Normal file
111
examples/runtime/multimodal/qwen_llava_server.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Usage:
|
||||
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
|
||||
# Installing latest sglang.
|
||||
|
||||
# Endpoint Service CLI:
|
||||
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
|
||||
|
||||
python3 qwen_llava_server.py
|
||||
|
||||
Output:
|
||||
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from llava.conversation import conv_qwen
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
    """POST `data` as JSON to `url` after an optional delay; return the parsed JSON reply."""
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            return await resp.json()
|
||||
|
||||
|
||||
async def test_concurrent(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
prompt = "<image>\nPlease generate caption towards this image."
|
||||
conv_template = copy.deepcopy(conv_qwen)
|
||||
conv_template.append_message(role=conv_template.roles[0], message=prompt)
|
||||
conv_template.append_message(role=conv_template.roles[1], message=None)
|
||||
prompt_with_template = conv_template.get_prompt()
|
||||
response = []
|
||||
for i in range(1):
|
||||
response.append(
|
||||
send_request(
|
||||
url + "/generate",
|
||||
{
|
||||
"text": prompt_with_template,
|
||||
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 2,
|
||||
"frequency_penalty": 2,
|
||||
"stop": "<|im_end|>",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
rets = await asyncio.gather(*response)
|
||||
for ret in rets:
|
||||
print(ret["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
prompt = "<image>\nPlease generate caption towards this image."
|
||||
conv_template = copy.deepcopy(conv_qwen)
|
||||
conv_template.append_message(role=conv_template.roles[0], message=prompt)
|
||||
conv_template.append_message(role=conv_template.roles[1], message=None)
|
||||
prompt_with_template = conv_template.get_prompt()
|
||||
pload = {
|
||||
"text": prompt_with_template,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 2,
|
||||
"frequency_penalty": 2,
|
||||
"stop": "<|im_end|>",
|
||||
},
|
||||
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
|
||||
"stream": True,
|
||||
}
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json=pload,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
prev = 0
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
asyncio.run(test_concurrent(args))
|
||||
test_streaming(args)
|
||||
18
examples/runtime/multimodal_embedding.py
Normal file
18
examples/runtime/multimodal_embedding.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# launch server
# python -m sglang.launch_server --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct --is-embedding

import requests

url = "http://127.0.0.1:30000"

text_input = "Represent this image in embedding space."
image_path = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"

# Mixed text + image embedding request against the OpenAI-compatible endpoint.
request_body = {
    "model": "gme-qwen2-vl",
    "input": [{"text": text_input}, {"image": image_path}],
}

result = requests.post(url + "/v1/embeddings", json=request_body).json()

print("Embeddings:", [item.get("embedding") for item in result.get("data", [])])
|
||||
53
examples/runtime/openai_chat_with_response_prefill.py
Normal file
53
examples/runtime/openai_chat_with_response_prefill.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
Usage:
|
||||
1) Launch the server in one terminal:
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000
|
||||
|
||||
2) Run this script in another terminal:
|
||||
python openai_chat_with_response_prefill.py
|
||||
|
||||
This example demonstrates two chat completion calls:
|
||||
- One with continue_final_message enabled (the final assistant message is used as a prefill).
|
||||
- One without continue_final_message (the final assistant message remains, starting a new turn).
|
||||
"""
|
||||
|
||||
import openai
|
||||
|
||||
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful AI assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": """
|
||||
Extract the name, size, price, and color from this product description as a JSON object:
|
||||
|
||||
<description>
|
||||
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99.
|
||||
At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—
|
||||
no matter where you place it in your home.
|
||||
This affordable little hub brings convenient hands-free control to your smart devices.
|
||||
</description>
|
||||
""",
|
||||
},
|
||||
{"role": "assistant", "content": "{\n"},
|
||||
]
|
||||
|
||||
# Calling the API with continue_final_message enabled.
|
||||
print("=== Prefill with continue_final_messagem ===")
|
||||
response_with = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
extra_body={"continue_final_message": True},
|
||||
)
|
||||
print(response_with.choices[0].message.content)
|
||||
|
||||
# Calling the API without continue_final_message (using default behavior).
|
||||
print("\n=== Prefill without continue_final_message ===")
|
||||
response_without = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
)
|
||||
print(response_without.choices[0].message.content)
|
||||
32
examples/runtime/reward_model.py
Normal file
32
examples/runtime/reward_model.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# launch server
|
||||
# python -m sglang.launch_server --model LxzGordon/URM-LLaMa-3.1-8B --is-embedding
|
||||
|
||||
import requests
|
||||
|
||||
url = "http://127.0.0.1:30000"
|
||||
|
||||
PROMPT = (
|
||||
"What is the range of the numeric output of a sigmoid node in a neural network?"
|
||||
)
|
||||
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
|
||||
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
|
||||
|
||||
json_data = {
|
||||
"conv": [
|
||||
[
|
||||
{"role": "user", "content": PROMPT},
|
||||
{"role": "assistant", "content": RESPONSE1},
|
||||
],
|
||||
[
|
||||
{"role": "user", "content": PROMPT},
|
||||
{"role": "assistant", "content": RESPONSE2},
|
||||
],
|
||||
],
|
||||
}
|
||||
response = requests.post(
|
||||
url + "/classify",
|
||||
json=json_data,
|
||||
).json()
|
||||
|
||||
print(response)
|
||||
print("scores:", [x["embedding"] for x in response])
|
||||
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
This example demonstrates how to provide tokenized ids to LLM as input instead of text prompt, i.e. a token-in-token-out workflow.
|
||||
"""
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def main():
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = {"temperature": 0.8, "top_p": 0.95}
|
||||
|
||||
# Tokenize inputs
|
||||
tokenizer = get_tokenizer(MODEL_PATH)
|
||||
token_ids_list = [tokenizer.encode(prompt) for prompt in prompts]
|
||||
|
||||
# Create an LLM.
|
||||
llm = sgl.Engine(model_path=MODEL_PATH, skip_tokenizer_init=True)
|
||||
|
||||
outputs = llm.generate(input_ids=token_ids_list, sampling_params=sampling_params)
|
||||
# Print the outputs.
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
decode_output = tokenizer.decode(output["output_ids"])
|
||||
print("===============================")
|
||||
print(
|
||||
f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}\nGenerated text: {decode_output}"
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
# The __main__ condition is necessary here because we use "spawn" to create subprocesses
|
||||
# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
python token_in_token_out_llm_server.py
|
||||
|
||||
"""
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.test.test_utils import is_in_ci
|
||||
from sglang.utils import terminate_process, wait_for_server
|
||||
|
||||
if is_in_ci():
|
||||
from docs.backend.patch import launch_server_cmd
|
||||
else:
|
||||
from sglang.utils import launch_server_cmd
|
||||
|
||||
|
||||
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
||||
def main():
|
||||
# Launch the server
|
||||
server_process, port = launch_server_cmd(
|
||||
f"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0"
|
||||
)
|
||||
wait_for_server(f"http://localhost:{port}")
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = {"temperature": 0.8, "top_p": 0.95}
|
||||
|
||||
# Tokenize inputs
|
||||
tokenizer = get_tokenizer(MODEL_PATH)
|
||||
token_ids_list = [tokenizer.encode(prompt) for prompt in prompts]
|
||||
|
||||
json_data = {
|
||||
"input_ids": token_ids_list,
|
||||
"sampling_params": sampling_params,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"http://localhost:{port}/generate",
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
outputs = response.json()
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
print("===============================")
|
||||
decode_output = tokenizer.decode(output["output_ids"])
|
||||
print(
|
||||
f"Prompt: {prompt}\nGenerated token ids: {output['output_ids']}\nGenerated text: {decode_output}"
|
||||
)
|
||||
print()
|
||||
|
||||
terminate_process(server_process)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,74 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
from typing import Tuple
|
||||
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from sglang import Engine
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.test.test_utils import DEFAULT_IMAGE_URL
|
||||
|
||||
|
||||
def get_input_ids(
    server_args: ServerArgs, model_config: ModelConfig
) -> Tuple[list[int], list]:
    """Build a tokenized VLM prompt and the image list it refers to.

    Returns the prompt token ids (with the model's image token prepended)
    and the image URLs to pass alongside them.
    """
    template = get_chat_template_by_model_path(model_config.model_path)
    prompt = f"{template.image_token}What is in this picture?"
    images = [DEFAULT_IMAGE_URL]

    processor = AutoProcessor.from_pretrained(
        model_config.model_path, trust_remote_code=server_args.trust_remote_code
    )

    # Tokenize via the processor's underlying text tokenizer only.
    batch = processor.tokenizer(
        text=[prompt],
        return_tensors="pt",
    )
    token_ids = batch.input_ids[0].tolist()

    return token_ids, images
||||
|
||||
|
||||
def token_in_out_example(
    server_args: ServerArgs,
):
    """Run one token-in/token-out VLM generation through the offline Engine.

    Tokenizes the prompt locally (the caller sets skip_tokenizer_init on
    *server_args*), generates with the image attached, and prints the raw
    output token ids.
    """
    input_ids, image_data = get_input_ids(
        server_args,
        ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_override_args=server_args.json_model_override_args,
        ),
    )
    backend = Engine(**dataclasses.asdict(server_args))

    output = backend.generate(
        input_ids=input_ids,
        image_data=image_data,
        sampling_params={
            "temperature": 0.8,
            "max_new_tokens": 32,
        },
    )

    print("===============================")
    # Was an f-string with no placeholders (ruff F541); plain literal is
    # byte-identical at runtime.
    print("Output token ids: ", output["output_ids"])

    backend.shutdown()
||||
if __name__ == "__main__":
    # Feed a fixed demo argv through the full ServerArgs CLI parser.
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    demo_argv = [
        "--model-path=Qwen/Qwen2-VL-2B",
    ]
    parsed = parser.parse_args(args=demo_argv)
    server_args = ServerArgs.from_cli_args(parsed)
    # Token-in/token-out requires the engine to skip its own tokenizer.
    server_args.skip_tokenizer_init = True
    token_in_out_example(server_args)
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
python token_in_token_out_vlm_server.py
|
||||
|
||||
"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import requests
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.test.test_utils import DEFAULT_IMAGE_URL, is_in_ci
|
||||
from sglang.utils import terminate_process, wait_for_server
|
||||
|
||||
# In CI, use the docs-build patched launch helper; otherwise use the
# regular sglang utility. Both expose the same launch_server_cmd API.
if is_in_ci():
    from docs.backend.patch import launch_server_cmd
else:
    from sglang.utils import launch_server_cmd


# Vision-language model served by this example.
MODEL_PATH = "Qwen/Qwen2-VL-2B"
|
||||
def get_input_ids() -> Tuple[list[int], list]:
    """Tokenize the demo VLM prompt for MODEL_PATH.

    Returns the prompt token ids (image token included) together with the
    list of image URLs to send with the request.
    """
    template = get_chat_template_by_model_path(MODEL_PATH)
    prompt = f"{template.image_token}What is in this picture?"
    images = [DEFAULT_IMAGE_URL]

    # Only the text tokenizer inside the processor is needed here.
    tokenizer = AutoProcessor.from_pretrained(MODEL_PATH).tokenizer
    batch = tokenizer(
        text=[prompt],
        return_tensors="pt",
    )
    token_ids = batch.input_ids[0].tolist()

    return token_ids, images
||||
|
||||
def main():
    """Token-in/token-out VLM demo against a locally launched server.

    The server runs with --skip-tokenizer-init, so this client tokenizes
    the prompt itself, posts raw token ids plus image data to /generate,
    and prints the generated token ids.
    """
    # Launch the server.
    server_process, port = launch_server_cmd(
        f"python -m sglang.launch_server --model-path {MODEL_PATH} --skip-tokenizer-init --host 0.0.0.0"
    )
    wait_for_server(f"http://localhost:{port}")

    input_ids, image_data = get_input_ids()

    sampling_params = {
        "temperature": 0.8,
        "max_new_tokens": 32,
    }

    json_data = {
        "input_ids": input_ids,
        "image_data": image_data,
        "sampling_params": sampling_params,
    }

    response = requests.post(
        f"http://localhost:{port}/generate",
        json=json_data,
    )

    output = response.json()
    print("===============================")
    # Was an f-string with no placeholders (ruff F541); plain literal is
    # byte-identical at runtime.
    print("Output token ids: ", output["output_ids"])

    terminate_process(server_process)


if __name__ == "__main__":
    main()
66
examples/runtime/vertex_predict.py
Normal file
66
examples/runtime/vertex_predict.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Usage:
|
||||
python -m sglang.launch_server --model meta-llama/Llama-2-7b-hf --port 30000
|
||||
python vertex_predict.py
|
||||
|
||||
This example shows the request and response formats of the prediction route for
|
||||
Google Cloud Vertex AI Online Predictions.
|
||||
|
||||
Vertex AI SDK for Python is recommended for deploying models to Vertex AI
|
||||
instead of a local server. After deploying the model to a Vertex AI Online
|
||||
Prediction Endpoint, send requests via the Python SDK:
|
||||
|
||||
response = endpoint.predict(
|
||||
instances=[
|
||||
{"text": "The capital of France is"},
|
||||
{"text": "What is a car?"},
|
||||
],
|
||||
parameters={"sampling_params": {"max_new_tokens": 16}},
|
||||
)
|
||||
print(response.predictions)
|
||||
|
||||
More details about getting online predictions from Vertex AI can be found at
|
||||
https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@dataclass
class VertexPrediction:
    """Minimal stand-in for the Vertex AI SDK's prediction response."""

    # One prediction entry per request instance.
    predictions: List
|
||||
class LocalVertexEndpoint:
    """Drives a locally running sglang server through its Vertex AI route,
    mimicking the Vertex AI SDK endpoint interface."""

    def __init__(self) -> None:
        # Default local sglang server address (see module docstring).
        self.base_url = "http://127.0.0.1:30000"

    def predict(self, instances: List[dict], parameters: Optional[dict] = None):
        """POST instances and parameters to /vertex_generate and wrap the
        server's predictions in a VertexPrediction."""
        payload = {
            "instances": instances,
            "parameters": parameters,
        }
        resp = requests.post(self.base_url + "/vertex_generate", json=payload)
        return VertexPrediction(predictions=resp.json()["predictions"])
|
||||
endpoint = LocalVertexEndpoint()

# Single-prompt prediction with default parameters.
single = endpoint.predict(instances=[{"text": "The capital of France is"}])
print(single.predictions)

# Batched prediction with explicit sampling parameters.
batched = endpoint.predict(
    instances=[
        {"text": "The capital of France is"},
        {"text": "What is a car?"},
    ],
    parameters={"sampling_params": {"max_new_tokens": 16}},
)
print(batched.predictions)
||||
Reference in New Issue
Block a user