fix: retrieve mm token by modality, raise error if none (#8221)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
Xinyuan Tong
2025-07-21 17:06:35 -07:00
committed by GitHub
parent 114837854f
commit 69adc4f81c

View File

@@ -101,6 +101,14 @@ class MultimodalSpecialTokens:
return None
def get_token_id_by_modality(self, modality: Modality) -> Optional[int]:
    """Look up the special-token id that stands in for ``modality``.

    ``Modality.IMAGE`` and ``Modality.MULTI_IMAGES`` both resolve to
    ``self.image_token_id``; ``Modality.VIDEO`` resolves to
    ``self.video_token_id`` and ``Modality.AUDIO`` to
    ``self.audio_token_id``.

    Returns:
        The value of the matching attribute, or ``None`` when
        ``modality`` is not one of the four modalities listed above.
    """
    # Read every id up front so all three attributes are accessed on each
    # call, regardless of which modality was requested.
    image_id = self.image_token_id
    video_id = self.video_token_id
    audio_id = self.audio_token_id

    token_id_by_modality = {
        Modality.IMAGE: image_id,
        Modality.MULTI_IMAGES: image_id,
        Modality.VIDEO: video_id,
        Modality.AUDIO: audio_id,
    }
    return token_id_by_modality.get(modality)
def parse_regex(self):
if self.image_token_regex is None and self.image_token is not None:
self.image_token_regex = re.compile(re.escape(self.image_token))
@@ -608,14 +616,12 @@ class BaseMultimodalProcessor(ABC):
# Add offsets to all items
for mm_item in all_collected_items:
mm_token_id = mm_tokens.get_token_id_by_modality(mm_item.modality)
if mm_token_id is None:
raise ValueError(f"No token id found for modality: {mm_item.modality}")
mm_item.offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id={
Modality.IMAGE: mm_tokens.image_token_id,
Modality.MULTI_IMAGES: mm_tokens.image_token_id,
Modality.VIDEO: mm_tokens.video_token_id,
Modality.AUDIO: mm_tokens.audio_token_id,
}.get(mm_item.modality, None),
mm_token_id=mm_token_id,
)
return all_collected_items, input_ids, ret