model(vlm): pixtral (#5084)

This commit is contained in:
Kiv Chen
2025-05-13 00:16:10 -07:00
committed by GitHub
parent b2e95f62b4
commit 5380cd7ea3
16 changed files with 1125 additions and 39 deletions

View File

@@ -0,0 +1,111 @@
"""
Usage:
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
# Installing latest sglang.
# Endpoint Service CLI:
python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
python3 llama3_llava_server.py
Output:
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
"""
import argparse
import asyncio
import copy
import json
import aiohttp
import requests
from llava.conversation import conv_llava_llama_3
async def send_request(url, data, delay=0):
await asyncio.sleep(delay)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as resp:
output = await resp.json()
return output
async def test_concurrent(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_llava_llama_3)
conv_template.append_message(role=conv_template.roles[0], message=prompt)
conv_template.append_message(role=conv_template.roles[1], message=None)
prompt_with_template = conv_template.get_prompt()
response = []
for i in range(1):
response.append(
send_request(
url + "/generate",
{
"text": prompt_with_template,
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
"sampling_params": {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 1.0,
"presence_penalty": 2,
"frequency_penalty": 2,
"stop": "<|eot_id|>",
},
},
)
)
rets = await asyncio.gather(*response)
for ret in rets:
print(ret["text"])
def test_streaming(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_llava_llama_3)
conv_template.append_message(role=conv_template.roles[0], message=prompt)
conv_template.append_message(role=conv_template.roles[1], message=None)
prompt_with_template = conv_template.get_prompt()
pload = {
"text": prompt_with_template,
"sampling_params": {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 1.0,
"presence_penalty": 2,
"frequency_penalty": 2,
"stop": "<|eot_id|>",
},
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
"stream": True,
}
response = requests.post(
url + "/generate",
json=pload,
stream=True,
)
prev = 0
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
asyncio.run(test_concurrent(args))
test_streaming(args)

View File

@@ -0,0 +1,264 @@
"""
Usage:
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8
python3 llava_onevision_server.py
"""
import base64
import io
import os
import sys
import time
import numpy as np
import openai
import requests
from decord import VideoReader, cpu
from PIL import Image
# pip install httpx==0.23.3
# pip install decord
# pip install protobuf==3.20.0
def download_video(url, cache_dir):
file_path = os.path.join(cache_dir, "jobs.mp4")
os.makedirs(cache_dir, exist_ok=True)
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
print(f"File downloaded and saved to: {file_path}")
return file_path
def create_openai_client(base_url):
return openai.Client(api_key="EMPTY", base_url=base_url)
def image_stream_request_test(client):
print("----------------------Image Stream Request Test----------------------")
stream_request = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
},
{
"type": "text",
"text": "Please describe this image. Please list the benchmarks and the models.",
},
],
},
],
temperature=0.7,
max_tokens=1024,
stream=True,
)
stream_response = ""
for chunk in stream_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
stream_response += content
sys.stdout.write(content)
sys.stdout.flush()
print("-" * 30)
def multi_image_stream_request_test(client):
print(
"----------------------Multi-Images Stream Request Test----------------------"
)
stream_request = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
"modalities": "multi-images",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
},
"modalities": "multi-images",
},
{
"type": "text",
"text": "I have shown you two images. Please describe the two images to me.",
},
],
},
],
temperature=0.7,
max_tokens=1024,
stream=True,
)
stream_response = ""
for chunk in stream_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
stream_response += content
sys.stdout.write(content)
sys.stdout.flush()
print("-" * 30)
def video_stream_request_test(client, video_path):
print("------------------------Video Stream Request Test----------------------")
messages = prepare_video_messages(video_path)
video_request = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=1024,
stream=True,
)
print("-" * 30)
video_response = ""
for chunk in video_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
video_response += content
sys.stdout.write(content)
sys.stdout.flush()
print("-" * 30)
def image_speed_test(client):
print("----------------------Image Speed Test----------------------")
start_time = time.time()
request = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
},
{
"type": "text",
"text": "Please describe this image. Please list the benchmarks and the models.",
},
],
},
],
temperature=0,
max_tokens=1024,
)
end_time = time.time()
response = request.choices[0].message.content
print(response)
print("-" * 30)
print_speed_test_results(request, start_time, end_time)
def video_speed_test(client, video_path):
print("------------------------Video Speed Test------------------------")
messages = prepare_video_messages(video_path)
start_time = time.time()
video_request = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=1024,
)
end_time = time.time()
video_response = video_request.choices[0].message.content
print(video_response)
print("-" * 30)
print_speed_test_results(video_request, start_time, end_time)
def prepare_video_messages(video_path):
max_frames_num = 32
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(
0, total_frame_num - 1, max_frames_num, dtype=int
)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()
base64_frames = []
for frame in frames:
pil_img = Image.fromarray(frame)
buff = io.BytesIO()
pil_img.save(buff, format="JPEG")
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
base64_frames.append(base64_str)
messages = [{"role": "user", "content": []}]
for base64_frame in base64_frames:
frame_format = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"},
"modalities": "video",
}
messages[0]["content"].append(frame_format)
prompt = {"type": "text", "text": "Please describe the video in detail."}
messages[0]["content"].append(prompt)
return messages
def print_speed_test_results(request, start_time, end_time):
total_tokens = request.usage.total_tokens
completion_tokens = request.usage.completion_tokens
prompt_tokens = request.usage.prompt_tokens
print(f"Total tokens: {total_tokens}")
print(f"Completion tokens: {completion_tokens}")
print(f"Prompt tokens: {prompt_tokens}")
print(f"Time taken: {end_time - start_time} seconds")
print(f"Token per second: {total_tokens / (end_time - start_time)}")
print(f"Completion token per second: {completion_tokens / (end_time - start_time)}")
print(f"Prompt token per second: {prompt_tokens / (end_time - start_time)}")
def main():
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache")
video_path = download_video(url, cache_dir)
client = create_openai_client("http://127.0.0.1:30000/v1")
image_stream_request_test(client)
multi_image_stream_request_test(client)
video_stream_request_test(client, video_path)
image_speed_test(client)
video_speed_test(client, video_path)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,127 @@
"""
Usage:
# Run a Pixtral model with SGLang:
# HuggingFace:
python -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
# ModelScope:
python -m sglang.launch_server --model-path AI-ModelScope/pixtral-12b --port=30000
# Then test it with:
python pixtral_server.py
This script tests Pixtral model with both single and multiple images.
"""
import argparse
import asyncio
import json
import aiohttp
import requests
IMAGE_TOKEN_SEP = "\n[IMG]"
ROUTE = "/generate"
async def send_request(url, data, delay=0):
await asyncio.sleep(delay)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as resp:
output = await resp.json()
return output
async def test_concurrent(args):
url = f"{args.host}:{args.port}{ROUTE}"
# Single image test
if args.single_image:
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
image_url = "https://picsum.photos/id/237/400/300"
modality = ["image"]
# Multiple images test
else:
image_urls = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/27/500/500",
]
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
image_url = image_urls
modality = ["multi-images"]
response = await send_request(
url,
{
"text": prompt,
"image_data": image_url,
"sampling_params": {
"max_new_tokens": 100,
"temperature": 0.7,
"top_p": 0.9,
},
"modalities": modality,
},
)
print(f"Response: {response}")
if "text" in response:
print("\nOutput text:", response["text"])
def test_streaming(args):
url = f"{args.host}:{args.port}/generate"
# Single image test
if args.single_image:
prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
image_data = "https://picsum.photos/id/237/400/300"
modality = ["image"]
# Multiple images test
else:
image_urls = [
"https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/27/500/500",
]
prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
image_data = image_urls
modality = ["multi-images"]
pload = {
"text": prompt,
"image_data": image_data,
"sampling_params": {"max_new_tokens": 100, "temperature": 0.7, "top_p": 0.9},
"modalities": modality,
"stream": True,
}
response = requests.post(url, json=pload, stream=True)
print("Streaming response:")
prev = 0
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument(
"--single-image",
action="store_true",
help="Test with single image instead of multiple images",
)
parser.add_argument("--no-stream", action="store_true", help="Don't test streaming")
args = parser.parse_args()
asyncio.run(test_concurrent(args))
if not args.no_stream:
test_streaming(args)

View File

@@ -0,0 +1,111 @@
"""
Usage:
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
# Installing latest sglang.
# Endpoint Service CLI:
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
python3 qwen_llava_server.py
Output:
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
"""
import argparse
import asyncio
import copy
import json
import aiohttp
import requests
from llava.conversation import conv_qwen
async def send_request(url, data, delay=0):
await asyncio.sleep(delay)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as resp:
output = await resp.json()
return output
async def test_concurrent(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_qwen)
conv_template.append_message(role=conv_template.roles[0], message=prompt)
conv_template.append_message(role=conv_template.roles[1], message=None)
prompt_with_template = conv_template.get_prompt()
response = []
for i in range(1):
response.append(
send_request(
url + "/generate",
{
"text": prompt_with_template,
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
"sampling_params": {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 1.0,
"presence_penalty": 2,
"frequency_penalty": 2,
"stop": "<|im_end|>",
},
},
)
)
rets = await asyncio.gather(*response)
for ret in rets:
print(ret["text"])
def test_streaming(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_qwen)
conv_template.append_message(role=conv_template.roles[0], message=prompt)
conv_template.append_message(role=conv_template.roles[1], message=None)
prompt_with_template = conv_template.get_prompt()
pload = {
"text": prompt_with_template,
"sampling_params": {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 1.0,
"presence_penalty": 2,
"frequency_penalty": 2,
"stop": "<|im_end|>",
},
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
"stream": True,
}
response = requests.post(
url + "/generate",
json=pload,
stream=True,
)
prev = 0
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
asyncio.run(test_concurrent(args))
test_streaming(args)