Supported precomputed feature for Kimi VL (#6599)
This commit is contained in:
@@ -5,7 +5,7 @@ import multiprocessing as mp
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
|
|||||||
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
|
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
def _extract_processor_features(
    items: List[Any], attr_name: str
) -> Optional[torch.Tensor]:
    """
    Collect one attribute from every item and join the values into one tensor.

    Each item is read with ``getattr(item, attr_name)``; items whose value is
    ``None`` are skipped. The remaining values are handed to ``torch.concat``
    in item order (no explicit ``dim`` is passed).

    Args:
        items: Objects exposing ``attr_name`` as an attribute.
        attr_name: Name of the attribute to gather.

    Returns:
        The concatenated tensor, or ``None`` when no item carries a
        non-None value (including when ``items`` is empty).
    """
    collected = []
    for entry in items:
        value = getattr(entry, attr_name)
        if value is not None:
            collected.append(value)

    if not collected:
        return None
    return torch.concat(collected)
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from typing import List, Union
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||||
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
|||||||
def __init__(self, hf_config, server_args, _processor):
|
def __init__(self, hf_config, server_args, _processor):
|
||||||
super().__init__(hf_config, server_args, _processor)
|
super().__init__(hf_config, server_args, _processor)
|
||||||
self.IMAGE_TOKEN = "<|media_pad|>"
|
self.IMAGE_TOKEN = "<|media_pad|>"
|
||||||
|
self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
|
||||||
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
|
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
|
||||||
|
|
||||||
self.im_start = "<|media_start|>"
|
|
||||||
self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
|
|
||||||
|
|
||||||
self.im_end = "<|media_end|>"
|
|
||||||
self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
|
|
||||||
|
|
||||||
self.im_content = "<|media_content|>"
|
|
||||||
self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
|
|
||||||
|
|
||||||
async def process_mm_data_async(
|
async def process_mm_data_async(
|
||||||
self,
|
self,
|
||||||
image_data: List[Union[str, bytes]],
|
image_data: List[Union[str, bytes, Dict]],
|
||||||
input_text,
|
input_text,
|
||||||
request_obj,
|
request_obj,
|
||||||
max_req_input_len,
|
max_req_input_len,
|
||||||
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
|||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
prompt=input_text,
|
prompt=input_text,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
|
multimodal_tokens=MultimodalSpecialTokens(
|
||||||
|
image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
|
||||||
|
),
|
||||||
max_req_input_len=max_req_input_len,
|
max_req_input_len=max_req_input_len,
|
||||||
)
|
)
|
||||||
ret = self.process_mm_data(
|
|
||||||
input_text=base_output.input_text,
|
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
||||||
images=base_output.images,
|
if not images_are_preprocessed:
|
||||||
)
|
ret = self.process_mm_data(
|
||||||
input_ids = ret["input_ids"].flatten()
|
input_text=base_output.input_text,
|
||||||
|
images=base_output.images,
|
||||||
|
)
|
||||||
|
input_ids = ret["input_ids"].flatten()
|
||||||
|
image_grid_thws = ret["image_grid_hws"]
|
||||||
|
pixel_values = ret["pixel_values"]
|
||||||
|
precomputed_features = None
|
||||||
|
else:
|
||||||
|
input_ids = self._processor.tokenizer(
|
||||||
|
base_output.input_text,
|
||||||
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
|
).input_ids.flatten()
|
||||||
|
|
||||||
|
image_grid_thws = self._extract_processor_features(
|
||||||
|
base_output.images, "image_grid_thws"
|
||||||
|
)
|
||||||
|
precomputed_features = self._extract_processor_features(
|
||||||
|
base_output.images, "precomputed_features"
|
||||||
|
)
|
||||||
|
pixel_values = self._extract_processor_features(
|
||||||
|
base_output.images, "pixel_values"
|
||||||
|
)
|
||||||
|
|
||||||
image_offsets = self.get_mm_items_offset(
|
image_offsets = self.get_mm_items_offset(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
mm_token_id=self.im_token_id,
|
mm_token_id=self.im_token_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids.tolist(),
|
"input_ids": input_ids.tolist(),
|
||||||
"mm_items": [
|
"mm_items": [
|
||||||
MultimodalDataItem(
|
MultimodalDataItem(
|
||||||
pixel_values=ret["pixel_values"],
|
pixel_values=pixel_values,
|
||||||
image_grid_thws=ret["image_grid_hws"],
|
image_grid_thws=image_grid_thws,
|
||||||
|
precomputed_features=precomputed_features,
|
||||||
modality=Modality.IMAGE,
|
modality=Modality.IMAGE,
|
||||||
image_offsets=image_offsets,
|
image_offsets=image_offsets,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
"im_token_id": self.im_token_id,
|
"im_token_id": self.im_token_id,
|
||||||
"im_start_id": self.im_start_id,
|
|
||||||
"im_end_id": self.im_end_id,
|
|
||||||
"im_content_id": self.im_content_id,
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
|||||||
audio_data=audio_data,
|
audio_data=audio_data,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
multimodal_tokens=MultimodalSpecialTokens(
|
multimodal_tokens=MultimodalSpecialTokens(
|
||||||
image_token=self.image_token, audio_token=self.audio_token
|
image_token=self.image_token,
|
||||||
|
audio_token=self.audio_token,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if base_output is None:
|
if base_output is None:
|
||||||
|
|||||||
@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
|||||||
|
|
||||||
if base_output.images:
|
if base_output.images:
|
||||||
if images_are_preprocessed:
|
if images_are_preprocessed:
|
||||||
all_image_grid_thws = [
|
image_grid_thw = self._extract_processor_features(
|
||||||
item.image_grid_thws
|
base_output.images, "image_grid_thws"
|
||||||
for item in base_output.images
|
|
||||||
if item.image_grid_thws is not None
|
|
||||||
]
|
|
||||||
all_pixel_values = [
|
|
||||||
item.pixel_values
|
|
||||||
for item in base_output.images
|
|
||||||
if item.pixel_values is not None
|
|
||||||
]
|
|
||||||
all_precomputed_features = [
|
|
||||||
item.precomputed_features
|
|
||||||
for item in base_output.images
|
|
||||||
if item.precomputed_features is not None
|
|
||||||
]
|
|
||||||
image_grid_thw = (
|
|
||||||
torch.concat(all_image_grid_thws) if all_image_grid_thws else None
|
|
||||||
)
|
)
|
||||||
pixel_values = (
|
precomputed_features = self._extract_processor_features(
|
||||||
torch.concat(all_pixel_values) if all_pixel_values else None
|
base_output.images, "precomputed_features"
|
||||||
)
|
)
|
||||||
precomputed_features = (
|
pixel_values = self._extract_processor_features(
|
||||||
torch.concat(all_precomputed_features)
|
base_output.images, "pixel_values"
|
||||||
if all_precomputed_features
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image_grid_thw = ret["image_grid_thw"]
|
image_grid_thw = ret["image_grid_thw"]
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
AutoModel,
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
Gemma3ForConditionalGeneration,
|
Gemma3ForConditionalGeneration,
|
||||||
Qwen2_5_VLForConditionalGeneration,
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
@@ -51,6 +52,7 @@ class VLMInputTestBase:
|
|||||||
mem_fraction_static=0.8,
|
mem_fraction_static=0.8,
|
||||||
enable_multimodal=True,
|
enable_multimodal=True,
|
||||||
disable_cuda_graph=True,
|
disable_cuda_graph=True,
|
||||||
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
@@ -183,5 +185,32 @@ class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCa
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestKimiVLImageUnderstandsImage(
    VLMInputTestBase, unittest.IsolatedAsyncioTestCase
):
    """Input-format tests for the Kimi VL model, built on VLMInputTestBase."""

    model_path = "moonshotai/Kimi-VL-A3B-Instruct"
    chat_template = "kimi-vl"

    @classmethod
    def _init_visual(cls):
        """Load the model and expose its vision tower, projector and a combined encoder."""
        kimi_model = AutoModel.from_pretrained(cls.model_path, trust_remote_code=True)
        cls.vision_tower = kimi_model.vision_tower.eval().to(cls.device)
        cls.mm_projector = kimi_model.multi_modal_projector.eval().to(cls.device)

        def visual(tokenizer_output):
            # Run the vision tower on the processor output, then project the result.
            tower_output = cls.vision_tower(
                pixel_values=tokenizer_output["pixel_values"],
                grid_hws=tokenizer_output["image_grid_hws"],
            )
            return cls.mm_projector(tower_output)

        # NOTE(review): stored as a plain function on the class, so access via an
        # instance would bind it as a method — confirm callers go through the class.
        cls.visual = visual

    def _pixel_values_image_data(self, processor_output):
        """Build the image-data dict from processor output (grid sizes and pixel values)."""
        return {
            "modality": "IMAGE",
            # The processor emits "image_grid_hws"; it is passed on as "image_grid_thws".
            "image_grid_thws": processor_output["image_grid_hws"],
            "pixel_values": processor_output["pixel_values"],
        }
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user