From 361971b8593431aa572b3a5b2aa12b250ec09b03 Mon Sep 17 00:00:00 2001 From: Pan Lyu Date: Fri, 7 Mar 2025 08:46:20 +0800 Subject: [PATCH] Add Support for Qwen2-VL Multi-modal Embedding Models (#3694) --- docs/references/supported_models.md | 2 + python/sglang/srt/conversation.py | 64 ++++++++++ python/sglang/srt/entrypoints/engine.py | 4 +- python/sglang/srt/managers/io_struct.py | 47 +++++--- python/sglang/srt/managers/scheduler.py | 24 ++++ .../sglang/srt/managers/tokenizer_manager.py | 12 +- python/sglang/srt/openai_api/adapter.py | 28 +++++ python/sglang/srt/openai_api/protocol.py | 9 +- python/sglang/test/runners.py | 114 ++++++++++++++++-- test/srt/models/test_gme_qwen_models.py | 85 +++++++++++++ test/srt/run_suite.py | 1 + 11 files changed, 356 insertions(+), 34 deletions(-) create mode 100644 test/srt/models/test_gme_qwen_models.py diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 69b11bfbd..396746a0d 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -38,6 +38,8 @@ - Mistral embedding models - Qwen embedding models - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding` +- Multi-modal embedding models + - `python -m sglang.launch_server --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct --is-embedding --chat-template gme-qwen2-vl` ## Reward Models diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 3a775aa1e..a19a9e735 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -44,6 +44,7 @@ class SeparatorStyle(IntEnum): CHATGLM3 = auto() DEEPSEEK_CHAT = auto() METAMATH = auto() + QWEN2_VL_EMBED = auto() @dataclasses.dataclass @@ -110,6 +111,15 @@ class Conversation: else: ret += role + "\n" return ret + elif self.sep_style == SeparatorStyle.QWEN2_VL_EMBED: + ret = "" if system_prompt == "" else system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + else: + ret += role + "\n" + ret += self.stop_str + return ret elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: ret = system_prompt for role, message in self.messages: @@ -366,6 +376,46 @@ def chat_template_exists(template_name: str) -> bool: return template_name in chat_templates +def generate_embedding_convs( + texts: List[str], images: List[str], template_name: str +) -> List[Conversation]: + conv_template = chat_templates[template_name].copy() + convs = [] + for text, image in zip(texts, images): + conv = Conversation( + name=conv_template.name, + system_template=conv_template.system_template, + system_message=conv_template.system_message, + roles=conv_template.roles, + messages=list(conv_template.messages), # prevent in-place modification + offset=conv_template.offset, + sep_style=SeparatorStyle(conv_template.sep_style), + sep=conv_template.sep, + sep2=conv_template.sep2, + stop_str=conv_template.stop_str, + image_data=[], + modalities=[], + image_token=conv_template.image_token, + ) + real_content = "" + + if image is not None: + image_token = ( + conv.image_token + "\n" + if conv.name != "gme-qwen2-vl" + else conv.image_token + ) + real_content += image_token + if text is not None: + real_content += text + conv.append_message(conv.roles[0], real_content) + # Add a blank message for the assistant. 
+ conv.append_message(conv.roles[1], None) + convs.append(conv) + + return convs + + def generate_chat_conv( request: ChatCompletionRequest, template_name: str ) -> Conversation: @@ -555,6 +605,20 @@ register_conv_template( ) ) +# Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage +register_conv_template( + Conversation( + name="gme-qwen2-vl", + system_message="You are a helpful assistant.", + system_template="<|im_start|>system\n{system_message}", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep="<|im_end|>\n", + sep_style=SeparatorStyle.QWEN2_VL_EMBED, + stop_str="<|endoftext|>", + image_token="<|vision_start|><|image_pad|><|vision_end|>", + ) +) + # Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage register_conv_template( Conversation( diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 074691a4f..7c0f287b7 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -214,13 +214,13 @@ class Engine: def encode( self, prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + image_data: Optional[Union[List[str], str]] = None, ) -> Dict: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. Please refer to `EmbeddingReqInput` for the documentation. """ - - obj = EmbeddingReqInput(text=prompt) + obj = EmbeddingReqInput(text=prompt, image_data=image_data) loop = asyncio.get_event_loop() generator = self.tokenizer_manager.generate_request(obj, None) ret = loop.run_until_complete(generator.__anext__()) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index e7d548710..232fb3859 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -293,6 +293,8 @@ class TokenizedGenerateReqInput: class EmbeddingReqInput: # The input prompt. It can be a single prompt or a batch of prompts. text: Optional[Union[List[str], str]] = None + # The image input. It can be a file name, a url, or base64 encoded string. + image_data: Optional[Union[List[str], str]] = None # The token ids for text; one can either specify text or input_ids. input_ids: Optional[Union[List[List[int]], List[int]]] = None # The request id. @@ -303,28 +305,40 @@ class EmbeddingReqInput: input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None # Whether to log metrics for this request (e.g. 
health_generate calls do not log metrics) log_metrics: bool = True + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None def normalize_batch_and_arguments(self): - if (self.text is None and self.input_ids is None) or ( - self.text is not None and self.input_ids is not None - ): - raise ValueError("Either text or input_ids should be provided.") + # at least one of text, input_ids, or image should be provided + if self.text is None and self.input_ids is None and self.image_data is None: + raise ValueError( + "At least one of text, input_ids, or image should be provided" + ) + + # text and input_ids cannot be provided at the same time + if self.text is not None and self.input_ids is not None: + raise ValueError("text and input_ids cannot be provided at the same time") # Derive the batch size + self.batch_size = 0 + self.is_single = True + + # check the batch size of text if self.text is not None: - if isinstance(self.text, str): - self.is_single = True - self.batch_size = 1 + if isinstance(self.text, list): + self.batch_size += len(self.text) else: - self.is_single = False - self.batch_size = len(self.text) - else: - if isinstance(self.input_ids[0], int): - self.is_single = True - self.batch_size = 1 + self.batch_size += 1 + + # check the batch size of input_ids + if self.input_ids is not None: + if isinstance(self.input_ids[0], list): + self.batch_size += len(self.input_ids) else: - self.is_single = False - self.batch_size = len(self.input_ids) + self.batch_size += 1 + + if self.batch_size > 1: + self.is_single = False # Fill in default arguments if self.is_single: @@ -352,6 +366,7 @@ class EmbeddingReqInput: return EmbeddingReqInput( text=self.text[i] if self.text is not None else None, input_ids=self.input_ids[i] if self.input_ids is not None else None, + image_data=self.image_data[i] if self.image_data is not None else None, sampling_params=self.sampling_params[i], rid=self.rid[i], ) @@ -365,6 +380,8 @@ class TokenizedEmbeddingReqInput: input_text: str # The input token ids input_ids: List[int] + # The image inputs + image_inputs: dict # Dummy sampling params for compatibility sampling_params: SamplingParams diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 10698b0bc..05bc8d730 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -767,6 +767,30 @@ class Scheduler: ) req.tokenizer = self.tokenizer + # Handle multimodal inputs + if recv_req.image_inputs is not None: + image_inputs = ImageInputs.from_dict(recv_req.image_inputs) + # Expand a single image token into multiple dummy tokens for receiving image embeddings + req.origin_input_ids = self.pad_input_ids_func( + req.origin_input_ids, image_inputs + ) + req.extend_image_inputs(image_inputs) + + if len(req.origin_input_ids) >= self.max_req_input_len: + error_msg = ( + "Multimodal prompt is too long after expanding multimodal tokens. " + f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}." 
+ ) + logger.error(error_msg) + req.origin_input_ids = [0] + req.image_inputs = None + req.sampling_params.max_new_tokens = 0 + req.finished_reason = FINISH_ABORT( + error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" + ) + self.waiting_queue.append(req) + return + # Validate prompts length error_msg = validate_input_length( req, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 743c0c430..3132060ed 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -372,13 +372,12 @@ class TokenizerManager: ) input_ids = self.tokenizer.encode(input_text) + image_inputs: Dict = await self.image_processor.process_images_async( + obj.image_data, input_text or input_ids, obj, self.max_req_input_len + ) + if image_inputs and "input_ids" in image_inputs: + input_ids = image_inputs["input_ids"] if self.is_generation: - # TODO: also support getting embeddings for multimodal models - image_inputs: Dict = await self.image_processor.process_images_async( - obj.image_data, input_text or input_ids, obj, self.max_req_input_len - ) - if image_inputs and "input_ids" in image_inputs: - input_ids = image_inputs["input_ids"] return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len top_logprobs_num = obj.top_logprobs_num @@ -438,6 +437,7 @@ class TokenizerManager: obj.rid, input_text, input_ids, + image_inputs, sampling_params, ) diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 50464ba4b..43c3625bb 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -38,6 +38,7 @@ from sglang.srt.conversation import ( SeparatorStyle, chat_template_exists, generate_chat_conv, + generate_embedding_convs, register_conv_template, ) from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser @@ -68,6 +69,7 @@ from sglang.srt.openai_api.protocol import ( FileResponse, FunctionResponse, LogProbs, + MultimodalEmbeddingInput, ToolCall, TopLogprob, UsageInfo, @@ -1556,11 +1558,37 @@ def v1_embedding_request(all_requests, tokenizer_manager): prompt = prompts[0] if isinstance(prompt, str) or isinstance(prompt[0], str): prompt_kwargs = {"text": prompt} + elif isinstance(prompt, list) and isinstance( + prompt[0], MultimodalEmbeddingInput + ): + assert ( + chat_template_name is not None + ), "chat_template_name is required for multimodal inputs" + texts = [] + images = [] + for item in prompt: + texts.append(item.text if item.text is not None else None) + images.append(item.image if item.image is not None else None) + convs = generate_embedding_convs(texts, images, chat_template_name) + generate_prompts = [] + for conv in convs: + generate_prompts.append(conv.get_prompt()) + if len(generate_prompts) == 1: + prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]} + else: + prompt_kwargs = {"text": generate_prompts, "image_data": images} else: prompt_kwargs = {"input_ids": prompt} else: if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): prompt_kwargs = {"text": prompts} + elif isinstance(prompts[0], list) and isinstance( + prompts[0][0], MultimodalEmbeddingInput + ): + # TODO: multiple requests + raise NotImplementedError( + "Multiple requests with multimodal inputs are not supported yet" + ) else: prompt_kwargs = {"input_ids": prompts} diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 
0c0aa0961..0b6148321 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -403,10 +403,17 @@ class ChatCompletionStreamResponse(BaseModel): usage: Optional[UsageInfo] = None +class MultimodalEmbeddingInput(BaseModel): + text: Optional[str] = None + image: Optional[str] = None + + class EmbeddingRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/embeddings/create - input: Union[List[int], List[List[int]], str, List[str]] + input: Union[ + List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput] + ] model: str encoding_format: str = "float" dimensions: int = None diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index faccb16e5..d92dca4f0 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.server import Engine @@ -135,6 +135,76 @@ class HFRunner: return True return False + # copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py + + def _get_gme_qwen2_vl_embeddings( + self, prompts, image_data: Optional[List[str]] = None + ): + from sglang.srt.utils import load_image + + images = None + if image_data is not None: + images = [load_image(image)[0] for image in image_data] + + inputs = self.processor( + text=prompts, + images=images, + padding=True, + truncation=True, + max_length=1800, + return_tensors="pt", + ) + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + with torch.no_grad(): + embeddings = self._forward_gme_qwen2_vl(**inputs) + return embeddings.tolist() + + def _forward_gme_qwen2_vl( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + pooling_mask: Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.model.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.model.visual.get_dtype()) + image_embeds = self.model.visual( + pixel_values, grid_thw=image_grid_thw + ).to(inputs_embeds.device) + image_mask = input_ids == self.model.config.image_token_id + inputs_embeds[image_mask] = image_embeds + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + ) + + pooling_mask = attention_mask if pooling_mask is None else pooling_mask + left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0] # TODO + if left_padding: + embeddings = outputs.last_hidden_state[:, -1] + else: + sequence_lengths = pooling_mask.sum(dim=1) - 1 + batch_size = outputs.last_hidden_state.shape[0] + embeddings = outputs.last_hidden_state[ + torch.arange(batch_size, device=outputs.last_hidden_state.device), + sequence_lengths, 
+ ] + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + return embeddings.contiguous() + def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): # Apply model-specific patches monkey_patch_gemma2_sdpa() @@ -148,9 +218,18 @@ class HFRunner: low_cpu_mem_usage=True, ).cuda() elif self.model_type == "embedding": - self.model = _get_sentence_transformer_embedding_model( - model_path, torch_dtype - ) + if "gme-qwen2-vl" in model_path.lower(): + self.model = AutoModelForVision2Seq.from_pretrained( + model_path, + torch_dtype=torch_dtype, + trust_remote_code=False, + low_cpu_mem_usage=True, + ).cuda() + self.processor = AutoProcessor.from_pretrained(model_path) + else: + self.model = _get_sentence_transformer_embedding_model( + model_path, torch_dtype + ) elif self.model_type == "reward": from transformers import AutoModelForSequenceClassification @@ -169,7 +248,9 @@ class HFRunner: # Run forward while True: - prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get() + prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = ( + in_queue.get() + ) if lora_paths is not None: assert len(prompts) == len(lora_paths) @@ -189,7 +270,10 @@ class HFRunner: ) elif self.model_type == "embedding": assert not self.output_str_only - logits = self.model.encode(prompts).tolist() + if "gme-qwen2-vl" in model_path.lower(): + logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data) + else: + logits = self.model.encode(prompts).tolist() out_queue.put(ModelOutput(embed_logits=logits)) elif self.model_type == "reward": @@ -211,11 +295,14 @@ class HFRunner: def forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, + image_data: Optional[List[str]] = None, max_new_tokens: int = 8, lora_paths: Optional[List[str]] = None, token_ids_logprob: Optional[int] = None, ): - self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob)) + self.in_queue.put( + (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob) + ) return self.out_queue.get() def terminate(self): @@ -396,6 +483,7 @@ class SRTRunner: def forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, + image_data: Optional[List[str]] = None, max_new_tokens: int = 8, lora_paths: Optional[List[str]] = None, logprob_start_len: int = 0, @@ -413,17 +501,23 @@ class SRTRunner: token_ids_logprob=token_ids_logprob, ) else: - response = self.engine.encode(prompts) if self.model_type == "embedding": - logits = [x["embedding"] for x in response] + response = self.engine.encode(prompt=prompts, image_data=image_data) + if isinstance(response, list): + logits = [x["embedding"] for x in response] + else: + logits = [response["embedding"]] return ModelOutput(embed_logits=logits) + # reward model else: + response = self.engine.encode(prompts) scores = [x["embedding"][0] for x in response] return ModelOutput(scores=scores) def batch_forward( self, prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, + image_data: Optional[List[str]] = None, max_new_tokens=8, lora_paths=None, ): @@ -439,7 +533,7 @@ class SRTRunner: lora_paths=lora_paths, ) else: - response = self.engine.encode(prompts) + response = self.engine.encode(prompts, image_data) if self.model_type == "embedding": logits = [x["embedding"] for x in response] return ModelOutput(embed_logits=logits) diff --git a/test/srt/models/test_gme_qwen_models.py b/test/srt/models/test_gme_qwen_models.py new file mode 100644 index 000000000..82f56adb3 --- /dev/null +++ 
b/test/srt/models/test_gme_qwen_models.py
@@ -0,0 +1,85 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import multiprocessing as mp
+import unittest
+
+import torch
+
+from sglang.test.runners import HFRunner, SRTRunner
+from sglang.test.test_utils import get_similarities
+
+TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
+IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
+
+
+MODELS = [
+    ("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", 1e-3),
+]
+TORCH_DTYPES = [torch.float16]
+
+
+class TestGmeQwenModels(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        mp.set_start_method("spawn", force=True)
+
+    def assert_close_embeddings(self, model, prefill_tolerance, torch_dtype):
+
+        prompts_no_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{TEXTS}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
+        prompts_with_image = "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
+        with HFRunner(
+            model,
+            torch_dtype=torch_dtype,
+            model_type="embedding",
+        ) as hf_runner:
+            hf_text_embeddings = hf_runner.forward(prompts=[prompts_no_image])
+            hf_image_embeddings = hf_runner.forward(
+                prompts=[prompts_with_image], image_data=[IMAGES]
+            )
+        with SRTRunner(
+            model,
+            tp_size=1,
+            torch_dtype=torch_dtype,
+            model_type="embedding",
+        ) as srt_runner:
+            srt_text_embeddings = srt_runner.forward(prompts=prompts_no_image)
+            srt_image_embeddings = srt_runner.forward(
+                prompts=prompts_with_image, image_data=IMAGES
+            )
+
+        similarity = get_similarities(
+            hf_text_embeddings.embed_logits[0], srt_text_embeddings.embed_logits[0]
+        )
+        print("text similarity diff", abs(similarity - 1))
+        assert torch.all(
+            abs(similarity - 1) < prefill_tolerance
+        ), "embeddings are not all close"
+        similarity = get_similarities(
+            hf_image_embeddings.embed_logits[0], srt_image_embeddings.embed_logits[0]
+        )
+        print("image similarity diff", abs(similarity - 1))
+        assert torch.all(
+            abs(similarity - 1) < prefill_tolerance
+        ), "embeddings are not all close"
+
+    def test_accuracy(self):
+        for model, prefill_tolerance in MODELS:
+            for torch_dtype in TORCH_DTYPES:
+                self.assert_close_embeddings(model, prefill_tolerance, torch_dtype)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 5ad56e5f5..0ff12ce00 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -13,6 +13,7 @@ suites = {
         "models/test_qwen_models.py",
         "models/test_reward_models.py",
+        "models/test_gme_qwen_models.py",
         "test_gptqmodel_dynamic.py",
         "test_abort.py",
         "test_chunked_prefill.py",
         "test_custom_allreduce.py",
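
Usage sketch: with this patch, the OpenAI-compatible /v1/embeddings route accepts a list of MultimodalEmbeddingInput items (see protocol.py above), where each item may set "text", "image", or both. A minimal client-side sketch, assuming the server was launched with the command from docs/references/supported_models.md and listens on SGLang's default port 30000; the host/port, the example text, and the response parsing follow the standard OpenAI embedding response shape and are assumptions, not taken from this patch:

    import requests

    # Each input item maps onto MultimodalEmbeddingInput: "text" and/or "image".
    # The image may be a URL, a file name, or a base64-encoded string.
    payload = {
        "model": "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
        "input": [
            {"text": "two Subway Series sandwiches on a black background"},
            {"image": "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"},
        ],
    }
    resp = requests.post("http://127.0.0.1:30000/v1/embeddings", json=payload)
    resp.raise_for_status()
    # One embedding vector per input item, in request order.
    embeddings = [item["embedding"] for item in resp.json()["data"]]
    print(len(embeddings), len(embeddings[0]))

The adapter renders each item through the gme-qwen2-vl conversation template (generate_embedding_convs), so --chat-template gme-qwen2-vl must be passed at launch; multimodal input without a chat template is rejected by the assert in v1_embedding_request.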
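
Offline sketch of the extended Engine.encode path, the one SRTRunner exercises in the new test. encode takes the already-templated prompt string, so the Qwen2-VL vision tokens must appear in the text, exactly as prompts_with_image does in test_gme_qwen_models.py; the Engine constructor arguments and the shutdown call below are assumptions about the existing Engine API, only the encode(prompt, image_data) signature comes from this patch:

    import sglang as sgl

    IMAGE_URL = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"

    # Prompt pre-rendered with the gme-qwen2-vl template, vision tokens included.
    prompt = (
        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n"
        "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
        "<|im_start|>assistant\n<|endoftext|>"
    )
    engine = sgl.Engine(
        model_path="Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
        is_embedding=True,
    )
    # image_data may be a file name, a URL, or a base64-encoded string
    # (see the EmbeddingReqInput field added in io_struct.py).
    out = engine.encode(prompt=prompt, image_data=IMAGE_URL)
    # A single prompt returns one dict with an "embedding" key.
    print(len(out["embedding"]))
    engine.shutdown()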