ci: simplify multi-modality tests by using mixins (#9006)

This commit is contained in:
Mick
2025-08-17 13:25:02 +08:00
committed by GitHub
parent 66d6be0874
commit 1df84ff414
6 changed files with 264 additions and 400 deletions

View File

@@ -1,8 +1,6 @@
import base64
import io
import json
import os
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import openai
@@ -10,12 +8,7 @@ import requests
from PIL import Image
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, CustomTestCase
# image
IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
@@ -29,33 +22,123 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIVisionServer(CustomTestCase):
class TestOpenAIOmniServerBase(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
cls.model = ""
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
)
cls.process = None
cls.base_url += "/v1"
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def get_audio_request_kwargs(self):
return self.get_request_kwargs()
def get_vision_request_kwargs(self):
return self.get_request_kwargs()
def get_request_kwargs(self):
return {}
def get_or_download_file(self, url: str) -> str:
cache_dir = os.path.expanduser("~/.cache")
if url is None:
raise ValueError()
file_name = url.split("/")[-1]
file_path = os.path.join(cache_dir, file_name)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
return file_path
class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
"role": "user",
"content": [
{
"type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"},
},
{
"type": "text",
"text": prompt,
},
],
}
]
return messages
def get_audio_request_kwargs(self):
return self.get_request_kwargs()
def get_audio_response(self, url: str, prompt, category):
audio_file_path = self.get_or_download_file(url)
client = openai.Client(api_key="sk-123456", base_url=self.base_url)
messages = self.prepare_audio_messages(prompt, audio_file_path)
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
**(self.get_audio_request_kwargs()),
)
audio_response = response.choices[0].message.content
print("-" * 30)
print(f"audio {category} response:\n{audio_response}")
print("-" * 30)
audio_response = audio_response.lower()
self.assertIsNotNone(audio_response)
self.assertGreater(len(audio_response), 0)
return audio_response.lower()
def test_audio_speech_completion(self):
# a fragment of Trump's speech
audio_response = self.get_audio_response(
AUDIO_TRUMP_SPEECH_URL,
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: {audio_response} should contain {check_word}"
def test_audio_ambient_completion(self):
# bird song
audio_response = self.get_audio_response(
AUDIO_BIRD_SONG_URL,
"Please listen to the audio snippet carefully and transcribe the content in English.",
"ambient",
)
assert "bird" in audio_response
class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
def test_single_image_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -316,38 +399,6 @@ class TestOpenAIVisionServer(CustomTestCase):
return messages
def prepare_video_messages(self, video_path):
messages = [
{
"role": "user",
"content": [
{
"type": "video_url",
"video_url": {"url": f"{video_path}"},
},
{"type": "text", "text": "Please describe the video in detail."},
],
},
]
return messages
def get_or_download_file(self, url: str) -> str:
cache_dir = os.path.expanduser("~/.cache")
if url is None:
raise ValueError()
file_name = url.split("/")[-1]
file_path = os.path.join(cache_dir, file_name)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
return file_path
# this test samples frames of video as input, but not video directly
def test_video_images_chat_completion(self):
url = VIDEO_JOBS_URL
file_path = self.get_or_download_file(url)
@@ -409,7 +460,24 @@ class TestOpenAIVisionServer(CustomTestCase):
self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0)
def _test_video_chat_completion(self):
class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
def prepare_video_messages(self, video_path):
messages = [
{
"role": "user",
"content": [
{
"type": "video_url",
"video_url": {"url": f"{video_path}"},
},
{"type": "text", "text": "Please describe the video in detail."},
],
},
]
return messages
def test_video_chat_completion(self):
url = VIDEO_JOBS_URL
file_path = self.get_or_download_file(url)
@@ -457,170 +525,3 @@ class TestOpenAIVisionServer(CustomTestCase):
), f"video_response: {video_response}, should contain 'black' or 'dark'"
self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0)
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
regex = (
r"""\{"""
+ r""""color":"[\w]+","""
+ r""""number_of_cars":[\d]+"""
+ r"""\}"""
)
extra_kwargs = self.get_vision_request_kwargs()
extra_kwargs.setdefault("extra_body", {})["regex"] = regex
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
"text": "Describe this image in the JSON format.",
},
],
},
],
temperature=0,
**extra_kwargs,
)
text = response.choices[0].message.content
try:
js_obj = json.loads(text)
except (TypeError, json.decoder.JSONDecodeError):
print("JSONDecodeError", text)
raise
assert isinstance(js_obj["color"], str)
assert isinstance(js_obj["number_of_cars"], int)
def run_decode_with_image(self, image_id):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
content = []
if image_id == 0:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_MAN_IRONING_URL},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {"url": IMAGE_SGL_LOGO_URL},
}
)
else:
pass
content.append(
{
"type": "text",
"text": "Describe this image in a sentence.",
}
)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "user", "content": content},
],
temperature=0,
**(self.get_vision_request_kwargs()),
)
assert response.choices[0].message.role == "assistant"
text = response.choices[0].message.content
assert isinstance(text, str)
def test_mixed_batch(self):
image_ids = [0, 1, 2] * 4
with ThreadPoolExecutor(4) as executor:
list(executor.map(self.run_decode_with_image, image_ids))
def prepare_audio_messages(self, prompt, audio_file_name):
messages = [
{
"role": "user",
"content": [
{
"type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"},
},
{
"type": "text",
"text": prompt,
},
],
}
]
return messages
def get_audio_response(self, url: str, prompt, category):
audio_file_path = self.get_or_download_file(url)
client = openai.Client(api_key="sk-123456", base_url=self.base_url)
messages = self.prepare_audio_messages(prompt, audio_file_path)
response = client.chat.completions.create(
model="default",
messages=messages,
temperature=0,
max_tokens=128,
stream=False,
**(self.get_audio_request_kwargs()),
)
audio_response = response.choices[0].message.content
print("-" * 30)
print(f"audio {category} response:\n{audio_response}")
print("-" * 30)
audio_response = audio_response.lower()
self.assertIsNotNone(audio_response)
self.assertGreater(len(audio_response), 0)
return audio_response.lower()
def _test_audio_speech_completion(self):
# a fragment of Trump's speech
audio_response = self.get_audio_response(
AUDIO_TRUMP_SPEECH_URL,
"Listen to this audio and write down the audio transcription in English.",
category="speech",
)
check_list = [
"thank you",
"it's a privilege to be here",
"leader",
"science",
"art",
]
for check_word in check_list:
assert (
check_word in audio_response
), f"audio_response: {audio_response} should contain {check_word}"
def _test_audio_ambient_completion(self):
# bird song
audio_response = self.get_audio_response(
AUDIO_BIRD_SONG_URL,
"Please listen to the audio snippet carefully and transcribe the content in English.",
"ambient",
)
assert "bird" in audio_response
def test_audio_chat_completion(self):
pass