From 0d503090aada82c5c561fad7c7f5fd4e97005368 Mon Sep 17 00:00:00 2001
From: Lifu Huang
Date: Mon, 26 May 2025 01:24:13 -0700
Subject: [PATCH] Supported precomputed feature for Kimi VL (#6599)

---
 .../multimodal_processors/base_processor.py   | 16 ++++-
 .../managers/multimodal_processors/kimi_vl.py | 63 ++++++++++++-------
 .../managers/multimodal_processors/minicpm.py |  3 +-
 .../managers/multimodal_processors/qwen_vl.py | 29 ++-------
 test/srt/test_vlm_input_format.py             | 29 +++++++++
 5 files changed, 93 insertions(+), 47 deletions(-)

diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py
index a293d4be4..fa2e4d824 100644
--- a/python/sglang/srt/managers/multimodal_processors/base_processor.py
+++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py
@@ -5,7 +5,7 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
                 "Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
             )
         return ret
+
+    @staticmethod
+    def _extract_processor_features(
+        items: List[Any], attr_name: str
+    ) -> Optional[torch.Tensor]:
+        """
+        Helper function to concat extracted attributes from processor output.
+        """
+        values = [
+            getattr(item, attr_name)
+            for item in items
+            if getattr(item, attr_name) is not None
+        ]
+        return torch.concat(values) if values else None
diff --git a/python/sglang/srt/managers/multimodal_processors/kimi_vl.py b/python/sglang/srt/managers/multimodal_processors/kimi_vl.py
index 0a276f7ce..86d41189e 100644
--- a/python/sglang/srt/managers/multimodal_processors/kimi_vl.py
+++ b/python/sglang/srt/managers/multimodal_processors/kimi_vl.py
@@ -1,4 +1,7 @@
-from typing import List, Union
+import re
+from typing import Any, Dict, List, Optional, Union
+
+import torch
 
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<|media_pad|>"
+        self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
         self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
-
-        self.im_start = "<|media_start|>"
-        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
-
-        self.im_end = "<|media_end|>"
-        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
-
-        self.im_content = "<|media_content|>"
-        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
 
     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes]],
+        image_data: List[Union[str, bytes, Dict]],
         input_text,
         request_obj,
         max_req_input_len,
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
+            ),
             max_req_input_len=max_req_input_len,
         )
-        ret = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=base_output.images,
-        )
-        input_ids = ret["input_ids"].flatten()
+
+        images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
+        if not images_are_preprocessed:
+            ret = self.process_mm_data(
+                input_text=base_output.input_text,
+                images=base_output.images,
+            )
+            input_ids = ret["input_ids"].flatten()
+            image_grid_thws = ret["image_grid_hws"]
+            pixel_values = ret["pixel_values"]
+            precomputed_features = None
+        else:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+
+            image_grid_thws = self._extract_processor_features(
+                base_output.images, "image_grid_thws"
+            )
+            precomputed_features = self._extract_processor_features(
+                base_output.images, "precomputed_features"
+            )
+            pixel_values = self._extract_processor_features(
+                base_output.images, "pixel_values"
+            )
+
         image_offsets = self.get_mm_items_offset(
             input_ids=input_ids,
             mm_token_id=self.im_token_id,
         )
+
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": [
                 MultimodalDataItem(
-                    pixel_values=ret["pixel_values"],
-                    image_grid_thws=ret["image_grid_hws"],
+                    pixel_values=pixel_values,
+                    image_grid_thws=image_grid_thws,
+                    precomputed_features=precomputed_features,
                     modality=Modality.IMAGE,
                     image_offsets=image_offsets,
                 )
             ],
             "im_token_id": self.im_token_id,
-            "im_start_id": self.im_start_id,
-            "im_end_id": self.im_end_id,
-            "im_content_id": self.im_content_id,
         }
diff --git a/python/sglang/srt/managers/multimodal_processors/minicpm.py b/python/sglang/srt/managers/multimodal_processors/minicpm.py
index dba7245e8..95bb231c3 100644
--- a/python/sglang/srt/managers/multimodal_processors/minicpm.py
+++ b/python/sglang/srt/managers/multimodal_processors/minicpm.py
@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data=audio_data,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.image_token, audio_token=self.audio_token
+                image_token=self.image_token,
+                audio_token=self.audio_token,
             ),
         )
         if base_output is None:
diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py
index c47652c15..268d3350b 100644
--- a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py
+++ b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py
@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
 
         if base_output.images:
             if images_are_preprocessed:
-                all_image_grid_thws = [
-                    item.image_grid_thws
-                    for item in base_output.images
-                    if item.image_grid_thws is not None
-                ]
-                all_pixel_values = [
-                    item.pixel_values
-                    for item in base_output.images
-                    if item.pixel_values is not None
-                ]
-                all_precomputed_features = [
-                    item.precomputed_features
-                    for item in base_output.images
-                    if item.precomputed_features is not None
-                ]
-                image_grid_thw = (
-                    torch.concat(all_image_grid_thws) if all_image_grid_thws else None
+                image_grid_thw = self._extract_processor_features(
+                    base_output.images, "image_grid_thws"
                 )
-                pixel_values = (
-                    torch.concat(all_pixel_values) if all_pixel_values else None
+                precomputed_features = self._extract_processor_features(
+                    base_output.images, "precomputed_features"
                 )
-                precomputed_features = (
-                    torch.concat(all_precomputed_features)
-                    if all_precomputed_features
-                    else None
+                pixel_values = self._extract_processor_features(
+                    base_output.images, "pixel_values"
                 )
             else:
                 image_grid_thw = ret["image_grid_thw"]
diff --git a/test/srt/test_vlm_input_format.py b/test/srt/test_vlm_input_format.py
index ccf8d33de..51f2f5592 100644
--- a/test/srt/test_vlm_input_format.py
+++ b/test/srt/test_vlm_input_format.py
@@ -7,6 +7,7 @@ import requests
 import torch
 from PIL import Image
 from transformers import (
+    AutoModel,
     AutoProcessor,
     Gemma3ForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
@@ -51,6 +52,7 @@ class VLMInputTestBase:
             mem_fraction_static=0.8,
             enable_multimodal=True,
             disable_cuda_graph=True,
+            trust_remote_code=True,
         )
 
     def tearDown(self):
@@ -183,5 +185,32 @@ class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCa
         )
 
 
+class TestKimiVLImageUnderstandsImage(
+    VLMInputTestBase, unittest.IsolatedAsyncioTestCase
+):
+    model_path = "moonshotai/Kimi-VL-A3B-Instruct"
+    chat_template = "kimi-vl"
+
+    @classmethod
+    def _init_visual(cls):
+        model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True)
+        cls.vision_tower = model.vision_tower.eval().to(cls.device)
+        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
+
+        cls.visual = lambda tokenizer_output: cls.mm_projector(
+            cls.vision_tower(
+                pixel_values=tokenizer_output["pixel_values"],
+                grid_hws=tokenizer_output["image_grid_hws"],
+            )
+        )
+
+    def _pixel_values_image_data(self, processor_output):
+        return dict(
+            modality="IMAGE",
+            image_grid_thws=processor_output["image_grid_hws"],
+            pixel_values=processor_output["pixel_values"],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
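
Why the patch introduces IMAGE_TOKEN_REGEX: a request that ships preprocessed inputs arrives with the <|media_pad|> run already expanded in the prompt, and matching the whole run as one unit lets load_mm_data map it to a single image slot rather than one slot per repeated token. A self-contained sketch of that matching behavior (the prompt string is invented for illustration):

    import re

    # Same pattern as the patch: one or more consecutive media-pad tokens.
    IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")

    prompt = "Describe this image: <|media_pad|><|media_pad|><|media_pad|>"

    # The entire run matches once, i.e. one image slot, not three.
    assert len(IMAGE_TOKEN_REGEX.findall(prompt)) == 1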
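
The new _extract_processor_features helper centralizes the gather-and-concat pattern that qwen_vl.py previously open-coded once per attribute. A minimal standalone sketch of its contract, with a hypothetical Item dataclass standing in for SGLang's MultimodalDataItem:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class Item:  # hypothetical stand-in for a preprocessed multimodal item
        pixel_values: Optional[torch.Tensor] = None
        precomputed_features: Optional[torch.Tensor] = None

    def extract(items, attr_name):
        # Keep the attribute from every item that actually carries it ...
        values = [
            getattr(item, attr_name)
            for item in items
            if getattr(item, attr_name) is not None
        ]
        # ... and concatenate along dim 0; None means no item had it.
        return torch.concat(values) if values else None

    items = [Item(pixel_values=torch.ones(2, 4)), Item(), Item(pixel_values=torch.zeros(1, 4))]
    assert extract(items, "pixel_values").shape == (3, 4)
    assert extract(items, "precomputed_features") is None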
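
In the preprocessed branch there is no HF processor output to read image placements from, so the Kimi VL path tokenizes the prompt itself and relies on get_mm_items_offset to locate each contiguous run of the media-pad token id. A toy sketch of that kind of run scan; the inclusive (start, end) pair convention here is an assumption for illustration, not necessarily the exact shape SGLang returns:

    import torch

    MEDIA_PAD_ID = 7  # toy id standing in for <|media_pad|>

    def offsets_of_runs(input_ids: torch.Tensor, token_id: int):
        # Collect inclusive (start, end) index pairs for each contiguous
        # run of token_id in a 1-D tensor of token ids.
        runs, start = [], None
        for i, hit in enumerate((input_ids == token_id).tolist()):
            if hit and start is None:
                start = i
            elif not hit and start is not None:
                runs.append((start, i - 1))
                start = None
        if start is not None:
            runs.append((start, len(input_ids) - 1))
        return runs

    ids = torch.tensor([1, 7, 7, 7, 2, 7, 7, 3])
    assert offsets_of_runs(ids, MEDIA_PAD_ID) == [(1, 3), (5, 6)]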