diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py
index 933341ee9..d650535cb 100644
--- a/python/sglang/srt/multimodal/processors/base_processor.py
+++ b/python/sglang/srt/multimodal/processors/base_processor.py
@@ -217,9 +217,9 @@ class BaseMultimodalProcessor(ABC):
         if videos:
             kwargs["videos"] = videos
         if audios:
-            if self.arch in {
-                "Gemma3nForConditionalGeneration",
-                "Qwen2AudioForConditionalGeneration",
+            if self._processor.__class__.__name__ in {
+                "Gemma3nProcessor",
+                "Qwen2AudioProcessor",
             }:
                 # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
                 kwargs["audio"] = audios
diff --git a/python/sglang/srt/multimodal/processors/llava.py b/python/sglang/srt/multimodal/processors/llava.py
index 5031dccbd..1647ea1e5 100644
--- a/python/sglang/srt/multimodal/processors/llava.py
+++ b/python/sglang/srt/multimodal/processors/llava.py
@@ -18,7 +18,7 @@ from sglang.srt.models.llavavid import LlavaVidForCausalLM
 from sglang.srt.models.mistral import Mistral3ForConditionalGeneration
 from sglang.srt.multimodal.mm_utils import expand2square, process_anyres_image
 from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
-from sglang.srt.utils import load_image, logger
+from sglang.srt.utils import ImageData, load_image, logger
 from sglang.utils import get_exception_traceback
@@ -35,7 +35,7 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
 
     @staticmethod
     def _process_single_image_task(
-        image_data: Union[str, bytes],
+        image_data: Union[str, bytes, ImageData],
         image_aspect_ratio: Optional[str] = None,
         image_grid_pinpoints: Optional[str] = None,
         processor=None,
@@ -44,10 +44,11 @@
         image_processor = processor.image_processor
 
         try:
-            image, image_size = load_image(image_data)
+            url = image_data.url if isinstance(image_data, ImageData) else image_data
+            image, image_size = load_image(url)
             if image_size is not None:
                 # It is a video with multiple images
-                image_hash = hash(image_data)
+                image_hash = hash(url)
                 pixel_values = image_processor(image)["pixel_values"]
                 for _ in range(len(pixel_values)):
                     pixel_values[_] = pixel_values[_].astype(np.float16)
@@ -55,7 +56,7 @@
                 return pixel_values, image_hash, image_size
             else:
                 # It is an image
-                image_hash = hash(image_data)
+                image_hash = hash(url)
                 if image_aspect_ratio == "pad":
                     image = expand2square(
                         image,
@@ -82,7 +83,10 @@
             logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
 
     async def _process_single_image(
-        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+        self,
+        image_data: Union[bytes, str, ImageData],
+        aspect_ratio: str,
+        grid_pinpoints: str,
     ):
         if self.cpu_executor is not None:
             loop = asyncio.get_event_loop()
@@ -104,7 +108,7 @@
     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, ImageData]],
         input_text,
         request_obj,
         *args,
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 69fdb949c..26e99ae10 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -110,8 +110,8 @@
         TestFile("test_utils_update_weights.py", 48),
         TestFile("test_vision_chunked_prefill.py", 175),
TestFile("test_vlm_input_format.py", 300), - TestFile("test_vision_openai_server_a.py", 989), - TestFile("test_vision_openai_server_b.py", 620), + TestFile("test_vision_openai_server_a.py", 403), + TestFile("test_vision_openai_server_b.py", 446), ], "per-commit-2-gpu": [ TestFile("lora/test_lora_tp.py", 116), diff --git a/test/srt/test_vision_openai_server_a.py b/test/srt/test_vision_openai_server_a.py index 9d69b918c..9e311d5b1 100644 --- a/test/srt/test_vision_openai_server_a.py +++ b/test/srt/test_vision_openai_server_a.py @@ -8,16 +8,28 @@ import unittest from test_vision_openai_server_common import * -from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, popen_launch_server, ) -class TestQwen2VLServer(TestOpenAIVisionServer): +class TestLlava(ImageOpenAITestMixin): + @classmethod + def setUpClass(cls): + cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + ) + cls.base_url += "/v1" + + +class TestQwen2VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "Qwen/Qwen2-VL-7B-Instruct" @@ -37,11 +49,8 @@ class TestQwen2VLServer(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_video_chat_completion(self): - self._test_video_chat_completion() - -class TestQwen2_5_VLServer(TestOpenAIVisionServer): +class TestQwen2_5_VLServer(ImageOpenAITestMixin, VideoOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "Qwen/Qwen2.5-VL-7B-Instruct" @@ -61,9 +70,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_video_chat_completion(self): - self._test_video_chat_completion() - class TestVLMContextLengthIssue(CustomTestCase): @classmethod @@ -137,11 +143,8 @@ class TestVLMContextLengthIssue(CustomTestCase): # ) # cls.base_url += "/v1" -# def test_video_chat_completion(self): -# pass - -class TestMinicpmvServer(TestOpenAIVisionServer): +class TestMinicpmvServer(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "openbmb/MiniCPM-V-2_6" @@ -162,7 +165,7 @@ class TestMinicpmvServer(TestOpenAIVisionServer): cls.base_url += "/v1" -class TestInternVL2_5Server(TestOpenAIVisionServer): +class TestInternVL2_5Server(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "OpenGVLab/InternVL2_5-2B" @@ -181,7 +184,7 @@ class TestInternVL2_5Server(TestOpenAIVisionServer): cls.base_url += "/v1" -class TestMinicpmoServer(TestOpenAIVisionServer): +class TestMinicpmoServer(ImageOpenAITestMixin, AudioOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "openbmb/MiniCPM-o-2_6" @@ -201,12 +204,8 @@ class TestMinicpmoServer(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_audio_chat_completion(self): - self._test_audio_speech_completion() - self._test_audio_ambient_completion() - -class TestMimoVLServer(TestOpenAIVisionServer): +class TestMimoVLServer(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "XiaomiMiMo/MiMo-VL-7B-RL" @@ -228,6 +227,95 @@ class TestMimoVLServer(TestOpenAIVisionServer): cls.base_url += "/v1" +class TestVILAServer(ImageOpenAITestMixin): + @classmethod + def setUpClass(cls): + cls.model = "Efficient-Large-Model/NVILA-Lite-2B-hf-0626" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.revision = 
"6bde1de5964b40e61c802b375fff419edc867506" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + "--trust-remote-code", + "--context-length=65536", + f"--revision={cls.revision}", + "--cuda-graph-max-bs", + "4", + ], + ) + cls.base_url += "/v1" + + +class TestPhi4MMServer(ImageOpenAITestMixin, AudioOpenAITestMixin): + @classmethod + def setUpClass(cls): + # Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default. + from huggingface_hub import constants, snapshot_download + + snapshot_download( + "microsoft/Phi-4-multimodal-instruct", + allow_patterns=["**/adapter_config.json"], + ) + + cls.model = "microsoft/Phi-4-multimodal-instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--mem-fraction-static", + "0.70", + "--disable-radix-cache", + "--max-loras-per-batch", + "2", + "--revision", + revision, + "--lora-paths", + f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora", + f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora", + "--cuda-graph-max-bs", + "4", + ], + ) + cls.base_url += "/v1" + + def get_vision_request_kwargs(self): + return { + "extra_body": { + "lora_path": "vision", + "top_k": 1, + "top_p": 1.0, + } + } + + def get_audio_request_kwargs(self): + return { + "extra_body": { + "lora_path": "speech", + "top_k": 1, + "top_p": 1.0, + } + } + + # This _test_audio_ambient_completion test is way too complicated to pass for a small LLM + def test_audio_ambient_completion(self): + pass + + if __name__ == "__main__": - del TestOpenAIVisionServer + del ( + TestOpenAIOmniServerBase, + ImageOpenAITestMixin, + VideoOpenAITestMixin, + AudioOpenAITestMixin, + ) unittest.main() diff --git a/test/srt/test_vision_openai_server_b.py b/test/srt/test_vision_openai_server_b.py index c420c0ad8..fd952f82f 100644 --- a/test/srt/test_vision_openai_server_b.py +++ b/test/srt/test_vision_openai_server_b.py @@ -4,12 +4,11 @@ from test_vision_openai_server_common import * from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, popen_launch_server, ) -class TestPixtralServer(TestOpenAIVisionServer): +class TestPixtralServer(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "mistral-community/pixtral-12b" @@ -29,11 +28,8 @@ class TestPixtralServer(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_video_chat_completion(self): - pass - -class TestMistral3_1Server(TestOpenAIVisionServer): +class TestMistral3_1Server(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "unsloth/Mistral-Small-3.1-24B-Instruct-2503" @@ -53,11 +49,8 @@ class TestMistral3_1Server(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_video_chat_completion(self): - pass - -class TestDeepseekVL2Server(TestOpenAIVisionServer): +class TestDeepseekVL2Server(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "deepseek-ai/deepseek-vl2-small" @@ -77,11 +70,8 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_video_chat_completion(self): - pass - -class TestJanusProServer(TestOpenAIVisionServer): +class 
TestJanusProServer(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "deepseek-ai/Janus-Pro-7B" @@ -104,10 +94,6 @@ class TestJanusProServer(TestOpenAIVisionServer): def test_video_images_chat_completion(self): pass - def test_single_image_chat_completion(self): - # Skip this test because it is flaky - pass - ## Skip for ci test # class TestLlama4Server(TestOpenAIVisionServer): @@ -135,11 +121,8 @@ class TestJanusProServer(TestOpenAIVisionServer): # ) # cls.base_url += "/v1" -# def test_video_chat_completion(self): -# pass - -class TestGemma3itServer(TestOpenAIVisionServer): +class TestGemma3itServer(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "google/gemma-3-4b-it" @@ -160,11 +143,8 @@ class TestGemma3itServer(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_video_chat_completion(self): - pass - -class TestGemma3nServer(TestOpenAIVisionServer): +class TestGemma3nServer(ImageOpenAITestMixin, AudioOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "google/gemma-3n-E4B-it" @@ -184,16 +164,15 @@ class TestGemma3nServer(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_audio_chat_completion(self): - self._test_audio_speech_completion() - # This _test_audio_ambient_completion test is way too complicated to pass for a small LLM - # self._test_audio_ambient_completion() + # This _test_audio_ambient_completion test is way too complicated to pass for a small LLM + def test_audio_ambient_completion(self): + pass def _test_mixed_image_audio_chat_completion(self): self._test_mixed_image_audio_chat_completion() -class TestQwen2AudioServer(TestOpenAIVisionServer): +class TestQwen2AudioServer(AudioOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "Qwen/Qwen2-Audio-7B-Instruct" @@ -211,36 +190,8 @@ class TestQwen2AudioServer(TestOpenAIVisionServer): ) cls.base_url += "/v1" - def test_audio_chat_completion(self): - self._test_audio_speech_completion() - self._test_audio_ambient_completion() - # Qwen2Audio does not support image - def test_single_image_chat_completion(self): - pass - - # Qwen2Audio does not support image - def test_multi_turn_chat_completion(self): - pass - - # Qwen2Audio does not support image - def test_multi_images_chat_completion(self): - pass - - # Qwen2Audio does not support image - def test_video_images_chat_completion(self): - pass - - # Qwen2Audio does not support image - def test_regex(self): - pass - - # Qwen2Audio does not support image - def test_mixed_batch(self): - pass - - -class TestKimiVLServer(TestOpenAIVisionServer): +class TestKimiVLServer(ImageOpenAITestMixin): @classmethod def setUpClass(cls): cls.model = "moonshotai/Kimi-VL-A3B-Instruct" @@ -266,91 +217,6 @@ class TestKimiVLServer(TestOpenAIVisionServer): pass -class TestPhi4MMServer(TestOpenAIVisionServer): - @classmethod - def setUpClass(cls): - # Manually download LoRA adapter_config.json as it's not downloaded by the model loader by default. 
- from huggingface_hub import constants, snapshot_download - - snapshot_download( - "microsoft/Phi-4-multimodal-instruct", - allow_patterns=["**/adapter_config.json"], - ) - - cls.model = "microsoft/Phi-4-multimodal-instruct" - cls.base_url = DEFAULT_URL_FOR_TEST - cls.api_key = "sk-123456" - - revision = "33e62acdd07cd7d6635badd529aa0a3467bb9c6a" - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--mem-fraction-static", - "0.70", - "--disable-radix-cache", - "--max-loras-per-batch", - "2", - "--revision", - revision, - "--lora-paths", - f"vision={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/vision-lora", - f"speech={constants.HF_HUB_CACHE}/models--microsoft--Phi-4-multimodal-instruct/snapshots/{revision}/speech-lora", - "--cuda-graph-max-bs", - "4", - ], - ) - cls.base_url += "/v1" - - def get_vision_request_kwargs(self): - return { - "extra_body": { - "lora_path": "vision", - "top_k": 1, - "top_p": 1.0, - } - } - - def get_audio_request_kwargs(self): - return { - "extra_body": { - "lora_path": "speech", - "top_k": 1, - "top_p": 1.0, - } - } - - def test_audio_chat_completion(self): - self._test_audio_speech_completion() - # This _test_audio_ambient_completion test is way too complicated to pass for a small LLM - # self._test_audio_ambient_completion() - - -class TestVILAServer(TestOpenAIVisionServer): - @classmethod - def setUpClass(cls): - cls.model = "Efficient-Large-Model/NVILA-Lite-2B-hf-0626" - cls.base_url = DEFAULT_URL_FOR_TEST - cls.api_key = "sk-123456" - cls.revision = "6bde1de5964b40e61c802b375fff419edc867506" - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - api_key=cls.api_key, - other_args=[ - "--trust-remote-code", - "--context-length=65536", - f"--revision={cls.revision}", - "--cuda-graph-max-bs", - "4", - ], - ) - cls.base_url += "/v1" - - # Skip for ci test # class TestGLM41VServer(TestOpenAIVisionServer): # @classmethod @@ -379,5 +245,10 @@ class TestVILAServer(TestOpenAIVisionServer): if __name__ == "__main__": - del TestOpenAIVisionServer + del ( + TestOpenAIOmniServerBase, + ImageOpenAITestMixin, + VideoOpenAITestMixin, + AudioOpenAITestMixin, + ) unittest.main() diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index 6a46a0610..792636060 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -1,8 +1,6 @@ import base64 import io -import json import os -from concurrent.futures import ThreadPoolExecutor import numpy as np import openai @@ -10,12 +8,7 @@ import requests from PIL import Image from sglang.srt.utils import kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) +from sglang.test.test_utils import DEFAULT_URL_FOR_TEST, CustomTestCase # image IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png" @@ -29,33 +22,123 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3" -class TestOpenAIVisionServer(CustomTestCase): +class TestOpenAIOmniServerBase(CustomTestCase): @classmethod def setUpClass(cls): - 
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov" + cls.model = "" cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - api_key=cls.api_key, - ) + cls.process = None cls.base_url += "/v1" @classmethod def tearDownClass(cls): kill_process_tree(cls.process.pid) - def get_audio_request_kwargs(self): - return self.get_request_kwargs() - def get_vision_request_kwargs(self): return self.get_request_kwargs() def get_request_kwargs(self): return {} + def get_or_download_file(self, url: str) -> str: + cache_dir = os.path.expanduser("~/.cache") + if url is None: + raise ValueError() + file_name = url.split("/")[-1] + file_path = os.path.join(cache_dir, file_name) + os.makedirs(cache_dir, exist_ok=True) + + if not os.path.exists(file_path): + response = requests.get(url) + response.raise_for_status() + + with open(file_path, "wb") as f: + f.write(response.content) + return file_path + + +class AudioOpenAITestMixin(TestOpenAIOmniServerBase): + def prepare_audio_messages(self, prompt, audio_file_name): + messages = [ + { + "role": "user", + "content": [ + { + "type": "audio_url", + "audio_url": {"url": f"{audio_file_name}"}, + }, + { + "type": "text", + "text": prompt, + }, + ], + } + ] + + return messages + + def get_audio_request_kwargs(self): + return self.get_request_kwargs() + + def get_audio_response(self, url: str, prompt, category): + audio_file_path = self.get_or_download_file(url) + client = openai.Client(api_key="sk-123456", base_url=self.base_url) + + messages = self.prepare_audio_messages(prompt, audio_file_path) + + response = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=128, + stream=False, + **(self.get_audio_request_kwargs()), + ) + + audio_response = response.choices[0].message.content + + print("-" * 30) + print(f"audio {category} response:\n{audio_response}") + print("-" * 30) + + audio_response = audio_response.lower() + + self.assertIsNotNone(audio_response) + self.assertGreater(len(audio_response), 0) + + return audio_response.lower() + + def test_audio_speech_completion(self): + # a fragment of Trump's speech + audio_response = self.get_audio_response( + AUDIO_TRUMP_SPEECH_URL, + "Listen to this audio and write down the audio transcription in English.", + category="speech", + ) + check_list = [ + "thank you", + "it's a privilege to be here", + "leader", + "science", + "art", + ] + for check_word in check_list: + assert ( + check_word in audio_response + ), f"audio_response: |{audio_response}| should contain |{check_word}|" + + def test_audio_ambient_completion(self): + # bird song + audio_response = self.get_audio_response( + AUDIO_BIRD_SONG_URL, + "Please listen to the audio snippet carefully and transcribe the content in English.", + "ambient", + ) + assert "bird" in audio_response + + +class ImageOpenAITestMixin(TestOpenAIOmniServerBase): def test_single_image_chat_completion(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) @@ -316,38 +399,6 @@ class TestOpenAIVisionServer(CustomTestCase): return messages - def prepare_video_messages(self, video_path): - messages = [ - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": {"url": f"{video_path}"}, - }, - {"type": "text", "text": "Please describe the video in detail."}, - ], - }, - ] - return messages - - def get_or_download_file(self, url: str) -> str: - cache_dir = os.path.expanduser("~/.cache") - 
if url is None: - raise ValueError() - file_name = url.split("/")[-1] - file_path = os.path.join(cache_dir, file_name) - os.makedirs(cache_dir, exist_ok=True) - - if not os.path.exists(file_path): - response = requests.get(url) - response.raise_for_status() - - with open(file_path, "wb") as f: - f.write(response.content) - return file_path - - # this test samples frames of video as input, but not video directly def test_video_images_chat_completion(self): url = VIDEO_JOBS_URL file_path = self.get_or_download_file(url) @@ -409,7 +460,24 @@ class TestOpenAIVisionServer(CustomTestCase): self.assertIsNotNone(video_response) self.assertGreater(len(video_response), 0) - def _test_video_chat_completion(self): + +class VideoOpenAITestMixin(TestOpenAIOmniServerBase): + def prepare_video_messages(self, video_path): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video_url", + "video_url": {"url": f"{video_path}"}, + }, + {"type": "text", "text": "Please describe the video in detail."}, + ], + }, + ] + return messages + + def test_video_chat_completion(self): url = VIDEO_JOBS_URL file_path = self.get_or_download_file(url) @@ -457,170 +525,3 @@ class TestOpenAIVisionServer(CustomTestCase): ), f"video_response: {video_response}, should contain 'black' or 'dark'" self.assertIsNotNone(video_response) self.assertGreater(len(video_response), 0) - - def test_regex(self): - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - - regex = ( - r"""\{""" - + r""""color":"[\w]+",""" - + r""""number_of_cars":[\d]+""" - + r"""\}""" - ) - - extra_kwargs = self.get_vision_request_kwargs() - extra_kwargs.setdefault("extra_body", {})["regex"] = regex - - response = client.chat.completions.create( - model="default", - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": {"url": IMAGE_MAN_IRONING_URL}, - }, - { - "type": "text", - "text": "Describe this image in the JSON format.", - }, - ], - }, - ], - temperature=0, - **extra_kwargs, - ) - text = response.choices[0].message.content - - try: - js_obj = json.loads(text) - except (TypeError, json.decoder.JSONDecodeError): - print("JSONDecodeError", text) - raise - assert isinstance(js_obj["color"], str) - assert isinstance(js_obj["number_of_cars"], int) - - def run_decode_with_image(self, image_id): - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - - content = [] - if image_id == 0: - content.append( - { - "type": "image_url", - "image_url": {"url": IMAGE_MAN_IRONING_URL}, - } - ) - elif image_id == 1: - content.append( - { - "type": "image_url", - "image_url": {"url": IMAGE_SGL_LOGO_URL}, - } - ) - else: - pass - - content.append( - { - "type": "text", - "text": "Describe this image in a sentence.", - } - ) - - response = client.chat.completions.create( - model="default", - messages=[ - {"role": "user", "content": content}, - ], - temperature=0, - **(self.get_vision_request_kwargs()), - ) - - assert response.choices[0].message.role == "assistant" - text = response.choices[0].message.content - assert isinstance(text, str) - - def test_mixed_batch(self): - image_ids = [0, 1, 2] * 4 - with ThreadPoolExecutor(4) as executor: - list(executor.map(self.run_decode_with_image, image_ids)) - - def prepare_audio_messages(self, prompt, audio_file_name): - messages = [ - { - "role": "user", - "content": [ - { - "type": "audio_url", - "audio_url": {"url": f"{audio_file_name}"}, - }, - { - "type": "text", - "text": prompt, - }, - ], - } - ] - - return messages - - def get_audio_response(self, 
url: str, prompt, category): - audio_file_path = self.get_or_download_file(url) - client = openai.Client(api_key="sk-123456", base_url=self.base_url) - - messages = self.prepare_audio_messages(prompt, audio_file_path) - - response = client.chat.completions.create( - model="default", - messages=messages, - temperature=0, - max_tokens=128, - stream=False, - **(self.get_audio_request_kwargs()), - ) - - audio_response = response.choices[0].message.content - - print("-" * 30) - print(f"audio {category} response:\n{audio_response}") - print("-" * 30) - - audio_response = audio_response.lower() - - self.assertIsNotNone(audio_response) - self.assertGreater(len(audio_response), 0) - - return audio_response.lower() - - def _test_audio_speech_completion(self): - # a fragment of Trump's speech - audio_response = self.get_audio_response( - AUDIO_TRUMP_SPEECH_URL, - "Listen to this audio and write down the audio transcription in English.", - category="speech", - ) - check_list = [ - "thank you", - "it's a privilege to be here", - "leader", - "science", - "art", - ] - for check_word in check_list: - assert ( - check_word in audio_response - ), f"audio_response: |{audio_response}| should contain |{check_word}|" - - def _test_audio_ambient_completion(self): - # bird song - audio_response = self.get_audio_response( - AUDIO_BIRD_SONG_URL, - "Please listen to the audio snippet carefully and transcribe the content in English.", - "ambient", - ) - assert "bird" in audio_response - - def test_audio_chat_completion(self): - pass
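
For reference, here is a minimal sketch (not part of the patch) of how a new server test file would consume the mixin hierarchy this diff introduces. `TestMyVLMServer` and the `org/my-vlm-7b-instruct` model id are hypothetical placeholders; everything else mirrors the classes changed above.

```python
import unittest

# The star import provides the mixins plus DEFAULT_URL_FOR_TEST, as in the
# refactored test files above.
from test_vision_openai_server_common import *

from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_server,
)


# Hypothetical example: subclass one mixin per modality the model supports.
# TestOpenAIOmniServerBase (inherited via the mixins) supplies tearDownClass,
# the request-kwargs hooks, and get_or_download_file.
class TestMyVLMServer(ImageOpenAITestMixin, VideoOpenAITestMixin):
    @classmethod
    def setUpClass(cls):
        cls.model = "org/my-vlm-7b-instruct"  # placeholder model id
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        cls.base_url += "/v1"


if __name__ == "__main__":
    # Remove the abstract base and mixins from the module namespace so
    # unittest does not try to run them without a launched server.
    del (
        TestOpenAIOmniServerBase,
        ImageOpenAITestMixin,
        VideoOpenAITestMixin,
        AudioOpenAITestMixin,
    )
    unittest.main()
```

Each modality mixin contributes its `test_*` methods to the subclass, so a model that lacks a modality simply omits that mixin, and per-model skips (such as Gemma3n's `test_audio_ambient_completion`) shrink to a single `pass` override instead of a list of disabled tests.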