refactor: minor refactors regarding multimodal processing (#6187)
This commit is contained in:
@@ -22,13 +22,15 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from sglang.srt.mm_utils import has_valid_data
|
||||
|
||||
# handle serialization of Image for pydantic
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image
|
||||
else:
|
||||
Image = Any
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@@ -104,6 +106,9 @@ class GenerateReqInput:
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
"""
|
||||
Normalize the batch size and arguments for the request.
|
||||
@@ -487,6 +492,9 @@ class EmbeddingReqInput:
|
||||
# The modalities of the image data [image, multi-images, video]
|
||||
modalities: Optional[List[str]] = None
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
# at least one of text, input_ids, or image should be provided
|
||||
if self.text is None and self.input_ids is None and self.image_data is None:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Multi-modality utils
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
@@ -41,11 +42,26 @@ class MultiModalityDataPaddingPattern:
|
||||
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
|
||||
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
|
||||
|
||||
The padded value in a region enclosed by a token pair with be the same one, as the MultimodalDataItem's pad value
|
||||
|
||||
This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
|
||||
"""
|
||||
|
||||
def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
data_token_pairs: Optional[List[Tuple[int, int]]],
|
||||
data_start_token_ids: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
Args:
|
||||
data_start_token_ids marks the start of a single multimodal data
|
||||
See Minicpmo's slice_start_id for example
|
||||
"""
|
||||
self.data_token_id_pairs = data_token_pairs
|
||||
self.data_start_token_ids = data_start_token_ids or [
|
||||
s for s, _e in data_token_pairs
|
||||
]
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], mm_inputs: MultimodalInputs
|
||||
@@ -79,7 +95,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||
padded_ids.extend(input_ids[last_idx : start_idx + 1])
|
||||
|
||||
if input_ids[start_idx] in start_token_ids:
|
||||
if input_ids[start_idx] in self.data_start_token_ids:
|
||||
data_idx += 1
|
||||
mm_inputs.data_offsets += [start_idx]
|
||||
|
||||
@@ -170,7 +186,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
|
||||
output_ids_tensor[start_idx:end_idx] = pad_value
|
||||
else:
|
||||
logger.warning(f"Skipping region {i} due to None pad_value.")
|
||||
|
||||
return output_ids_tensor.tolist()
|
||||
|
||||
|
||||
@@ -202,7 +217,7 @@ def get_embedding_and_mask(
|
||||
num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
|
||||
if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
|
||||
logger.warning(
|
||||
f"Number of tokens in multimodal embedding does not match those in the input text."
|
||||
f"Number of tokens in multimodal embedding does not match those in the input text. "
|
||||
f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
|
||||
"tokens from multimodal embeddings."
|
||||
)
|
||||
|
||||
@@ -36,9 +36,21 @@ class BaseMultiModalProcessorOutput:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MultimodalSpecialTokens:
|
||||
image_token: Optional[str] = None
|
||||
video_token: Optional[str] = None
|
||||
audio_token: Optional[str] = None
|
||||
image_token: Optional[Union[int, str, List[str]]] = None
|
||||
video_token: Optional[Union[int, str, List[str]]] = None
|
||||
audio_token: Optional[Union[int, str, List[str]]] = None
|
||||
|
||||
def convert_to_str(self, token: Union[str, int], processor) -> str:
|
||||
if token is None:
|
||||
return token
|
||||
if isinstance(token, str):
|
||||
return token
|
||||
return processor.tokenizer.convert_ids_to_tokens([token])[0]
|
||||
|
||||
def convert_to_strs(self, processor):
|
||||
self.image_token = self.convert_to_str(self.image_token, processor)
|
||||
self.video_token = self.convert_to_str(self.video_token, processor)
|
||||
self.audio_token = self.convert_to_str(self.audio_token, processor)
|
||||
|
||||
image_token_regex: Optional[re.Pattern] = None
|
||||
video_token_regex: Optional[re.Pattern] = None
|
||||
@@ -74,6 +86,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
self.hf_config = hf_config
|
||||
self._processor = _processor
|
||||
self.arch = hf_config.architectures[0]
|
||||
self.server_args = server_args
|
||||
# FIXME: not accurate, model and image specific
|
||||
self.NUM_TOKEN_PER_FRAME = 330
|
||||
@@ -260,19 +273,10 @@ class BaseMultimodalProcessor(ABC):
|
||||
"""
|
||||
if not return_text:
|
||||
raise NotImplementedError()
|
||||
|
||||
if image_data is None:
|
||||
image_data = []
|
||||
if isinstance(multimodal_tokens.image_token, int):
|
||||
multimodal_tokens.image_token = re.compile(
|
||||
re.escape(
|
||||
self._processor.tokenizer.convert_ids_to_tokens(
|
||||
multimodal_tokens.image_token
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
multimodal_tokens.image_token = multimodal_tokens.image_token
|
||||
|
||||
multimodal_tokens.convert_to_strs(self._processor)
|
||||
multimodal_tokens_pattern = multimodal_tokens.collect()
|
||||
|
||||
if isinstance(prompt, list) and return_text:
|
||||
@@ -332,9 +336,9 @@ class BaseMultimodalProcessor(ABC):
|
||||
new_text += text_part
|
||||
|
||||
out = BaseMultiModalProcessorOutput(
|
||||
input_text=new_text,
|
||||
images=images,
|
||||
audios=audios,
|
||||
input_text=new_text,
|
||||
)
|
||||
out.normalize()
|
||||
return out
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from transformers import BaseImageProcessorFast
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
@@ -21,33 +20,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
self.image_token = "(<image>./</image>)"
|
||||
self.audio_token = "(<audio>./</audio>)"
|
||||
|
||||
def process_data_task(self, input_text, images=None, audios=None):
|
||||
|
||||
if isinstance(images, list) and len(images) == 0:
|
||||
images = None
|
||||
if isinstance(audios, list) and len(audios) == 0:
|
||||
audios = None
|
||||
processor = self._processor
|
||||
args = {}
|
||||
if isinstance(processor, BaseImageProcessorFast):
|
||||
args["device"] = "cuda"
|
||||
result = self._processor.__call__(
|
||||
text=input_text,
|
||||
images=images,
|
||||
audios=audios,
|
||||
return_tensors="pt",
|
||||
chunk_input=True,
|
||||
**args,
|
||||
)
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": getattr(result, "pixel_values", None),
|
||||
"tgt_sizes": getattr(result, "tgt_sizes", None),
|
||||
"audio_features": getattr(result, "audio_features", None),
|
||||
"audio_feature_lens": getattr(result, "audio_feature_lens", None),
|
||||
"audio_bounds": getattr(result, "audio_bounds", None),
|
||||
}
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
|
||||
@@ -324,8 +324,9 @@ class MultimodalInputs:
|
||||
video_token_id: Optional[int] = None
|
||||
|
||||
# audio
|
||||
audio_start_id: Optional[torch.Tensor] = None
|
||||
audio_end_id: Optional[torch.Tensor] = None
|
||||
audio_token_id: Optional[int] = None
|
||||
audio_start_id: Optional[int] = None
|
||||
audio_end_id: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict):
|
||||
@@ -349,6 +350,7 @@ class MultimodalInputs:
|
||||
"slice_end_id",
|
||||
"audio_start_id",
|
||||
"audio_end_id",
|
||||
"audio_token_id",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
|
||||
@@ -459,14 +459,16 @@ class TokenizerManager:
|
||||
)
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
|
||||
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
||||
image_data=obj.image_data,
|
||||
input_text=input_text or input_ids,
|
||||
request_obj=obj,
|
||||
max_req_input_len=self.max_req_input_len,
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
image_inputs: Optional[Dict] = None
|
||||
if obj.contains_mm_input():
|
||||
image_inputs = await self.mm_processor.process_mm_data_async(
|
||||
image_data=obj.image_data,
|
||||
input_text=input_text or input_ids,
|
||||
request_obj=obj,
|
||||
max_req_input_len=self.max_req_input_len,
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
|
||||
self._validate_token_len(obj, input_ids)
|
||||
return self._create_tokenized_object(
|
||||
|
||||
Reference in New Issue
Block a user