chore: remove vlm unnecessary import (#7541)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
Co-authored-by: yhyang201 <yhyang201@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
This commit is contained in:
Xinyuan Tong
2025-06-26 01:38:15 -07:00
committed by GitHub
parent 4d67025a1d
commit 9b00990bea
8 changed files with 2621 additions and 10 deletions

View File

@@ -23,6 +23,7 @@ class MultimodalInputFormat(Enum):
RAW_IMAGES = "raw_images"
PRECOMPUTED_FEATURES = "precomputed_features"
PIXEL_VALUES = "pixel_values"
AUDIO = "audio"
@dataclasses.dataclass
@@ -441,10 +442,13 @@ class BaseMultimodalProcessor(ABC):
has_image = False
has_pixel_values = False
has_precomputed_features = False
has_audio = False
for mm_input in mm_inputs:
if isinstance(mm_input, Image.Image):
has_image = True
elif isinstance(mm_input, np.ndarray):
has_audio = True
elif isinstance(mm_input, dict):
if mm_input.get("precomputed_features", None) is not None:
has_precomputed_features = True
@@ -461,13 +465,13 @@ class BaseMultimodalProcessor(ABC):
# Validate format consistency
format_count = sum(
[has_image, has_pixel_values, has_precomputed_features]
[has_image, has_pixel_values, has_precomputed_features, has_audio]
)
if format_count > 1:
raise ValueError(
"Unsupported: mixture of multimodal input formats. "
f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
f"precomputed_features={has_precomputed_features}"
f"precomputed_features={has_precomputed_features}, audio={has_audio}"
)
if has_image:
@@ -476,6 +480,8 @@ class BaseMultimodalProcessor(ABC):
return MultimodalInputFormat.PRECOMPUTED_FEATURES
elif has_pixel_values:
return MultimodalInputFormat.PIXEL_VALUES
elif has_audio:
return MultimodalInputFormat.AUDIO
else:
raise ValueError("No valid multimodal input format found")
except Exception as e:
@@ -521,20 +527,47 @@ class BaseMultimodalProcessor(ABC):
input_ids = tokenize_text(base_output.input_text)
return combined_mm_item, input_ids
def process_audio(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process inputs with audio."""
ret = self.process_mm_data(
input_text=base_output.input_text,
audio=base_output.audios, # Note: "audio" is for gemma3n only
)
combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
for key, value in ret.items():
if key != "input_ids" and hasattr(combined_mm_item, key):
setattr(combined_mm_item, key, value)
input_ids = ret["input_ids"].flatten()
return combined_mm_item, input_ids
def finalize_mm_item(
combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
) -> MultimodalDataItem:
"""Apply common post-processing to the multimodal item."""
combined_mm_item.image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.IM_TOKEN_ID,
)
if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
combined_mm_item.image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.IM_TOKEN_ID,
)
elif combined_mm_item.modality == Modality.AUDIO:
combined_mm_item.audio_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.AUDIO_TOKEN_ID,
)
elif combined_mm_item.modality == Modality.VIDEO:
combined_mm_item.video_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.VIDEO_TOKEN_ID,
)
else:
raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
return combined_mm_item
# Main logic
mm_inputs = base_output.images
# Main logic - determine input type and handle text-only case
mm_inputs = base_output.images or base_output.audios
if not mm_inputs:
# Return text-only case
input_ids = tokenize_text(base_output.input_text)
return None, input_ids
@@ -548,6 +581,8 @@ class BaseMultimodalProcessor(ABC):
combined_mm_item, input_ids = process_precomputed_features(base_output)
elif input_format == MultimodalInputFormat.PIXEL_VALUES:
combined_mm_item, input_ids = process_pixel_values(base_output)
elif input_format == MultimodalInputFormat.AUDIO:
combined_mm_item, input_ids = process_audio(base_output)
else:
raise ValueError(f"Unknown input format: {input_format}")

View File

@@ -0,0 +1,97 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
from typing import Dict, List, Optional, Union
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
)
from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
class Gemma3nSGLangProcessor(SGLangBaseProcessor):
"""Multimodal processor for Gemma3n supporting image and audio inputs."""
models = [Gemma3nForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image_soft_token>"
self.IMAGE_TOKEN_REGEX = re.compile(
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
)
self.AUDIO_TOKEN = "<audio_soft_token>"
self.AUDIO_TOKEN_REGEX = re.compile(
r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
)
self.IM_TOKEN_ID = hf_config.image_token_id
self.IM_START_TOKEN_ID = hf_config.boi_token_id
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
self.AUDIO_TOKEN_ID = hf_config.audio_token_id
self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
async def process_mm_data_async(
self,
image_data: Optional[List[Union[str, bytes, Dict]]] = None,
audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
input_text: str = "",
request_obj=None,
max_req_input_len: int = 0,
*args,
**kwargs,
):
"""Process multimodal data including images and audio."""
audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(audio_data, str):
audio_data = [audio_data]
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
audio_data=audio_data,
max_req_input_len=max_req_input_len,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.IMAGE_TOKEN,
image_token_regex=self.IMAGE_TOKEN_REGEX,
audio_token=self.AUDIO_TOKEN,
audio_token_regex=self.AUDIO_TOKEN_REGEX,
),
)
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
return {
"input_ids": input_ids.tolist(),
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"audio_start_id": self.AUDIO_START_TOKEN_ID,
"audio_end_id": self.AUDIO_END_TOKEN_ID,
}

View File

@@ -214,6 +214,10 @@ class MultimodalDataItem:
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
# gemma3n related
input_features: Optional[torch.Tensor] = None
input_features_mask: Optional[torch.Tensor] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
@staticmethod
@@ -277,7 +281,10 @@ class MultimodalDataItem:
if self.precomputed_features is not None:
self.hash = hash_feature(self.precomputed_features)
elif self.is_audio():
self.hash = hash_feature(self.audio_features)
if self.audio_features is not None:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
else:
self.hash = hash_feature(self.pixel_values)
@@ -288,6 +295,7 @@ class MultimodalDataItem:
return (self.modality == Modality.AUDIO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.audio_features)
or not MultimodalDataItem.is_empty_list(self.input_features)
)
def is_image(self):