Files
sglang/python/sglang/srt/multimodal/processors/gemma3n.py
2025-06-30 23:14:48 -07:00

83 lines
3.0 KiB
Python

# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
from typing import Dict, List, Optional, Union
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.models.gemma3n_mm import Gemma3nForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
class Gemma3nSGLangProcessor(SGLangBaseProcessor):
    """Gemma3n request preprocessor that handles image and audio attachments.

    ``models`` names ``Gemma3nForConditionalGeneration``; presumably this is
    how the processor gets matched to the model class -- confirm against the
    base class / registry, which are not visible here.
    """

    models = [Gemma3nForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        """Set up placeholder strings, placeholder patterns and token ids.

        Args:
            hf_config: Model config object. The image/audio token ids and the
                begin/end-of-image and begin/end-of-audio ids are read from it.
            server_args: Passed through unchanged to the base class.
            _processor: Passed through unchanged to the base class.
        """
        super().__init__(hf_config, server_args, _processor)

        # Each pattern matches the start marker on its own, or the start
        # marker followed by any number of soft tokens and the end marker.
        image_pattern = re.compile(
            r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
        )
        audio_pattern = re.compile(
            r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
        )

        # Text-level placeholders.
        self.IMAGE_TOKEN = "<image_soft_token>"
        self.IMAGE_TOKEN_REGEX = image_pattern
        self.AUDIO_TOKEN = "<audio_soft_token>"
        self.AUDIO_TOKEN_REGEX = audio_pattern

        # Id-level counterparts, taken from the model config.
        self.IM_TOKEN_ID = hf_config.image_token_id
        self.IM_START_TOKEN_ID = hf_config.boi_token_id
        self.IM_END_TOKEN_ID = hf_config.eoi_token_id
        self.AUDIO_TOKEN_ID = hf_config.audio_token_id
        self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
        self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id

    async def process_mm_data_async(
        self,
        image_data: Optional[List[Union[str, bytes, Dict]]] = None,
        audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
        input_text: str = "",
        request_obj=None,
        max_req_input_len: int = 0,
        *args,
        **kwargs,
    ):
        """Load the request's image/audio inputs and merge them with the prompt.

        Args:
            image_data: Optional image inputs, handed to ``load_mm_data``.
            audio_data: Optional audio inputs, handed to ``load_mm_data``.
            input_text: Prompt text, handed to ``load_mm_data`` as ``prompt``.
            request_obj: Accepted but not used in this method.
            max_req_input_len: Handed to ``load_mm_data`` unchanged.
            *args: Accepted but not used in this method.
            **kwargs: Accepted but not used in this method.

        Returns:
            A dict with keys ``"input_ids"`` (the combined ids after
            ``.tolist()``), ``"mm_items"``, ``"im_token_id"`` and
            ``"audio_token_id"``.
        """
        # Describe which placeholders the loader should look for in the prompt.
        special_tokens = MultimodalSpecialTokens(
            image_token=self.IMAGE_TOKEN,
            image_token_regex=self.IMAGE_TOKEN_REGEX,
            audio_token=self.AUDIO_TOKEN,
            audio_token_regex=self.AUDIO_TOKEN_REGEX,
        )

        loaded = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            audio_data=audio_data,
            max_req_input_len=max_req_input_len,
            multimodal_tokens=special_tokens,
        )

        items, token_ids = self.process_and_combine_mm_data(loaded)

        result = {
            "input_ids": token_ids.tolist(),
            "mm_items": items,
            "im_token_id": self.IM_TOKEN_ID,
            "audio_token_id": self.AUDIO_TOKEN_ID,
        }
        return result