From 681fdc264bd278a30559b4317c62763197fef9d5 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Sat, 24 May 2025 18:39:21 -0700 Subject: [PATCH] Refactor vlm embedding routine to use precomputed feature (#6543) Signed-off-by: Xinyuan Tong --- python/sglang/srt/managers/mm_utils.py | 126 ++++++++---- .../managers/multimodal_processors/qwen_vl.py | 14 +- python/sglang/srt/models/gemma3_mm.py | 7 - python/sglang/srt/models/qwen2_5_vl.py | 6 - python/sglang/srt/models/qwen2_vl.py | 6 - test/srt/run_suite.py | 2 +- test/srt/test_vlm_accuracy.py | 140 +------------ test/srt/test_vlm_input_format.py | 187 ++++++++++++++++++ 8 files changed, 285 insertions(+), 203 deletions(-) create mode 100644 test/srt/test_vlm_input_format.py diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index f39f730e3..8a9cea83a 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -252,40 +252,36 @@ def get_embedding_chunk( return embedding_chunk, start_index, end_index -def get_embedding_and_mask( +def _get_precomputed_embedding( + items: List[MultimodalDataItem], +) -> Optional[torch.Tensor]: + """ + If all items have precomputed_features, return their concatenation. + If some but not all have precomputed_features, raise NotImplementedError. + If none have precomputed_features, return None. + """ + precomputed_features = [item.precomputed_features for item in items] + if any(feature is not None for feature in precomputed_features): + if not all(feature is not None for feature in precomputed_features): + raise NotImplementedError( + "MM inputs where only some items are precomputed." + ) + result = torch.concat(precomputed_features) + # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk) + result = result.reshape(-1, result.shape[-1]) + return result + return None + + +def _get_chunked_prefill_embedding( data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor], embedding_items: List[MultimodalDataItem], - placeholder_tensor: torch.Tensor, - input_ids: torch.Tensor, items_size: List[int], prefix_length: List[int], extend_length: List[int], items_offset_list: List[List[Tuple[int, int]]], -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Generate multimodal embeddings and create a mask for identifying their positions in the input sequence. - - Args: - data_embedding_func: Function that generates embeddings for multimodal items - embedding_items: List of multimodal items to embed - placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content - input_ids: The input token IDs tensor - items_size: Cumulative sizes of multimodal items per request - prefix_length: Prefix lengths for each request - extend_length: Sequence lengths for each request - items_offset_list: List of offset ranges for multimodal items in each request - - Returns: - A tuple containing: - - The generated embeddings tensor - - A boolean mask tensor indicating where these embeddings should be placed - - Raises: - AssertionError: If the number of multimodal tokens in input_ids doesn't match - the number of tokens in the generated embeddings - """ - # 1. Get the embedding - # Calculate embedding for each request, try to get it from cache to avoid repeated calculation +) -> Optional[torch.Tensor]: + # Calculate embedding for each request, try to get it from cache to avoid repeated calculation embedding_list = [] for i in range(len(items_size) - 1): if items_size[i] == items_size[i + 1]: @@ -321,21 +317,28 @@ def get_embedding_and_mask( embedding_cache.free(embedding_items_hash) embedding_list.append(embedding_per_req_chunk) if len(embedding_list) == 0: - return None, None - embedding = torch.concat(embedding_list, dim=0) - # 2. Check the embedding - num_mm_tokens_in_embedding = embedding.shape[0] - special_multimodal_mask = torch.isin( - input_ids, - placeholder_tensor, - ).unsqueeze(-1) + return None + return torch.concat(embedding_list, dim=0) - num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item() + +def _get_multimodal_mask( + input_ids: torch.Tensor, placeholder_tensor: torch.Tensor +) -> torch.Tensor: + return torch.isin(input_ids, placeholder_tensor).unsqueeze(-1) + + +def _adjust_embedding_length( + embedding: torch.Tensor, + mask: torch.Tensor, + logger, +) -> torch.Tensor: + num_mm_tokens_in_embedding = embedding.shape[0] + num_mm_tokens_in_input_ids = mask.sum().item() if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding: logger.warning( f"Number of tokens in multimodal embedding does not match those in the input text. " f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} " - "tokens from multimodal embeddings." + f"tokens from multimodal embeddings." ) if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding: chunked_prefill_size = global_server_args_dict["chunked_prefill_size"] @@ -353,7 +356,54 @@ def get_embedding_and_mask( raise RuntimeError( f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error" ) + return embedding + +def get_embedding_and_mask( + data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor], + embedding_items: List[MultimodalDataItem], + placeholder_tensor: torch.Tensor, + input_ids: torch.Tensor, + items_size: List[int], + prefix_length: List[int], + extend_length: List[int], + items_offset_list: List[List[Tuple[int, int]]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate multimodal embeddings and create a mask for identifying their positions in the input sequence. + + Args: + data_embedding_func: Function that generates embeddings for multimodal items + embedding_items: List of multimodal items to embed + placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content + input_ids: The input token IDs tensor + items_size: Cumulative sizes of multimodal items per request + prefix_length: Prefix lengths for each request + extend_length: Sequence lengths for each request + items_offset_list: List of offset ranges for multimodal items in each request + + Returns: + A tuple containing: + - The generated embeddings tensor + - A boolean mask tensor indicating where these embeddings should be placed + """ + # 1. Get embedding + embedding = _get_precomputed_embedding(embedding_items) + if embedding is None: + embedding = _get_chunked_prefill_embedding( + data_embedding_func, + embedding_items, + items_size, + prefix_length, + extend_length, + items_offset_list, + ) + if embedding is None: + return None, None + # 2. Get mask + special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor) + # 3. Adjust embedding length if needed + embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger) return embedding, special_multimodal_mask diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py index 76c3c546f..c47652c15 100644 --- a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py +++ b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py @@ -144,12 +144,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): if base_output.images: if images_are_preprocessed: - image_grid_thw = torch.concat( - [ - torch.as_tensor(item.image_grid_thws) - for item in base_output.images - ] - ) + all_image_grid_thws = [ + item.image_grid_thws + for item in base_output.images + if item.image_grid_thws is not None + ] all_pixel_values = [ item.pixel_values for item in base_output.images @@ -160,6 +159,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): for item in base_output.images if item.precomputed_features is not None ] + image_grid_thw = ( + torch.concat(all_image_grid_thws) if all_image_grid_thws else None + ) pixel_values = ( torch.concat(all_pixel_values) if all_pixel_values else None ) diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 7d776da0d..45b3c4572 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -282,13 +282,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ - if any(item.precomputed_features is not None for item in items): - if not all(item.precomputed_features is not None for item in items): - raise NotImplementedError( - "MM inputs where only some items are precomputed." - ) - return torch.concat([item.precomputed_features for item in items]) - # Process images one by one to handle flatten_batch=True constraint in vision_tower all_pixel_values = flatten_nested_list([item.pixel_values for item in items]) vision_outputs_list = [] diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 420216c7b..4fbac6fd4 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -499,12 +499,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): return pattern.pad_input_tokens(input_ids, mm_inputs) def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: - if any(item.precomputed_features is not None for item in items): - if not all(item.precomputed_features is not None for item in items): - raise NotImplementedError( - "MM inputs where only some items are precomputed." - ) - return torch.concat([item.precomputed_features for item in items]) # in qwen-vl, last dim is the same pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( self.visual.dtype diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index b4421290e..f653401d8 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -486,12 +486,6 @@ class Qwen2VLForConditionalGeneration(nn.Module): return pattern.pad_input_tokens(input_ids, mm_inputs) def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: - if any(item.precomputed_features is not None for item in items): - if not all(item.precomputed_features is not None for item in items): - raise NotImplementedError( - "MM inputs where only some items are precomputed." - ) - return torch.concat([item.precomputed_features for item in items]) # in qwen-vl, last dim is the same pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type( self.visual.dtype diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index b63d7b5e7..50c014584 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -81,7 +81,7 @@ suites = { TestFile("test_update_weights_from_tensor.py", 48), TestFile("test_vertex_endpoint.py", 31), TestFile("test_vision_chunked_prefill.py", 175), - TestFile("test_vlm_accuracy.py", 60), + TestFile("test_vlm_input_format.py", 300), TestFile("test_vision_openai_server_a.py", 700), TestFile("test_vision_openai_server_b.py", 700), TestFile("test_w8a8_quantization.py", 46), diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index 9ad17dcb7..c05d2d526 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -10,15 +10,8 @@ import requests import torch import torch.nn.functional as F from PIL import Image -from transformers import ( - AutoModel, - AutoProcessor, - AutoTokenizer, - Gemma3ForConditionalGeneration, - Qwen2_5_VLForConditionalGeneration, -) +from transformers import AutoModel, AutoProcessor, AutoTokenizer -from sglang import Engine from sglang.srt.configs.model_config import ModelConfig from sglang.srt.conversation import generate_chat_conv from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache @@ -41,9 +34,6 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): def setUpClass(cls): cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - cls.model_path = "" - cls.chat_template = "" - cls.processor = "" response = requests.get(cls.image_url) cls.main_image = Image.open(BytesIO(response.content)) @@ -274,131 +264,3 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): ) self.compare_outputs(sglang_output, hf_output) - - -class TestQwenVLUnderstandsImage(VisionLLMLogitsBase): - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct" - cls.chat_template = "qwen2-vl" - cls.processor = AutoProcessor.from_pretrained( - cls.model_path, trust_remote_code=True, use_fast=True - ) - cls.visual = ( - Qwen2_5_VLForConditionalGeneration.from_pretrained( - cls.model_path, torch_dtype=torch.bfloat16 - ) - .eval() - .visual.to(cls.device) - ) - - def setUp(self): - self.engine = Engine( - model_path=self.model_path, - chat_template=self.chat_template, - device=self.device.type, - mem_fraction_static=0.8, - ) - - def tearDown(self): - self.engine.shutdown() - - async def test_qwen_vl_understands_image(self): - req = self.get_completion_request() - conv = generate_chat_conv(req, template_name=self.chat_template) - text = conv.get_prompt() - output = await self.engine.async_generate( - prompt=text, - image_data=[self.main_image], - sampling_params=dict(temperature=0.0), - ) - self.assertIn("taxi", output["text"].lower()) - - async def test_qwen_vl_understands_precomputed_features(self): - req = self.get_completion_request() - processor_output = self.get_processor_output(req=req) - with torch.inference_mode(): - precomputed_features = self.visual( - processor_output["pixel_values"], processor_output["image_grid_thw"] - ) - output = await self.engine.async_generate( - input_ids=processor_output["input_ids"][0].detach().cpu().tolist(), - image_data=[ - dict( - modality="IMAGE", - image_grid_thws=processor_output["image_grid_thw"], - precomputed_features=precomputed_features, - ) - ], - sampling_params=dict(temperature=0.0), - ) - self.assertIn("taxi", output["text"].lower()) - - -class TestGemmaUnderstandsImage(VisionLLMLogitsBase): - - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.model_path = "google/gemma-3-4b-it" - cls.chat_template = "gemma-it" - cls.processor = AutoProcessor.from_pretrained( - cls.model_path, trust_remote_code=True, use_fast=True - ) - model = Gemma3ForConditionalGeneration.from_pretrained( - cls.model_path, torch_dtype=torch.bfloat16 - ) - cls.vision_tower = model.vision_tower.eval().to(cls.device) - cls.mm_projector = model.multi_modal_projector.eval().to(cls.device) - - @classmethod - def visual(cls, pixel_values): - vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state - image_features = cls.mm_projector(vision_outputs) - return image_features - - def setUp(self): - self.engine = Engine( - model_path=self.model_path, - chat_template=self.chat_template, - device=self.device.type, - mem_fraction_static=0.5, - enable_multimodal=True, - ) - - def tearDown(self): - self.engine.shutdown() - - async def test_gemma_understands_image(self): - req = self.get_completion_request() - conv = generate_chat_conv(req, template_name=self.chat_template) - text = conv.get_prompt() - output = await self.engine.async_generate( - prompt=text, - image_data=[self.main_image], - sampling_params=dict(temperature=0.0), - ) - self.assertIn("taxi", output["text"].lower()) - - async def test_gemma_understands_precomputed_features(self): - req = self.get_completion_request() - processor_output = self.get_processor_output(req=req) - with torch.inference_mode(): - precomputed_features = self.visual(processor_output["pixel_values"]) - output = await self.engine.async_generate( - input_ids=processor_output["input_ids"][0].detach().cpu().tolist(), - image_data=[ - dict( - modality="IMAGE", - precomputed_features=precomputed_features, - ) - ], - sampling_params=dict(temperature=0.0), - ) - self.assertIn("taxi", output["text"].lower()) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/test_vlm_input_format.py b/test/srt/test_vlm_input_format.py new file mode 100644 index 000000000..ccf8d33de --- /dev/null +++ b/test/srt/test_vlm_input_format.py @@ -0,0 +1,187 @@ +import json +import unittest +from io import BytesIO +from typing import Optional + +import requests +import torch +from PIL import Image +from transformers import ( + AutoProcessor, + Gemma3ForConditionalGeneration, + Qwen2_5_VLForConditionalGeneration, +) + +from sglang import Engine +from sglang.srt.conversation import generate_chat_conv +from sglang.srt.openai_api.protocol import ChatCompletionRequest + +TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + + +class VLMInputTestBase: + model_path = None + chat_template = None + processor = None + visual = None # Should be a callable for precomputed features + + @classmethod + def setUpClass(cls): + assert cls.model_path is not None, "Set model_path in subclass" + assert cls.chat_template is not None, "Set chat_template in subclass" + cls.image_url = TEST_IMAGE_URL + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + response = requests.get(cls.image_url) + cls.main_image = Image.open(BytesIO(response.content)) + cls.processor = AutoProcessor.from_pretrained( + cls.model_path, trust_remote_code=True, use_fast=True + ) + cls._init_visual() + + @classmethod + def _init_visual(cls): + """Override in subclass to set up cls.visual as a callable for precomputed features.""" + raise NotImplementedError + + def setUp(self): + self.engine = Engine( + model_path=self.model_path, + chat_template=self.chat_template, + device=self.device.type, + mem_fraction_static=0.8, + enable_multimodal=True, + disable_cuda_graph=True, + ) + + def tearDown(self): + self.engine.shutdown() + + def get_completion_request(self) -> ChatCompletionRequest: + json_structure = { + "model": self.model_path, + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": self.image_url}}, + {"type": "text", "text": "What's in this picture?"}, + ], + } + ], + } + json_str = json.dumps(json_structure) + return ChatCompletionRequest.model_validate_json(json_str) + + def get_processor_output(self, req: Optional[ChatCompletionRequest] = None): + if req is None: + req = self.get_completion_request() + conv = generate_chat_conv(req, template_name=self.chat_template) + text = conv.get_prompt() + + # Process inputs using processor + inputs = self.processor( + text=[text], + images=[self.main_image], + return_tensors="pt", + ).to(self.device) + + return inputs + + async def test_understands_image(self): + req = self.get_completion_request() + conv = generate_chat_conv(req, template_name=self.chat_template) + text = conv.get_prompt() + output = await self.engine.async_generate( + prompt=text, + image_data=[self.main_image], + sampling_params=dict(temperature=0.0), + ) + self.assertIn("taxi", output["text"].lower()) + + async def test_understands_precomputed_features(self): + req = self.get_completion_request() + processor_output = self.get_processor_output(req=req) + with torch.inference_mode(): + precomputed_features = self.__class__.visual(processor_output) + output = await self.engine.async_generate( + input_ids=processor_output["input_ids"][0].detach().cpu().tolist(), + image_data=[ + self._precomputed_image_data(processor_output, precomputed_features) + ], + sampling_params=dict(temperature=0.0), + ) + self.assertIn("taxi", output["text"].lower()) + + async def test_understands_pixel_values(self): + req = self.get_completion_request() + processor_output = self.get_processor_output(req=req) + output = await self.engine.async_generate( + input_ids=processor_output["input_ids"][0].detach().cpu().tolist(), + image_data=[self._pixel_values_image_data(processor_output)], + sampling_params=dict(temperature=0.0), + ) + self.assertIn("taxi", output["text"].lower()) + + def _precomputed_image_data(self, processor_output, precomputed_features): + """This should not be overridden.""" + return dict( + modality="IMAGE", + precomputed_features=precomputed_features, + ) + + def _pixel_values_image_data(self, processor_output): + """Override in subclass to pass the correct set of arguments.""" + raise NotImplementedError + + +class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase): + model_path = "Qwen/Qwen2.5-VL-3B-Instruct" + chat_template = "qwen2-vl" + + @classmethod + def _init_visual(cls): + cls.visual_model = ( + Qwen2_5_VLForConditionalGeneration.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16 + ) + .eval() + .visual.to(cls.device) + ) + cls.visual = lambda processor_output: cls.visual_model( + processor_output["pixel_values"], processor_output["image_grid_thw"] + ) + + def _pixel_values_image_data(self, processor_output): + return dict( + modality="IMAGE", + image_grid_thws=processor_output["image_grid_thw"], + pixel_values=processor_output["pixel_values"], + ) + + +class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase): + model_path = "google/gemma-3-4b-it" + chat_template = "gemma-it" + + @classmethod + def _init_visual(cls): + model = Gemma3ForConditionalGeneration.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16 + ) + cls.vision_tower = model.vision_tower.eval().to(cls.device) + cls.mm_projector = model.multi_modal_projector.eval().to(cls.device) + cls.visual = lambda processor_output: cls.mm_projector( + cls.vision_tower( + pixel_values=processor_output["pixel_values"] + ).last_hidden_state + ) + + def _pixel_values_image_data(self, processor_output): + return dict( + modality="IMAGE", + pixel_values=processor_output["pixel_values"][0], + ) + + +if __name__ == "__main__": + unittest.main()