[Feat/WIP] add llava-onevision, with support for (1) siglip encoder, (2) qwen2 decoder (3) openai api compatible server. (#1123)

Co-authored-by: Bo Li <drluodian@gmail.com>
This commit is contained in:
Kaichen Zhang - NTU
2024-08-24 05:11:16 +08:00
committed by GitHub
parent 5fafcac008
commit a5b14ad043
13 changed files with 703 additions and 95 deletions

View File

@@ -0,0 +1,211 @@
import base64
import io
import os
import sys
import time
import numpy as np
import openai
import requests
from decord import VideoReader, cpu
from PIL import Image
# pip install httpx==0.23.3
# pip install decord
# pip install protobuf==3.20.0
def download_video(url, cache_dir):
file_path = os.path.join(cache_dir, "jobs.mp4")
os.makedirs(cache_dir, exist_ok=True)
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
print(f"File downloaded and saved to: {file_path}")
return file_path
def create_openai_client(base_url):
return openai.Client(api_key="EMPTY", base_url=base_url)
def image_stream_request_test(client):
print("----------------------Image Stream Request Test----------------------")
stream_request = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
},
{
"type": "text",
"text": "Please describe this image. Please list the benchmarks and the models.",
},
],
},
],
temperature=0.7,
max_tokens=1024,
stream=True,
)
stream_response = ""
for chunk in stream_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
stream_response += content
sys.stdout.write(content)
sys.stdout.flush()
print("-" * 30)
def video_stream_request_test(client, video_path):
print("------------------------Video Stream Request Test----------------------")
messages = prepare_video_messages(video_path)
start_time = time.time()
video_request = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=1024,
stream=True,
)
print("-" * 30)
video_response = ""
for chunk in video_request:
if chunk.choices[0].delta.content is not None:
content = chunk.choices[0].delta.content
video_response += content
sys.stdout.write(content)
sys.stdout.flush()
print("-" * 30)
def image_speed_test(client):
print("----------------------Image Speed Test----------------------")
start_time = time.time()
request = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
},
{
"type": "text",
"text": "Please describe this image. Please list the benchmarks and the models.",
},
],
},
],
temperature=0,
max_tokens=1024,
)
end_time = time.time()
response = request.choices[0].message.content
print(response)
print("-" * 30)
print_speed_test_results(request, start_time, end_time)
def video_speed_test(client, video_path):
print("------------------------Video Speed Test------------------------")
messages = prepare_video_messages(video_path)
start_time = time.time()
video_request = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=1024,
)
end_time = time.time()
video_response = video_request.choices[0].message.content
print(video_response)
print("-" * 30)
print_speed_test_results(video_request, start_time, end_time)
def prepare_video_messages(video_path):
max_frames_num = 32
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(
0, total_frame_num - 1, max_frames_num, dtype=int
)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()
base64_frames = []
for frame in frames:
pil_img = Image.fromarray(frame)
buff = io.BytesIO()
pil_img.save(buff, format="JPEG")
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
base64_frames.append(base64_str)
messages = [{"role": "user", "content": []}]
frame_format = {
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{}"},
}
for base64_frame in base64_frames:
frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format(
base64_frame
)
messages[0]["content"].append(frame_format.copy())
prompt = {"type": "text", "text": "Please describe the video in detail."}
messages[0]["content"].append(prompt)
return messages
def print_speed_test_results(request, start_time, end_time):
total_tokens = request.usage.total_tokens
completion_tokens = request.usage.completion_tokens
prompt_tokens = request.usage.prompt_tokens
print(f"Total tokens: {total_tokens}")
print(f"Completion tokens: {completion_tokens}")
print(f"Prompt tokens: {prompt_tokens}")
print(f"Time taken: {end_time - start_time} seconds")
print(f"Token per second: {total_tokens / (end_time - start_time)}")
print(f"Completion token per second: {completion_tokens / (end_time - start_time)}")
print(f"Prompt token per second: {prompt_tokens / (end_time - start_time)}")
def main():
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache")
video_path = download_video(url, cache_dir)
client = create_openai_client("http://127.0.0.1:30000/v1")
image_stream_request_test(client)
video_stream_request_test(client, video_path)
image_speed_test(client)
video_speed_test(client, video_path)
if __name__ == "__main__":
main()

View File

@@ -121,6 +121,20 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
if __name__ == "__main__":
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache")
file_path = os.path.join(cache_dir, "jobs.mp4")
os.makedirs(cache_dir, exist_ok=True)
response = requests.get(url)
response.raise_for_status() # Raise an exception for bad responses
with open(file_path, "wb") as f:
f.write(response.content)
print(f"File downloaded and saved to: {file_path}")
# Create the parser
parser = argparse.ArgumentParser(
description="Run video processing with specified port."
@@ -148,7 +162,7 @@ if __name__ == "__main__":
parser.add_argument(
"--video-dir",
type=str,
default="./videos/Q98Z4OTh8RwmDonc.mp4",
default=os.path.expanduser("~/.cache/jobs.mp4"),
help="The directory or path for the processed video files.",
)
parser.add_argument(