ci: simplify multi-modality tests by using mixins (#9006)
This commit is contained in:
@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC):
|
||||
if videos:
|
||||
kwargs["videos"] = videos
|
||||
if audios:
|
||||
if self.arch in {
|
||||
"Gemma3nForConditionalGeneration",
|
||||
"Qwen2AudioForConditionalGeneration",
|
||||
if self._processor.__class__.__name__ in {
|
||||
"Gemma3nProcessor",
|
||||
"Qwen2AudioProcessor",
|
||||
}:
|
||||
# Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
|
||||
kwargs["audio"] = audios
|
||||
|
||||
@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM
|
||||
from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
|
||||
from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
|
||||
from sglang.srt.utils import load_image, logger
|
||||
from sglang.srt.utils import ImageData, load_image, logger
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(
|
||||
image_data: Union[str, bytes],
|
||||
image_data: Union[str, bytes, ImageData],
|
||||
image_aspect_ratio: Optional[str] = None,
|
||||
image_grid_pinpoints: Optional[str] = None,
|
||||
processor=None,
|
||||
@@ -44,10 +44,11 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
image_processor = processor.image_processor
|
||||
|
||||
try:
|
||||
image, image_size = load_image(image_data)
|
||||
url = image_data.url if isinstance(image_data, ImageData) else image_data
|
||||
image, image_size = load_image(url)
|
||||
if image_size is not None:
|
||||
# It is a video with multiple images
|
||||
image_hash = hash(image_data)
|
||||
image_hash = hash(url)
|
||||
pixel_values = image_processor(image)["pixel_values"]
|
||||
for _ in range(len(pixel_values)):
|
||||
pixel_values[_] = pixel_values[_].astype(np.float16)
|
||||
@@ -55,7 +56,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
return pixel_values, image_hash, image_size
|
||||
else:
|
||||
# It is an image
|
||||
image_hash = hash(image_data)
|
||||
image_hash = hash(url)
|
||||
if image_aspect_ratio == "pad":
|
||||
image = expand2square(
|
||||
image,
|
||||
@@ -82,7 +83,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
async def _process_single_image(
|
||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||
self,
|
||||
image_data: Union[bytes, str, ImageData],
|
||||
aspect_ratio: str,
|
||||
grid_pinpoints: str,
|
||||
):
|
||||
if self.cpu_executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -104,7 +108,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
image_data: List[Union[str, bytes, ImageData]],
|
||||
input_text,
|
||||
request_obj,
|
||||
*args,
|
||||
|
||||
@@ -110,8 +110,8 @@ suites = {
|
||||
TestFile("test_utils_update_weights.py", 48),
|
||||
TestFile("test_vision_chunked_prefill.py", 175),
|
||||
TestFile("test_vlm_input_format.py", 300),
|
||||
TestFile("test_vision_openai_server_a.py", 989),
|
||||
TestFile("test_vision_openai_server_b.py", 620),
|
||||
TestFile("test_vision_openai_server_a.py", 403),
|
||||
TestFile("test_vision_openai_server_b.py", 446),
|
||||
],
|
||||
"per-commit-2-gpu": [
|
||||
TestFile("lora/test_lora_tp.py", 116),
|
||||
|
||||
@@ -8,16 +8,28 @@ import unittest
|
||||
|
||||
from test_vision_openai_server_common import *
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestQwen2VLServer(TestOpenAIVisionServer):
|
||||
class TestLlava(ImageOpenAITestMixin):
    """Run the image chat-completion tests against a LLaVA-OneVision server."""

    @classmethod
    def setUpClass(cls):
        """Launch one server for the whole class and point clients at ``/v1``."""
        server_url = DEFAULT_URL_FOR_TEST
        cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
        cls.api_key = "sk-123456"
        cls.base_url = server_url
        cls.process = popen_launch_server(
            cls.model,
            server_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        # The server is launched on the bare URL; test clients use the /v1 prefix.
        cls.base_url = server_url + "/v1"
|
||||
|
||||
|
||||
class TestQwen2VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "Qwen/Qwen2-VL-7B-Instruct"
|
||||
@@ -37,11 +49,8 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
self._test_video_chat_completion()
|
||||
|
||||
|
||||
class TestQwen2_5_VLServer(TestOpenAIVisionServer):
|
||||
class TestQwen2_5_VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
@@ -61,9 +70,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
self._test_video_chat_completion()
|
||||
|
||||
|
||||
class TestVLMContextLengthIssue(CustomTestCase):
|
||||
@classmethod
|
||||
@@ -137,11 +143,8 @@ class TestVLMContextLengthIssue(CustomTestCase):
|
||||
# )
|
||||
# cls.base_url += "/v1"
|
||||
|
||||
# def test_video_chat_completion(self):
|
||||
# pass
|
||||
|
||||
|
||||
class TestMinicpmvServer(TestOpenAIVisionServer):
|
||||
class TestMinicpmvServer(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "openbmb/MiniCPM-V-2_6"
|
||||
@@ -162,7 +165,7 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
|
||||
cls.base_url += "/v1"
|
||||
|
||||
|
||||
class TestInternVL2_5Server(TestOpenAIVisionServer):
|
||||
class TestInternVL2_5Server(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "OpenGVLab/InternVL2_5-2B"
|
||||
@@ -181,7 +184,7 @@ class TestInternVL2_5Server(TestOpenAIVisionServer):
|
||||
cls.base_url += "/v1"
|
||||
|
||||
|
||||
class TestMinicpmoServer(TestOpenAIVisionServer):
|
||||
class TestMinicpmoServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "openbmb/MiniCPM-o-2_6"
|
||||
@@ -201,12 +204,8 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_audio_chat_completion(self):
|
||||
self._test_audio_speech_completion()
|
||||
self._test_audio_ambient_completion()
|
||||
|
||||
|
||||
class TestMimoVLServer(TestOpenAIVisionServer):
|
||||
class TestMimoVLServer(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "XiaomiMiMo/MiMo-VL-7B-RL"
|
||||
@@ -228,6 +227,95 @@ class TestMimoVLServer(TestOpenAIVisionServer):
|
||||
cls.base_url += "/v1"
|
||||
|
||||
|
||||
class TestVILAServer(ImageOpenAITestMixin):
    """Run the image chat-completion tests against NVILA-Lite at a pinned revision."""

    @classmethod
    def setUpClass(cls):
        """Launch the server once for the class, then switch ``base_url`` to ``/v1``."""
        cls.model = "Efficient-Large-Model/NVILA-Lite-2B-hf-0626"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.revision = "6bde1de5964b40e61c802b375fff419edc867506"

        # Extra command-line flags passed through to the launched server.
        launch_flags = [
            "--trust-remote-code",
            "--context-length=65536",
            "--revision=" + cls.revision,
            "--cuda-graph-max-bs",
            "4",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=launch_flags,
        )
        cls.base_url = cls.base_url + "/v1"
|
||||
|
||||
|
||||
class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
    """Image and audio tests for Phi-4-multimodal served with its two LoRA adapters.

    Vision requests are routed to the ``vision`` adapter and audio requests to
    the ``speech`` adapter through the per-modality request kwargs below.
    """

    @classmethod
    def setUpClass(cls):
        """Fetch the adapter configs, launch the server, and point clients at ``/v1``."""
        from huggingface_hub import constants, snapshot_download

        cls.model = "microsoft/Phi-4-multimodal-instruct"
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a"

        # Manually download LoRA adapter_config.json as it's not downloaded by
        # the model loader by default. The download is pinned to the same
        # revision the server is launched with: without it the files are
        # fetched for the default branch, which may resolve to a different
        # snapshot directory than the one the --lora-paths entries point at.
        snapshot_download(
            "microsoft/Phi-4-multimodal-instruct",
            revision=revision,
            allow_patterns=["**/adapter_config.json"],
        )

        # Snapshot directory of the pinned revision inside the local Hub cache.
        snapshot_root = f"{constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}"

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--mem-fraction-static",
                "0.70",
                "--disable-radix-cache",
                "--max-loras-per-batch",
                "2",
                "--revision",
                revision,
                "--lora-paths",
                f"vision={snapshot_root}/vision-lora",
                f"speech={snapshot_root}/speech-lora",
                "--cuda-graph-max-bs",
                "4",
            ],
        )
        cls.base_url += "/v1"

    def get_vision_request_kwargs(self):
        """Return request kwargs selecting the ``vision`` adapter with top_k=1, top_p=1.0."""
        return {
            "extra_body": {
                "lora_path": "vision",
                "top_k": 1,
                "top_p": 1.0,
            }
        }

    def get_audio_request_kwargs(self):
        """Return request kwargs selecting the ``speech`` adapter with top_k=1, top_p=1.0."""
        return {
            "extra_body": {
                "lora_path": "speech",
                "top_k": 1,
                "top_p": 1.0,
            }
        }

    # This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
    def test_audio_ambient_completion(self):
        """Override the inherited ambient-audio test with a no-op for this model."""
        pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
del TestOpenAIVisionServer
|
||||
del (
|
||||
TestOpenAIOmniServerBase,
|
||||
ImageOpenAITestMixin,
|
||||
VideoOpenAITestMixin,
|
||||
AudioOpenAITestMixin,
|
||||
)
|
||||
unittest.main()
|
||||
|
||||
@@ -4,12 +4,11 @@ from test_vision_openai_server_common import *
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestPixtralServer(TestOpenAIVisionServer):
|
||||
class TestPixtralServer(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "mistral-community/pixtral-12b"
|
||||
@@ -29,11 +28,8 @@ class TestPixtralServer(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestMistral3_1Server(TestOpenAIVisionServer):
|
||||
class TestMistral3_1Server(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "unsloth/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
@@ -53,11 +49,8 @@ class TestMistral3_1Server(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestDeepseekVL2Server(TestOpenAIVisionServer):
|
||||
class TestDeepseekVL2Server(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "deepseek-ai/deepseek-vl2-small"
|
||||
@@ -77,11 +70,8 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestJanusProServer(TestOpenAIVisionServer):
|
||||
class TestJanusProServer(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "deepseek-ai/Janus-Pro-7B"
|
||||
@@ -104,10 +94,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
|
||||
def test_video_images_chat_completion(self):
|
||||
pass
|
||||
|
||||
def test_single_image_chat_completion(self):
|
||||
# Skip this test because it is flaky
|
||||
pass
|
||||
|
||||
|
||||
## Skip for ci test
|
||||
# class TestLlama4Server(TestOpenAIVisionServer):
|
||||
@@ -135,11 +121,8 @@ class TestJanusProServer(TestOpenAIVisionServer):
|
||||
# )
|
||||
# cls.base_url += "/v1"
|
||||
|
||||
# def test_video_chat_completion(self):
|
||||
# pass
|
||||
|
||||
|
||||
class TestGemma3itServer(TestOpenAIVisionServer):
|
||||
class TestGemma3itServer(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "google/gemma-3-4b-it"
|
||||
@@ -160,11 +143,8 @@ class TestGemma3itServer(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestGemma3nServer(TestOpenAIVisionServer):
|
||||
class TestGemma3nServer(ImageOpenAITestMixin, AudioOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "google/gemma-3n-E4B-it"
|
||||
@@ -184,16 +164,15 @@ class TestGemma3nServer(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_audio_chat_completion(self):
|
||||
self._test_audio_speech_completion()
|
||||
# This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
|
||||
# self._test_audio_ambient_completion()
|
||||
# This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
|
||||
def test_audio_ambient_completion(self):
|
||||
pass
|
||||
|
||||
def _test_mixed_image_audio_chat_completion(self):
    """Run the inherited mixed image+audio chat-completion check.

    The name does not start with ``test``, so unittest does not collect it;
    it only runs when called explicitly.
    """
    # Delegate through super(): calling self._test_mixed_image_audio_chat_completion()
    # here would re-enter this override and recurse until RecursionError.
    # NOTE(review): assumes a parent class defines this helper — it is not
    # visible from this file; confirm before enabling the test.
    super()._test_mixed_image_audio_chat_completion()
|
||||
|
||||
|
||||
class TestQwen2AudioServer(TestOpenAIVisionServer):
|
||||
class TestQwen2AudioServer(AudioOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
@@ -211,36 +190,8 @@ class TestQwen2AudioServer(TestOpenAIVisionServer):
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_audio_chat_completion(self):
|
||||
self._test_audio_speech_completion()
|
||||
self._test_audio_ambient_completion()
|
||||
|
||||
# Qwen2Audio does not support image
|
||||
def test_single_image_chat_completion(self):
|
||||
pass
|
||||
|
||||
# Qwen2Audio does not support image
|
||||
def test_multi_turn_chat_completion(self):
|
||||
pass
|
||||
|
||||
# Qwen2Audio does not support image
|
||||
def test_multi_images_chat_completion(self):
|
||||
pass
|
||||
|
||||
# Qwen2Audio does not support image
|
||||
def test_video_images_chat_completion(self):
|
||||
pass
|
||||
|
||||
# Qwen2Audio does not support image
|
||||
def test_regex(self):
|
||||
pass
|
||||
|
||||
# Qwen2Audio does not support image
|
||||
def test_mixed_batch(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestKimiVLServer(TestOpenAIVisionServer):
|
||||
class TestKimiVLServer(ImageOpenAITestMixin):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "moonshotai/Kimi-VL-A3B-Instruct"
|
||||
@@ -266,91 +217,6 @@ class TestKimiVLServer(TestOpenAIVisionServer):
|
||||
pass
|
||||
|
||||
|
||||
class TestPhi4MMServer(TestOpenAIVisionServer):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default.
|
||||
from huggingface_hub import constants, snapshot_download
|
||||
|
||||
snapshot_download(
|
||||
"microsoft/Phi-4-multimodal-instruct",
|
||||
allow_patterns=["**/adapter_config.json"],
|
||||
)
|
||||
|
||||
cls.model = "microsoft/Phi-4-multimodal-instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
|
||||
revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--mem-fraction-static",
|
||||
"0.70",
|
||||
"--disable-radix-cache",
|
||||
"--max-loras-per-batch",
|
||||
"2",
|
||||
"--revision",
|
||||
revision,
|
||||
"--lora-paths",
|
||||
f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora",
|
||||
f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora",
|
||||
"--cuda-graph-max-bs",
|
||||
"4",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def get_vision_request_kwargs(self):
|
||||
return {
|
||||
"extra_body": {
|
||||
"lora_path": "vision",
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
def get_audio_request_kwargs(self):
|
||||
return {
|
||||
"extra_body": {
|
||||
"lora_path": "speech",
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
def test_audio_chat_completion(self):
|
||||
self._test_audio_speech_completion()
|
||||
# This _test_audio_ambient_completion test is way too complicated to pass for a small LLM
|
||||
# self._test_audio_ambient_completion()
|
||||
|
||||
|
||||
class TestVILAServer(TestOpenAIVisionServer):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "Efficient-Large-Model/NVILA-Lite-2B-hf-0626"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.revision = "6bde1de5964b40e61c802b375fff419edc867506"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--context-length=65536",
|
||||
f"--revision={cls.revision}",
|
||||
"--cuda-graph-max-bs",
|
||||
"4",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
|
||||
# Skip for ci test
|
||||
# class TestGLM41VServer(TestOpenAIVisionServer):
|
||||
# @classmethod
|
||||
@@ -379,5 +245,10 @@ class TestVILAServer(TestOpenAIVisionServer):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
del TestOpenAIVisionServer
|
||||
del (
|
||||
TestOpenAIOmniServerBase,
|
||||
ImageOpenAITestMixin,
|
||||
VideoOpenAITestMixin,
|
||||
AudioOpenAITestMixin,
|
||||
)
|
||||
unittest.main()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
@@ -10,12 +8,7 @@ import requests
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, CustomTestCase
|
||||
|
||||
# image
|
||||
IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
|
||||
@@ -29,33 +22,123 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
|
||||
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
|
||||
|
||||
|
||||
class TestOpenAIVisionServer(CustomTestCase):
|
||||
class TestOpenAIOmniServerBase(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
|
||||
cls.model = ""
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.process = None
|
||||
cls.base_url += "/v1"
|
||||
|
||||
@classmethod
def tearDownClass(cls):
    """Kill the server process tree started for this class, if there is one.

    ``process`` may be ``None`` (the base ``setUpClass`` sets it so) or
    missing entirely when a subclass's ``setUpClass`` failed before the
    launch; in both cases there is nothing to kill and teardown must not raise.
    """
    process = getattr(cls, "process", None)
    if process is not None:
        kill_process_tree(process.pid)
|
||||
|
||||
def get_audio_request_kwargs(self):
    """Return extra kwargs for audio requests; defaults to the shared kwargs."""
    shared_kwargs = self.get_request_kwargs()
    return shared_kwargs
|
||||
|
||||
def get_vision_request_kwargs(self):
    """Return extra kwargs for vision requests; defaults to the shared kwargs."""
    shared_kwargs = self.get_request_kwargs()
    return shared_kwargs
|
||||
|
||||
def get_request_kwargs(self):
    """Return the extra request kwargs shared by all modalities (none by default)."""
    # A fresh dict on every call, so callers may mutate the result freely.
    return dict()
|
||||
|
||||
def get_or_download_file(self, url: str) -> str:
    """Return a local path for ``url``, downloading it into ``~/.cache`` on first use.

    The cache entry is named after the last path segment of the URL, so two
    URLs that end in the same file name share one cached file.

    Args:
        url: Location of the file to fetch.

    Returns:
        Path of the cached file.

    Raises:
        ValueError: If ``url`` is ``None``/empty or has no file name segment.
        requests.HTTPError: If the server answers with an error status.
    """
    if not url:
        raise ValueError("url must be a non-empty string")
    file_name = url.split("/")[-1]
    if not file_name:
        # A trailing "/" would otherwise resolve to the cache directory itself.
        raise ValueError(f"url has no file name component: {url!r}")

    cache_dir = os.path.expanduser("~/.cache")
    file_path = os.path.join(cache_dir, file_name)
    os.makedirs(cache_dir, exist_ok=True)

    if os.path.exists(file_path):
        return file_path

    # Bound the wait so an unresponsive host fails the test instead of hanging it.
    response = requests.get(url, timeout=60)
    response.raise_for_status()

    # Write to a side file and rename it into place: an interrupted write must
    # not leave a truncated file that later runs would accept as a cache hit.
    partial_path = f"{file_path}.{os.getpid()}.part"
    try:
        with open(partial_path, "wb") as f:
            f.write(response.content)
        os.replace(partial_path, file_path)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return file_path
|
||||
|
||||
|
||||
class AudioOpenAITestMixin(TestOpenAIOmniServerBase):
|
||||
def prepare_audio_messages(self, prompt, audio_file_name):
    """Build a single-turn user message holding one audio clip followed by ``prompt``."""
    audio_part = {
        "type": "audio_url",
        "audio_url": {"url": f"{audio_file_name}"},
    }
    text_part = {
        "type": "text",
        "text": prompt,
    }
    user_turn = {"role": "user", "content": [audio_part, text_part]}
    return [user_turn]
|
||||
|
||||
def get_audio_request_kwargs(self):
    """Return extra kwargs for audio requests; defaults to the shared kwargs."""
    audio_kwargs = self.get_request_kwargs()
    return audio_kwargs
|
||||
|
||||
def get_audio_response(self, url: str, prompt, category):
    """Send one audio clip plus ``prompt`` to the server and return the reply lower-cased.

    Args:
        url: Location of the audio file; it is fetched through
            ``get_or_download_file`` and passed on as a local path.
        prompt: Text instruction sent alongside the audio.
        category: Label used only in the printed log line.

    Returns:
        The non-empty assistant message content, lower-cased.
    """
    audio_file_path = self.get_or_download_file(url)
    # Use the key the class was configured with, like the image tests do; fall
    # back to the historical literal for subclasses that never set ``api_key``.
    client = openai.Client(
        api_key=getattr(self, "api_key", "sk-123456"),
        base_url=self.base_url,
    )

    messages = self.prepare_audio_messages(prompt, audio_file_path)

    response = client.chat.completions.create(
        model="default",
        messages=messages,
        temperature=0,
        max_tokens=128,
        stream=False,
        **(self.get_audio_request_kwargs()),
    )

    audio_response = response.choices[0].message.content

    print("-" * 30)
    print(f"audio {category} response:\n{audio_response}")
    print("-" * 30)

    # Validate before normalising: calling .lower() on a None reply would
    # raise AttributeError and mask the real assertion failure.
    self.assertIsNotNone(audio_response)
    self.assertGreater(len(audio_response), 0)

    return audio_response.lower()
|
||||
|
||||
def test_audio_speech_completion(self):
    """Check that the transcription of the speech clip contains the expected phrases."""
    # a fragment of Trump's speech
    transcript = self.get_audio_response(
        AUDIO_TRUMP_SPEECH_URL,
        "Listen to this audio and write down the audio transcription in English.",
        category="speech",
    )
    expected_phrases = (
        "thank you",
        "it's a privilege to be here",
        "leader",
        "science",
        "art",
    )
    # Stop at the first missing phrase, reporting the full transcript with it.
    for phrase in expected_phrases:
        found = phrase in transcript
        assert found, f"audio_response: |{transcript}| should contain |{phrase}|"
|
||||
|
||||
def test_audio_ambient_completion(self):
    """Check that the model's description of the bird-song clip mentions a bird."""
    # bird song
    instruction = "Please listen to the audio snippet carefully and transcribe the content in English."
    description = self.get_audio_response(AUDIO_BIRD_SONG_URL, instruction, "ambient")
    mentions_bird = "bird" in description
    assert mentions_bird
|
||||
|
||||
|
||||
class ImageOpenAITestMixin(TestOpenAIOmniServerBase):
|
||||
def test_single_image_chat_completion(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
@@ -316,38 +399,6 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
|
||||
return messages
|
||||
|
||||
def prepare_video_messages(self, video_path):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {"url": f"{video_path}"},
|
||||
},
|
||||
{"type": "text", "text": "Please describe the video in detail."},
|
||||
],
|
||||
},
|
||||
]
|
||||
return messages
|
||||
|
||||
def get_or_download_file(self, url: str) -> str:
|
||||
cache_dir = os.path.expanduser("~/.cache")
|
||||
if url is None:
|
||||
raise ValueError()
|
||||
file_name = url.split("/")[-1]
|
||||
file_path = os.path.join(cache_dir, file_name)
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
return file_path
|
||||
|
||||
# this test samples frames of video as input, but not video directly
|
||||
def test_video_images_chat_completion(self):
|
||||
url = VIDEO_JOBS_URL
|
||||
file_path = self.get_or_download_file(url)
|
||||
@@ -409,7 +460,24 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
self.assertIsNotNone(video_response)
|
||||
self.assertGreater(len(video_response), 0)
|
||||
|
||||
def _test_video_chat_completion(self):
|
||||
|
||||
class VideoOpenAITestMixin(TestOpenAIOmniServerBase):
|
||||
def prepare_video_messages(self, video_path):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video_url",
|
||||
"video_url": {"url": f"{video_path}"},
|
||||
},
|
||||
{"type": "text", "text": "Please describe the video in detail."},
|
||||
],
|
||||
},
|
||||
]
|
||||
return messages
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
url = VIDEO_JOBS_URL
|
||||
file_path = self.get_or_download_file(url)
|
||||
|
||||
@@ -457,170 +525,3 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
), f"video_response: {video_response}, should contain 'black' or 'dark'"
|
||||
self.assertIsNotNone(video_response)
|
||||
self.assertGreater(len(video_response), 0)
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
regex = (
|
||||
r"""\{"""
|
||||
+ r""""color":"[\w]+","""
|
||||
+ r""""number_of_cars":[\d]+"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
extra_kwargs = self.get_vision_request_kwargs()
|
||||
extra_kwargs.setdefault("extra_body", {})["regex"] = regex
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe this image in the JSON format.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
**extra_kwargs,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
assert isinstance(js_obj["color"], str)
|
||||
assert isinstance(js_obj["number_of_cars"], int)
|
||||
|
||||
def run_decode_with_image(self, image_id):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
content = []
|
||||
if image_id == 0:
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||
}
|
||||
)
|
||||
elif image_id == 1:
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": IMAGE_SGL_LOGO_URL},
|
||||
}
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe this image in a sentence.",
|
||||
}
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "user", "content": content},
|
||||
],
|
||||
temperature=0,
|
||||
**(self.get_vision_request_kwargs()),
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
text = response.choices[0].message.content
|
||||
assert isinstance(text, str)
|
||||
|
||||
def test_mixed_batch(self):
|
||||
image_ids = [0, 1, 2] * 4
|
||||
with ThreadPoolExecutor(4) as executor:
|
||||
list(executor.map(self.run_decode_with_image, image_ids))
|
||||
|
||||
def prepare_audio_messages(self, prompt, audio_file_name):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {"url": f"{audio_file_name}"},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
def get_audio_response(self, url: str, prompt, category):
|
||||
audio_file_path = self.get_or_download_file(url)
|
||||
client = openai.Client(api_key="sk-123456", base_url=self.base_url)
|
||||
|
||||
messages = self.prepare_audio_messages(prompt, audio_file_path)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
stream=False,
|
||||
**(self.get_audio_request_kwargs()),
|
||||
)
|
||||
|
||||
audio_response = response.choices[0].message.content
|
||||
|
||||
print("-" * 30)
|
||||
print(f"audio {category} response:\n{audio_response}")
|
||||
print("-" * 30)
|
||||
|
||||
audio_response = audio_response.lower()
|
||||
|
||||
self.assertIsNotNone(audio_response)
|
||||
self.assertGreater(len(audio_response), 0)
|
||||
|
||||
return audio_response.lower()
|
||||
|
||||
def _test_audio_speech_completion(self):
|
||||
# a fragment of Trump's speech
|
||||
audio_response = self.get_audio_response(
|
||||
AUDIO_TRUMP_SPEECH_URL,
|
||||
"Listen to this audio and write down the audio transcription in English.",
|
||||
category="speech",
|
||||
)
|
||||
check_list = [
|
||||
"thank you",
|
||||
"it's a privilege to be here",
|
||||
"leader",
|
||||
"science",
|
||||
"art",
|
||||
]
|
||||
for check_word in check_list:
|
||||
assert (
|
||||
check_word in audio_response
|
||||
), f"audio_response: |{audio_response}| should contain |{check_word}|"
|
||||
|
||||
def _test_audio_ambient_completion(self):
|
||||
# bird song
|
||||
audio_response = self.get_audio_response(
|
||||
AUDIO_BIRD_SONG_URL,
|
||||
"Please listen to the audio snippet carefully and transcribe the content in English.",
|
||||
"ambient",
|
||||
)
|
||||
assert "bird" in audio_response
|
||||
|
||||
def test_audio_chat_completion(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user