diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 44e22885c..5c44c4d49 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -21,7 +21,7 @@ class BaseMultiModalProcessorOutput: # input_text, with each frame of video/image represented with a image_token input_text: str - # frames loaded from image and video, in given order + # frames loaded from image, in given order images: Optional[list[Union[Image.Image, dict]]] = None # videos @@ -44,14 +44,26 @@ class BaseMultiModalProcessorOutput: @dataclasses.dataclass class MultimodalSpecialTokens: - image_token: Optional[Union[int, str, List[str]]] = None - video_token: Optional[Union[int, str, List[str]]] = None - audio_token: Optional[Union[int, str, List[str]]] = None + image_token: Optional[Union[str, List[str]]] = None + video_token: Optional[Union[str, List[str]]] = None + audio_token: Optional[Union[str, List[str]]] = None + + image_token_id: Optional[int] = None + video_token_id: Optional[int] = None + audio_token_id: Optional[int] = None image_token_regex: Optional[re.Pattern] = None video_token_regex: Optional[re.Pattern] = None audio_token_regex: Optional[re.Pattern] = None + combined_regex: Optional[re.Pattern] = None + + def build(self, processor): + self.convert_to_strs(processor) + self.parse_regex() + self.get_combined_regex() + return self + def convert_to_str(self, token: Union[str, int], processor) -> str: if token is None: return token @@ -60,11 +72,14 @@ class MultimodalSpecialTokens: return processor.tokenizer.convert_ids_to_tokens([token])[0] def convert_to_strs(self, processor): - self.image_token = self.convert_to_str(self.image_token, processor) - self.video_token = self.convert_to_str(self.video_token, processor) - self.audio_token = self.convert_to_str(self.audio_token, processor) + if not self.image_token: + self.image_token = self.convert_to_str(self.image_token_id, processor) + if not self.video_token: + self.video_token = self.convert_to_str(self.video_token_id, processor) + if not self.audio_token: + self.audio_token = self.convert_to_str(self.audio_token_id, processor) - def get_modality_of_token(self, token) -> Optional[Modality]: + def get_modality_of_token(self, token: str) -> Optional[Modality]: """ :return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex """ @@ -94,7 +109,12 @@ class MultimodalSpecialTokens: if self.audio_token_regex is None and self.audio_token is not None: self.audio_token_regex = re.compile(re.escape(self.audio_token)) - def combine_regex(self) -> re.Pattern: + def get_combined_regex(self) -> re.Pattern: + """ + Builds and returns a regex, used to split input str into tokens (with mm special tokens) + """ + if self.combined_regex: + return self.combined_regex tokens = [ self.image_token_regex, self.video_token_regex, @@ -107,7 +127,8 @@ class MultimodalSpecialTokens: patterns.append(t.pattern) flags |= t.flags combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")" - return re.compile(combined, flags) + self.combined_regex = re.compile(combined, flags) + return self.combined_regex class BaseMultimodalProcessor(ABC): @@ -341,9 +362,8 @@ class BaseMultimodalProcessor(ABC): discard_alpha_channel: if True, discards the alpha channel in the returned images """ - multimodal_tokens.convert_to_strs(self._processor) - multimodal_tokens.parse_regex() - multimodal_tokens_pattern = multimodal_tokens.combine_regex() + multimodal_tokens_pattern = multimodal_tokens.get_combined_regex() + if isinstance(prompt, list) and return_text: assert len(prompt) and isinstance(prompt[0], int) prompt = self._processor.tokenizer.decode(prompt) @@ -445,7 +465,6 @@ class BaseMultimodalProcessor(ABC): return result = [(2,4),(6,7)] """ mask = input_ids == mm_token_id - start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0] end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0] @@ -554,7 +573,9 @@ class BaseMultimodalProcessor(ABC): return collected_items, input_ids, ret def process_and_combine_mm_data( - self, base_output: BaseMultiModalProcessorOutput + self, + base_output: BaseMultiModalProcessorOutput, + mm_tokens: MultimodalSpecialTokens, ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]: """ Process multimodal data and return the combined multimodal items and input_ids. @@ -618,22 +639,14 @@ class BaseMultimodalProcessor(ABC): # Add offsets to all items for mm_item in all_collected_items: - if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]: - mm_item.offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.IM_TOKEN_ID, - ) - elif mm_item.modality == Modality.AUDIO: - mm_item.offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.AUDIO_TOKEN_ID, - ) - elif mm_item.modality == Modality.VIDEO: - mm_item.offsets = self.get_mm_items_offset( - input_ids=input_ids, - mm_token_id=self.VIDEO_TOKEN_ID, - ) - else: - raise ValueError(f"Unknown modality: {mm_item.modality}") + mm_item.offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id={ + Modality.IMAGE: mm_tokens.image_token_id, + Modality.MULTI_IMAGES: mm_tokens.image_token_id, + Modality.VIDEO: mm_tokens.video_token_id, + Modality.AUDIO: mm_tokens.audio_token_id, + }.get(mm_item.modality, None), + ) return all_collected_items, input_ids, ret diff --git a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py index 50547ad2d..c21dce176 100644 --- a/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py +++ b/python/sglang/srt/multimodal/processors/deepseek_vl_v2.py @@ -33,7 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.IMAGE_TOKEN = "" + self.mm_tokens = MultimodalSpecialTokens(image_token="").build( + _processor + ) async def process_mm_data_async( self, @@ -47,7 +49,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor): base_output = self.load_mm_data( input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, ) res = self.process_mm_data( diff --git a/python/sglang/srt/multimodal/processors/gemma3.py b/python/sglang/srt/multimodal/processors/gemma3.py index e0858674a..dac9bd5c8 100644 --- a/python/sglang/srt/multimodal/processors/gemma3.py +++ b/python/sglang/srt/multimodal/processors/gemma3.py @@ -4,7 +4,6 @@ from typing import Dict, List, Union from sglang.srt.managers.multimodal_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, ) -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens @@ -17,15 +16,17 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - # The single, pre-expanded image token. - self.IMAGE_TOKEN = "" - # The regex that matches expanded image tokens. - self.IMAGE_TOKEN_REGEX = re.compile( - r"(?:(?:)*)?" - ) self.IM_START_TOKEN_ID = hf_config.boi_token_index self.IM_END_TOKEN_ID = hf_config.eoi_token_index - self.IM_TOKEN_ID = hf_config.image_token_index + self.mm_tokens = MultimodalSpecialTokens( + # The single, pre-expanded image token. + image_token="", + image_token_id=hf_config.image_token_index, + # The regex that matches expanded image tokens. + image_token_regex=re.compile( + r"(?:(?:)*)?" + ), + ).build(_processor) async def process_mm_data_async( self, @@ -39,14 +40,14 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): base_output = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX - ), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, discard_alpha_channel=True, ) - mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, diff --git a/python/sglang/srt/multimodal/processors/gemma3n.py b/python/sglang/srt/multimodal/processors/gemma3n.py index 92f3c0b93..aafeab7c9 100644 --- a/python/sglang/srt/multimodal/processors/gemma3n.py +++ b/python/sglang/srt/multimodal/processors/gemma3n.py @@ -30,23 +30,23 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.IMAGE_TOKEN = "" - self.IMAGE_TOKEN_REGEX = re.compile( - r"(?:(?:)*)?" - ) - - self.AUDIO_TOKEN = "" - self.AUDIO_TOKEN_REGEX = re.compile( - r"(?:(?:)*)?" - ) - - self.IM_TOKEN_ID = hf_config.image_token_id self.IM_START_TOKEN_ID = hf_config.boi_token_id self.IM_END_TOKEN_ID = hf_config.eoi_token_id - self.AUDIO_TOKEN_ID = hf_config.audio_token_id self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id + self.mm_tokens = MultimodalSpecialTokens( + image_token="", + image_token_id=hf_config.image_token_id, + image_token_regex=re.compile( + r"(?:(?:)*)?" + ), + audio_token="", + audio_token_id=hf_config.audio_token_id, + audio_token_regex=re.compile( + r"(?:(?:)*)?" + ), + ).build(_processor) async def process_mm_data_async( self, @@ -64,19 +64,17 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor): image_data=image_data, audio_data=audio_data, max_req_input_len=max_req_input_len, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self.IMAGE_TOKEN, - image_token_regex=self.IMAGE_TOKEN_REGEX, - audio_token=self.AUDIO_TOKEN, - audio_token_regex=self.AUDIO_TOKEN_REGEX, - ), + multimodal_tokens=self.mm_tokens, ) - mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, - "im_token_id": self.IM_TOKEN_ID, - "audio_token_id": self.AUDIO_TOKEN_ID, + # TODO(mick): could we return MultimodalSpecialTokens directly? + "im_token_id": self.mm_tokens.image_token_id, + "audio_token_id": self.mm_tokens.audio_token_id, } diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py index f9ed9ba76..d3413c457 100644 --- a/python/sglang/srt/multimodal/processors/internvl.py +++ b/python/sglang/srt/multimodal/processors/internvl.py @@ -24,7 +24,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor): self.IMG_CONTEXT_TOKEN = "" self.IMG_START_TOKEN = "" self.IMG_END_TOKEN = "" - self.IMG_TOKEN = "" self.num_image_token = int( (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2) ) @@ -32,9 +31,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor): tokenizer = self._processor self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN) self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN) - self.img_context_token_id = tokenizer.convert_tokens_to_ids( - self.IMG_CONTEXT_TOKEN - ) + self.mm_tokens = MultimodalSpecialTokens( + image_token="", + image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN), + ).build(_image_processor) @staticmethod def build_transform(input_size): @@ -175,7 +175,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor): base_output = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, discard_alpha_channel=True, ) @@ -219,7 +219,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor): input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten() image_offsets = self.get_mm_items_offset( input_ids=input_ids, - mm_token_id=self.img_context_token_id, + mm_token_id=self.mm_tokens.image_token_id, ) items = [ MultimodalDataItem( @@ -234,5 +234,5 @@ class InternVLImageProcessor(BaseMultimodalProcessor): "mm_items": items, "im_start_id": self.img_start_token_id, "im_end_id": self.img_end_token_id, - "im_token_id": self.img_context_token_id, + "im_token_id": self.mm_tokens.image_token_id, } diff --git a/python/sglang/srt/multimodal/processors/janus_pro.py b/python/sglang/srt/multimodal/processors/janus_pro.py index 8ea013d29..28be34c57 100644 --- a/python/sglang/srt/multimodal/processors/janus_pro.py +++ b/python/sglang/srt/multimodal/processors/janus_pro.py @@ -11,8 +11,12 @@ from sglang.srt.multimodal.processors.base_processor import ( class JanusProImageProcessor(BaseMultimodalProcessor): models = [MultiModalityCausalLM] - def __init__(self, hf_config, server_args, _processor): - super().__init__(hf_config, server_args, _processor) + def __init__(self, hf_config, server_args, processor): + super().__init__(hf_config, server_args, processor) + + self.mm_tokens = MultimodalSpecialTokens( + image_token=processor.image_token + ).build(processor) async def process_mm_data_async( self, @@ -27,9 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor): base_out = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens( - image_token=processor.image_token - ), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, ) diff --git a/python/sglang/srt/multimodal/processors/kimi_vl.py b/python/sglang/srt/multimodal/processors/kimi_vl.py index b593da48f..ef533c16d 100644 --- a/python/sglang/srt/multimodal/processors/kimi_vl.py +++ b/python/sglang/srt/multimodal/processors/kimi_vl.py @@ -1,9 +1,6 @@ import re -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Union -import torch - -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, @@ -17,9 +14,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.IMAGE_TOKEN = "<|media_pad|>" - self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+") - self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN) + self.mm_tokens = MultimodalSpecialTokens( + image_token="<|media_pad|>", + # TODO: could we convert in MultimodalSpecialTokens? + image_token_id=hf_config.media_placeholder_token_id, + image_token_regex=re.compile(r"(?:<\|media_pad\|>)+"), + ).build(_processor) async def process_mm_data_async( self, @@ -33,16 +33,16 @@ class KimiVLImageProcessor(SGLangBaseProcessor): base_output = self.load_mm_data( prompt=input_text, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX - ), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, ) - mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, - "im_token_id": self.IM_TOKEN_ID, + "im_token_id": self.mm_tokens.image_token_id, } diff --git a/python/sglang/srt/multimodal/processors/minicpm.py b/python/sglang/srt/multimodal/processors/minicpm.py index 369971ccb..3ba547b38 100644 --- a/python/sglang/srt/multimodal/processors/minicpm.py +++ b/python/sglang/srt/multimodal/processors/minicpm.py @@ -17,9 +17,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - self.image_token = "(./)" - self.audio_token = "()" - self.video_token = "()" + self.mm_tokens = MultimodalSpecialTokens( + image_token="(./)", + audio_token="()", + video_token="()", + ).build(_processor) async def process_mm_data_async( self, @@ -35,11 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): max_req_input_len=max_req_input_len, audio_data=audio_data, image_data=image_data, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self.image_token, - video_token=self.video_token, - audio_token=self.audio_token, - ), + multimodal_tokens=self.mm_tokens, ) if base_output is None: return None diff --git a/python/sglang/srt/multimodal/processors/mllama4.py b/python/sglang/srt/multimodal/processors/mllama4.py index ccf70adc8..566eb3230 100644 --- a/python/sglang/srt/multimodal/processors/mllama4.py +++ b/python/sglang/srt/multimodal/processors/mllama4.py @@ -26,8 +26,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor): self.eoi_token_index = hf_config.eoi_token_index self.image_token_index = hf_config.image_token_index self.multimodal_tokens = MultimodalSpecialTokens( - image_token=_processor.image_token - ) + image_token=_processor.image_token, + ).build(_processor) async def process_mm_data_async( self, diff --git a/python/sglang/srt/multimodal/processors/phi4mm.py b/python/sglang/srt/multimodal/processors/phi4mm.py index d2e009d27..aea06506d 100644 --- a/python/sglang/srt/multimodal/processors/phi4mm.py +++ b/python/sglang/srt/multimodal/processors/phi4mm.py @@ -21,7 +21,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor): super().__init__(hf_config, server_args, _processor) self.multimodal_tokens = MultimodalSpecialTokens( image_token=_IMAGE_SPECIAL_TOKEN, - ) + ).build(_processor) async def process_mm_data_async( self, diff --git a/python/sglang/srt/multimodal/processors/pixtral.py b/python/sglang/srt/multimodal/processors/pixtral.py index 8b741d627..b18dfa1b0 100644 --- a/python/sglang/srt/multimodal/processors/pixtral.py +++ b/python/sglang/srt/multimodal/processors/pixtral.py @@ -55,7 +55,7 @@ class PixtralProcessor(BaseMultimodalProcessor): self.patch_size = self.vision_config.patch_size self.multimodal_tokens = MultimodalSpecialTokens( image_token=_processor.image_token - ) + ).build(_processor) _processor.tokenizer.add_special_tokens( { "pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN), diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py index 1ecb4e119..bdfaf1406 100644 --- a/python/sglang/srt/multimodal/processors/qwen_vl.py +++ b/python/sglang/srt/multimodal/processors/qwen_vl.py @@ -203,16 +203,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) - # The single, pre-expanded image token. - self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>" # The regex that matches expanded image tokens. - self.IMAGE_TOKEN_REGEX = re.compile( - r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>" - ) self.IM_START_TOKEN_ID = hf_config.vision_start_token_id self.IM_END_TOKEN_ID = hf_config.vision_end_token_id - self.IM_TOKEN_ID = hf_config.image_token_id - self.VIDEO_TOKEN_ID = hf_config.video_token_id self.vision_start_token_id = hf_config.vision_start_token_id self.vision_end_token_id = hf_config.vision_end_token_id self.NUM_TOKEN_PER_FRAME = 770 @@ -220,12 +213,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): self.MIN_PIXELS = 4 * 28 * 28 self.MAX_PIXELS = 16384 * 28 * 28 self.MAX_RATIO = 200 - # TODO(mick): move all MultimodalSpecialTokens initializations into processor init - self.mm_special_tokens = MultimodalSpecialTokens( - image_token=self.IMAGE_TOKEN, - image_token_regex=self.IMAGE_TOKEN_REGEX, - video_token=self.VIDEO_TOKEN_ID, - ) + self.mm_tokens = MultimodalSpecialTokens( + image_token="<|vision_start|><|image_pad|><|vision_end|>", + image_token_id=hf_config.image_token_id, + image_token_regex=re.compile( + r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>" + ), + video_token_id=hf_config.video_token_id, + ).build(_processor) async def process_mm_data_async( self, @@ -241,7 +236,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): prompt=input_text, image_data=image_data, video_data=request_obj.video_data, - multimodal_tokens=self.mm_special_tokens, + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, ) @@ -255,13 +250,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): await preprocess_video(video) for video in base_output.videos ] - mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, ret = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) input_ids = input_ids.flatten() mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index( spatial_merge_size=self.hf_config.vision_config.spatial_merge_size, - image_token_id=self.IM_TOKEN_ID, - video_token_id=self.VIDEO_TOKEN_ID, + image_token_id=self.mm_tokens.image_token_id, + video_token_id=self.mm_tokens.video_token_id, vision_start_token_id=self.vision_start_token_id, model_type=self.hf_config.model_type, tokens_per_second=getattr( @@ -279,8 +276,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): "mm_items": mm_items, "im_start_id": self.IM_START_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID, - "im_token_id": self.IM_TOKEN_ID, - "video_token_id": self.VIDEO_TOKEN_ID, + "im_token_id": self.mm_tokens.image_token_id, + "video_token_id": self.mm_tokens.video_token_id, "mrope_positions": mrope_positions, "mrope_position_delta": mrope_position_delta, } diff --git a/python/sglang/srt/multimodal/processors/vila.py b/python/sglang/srt/multimodal/processors/vila.py index c4d676c6d..8e0f04aca 100644 --- a/python/sglang/srt/multimodal/processors/vila.py +++ b/python/sglang/srt/multimodal/processors/vila.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Dict, List, Optional, Type import torch.nn as nn from transformers.configuration_utils import PretrainedConfig @@ -10,7 +10,6 @@ from sglang.srt.managers.io_struct import ( GenerateReqInput, ImageDataInputItem, ) -from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.models.vila import VILAForConditionalGeneration from sglang.srt.multimodal.processors.base_processor import ( BaseMultimodalProcessor, @@ -37,8 +36,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): _processor: VILAProcessor, ) -> None: super().__init__(hf_config, server_args, _processor) - self.IM_TOKEN_ID = hf_config.image_token_id - self.VIDEO_TOKEN_ID = hf_config.video_token_id + self.mm_tokens = MultimodalSpecialTokens( + image_token=self._processor.tokenizer.image_token, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + ).build(_processor) async def process_mm_data_async( self, @@ -50,18 +52,18 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): ) -> Optional[Dict[str, Any]]: base_output = self.load_mm_data( prompt=input_text, - multimodal_tokens=MultimodalSpecialTokens( - image_token=self._processor.tokenizer.image_token - ), + multimodal_tokens=self.mm_tokens, max_req_input_len=max_req_input_len, image_data=image_data, ) - mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output) + mm_items, input_ids, _ = self.process_and_combine_mm_data( + base_output, self.mm_tokens + ) return { "input_ids": input_ids.tolist(), "mm_items": mm_items, - "im_token_id": self.IM_TOKEN_ID, - "video_token_id": self.VIDEO_TOKEN_ID, + "im_token_id": self.mm_tokens.image_token_id, + "video_token_id": self.mm_tokens.video_token_id, }