Supported precomputed feature for Kimi VL (#6599)
This commit is contained in:
@@ -5,7 +5,7 @@ import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -382,3 +382,17 @@ class BaseMultimodalProcessor(ABC):
|
||||
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
|
||||
)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _extract_processor_features(
|
||||
items: List[Any], attr_name: str
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Helper function to concat extracted attributes from processor output.
|
||||
"""
|
||||
values = [
|
||||
getattr(item, attr_name)
|
||||
for item in items
|
||||
if getattr(item, attr_name) is not None
|
||||
]
|
||||
return torch.concat(values) if values else None
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import List, Union
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
@@ -17,20 +20,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<|media_pad|>"
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
|
||||
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
|
||||
|
||||
self.im_start = "<|media_start|>"
|
||||
self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
|
||||
|
||||
self.im_end = "<|media_end|>"
|
||||
self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
|
||||
|
||||
self.im_content = "<|media_content|>"
|
||||
self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
image_data: List[Union[str, bytes, Dict]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
@@ -45,30 +40,54 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
|
||||
),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
)
|
||||
input_ids = ret["input_ids"].flatten()
|
||||
|
||||
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
|
||||
if not images_are_preprocessed:
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
)
|
||||
input_ids = ret["input_ids"].flatten()
|
||||
image_grid_thws = ret["image_grid_hws"]
|
||||
pixel_values = ret["pixel_values"]
|
||||
precomputed_features = None
|
||||
else:
|
||||
input_ids = self._processor.tokenizer(
|
||||
base_output.input_text,
|
||||
return_tensors="pt",
|
||||
add_special_tokens=True,
|
||||
).input_ids.flatten()
|
||||
|
||||
image_grid_thws = self._extract_processor_features(
|
||||
base_output.images, "image_grid_thws"
|
||||
)
|
||||
precomputed_features = self._extract_processor_features(
|
||||
base_output.images, "precomputed_features"
|
||||
)
|
||||
pixel_values = self._extract_processor_features(
|
||||
base_output.images, "pixel_values"
|
||||
)
|
||||
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=self.im_token_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"],
|
||||
image_grid_thws=ret["image_grid_hws"],
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thws=image_grid_thws,
|
||||
precomputed_features=precomputed_features,
|
||||
modality=Modality.IMAGE,
|
||||
image_offsets=image_offsets,
|
||||
)
|
||||
],
|
||||
"im_token_id": self.im_token_id,
|
||||
"im_start_id": self.im_start_id,
|
||||
"im_end_id": self.im_end_id,
|
||||
"im_content_id": self.im_content_id,
|
||||
}
|
||||
|
||||
@@ -42,7 +42,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.image_token, audio_token=self.audio_token
|
||||
image_token=self.image_token,
|
||||
audio_token=self.audio_token,
|
||||
),
|
||||
)
|
||||
if base_output is None:
|
||||
|
||||
@@ -144,31 +144,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
if base_output.images:
|
||||
if images_are_preprocessed:
|
||||
all_image_grid_thws = [
|
||||
item.image_grid_thws
|
||||
for item in base_output.images
|
||||
if item.image_grid_thws is not None
|
||||
]
|
||||
all_pixel_values = [
|
||||
item.pixel_values
|
||||
for item in base_output.images
|
||||
if item.pixel_values is not None
|
||||
]
|
||||
all_precomputed_features = [
|
||||
item.precomputed_features
|
||||
for item in base_output.images
|
||||
if item.precomputed_features is not None
|
||||
]
|
||||
image_grid_thw = (
|
||||
torch.concat(all_image_grid_thws) if all_image_grid_thws else None
|
||||
image_grid_thw = self._extract_processor_features(
|
||||
base_output.images, "image_grid_thws"
|
||||
)
|
||||
pixel_values = (
|
||||
torch.concat(all_pixel_values) if all_pixel_values else None
|
||||
precomputed_features = self._extract_processor_features(
|
||||
base_output.images, "precomputed_features"
|
||||
)
|
||||
precomputed_features = (
|
||||
torch.concat(all_precomputed_features)
|
||||
if all_precomputed_features
|
||||
else None
|
||||
pixel_values = self._extract_processor_features(
|
||||
base_output.images, "pixel_values"
|
||||
)
|
||||
else:
|
||||
image_grid_thw = ret["image_grid_thw"]
|
||||
|
||||
Reference in New Issue
Block a user