fix: retrieve mm token by modality, raise error if none (#8221)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com> Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -101,6 +101,14 @@ class MultimodalSpecialTokens:
|
||||
|
||||
return None
|
||||
|
||||
def get_token_id_by_modality(self, modality: Modality) -> Optional[int]:
    """Look up the special-token id associated with ``modality``.

    ``Modality.IMAGE`` and ``Modality.MULTI_IMAGES`` both resolve to
    ``self.image_token_id``; ``Modality.VIDEO`` and ``Modality.AUDIO``
    resolve to ``self.video_token_id`` and ``self.audio_token_id``.

    Returns:
        The stored token id for the modality, or ``None`` when the
        modality is not one of the four listed above.
    """
    # The table is rebuilt on every call so it always reflects the
    # current values of the *_token_id attributes.
    token_id_by_modality = {
        Modality.IMAGE: self.image_token_id,
        Modality.MULTI_IMAGES: self.image_token_id,
        Modality.VIDEO: self.video_token_id,
        Modality.AUDIO: self.audio_token_id,
    }
    return token_id_by_modality.get(modality, None)
|
||||
|
||||
def parse_regex(self):
|
||||
if self.image_token_regex is None and self.image_token is not None:
|
||||
self.image_token_regex = re.compile(re.escape(self.image_token))
|
||||
@@ -608,14 +616,12 @@ class BaseMultimodalProcessor(ABC):
|
||||
|
||||
# Add offsets to all items
|
||||
for mm_item in all_collected_items:
|
||||
mm_token_id = mm_tokens.get_token_id_by_modality(mm_item.modality)
|
||||
if mm_token_id is None:
|
||||
raise ValueError(f"No token id found for modality: {mm_item.modality}")
|
||||
mm_item.offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id={
|
||||
Modality.IMAGE: mm_tokens.image_token_id,
|
||||
Modality.MULTI_IMAGES: mm_tokens.image_token_id,
|
||||
Modality.VIDEO: mm_tokens.video_token_id,
|
||||
Modality.AUDIO: mm_tokens.audio_token_id,
|
||||
}.get(mm_item.modality, None),
|
||||
mm_token_id=mm_token_id,
|
||||
)
|
||||
|
||||
return all_collected_items, input_ids, ret
|
||||
|
||||
Reference in New Issue
Block a user