From 37c8a5761f05d83b5ef3f946c8ebacbd51891651 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Wed, 27 Nov 2024 00:03:29 -0800 Subject: [PATCH] [feat] Support session control for vision language models (#2210) --- python/sglang/srt/managers/image_processor.py | 15 +- python/sglang/srt/managers/schedule_batch.py | 38 +++- python/sglang/srt/managers/scheduler.py | 5 +- .../sglang/srt/managers/session_controller.py | 19 +- python/sglang/srt/models/llava.py | 8 +- test/srt/run_suite.py | 1 + test/srt/test_session_control.py | 200 +++++++++++++++++- 7 files changed, 265 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index 2af817319..4cfd210ef 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -131,6 +131,7 @@ class LlavaImageProcessor(BaseImageProcessor): if not image_data: return None + modalities = request_obj.modalities or ["image"] aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) grid_pinpoints = ( self.hf_config.image_grid_pinpoints @@ -139,9 +140,12 @@ class LlavaImageProcessor(BaseImageProcessor): else None ) + if isinstance(image_data, str): + image_data = [image_data] + if isinstance(image_data, list) and len(image_data) > 0: - # Multiple images - if len(image_data) > 1: + if "multi-images" in modalities or "video" in modalities: + # Multiple images aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. 
We do not use anyres pixel_values, image_hashes, image_sizes = [], [], [] res = [] @@ -166,13 +170,6 @@ class LlavaImageProcessor(BaseImageProcessor): ) image_hashes = [image_hash] image_sizes = [image_size] - elif isinstance(image_data, str): - # A single image - pixel_values, image_hash, image_size = await self._process_single_image( - image_data, aspect_ratio, grid_pinpoints - ) - image_hashes = [image_hash] - image_sizes = [image_size] else: raise ValueError(f"Invalid image data: {image_data}") diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 4d1bbece2..971809124 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -31,6 +31,7 @@ import dataclasses import logging from typing import List, Optional, Tuple, Union +import numpy as np import torch import triton import triton.language as tl @@ -167,6 +168,30 @@ class ImageInputs: return ret + def merge(self, other, vocab_size): + assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:] + self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values]) + self.image_hashes += other.image_hashes + + self.pad_values = [ + (self.image_hashes) % vocab_size, + (self.image_hashes >> 16) % vocab_size, + (self.image_hashes >> 32) % vocab_size, + (self.image_hashes >> 64) % vocab_size, + ] + + optional_args = [ + "image_sizes", + "image_offsets", + # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images + "aspect_ratio_ids", + "aspect_ratio_mask", + "image_grid_thws", + ] + for arg in optional_args: + if getattr(self, arg, None) is not None: + setattr(self, arg, getattr(self, arg) + getattr(other, arg)) + class Req: """The input and output status of a request.""" @@ -177,6 +202,7 @@ class Req: origin_input_text: str, origin_input_ids: Tuple[int], sampling_params: SamplingParams, + origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_path: 
Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, session_id: Optional[str] = None, @@ -184,7 +210,11 @@ class Req: # Input and output info self.rid = rid self.origin_input_text = origin_input_text - self.origin_input_ids_unpadded = origin_input_ids # Before image padding + self.origin_input_ids_unpadded = ( + origin_input_ids_unpadded + if origin_input_ids_unpadded + else origin_input_ids # Before image padding + ) self.origin_input_ids = origin_input_ids self.output_ids = [] # Each decode stage's output ids self.fill_ids = None # fill_ids = origin_input_ids + output_ids @@ -260,6 +290,12 @@ class Req: # The number of cached tokens, that were already cached in the KV cache self.cached_tokens = 0 + def extend_image_inputs(self, image_inputs, vocab_size): + if self.image_inputs is None: + self.image_inputs = image_inputs + else: + self.image_inputs.merge(image_inputs, vocab_size) + # whether request reached finished condition def finished(self) -> bool: return self.finished_reason is not None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 35a064e5e..5e8197de8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -559,12 +559,13 @@ class Scheduler: # Image inputs if recv_req.image_inputs is not None: - req.image_inputs = ImageInputs.from_dict( + image_inputs = ImageInputs.from_dict( recv_req.image_inputs, self.model_config.vocab_size ) req.origin_input_ids = self.pad_input_ids_func( - req.origin_input_ids_unpadded, req.image_inputs + req.origin_input_ids, image_inputs ) + req.extend_image_inputs(image_inputs, self.model_config.vocab_size) if len(req.origin_input_ids) > self.max_req_input_len: req.finished_reason = FINISH_ABORT( diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index 6a5b5b5b5..f267a0dc2 100644 --- a/python/sglang/srt/managers/session_controller.py +++ 
b/python/sglang/srt/managers/session_controller.py @@ -41,16 +41,27 @@ class Session: ] + req.input_ids ) + input_ids_unpadded = ( + self.reqs[-1].origin_input_ids_unpadded + + self.reqs[-1].output_ids[ + : self.reqs[-1].sampling_params.max_new_tokens + ] + + req.input_ids + ) else: input_ids = req.input_ids + input_ids_unpadded = req.input_ids new_req = Req( - req.rid, - None, - input_ids, - req.sampling_params, + rid=req.rid, + origin_input_text=None, + origin_input_ids=input_ids, + origin_input_ids_unpadded=input_ids_unpadded, + sampling_params=req.sampling_params, lora_path=req.lora_path, session_id=self.session_id, ) + if len(self.reqs) > 0: + new_req.image_inputs = self.reqs[-1].image_inputs new_req.tokenizer = tokenizer if req.session_rid is not None and len(self.reqs) == 0: new_req.finished_reason = FINISH_ABORT( diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 780bf36b5..b07474ad9 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -49,7 +49,13 @@ class LlavaBaseForCausalLM(nn.Module): image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values # hardcode for spatial_unpad + anyres - image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad" + if image_inputs.modalities is not None and ( + "multi-images" in image_inputs.modalities + or "video" in image_inputs.modalities + ): + image_aspect_ratio = "pad" + else: + image_aspect_ratio = "anyres" offset_list = [] for image_s in image_sizes: if len(image_sizes) > 16: diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 6611330ae..3f55eb25f 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -36,6 +36,7 @@ suites = { "test_triton_attention_backend.py", "test_update_weights.py", "test_vision_openai_server.py", + "test_session_control.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True diff --git a/test/srt/test_session_control.py 
b/test/srt/test_session_control.py index e5b5e7b6c..7396779f6 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -1,7 +1,7 @@ """ Usage: python3 -m unittest test_session_control.TestSessionControl.test_session_control -python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm +python3 -m unittest test_session_control.TestSessionControlVision.test_session_control """ import unittest @@ -61,6 +61,8 @@ class TestSessionControl(unittest.TestCase): "max_new_tokens": ( 16 if i > 0 else 0 ), # prefill only for the first chunk + "no_stop_trim": True, + "skip_special_tokens": False, }, }, ).json() @@ -79,6 +81,8 @@ class TestSessionControl(unittest.TestCase): "sampling_params": { "temperature": 0, "max_new_tokens": 16, + "no_stop_trim": True, + "skip_special_tokens": False, }, }, ).json() @@ -93,6 +97,8 @@ class TestSessionControl(unittest.TestCase): "sampling_params": { "temperature": 0, "max_new_tokens": 16, + "no_stop_trim": True, + "skip_special_tokens": False, }, }, ).json() @@ -113,6 +119,8 @@ class TestSessionControl(unittest.TestCase): "sampling_params": { "temperature": 0, "max_new_tokens": 16, + "no_stop_trim": True, + "skip_special_tokens": False, }, }, ).json() @@ -133,13 +141,16 @@ class TestSessionControl(unittest.TestCase): "max_new_tokens": ( 16 if i > 0 else 0 ), # prefill only for the first chunk + "no_stop_trim": True, + "skip_special_tokens": False, }, }, ).json() if i > 0: - input_ids += tokenizer.encode(response["text"])[ - 1: - ] # drop the bos token + output_ids = tokenizer.encode(response["text"]) + if output_ids[0] == tokenizer.bos_token_id: + output_ids = output_ids[1:] + input_ids += output_ids outputs_normal.append(response["text"]) if i == 0: input_ids_first_req = input_ids.copy() @@ -152,6 +163,187 @@ class TestSessionControl(unittest.TestCase): "sampling_params": { "temperature": 0, "max_new_tokens": 16, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() 
+ outputs_normal.append(response["text"]) + + print("outputs from chunked queries with session control:") + print(outputs_from_session) + print("outputs from normal queries:") + print(outputs_normal) + assert outputs_from_session == outputs_normal + + +class TestSessionControlVision(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmms-lab/llava-onevision-qwen2-7b-ov" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + # other_args={"--disable-radix"}, + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid, include_self=True) + + def test_session_control(self): + text_chunks = [ + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n", + "<|im_start|>user\n\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\n\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\n\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n", + ] + image_chunks = [ + "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png", + "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", + "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", + ] + assert len(text_chunks) == len(image_chunks) + 1 + tokenizer = get_tokenizer(self.model) + text_input_ids = [tokenizer.encode(x) for x in text_chunks] + + # 1. 
using session control + session_id = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000}, + ).json() + rid = None + + first_rid = None + outputs_from_session = [] + for i in range(len(text_input_ids)): + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": text_input_ids[i], + "image_data": image_chunks[i - 1] if i > 0 else None, + "modalities": ["multi-images"], + "session": [session_id, rid], + "sampling_params": { + "temperature": 0, + "max_new_tokens": ( + 16 if i > 0 else 0 + ), # prefill only for the first chunk + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + rid = response["meta_info"]["id"] + if i == 0: + first_rid = rid + if i > 0: + outputs_from_session.append(response["text"]) + + # backtrack to the first request and regenerate + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": text_input_ids[-1], + "image_data": image_chunks[-1:], + "modalities": ["multi-images"], + "session": [session_id, first_rid], + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + outputs_from_session.append(response["text"]) + + # query with a non-existing rid (the last one should have disappeared because of backtrack), should see abort + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": text_input_ids[-1], + "image_data": image_chunks[-1:], + "modalities": ["multi-images"], + "session": [session_id, rid], + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + assert response["meta_info"]["finish_reason"]["type"] == "abort" + + ret = requests.post( + self.base_url + "/close_session", + json={"session_id": session_id}, + ) + assert ret.status_code == 200 + + # send a request to a closed session, should see abort + response = requests.post( +
self.base_url + "/generate", + json={ + "input_ids": text_input_ids[-1], + "session": [session_id, first_rid], + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + assert response["meta_info"]["finish_reason"]["type"] == "abort" + + # 2. not use session control + input_ids_first_req = None + input_ids = [] + outputs_normal = [] + for i in range(len(text_input_ids)): + input_ids += text_input_ids[i] + image_data = image_chunks[:i] if i > 0 else None + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids, + "image_data": image_data, + "modalities": ["multi-images"], + "sampling_params": { + "temperature": 0, + "max_new_tokens": ( + 16 if i > 0 else 0 + ), # prefill only for the first chunk + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + if i > 0: + output_ids = tokenizer.encode(response["text"]) + if output_ids[0] == tokenizer.bos_token_id: + output_ids = output_ids[1:] + input_ids += output_ids + outputs_normal.append(response["text"]) + if i == 0: + input_ids_first_req = input_ids.copy() + + input_ids_first_req += text_input_ids[-1] + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids_first_req, + "image_data": image_chunks[-1:], + "modalities": ["multi-images"], + "sampling_params": { + "temperature": 0, + "max_new_tokens": 16, + "no_stop_trim": True, + "skip_special_tokens": False, }, }, ).json()