Supported precomputed feature for Kimi VL (#6599)

This commit is contained in:
Lifu Huang
2025-05-26 01:24:13 -07:00
committed by GitHub
parent 501efc3d36
commit 0d503090aa
5 changed files with 93 additions and 47 deletions

View File

@@ -5,7 +5,7 @@ import multiprocessing as mp
import os import os
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed." "Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
) )
return ret return ret
@staticmethod
def _extract_processor_features(
items: List[Any], attr_name: str
) -> Optional[torch.Tensor]:
"""
Helper function to concat extracted attributes from processor output.
"""
values = [
getattr(item, attr_name)
for item in items
if getattr(item, attr_name) is not None
]
return torch.concat(values) if values else None

View File

@@ -1,4 +1,7 @@
from typing import List, Union import re
from typing import Any, Dict, List, Optional, Union
import torch
from sglang.srt.managers.multimodal_processors.base_processor import ( from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor, BaseMultimodalProcessor as SGLangBaseProcessor,
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|media_pad|>" self.IMAGE_TOKEN = "<|media_pad|>"
self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
self.im_start = "<|media_start|>"
self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
self.im_end = "<|media_end|>"
self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
self.im_content = "<|media_content|>"
self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes, Dict]],
input_text, input_text,
request_obj, request_obj,
max_req_input_len, max_req_input_len,
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
base_output = self.load_mm_data( base_output = self.load_mm_data(
prompt=input_text, prompt=input_text,
image_data=image_data, image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN), multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
),
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
) )
ret = self.process_mm_data(
input_text=base_output.input_text, images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
images=base_output.images, if not images_are_preprocessed:
) ret = self.process_mm_data(
input_ids = ret["input_ids"].flatten() input_text=base_output.input_text,
images=base_output.images,
)
input_ids = ret["input_ids"].flatten()
image_grid_thws = ret["image_grid_hws"]
pixel_values = ret["pixel_values"]
precomputed_features = None
else:
input_ids = self._processor.tokenizer(
base_output.input_text,
return_tensors="pt",
add_special_tokens=True,
).input_ids.flatten()
image_grid_thws = self._extract_processor_features(
base_output.images, "image_grid_thws"
)
precomputed_features = self._extract_processor_features(
base_output.images, "precomputed_features"
)
pixel_values = self._extract_processor_features(
base_output.images, "pixel_values"
)
image_offsets = self.get_mm_items_offset( image_offsets = self.get_mm_items_offset(
input_ids=input_ids, input_ids=input_ids,
mm_token_id=self.im_token_id, mm_token_id=self.im_token_id,
) )
return { return {
"input_ids": input_ids.tolist(), "input_ids": input_ids.tolist(),
"mm_items": [ "mm_items": [
MultimodalDataItem( MultimodalDataItem(
pixel_values=ret["pixel_values"], pixel_values=pixel_values,
image_grid_thws=ret["image_grid_hws"], image_grid_thws=image_grid_thws,
precomputed_features=precomputed_features,
modality=Modality.IMAGE, modality=Modality.IMAGE,
image_offsets=image_offsets, image_offsets=image_offsets,
) )
], ],
"im_token_id": self.im_token_id, "im_token_id": self.im_token_id,
"im_start_id": self.im_start_id,
"im_end_id": self.im_end_id,
"im_content_id": self.im_content_id,
} }

View File

@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data=audio_data, audio_data=audio_data,
image_data=image_data, image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens( multimodal_tokens=MultimodalSpecialTokens(
image_token=self.image_token, audio_token=self.audio_token image_token=self.image_token,
audio_token=self.audio_token,
), ),
) )
if base_output is None: if base_output is None:

View File

@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if base_output.images: if base_output.images:
if images_are_preprocessed: if images_are_preprocessed:
all_image_grid_thws = [ image_grid_thw = self._extract_processor_features(
item.image_grid_thws base_output.images, "image_grid_thws"
for item in base_output.images
if item.image_grid_thws is not None
]
all_pixel_values = [
item.pixel_values
for item in base_output.images
if item.pixel_values is not None
]
all_precomputed_features = [
item.precomputed_features
for item in base_output.images
if item.precomputed_features is not None
]
image_grid_thw = (
torch.concat(all_image_grid_thws) if all_image_grid_thws else None
) )
pixel_values = ( precomputed_features = self._extract_processor_features(
torch.concat(all_pixel_values) if all_pixel_values else None base_output.images, "precomputed_features"
) )
precomputed_features = ( pixel_values = self._extract_processor_features(
torch.concat(all_precomputed_features) base_output.images, "pixel_values"
if all_precomputed_features
else None
) )
else: else:
image_grid_thw = ret["image_grid_thw"] image_grid_thw = ret["image_grid_thw"]

View File

@@ -7,6 +7,7 @@ import requests
import torch import torch
from PIL import Image from PIL import Image
from transformers import ( from transformers import (
AutoModel,
AutoProcessor, AutoProcessor,
Gemma3ForConditionalGeneration, Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration,
@@ -51,6 +52,7 @@ class VLMInputTestBase:
mem_fraction_static=0.8, mem_fraction_static=0.8,
enable_multimodal=True, enable_multimodal=True,
disable_cuda_graph=True, disable_cuda_graph=True,
trust_remote_code=True,
) )
def tearDown(self): def tearDown(self):
@@ -183,5 +185,32 @@ class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCa
) )
class TestKimiVLImageUnderstandsImage(
VLMInputTestBase, unittest.IsolatedAsyncioTestCase
):
model_path = "moonshotai/Kimi-VL-A3B-Instruct"
chat_template = "kimi-vl"
@classmethod
def _init_visual(cls):
model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True)
cls.vision_tower = model.vision_tower.eval().to(cls.device)
cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
cls.visual = lambda tokenizer_output: cls.mm_projector(
cls.vision_tower(
pixel_values=tokenizer_output["pixel_values"],
grid_hws=tokenizer_output["image_grid_hws"],
)
)
def _pixel_values_image_data(self, processor_output):
return dict(
modality="IMAGE",
image_grid_thws=processor_output["image_grid_hws"],
pixel_values=processor_output["pixel_values"],
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()