refactor: minor refactors regarding multimodal processing (#6187)

This commit is contained in:
Mick
2025-05-18 13:53:20 +08:00
committed by GitHub
parent b3f3d610fd
commit 01dd39bac1
15 changed files with 140 additions and 98 deletions

View File

@@ -36,9 +36,21 @@ class BaseMultiModalProcessorOutput:
@dataclasses.dataclass
class MultimodalSpecialTokens:
image_token: Optional[str] = None
video_token: Optional[str] = None
audio_token: Optional[str] = None
image_token: Optional[Union[int, str, List[str]]] = None
video_token: Optional[Union[int, str, List[str]]] = None
audio_token: Optional[Union[int, str, List[str]]] = None
def convert_to_str(
    self, token: Optional[Union[int, str, List[str]]], processor
) -> Optional[Union[str, List[str]]]:
    """Normalize a single modality token to its string form.

    Args:
        token: None, an already-string token, a list of token strings,
            or an int token id (matching the field types declared on
            this dataclass).
        processor: a HF processor exposing ``tokenizer.convert_ids_to_tokens``.

    Returns:
        None if ``token`` is None; the token unchanged if it is already a
        str or a list of strs; otherwise the tokenizer's string form of
        the int token id.
    """
    if token is None:
        return None
    if isinstance(token, str):
        return token
    if isinstance(token, list):
        # Already a list of token strings; nothing to convert.
        # (Previously this fell through to convert_ids_to_tokens and crashed.)
        return token
    # An int token id: map it back to its string form via the tokenizer.
    return processor.tokenizer.convert_ids_to_tokens([token])[0]
def convert_to_strs(self, processor):
    """Normalize every modality token field on this instance, in place.

    Runs ``convert_to_str`` over the image/video/audio token fields so
    that downstream code only ever sees string (or None) tokens.
    """
    for field_name in ("image_token", "video_token", "audio_token"):
        raw = getattr(self, field_name)
        setattr(self, field_name, self.convert_to_str(raw, processor))
image_token_regex: Optional[re.Pattern] = None
video_token_regex: Optional[re.Pattern] = None
@@ -74,6 +86,7 @@ class BaseMultimodalProcessor(ABC):
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
self.arch = hf_config.architectures[0]
self.server_args = server_args
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
@@ -260,19 +273,10 @@ class BaseMultimodalProcessor(ABC):
"""
if not return_text:
raise NotImplementedError()
if image_data is None:
image_data = []
if isinstance(multimodal_tokens.image_token, int):
multimodal_tokens.image_token = re.compile(
re.escape(
self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token
)
)
)
else:
multimodal_tokens.image_token = multimodal_tokens.image_token
multimodal_tokens.convert_to_strs(self._processor)
multimodal_tokens_pattern = multimodal_tokens.collect()
if isinstance(prompt, list) and return_text:
@@ -332,9 +336,9 @@ class BaseMultimodalProcessor(ABC):
new_text += text_part
out = BaseMultiModalProcessorOutput(
input_text=new_text,
images=images,
audios=audios,
input_text=new_text,
)
out.normalize()
return out

View File

@@ -1,7 +1,6 @@
from typing import List, Union
import torch
from transformers import BaseImageProcessorFast
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
@@ -21,33 +20,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)"
def process_data_task(self, input_text, images=None, audios=None):
    """Run the underlying HF processor over text plus optional media.

    Empty media lists are treated the same as "no media supplied" before
    invoking the processor; the relevant output tensors are collected
    into a plain dict (missing attributes become None).
    """
    # The HF processor expects None, not [], when a modality is absent.
    if isinstance(images, list) and not images:
        images = None
    if isinstance(audios, list) and not audios:
        audios = None

    extra_kwargs = {}
    # Fast image processors support on-device preprocessing.
    if isinstance(self._processor, BaseImageProcessorFast):
        extra_kwargs["device"] = "cuda"

    result = self._processor(
        text=input_text,
        images=images,
        audios=audios,
        return_tensors="pt",
        chunk_input=True,
        **extra_kwargs,
    )

    optional_keys = (
        "pixel_values",
        "tgt_sizes",
        "audio_features",
        "audio_feature_lens",
        "audio_bounds",
    )
    output = {"input_ids": result.input_ids}
    output.update({key: getattr(result, key, None) for key in optional_keys})
    return output
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],