From 8048c28c11b7b377d769bfc38fd8b8c87fb187de Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Thu, 21 Nov 2024 19:05:41 -0800 Subject: [PATCH] Fix #2037 - Context length check does not take into out pad tokens for visual models (#2106) --- python/sglang/srt/managers/scheduler.py | 9 ++++ test/srt/test_vision_openai_server.py | 58 +++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4e96fbc26..9af4caaf2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -557,6 +557,15 @@ class Scheduler: req.origin_input_ids_unpadded, req.image_inputs ) + if len(req.origin_input_ids) > self.max_req_input_len: + req.finished_reason = FINISH_ABORT( + "Image request length is longer than the KV cache pool size or " + "the max context length aborting because you cannot truncate the image embeds" + ) + req.sampling_params.max_new_tokens = 0 + self.waiting_queue.append(req) + return + req.return_logprob = recv_req.return_logprob req.top_logprobs_num = recv_req.top_logprobs_num req.stream = recv_req.stream diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index d70dac66f..95a1624cf 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -364,6 +364,64 @@ class TestQWen2VLServer(TestOpenAIVisionServer): cls.base_url += "/v1" +class TestQWen2VLServerContextLengthIssue(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen2-VL-7B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + "--chat-template", + "qwen2-vl", + "--context-length", + "300", + "--mem-fraction-static=0.80", + ], + ) + cls.base_url += "/v1" + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid, include_self=True) + + def test_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + }, + }, + { + "type": "text", + "text": "Give a lengthy description of this picture", + }, + ], + }, + ], + temperature=0, + ) + + assert response.choices[0].finish_reason == "abort" + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + class TestMllamaServer(TestOpenAIVisionServer): @classmethod def setUpClass(cls):