diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index b79d90b98..3d548a19e 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -101,6 +101,14 @@ class MultimodalSpecialTokens: return None + def get_token_id_by_modality(self, modality: Modality) -> Optional[int]: + return { + Modality.IMAGE: self.image_token_id, + Modality.MULTI_IMAGES: self.image_token_id, + Modality.VIDEO: self.video_token_id, + Modality.AUDIO: self.audio_token_id, + }.get(modality) + def parse_regex(self): if self.image_token_regex is None and self.image_token is not None: self.image_token_regex = re.compile(re.escape(self.image_token)) @@ -608,14 +616,12 @@ class BaseMultimodalProcessor(ABC): # Add offsets to all items for mm_item in all_collected_items: + mm_token_id = mm_tokens.get_token_id_by_modality(mm_item.modality) + if mm_token_id is None: + raise ValueError(f"No token id found for modality: {mm_item.modality}") mm_item.offsets = self.get_mm_items_offset( input_ids=input_ids, - mm_token_id={ - Modality.IMAGE: mm_tokens.image_token_id, - Modality.MULTI_IMAGES: mm_tokens.image_token_id, - Modality.VIDEO: mm_tokens.video_token_id, - Modality.AUDIO: mm_tokens.audio_token_id, - }.get(mm_item.modality, None), + mm_token_id=mm_token_id, ) return all_collected_items, input_ids, ret