From 66e7dcaf7008d2ffe892044a21513a6e06424d1a Mon Sep 17 00:00:00 2001 From: Kaichen Zhang - NTU Date: Mon, 26 Aug 2024 01:28:23 +0800 Subject: [PATCH] [Fix] Fixing the multi-images error for llava-onevision (#1205) --- .../http_llava_onevision_test.py | 46 +++++++++++++++++++ .../sglang/srt/managers/tokenizer_manager.py | 4 +- test/srt/test_vision_openai_server.py | 42 +++++++++++++++++ 3 files changed, 91 insertions(+), 1 deletion(-) diff --git a/examples/runtime/llava_onevision/http_llava_onevision_test.py b/examples/runtime/llava_onevision/http_llava_onevision_test.py index 40dc27ec2..41d60b12a 100644 --- a/examples/runtime/llava_onevision/http_llava_onevision_test.py +++ b/examples/runtime/llava_onevision/http_llava_onevision_test.py @@ -78,6 +78,51 @@ def image_stream_request_test(client): print("-" * 30) +def multi_image_stream_request_test(client): + print( + "----------------------Multi-Images Stream Request Test----------------------" + ) + stream_request = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" + }, + }, + { + "type": "text", + "text": "I have shown you two images. Please describe the two images to me.", + }, + ], + }, + ], + temperature=0.7, + max_tokens=1024, + stream=True, + ) + stream_response = "" + + for chunk in stream_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + stream_response += content + sys.stdout.write(content) + sys.stdout.flush() + + print("-" * 30) + + def video_stream_request_test(client, video_path): print("------------------------Video Stream Request Test----------------------") messages = prepare_video_messages(video_path) @@ -209,6 +254,7 @@ def main(): client = create_openai_client("http://127.0.0.1:30000/v1") image_stream_request_test(client) + multi_image_stream_request_test(client) video_stream_request_test(client, video_path) image_speed_test(client) video_speed_test(client, video_path) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8f6700575..5cc060be1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -744,7 +744,9 @@ def get_pixel_values( image, tuple(int(x * 255) for x in processor.image_processor.image_mean), ) - pixel_values = processor.image_processor(image)["pixel_values"][0] + pixel_values = processor.image_processor(image.convert("RGB"))[ + "pixel_values" + ][0] elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: pixel_values = process_anyres_image( image, processor.image_processor, image_grid_pinpoints diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 0a477a92a..0f136fe6e 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -74,6 +74,48 @@ class TestOpenAIVisionServer(unittest.TestCase): assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 + def test_mult_images_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + 
+ response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" + }, + }, + { + "type": "text", + "text": "I have shown you two images. Please describe the two images to me.", + }, + ], + }, + ], + temperature=0, + ) + + assert response.choices[0].message.role == "assistant" + text = response.choices[0].message.content + assert isinstance(text, str) + assert "man" in text or "cab" in text, text + assert "logo" in text, text + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + def prepare_video_messages(self, video_path): max_frames_num = 32 vr = VideoReader(video_path, ctx=cpu(0))