diff --git a/examples/runtime/llava_onevision/http_llava_onevision_test.py b/examples/runtime/llava_onevision/http_llava_onevision_test.py index 0c93d2ce2..2c7c2bd38 100644 --- a/examples/runtime/llava_onevision/http_llava_onevision_test.py +++ b/examples/runtime/llava_onevision/http_llava_onevision_test.py @@ -93,12 +93,14 @@ def multi_image_stream_request_test(client): "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" }, + "modalities": "multi-images", }, { "type": "image_url", "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" }, + "modalities": "multi-images", }, { "type": "text", @@ -218,6 +220,7 @@ def prepare_video_messages(video_path): frame_format = { "type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{}"}, + "modalities": "video", } for base64_frame in base64_frames: diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index dbc376d95..9a1227218 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -71,6 +71,7 @@ class Conversation: # Stop criteria (the default one is EOS token) stop_str: Union[str, List[str]] = None image_data: Optional[List[str]] = None + modalities: Optional[List[str]] = None def get_prompt(self) -> str: """Get the prompt for generation.""" @@ -379,6 +380,7 @@ def generate_chat_conv( sep2=conv.sep2, stop_str=conv.stop_str, image_data=[], + modalities=[], ) if isinstance(request.messages, str): @@ -408,6 +410,7 @@ def generate_chat_conv( for content in message.content: if content.type == "image_url": num_image_url += 1 + conv.modalities.append(content.modalities) if num_image_url > 1: image_token = "" else: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5b91ff62e..8e53df335 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -50,6 +50,8 @@ class 
GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. stream: bool = False + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None def post_init(self): if (self.text is None and self.input_ids is None) or ( @@ -177,6 +179,8 @@ class TokenizedGenerateReqInput: top_logprobs_num: int # Whether to stream output stream: bool + # Modalities of the input images + modalities: Optional[List[str]] = None @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c80cf2e27..f126cc9f3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -130,6 +130,7 @@ class Req: self.image_sizes = None self.image_offsets = None self.pad_value = None + self.modalities = None # Prefix info self.extend_input_len = 0 diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6af820641..d0cfed08c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -188,6 +188,7 @@ class TokenizerManager: pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data if not_use_index else obj.image_data[index] ) + modalities = obj.modalities return_logprob = ( obj.return_logprob if not_use_index else obj.return_logprob[index] ) @@ -243,6 +244,7 @@ class TokenizerManager: pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data[0] ) + modalities = obj.modalities return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] @@ -263,6 +265,7 @@ class TokenizerManager: logprob_start_len, top_logprobs_num, obj.stream, + modalities, ) else: # is embedding tokenized_obj = TokenizedEmbeddingReqInput( @@ -346,6 +349,7 @@ class TokenizerManager: pixel_values, image_hashes, image_sizes = 
( await self._get_pixel_values(obj.image_data[index]) ) + modalities = obj.modalities tokenized_obj = TokenizedGenerateReqInput( rid, @@ -359,6 +363,7 @@ class TokenizerManager: obj.logprob_start_len[index], obj.top_logprobs_num[index], obj.stream, + modalities, ) else: tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index c2c0e6c2d..7bb9c4335 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -358,6 +358,8 @@ class ModelTpServer: req.pixel_values, req.image_sizes, ) + # Only when pixel values is not None we have modalities + req.modalities = recv_req.modalities req.return_logprob = recv_req.return_logprob req.logprob_start_len = recv_req.logprob_start_len req.top_logprobs_num = recv_req.top_logprobs_num diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a443b113d..75f9136d3 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -78,6 +78,7 @@ class InputMetadata: pixel_values: List[torch.Tensor] = None image_sizes: List[List[List[int]]] = None image_offsets: List[List[int]] = None + modalities: List[List[str]] = None # Trition attention backend triton_max_seq_len: int = 0 @@ -96,6 +97,7 @@ class InputMetadata: self.pixel_values = [r.pixel_values for r in reqs] self.image_sizes = [r.image_sizes for r in reqs] self.image_offsets = [r.image_offsets for r in reqs] + self.modalities = [r.modalities for r in reqs] def compute_positions(self, batch: ScheduleBatch): position_ids_offsets = batch.position_ids_offsets diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 2e3c9ceba..62041a895 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -138,6 +138,12 @@ class LlavaBaseForCausalLM(nn.Module): ) -> 
torch.Tensor: if input_metadata.forward_mode == ForwardMode.EXTEND: bs = input_metadata.batch_size + # Got List[List[str]] extend it to List[str] + # The length of the List should be equal to batch size + modalities_list = [] + for modalities in input_metadata.modalities: + if modalities is not None: + modalities_list.extend(modalities) # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) @@ -179,7 +185,7 @@ class LlavaBaseForCausalLM(nn.Module): new_image_features = [] height = width = self.num_patches_per_side for image_idx, image_feature in enumerate(image_features): - if len(image_sizes[image_idx]) == 1: + if modalities_list[image_idx] == "image": image_aspect_ratio = ( self.config.image_aspect_ratio ) # single image @@ -191,6 +197,7 @@ class LlavaBaseForCausalLM(nn.Module): if ( image_feature.shape[0] > 1 and "anyres" in image_aspect_ratio + and modalities_list[image_idx] == "image" ): base_image_feature = image_feature[0] image_feature = image_feature[1:] @@ -290,7 +297,7 @@ class LlavaBaseForCausalLM(nn.Module): ) image_feature = image_feature.unsqueeze(0) else: - if image_feature.shape[0] > 16: # video + if modalities_list[image_idx] == "video": # video # 2x2 pooling num_of_frames = image_feature.shape[0] image_feature = image_feature.view( diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index cd7526b0d..f1195aff7 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -832,6 +832,7 @@ def v1_chat_generate_request( return_logprobs = [] logprob_start_lens = [] top_logprobs_nums = [] + modalities_list = [] # NOTE: with openai API, the prompt's logprobs are always not computed @@ -864,10 +865,12 @@ def v1_chat_generate_request( ) stop = request.stop image_data = None + modalities = [] else: conv = generate_chat_conv(request, chat_template_name) prompt = conv.get_prompt() image_data = conv.image_data + modalities = conv.modalities stop = 
conv.stop_str or [] if request.stop: if isinstance(request.stop, str): @@ -880,6 +883,7 @@ def v1_chat_generate_request( prompt_ids = request.messages stop = request.stop image_data = None + modalities = [] input_ids.append(prompt_ids) return_logprobs.append(request.logprobs) logprob_start_lens.append(-1) @@ -901,6 +905,7 @@ def v1_chat_generate_request( } ) image_data_list.append(image_data) + modalities_list.extend(modalities) if len(all_requests) == 1: input_ids = input_ids[0] if isinstance(input_ids, str): @@ -912,6 +917,7 @@ def v1_chat_generate_request( return_logprobs = return_logprobs[0] logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] + modalities_list = modalities_list[:1] else: if isinstance(input_ids[0], str): prompt_kwargs = {"text": input_ids} @@ -928,6 +934,7 @@ def v1_chat_generate_request( stream=all_requests[0].stream, return_text_in_logprobs=True, rid=request_ids, + modalities=modalities_list, ) if len(all_requests) == 1: return adapted_request, all_requests[0] diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 8073df795..5525cd882 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -213,6 +213,7 @@ class ChatCompletionMessageContentImageURL(BaseModel): class ChatCompletionMessageContentImagePart(BaseModel): type: Literal["image_url"] image_url: ChatCompletionMessageContentImageURL + modalities: Optional[Literal["image", "multi-images", "video"]] = "image" ChatCompletionMessageContentPart = Union[ diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 4f764c09c..727f5774c 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -140,12 +140,14 @@ class TestOpenAIVisionServer(unittest.TestCase): "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" }, + "modalities": 
"multi-images", }, { "type": "image_url", "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" }, + "modalities": "multi-images", }, { "type": "text", @@ -192,6 +194,7 @@ class TestOpenAIVisionServer(unittest.TestCase): frame_format = { "type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{}"}, + "modalities": "video", } for base64_frame in base64_frames: