Support precomputed features for Kimi VL (#6599)

This commit is contained in:
Lifu Huang
2025-05-26 01:24:13 -07:00
committed by GitHub
parent 501efc3d36
commit 0d503090aa
5 changed files with 93 additions and 47 deletions

View File

@@ -5,7 +5,7 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
)
return ret
@staticmethod
def _extract_processor_features(
items: List[Any], attr_name: str
) -> Optional[torch.Tensor]:
"""
Helper function to concat extracted attributes from processor output.
"""
values = [
getattr(item, attr_name)
for item in items
if getattr(item, attr_name) is not None
]
return torch.concat(values) if values else None

View File

@@ -1,4 +1,7 @@
from typing import List, Union
import re
from typing import Any, Dict, List, Optional, Union
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|media_pad|>"
self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
self.im_start = "<|media_start|>"
self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
self.im_end = "<|media_end|>"
self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
self.im_content = "<|media_content|>"
self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
),
max_req_input_len=max_req_input_len,
)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
)
input_ids = ret["input_ids"].flatten()
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
if not images_are_preprocessed:
ret = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
)
input_ids = ret["input_ids"].flatten()
image_grid_thws = ret["image_grid_hws"]
pixel_values = ret["pixel_values"]
precomputed_features = None
else:
input_ids = self._processor.tokenizer(
base_output.input_text,
return_tensors="pt",
add_special_tokens=True,
).input_ids.flatten()
image_grid_thws = self._extract_processor_features(
base_output.images, "image_grid_thws"
)
precomputed_features = self._extract_processor_features(
base_output.images, "precomputed_features"
)
pixel_values = self._extract_processor_features(
base_output.images, "pixel_values"
)
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.im_token_id,
)
return {
"input_ids": input_ids.tolist(),
"mm_items": [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=ret["image_grid_hws"],
pixel_values=pixel_values,
image_grid_thws=image_grid_thws,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
],
"im_token_id": self.im_token_id,
"im_start_id": self.im_start_id,
"im_end_id": self.im_end_id,
"im_content_id": self.im_content_id,
}

View File

@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.image_token, audio_token=self.audio_token
image_token=self.image_token,
audio_token=self.audio_token,
),
)
if base_output is None:

View File

@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
if base_output.images:
if images_are_preprocessed:
all_image_grid_thws = [
item.image_grid_thws
for item in base_output.images
if item.image_grid_thws is not None
]
all_pixel_values = [
item.pixel_values
for item in base_output.images
if item.pixel_values is not None
]
all_precomputed_features = [
item.precomputed_features
for item in base_output.images
if item.precomputed_features is not None
]
image_grid_thw = (
torch.concat(all_image_grid_thws) if all_image_grid_thws else None
image_grid_thw = self._extract_processor_features(
base_output.images, "image_grid_thws"
)
pixel_values = (
torch.concat(all_pixel_values) if all_pixel_values else None
precomputed_features = self._extract_processor_features(
base_output.images, "precomputed_features"
)
precomputed_features = (
torch.concat(all_precomputed_features)
if all_precomputed_features
else None
pixel_values = self._extract_processor_features(
base_output.images, "pixel_values"
)
else:
image_grid_thw = ret["image_grid_thw"]