refactor: simply MultimodalTokens logic (#7924)
This commit is contained in:
@@ -21,7 +21,7 @@ class BaseMultiModalProcessorOutput:
|
||||
# input_text, with each frame of video/image represented with a image_token
|
||||
input_text: str
|
||||
|
||||
# frames loaded from image and video, in given order
|
||||
# frames loaded from image, in given order
|
||||
images: Optional[list[Union[Image.Image, dict]]] = None
|
||||
|
||||
# videos
|
||||
@@ -44,14 +44,26 @@ class BaseMultiModalProcessorOutput:
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MultimodalSpecialTokens:
|
||||
image_token: Optional[Union[int, str, List[str]]] = None
|
||||
video_token: Optional[Union[int, str, List[str]]] = None
|
||||
audio_token: Optional[Union[int, str, List[str]]] = None
|
||||
image_token: Optional[Union[str, List[str]]] = None
|
||||
video_token: Optional[Union[str, List[str]]] = None
|
||||
audio_token: Optional[Union[str, List[str]]] = None
|
||||
|
||||
image_token_id: Optional[int] = None
|
||||
video_token_id: Optional[int] = None
|
||||
audio_token_id: Optional[int] = None
|
||||
|
||||
image_token_regex: Optional[re.Pattern] = None
|
||||
video_token_regex: Optional[re.Pattern] = None
|
||||
audio_token_regex: Optional[re.Pattern] = None
|
||||
|
||||
combined_regex: Optional[re.Pattern] = None
|
||||
|
||||
def build(self, processor):
|
||||
self.convert_to_strs(processor)
|
||||
self.parse_regex()
|
||||
self.get_combined_regex()
|
||||
return self
|
||||
|
||||
def convert_to_str(self, token: Union[str, int], processor) -> str:
|
||||
if token is None:
|
||||
return token
|
||||
@@ -60,11 +72,14 @@ class MultimodalSpecialTokens:
|
||||
return processor.tokenizer.convert_ids_to_tokens([token])[0]
|
||||
|
||||
def convert_to_strs(self, processor):
|
||||
self.image_token = self.convert_to_str(self.image_token, processor)
|
||||
self.video_token = self.convert_to_str(self.video_token, processor)
|
||||
self.audio_token = self.convert_to_str(self.audio_token, processor)
|
||||
if not self.image_token:
|
||||
self.image_token = self.convert_to_str(self.image_token_id, processor)
|
||||
if not self.video_token:
|
||||
self.video_token = self.convert_to_str(self.video_token_id, processor)
|
||||
if not self.audio_token:
|
||||
self.audio_token = self.convert_to_str(self.audio_token_id, processor)
|
||||
|
||||
def get_modality_of_token(self, token) -> Optional[Modality]:
|
||||
def get_modality_of_token(self, token: str) -> Optional[Modality]:
|
||||
"""
|
||||
:return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
|
||||
"""
|
||||
@@ -94,7 +109,12 @@ class MultimodalSpecialTokens:
|
||||
if self.audio_token_regex is None and self.audio_token is not None:
|
||||
self.audio_token_regex = re.compile(re.escape(self.audio_token))
|
||||
|
||||
def combine_regex(self) -> re.Pattern:
|
||||
def get_combined_regex(self) -> re.Pattern:
|
||||
"""
|
||||
Builds and returns a regex, used to split input str into tokens (with mm special tokens)
|
||||
"""
|
||||
if self.combined_regex:
|
||||
return self.combined_regex
|
||||
tokens = [
|
||||
self.image_token_regex,
|
||||
self.video_token_regex,
|
||||
@@ -107,7 +127,8 @@ class MultimodalSpecialTokens:
|
||||
patterns.append(t.pattern)
|
||||
flags |= t.flags
|
||||
combined = "(" + "|".join(f"(?:{p})" for p in patterns) + ")"
|
||||
return re.compile(combined, flags)
|
||||
self.combined_regex = re.compile(combined, flags)
|
||||
return self.combined_regex
|
||||
|
||||
|
||||
class BaseMultimodalProcessor(ABC):
|
||||
@@ -341,9 +362,8 @@ class BaseMultimodalProcessor(ABC):
|
||||
discard_alpha_channel: if True, discards the alpha channel in the returned images
|
||||
|
||||
"""
|
||||
multimodal_tokens.convert_to_strs(self._processor)
|
||||
multimodal_tokens.parse_regex()
|
||||
multimodal_tokens_pattern = multimodal_tokens.combine_regex()
|
||||
multimodal_tokens_pattern = multimodal_tokens.get_combined_regex()
|
||||
|
||||
if isinstance(prompt, list) and return_text:
|
||||
assert len(prompt) and isinstance(prompt[0], int)
|
||||
prompt = self._processor.tokenizer.decode(prompt)
|
||||
@@ -445,7 +465,6 @@ class BaseMultimodalProcessor(ABC):
|
||||
return result = [(2,4),(6,7)]
|
||||
"""
|
||||
mask = input_ids == mm_token_id
|
||||
|
||||
start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
|
||||
end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
|
||||
|
||||
@@ -554,7 +573,9 @@ class BaseMultimodalProcessor(ABC):
|
||||
return collected_items, input_ids, ret
|
||||
|
||||
def process_and_combine_mm_data(
|
||||
self, base_output: BaseMultiModalProcessorOutput
|
||||
self,
|
||||
base_output: BaseMultiModalProcessorOutput,
|
||||
mm_tokens: MultimodalSpecialTokens,
|
||||
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
|
||||
"""
|
||||
Process multimodal data and return the combined multimodal items and input_ids.
|
||||
@@ -618,22 +639,14 @@ class BaseMultimodalProcessor(ABC):
|
||||
|
||||
# Add offsets to all items
|
||||
for mm_item in all_collected_items:
|
||||
if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
|
||||
mm_item.offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=self.IM_TOKEN_ID,
|
||||
)
|
||||
elif mm_item.modality == Modality.AUDIO:
|
||||
mm_item.offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=self.AUDIO_TOKEN_ID,
|
||||
)
|
||||
elif mm_item.modality == Modality.VIDEO:
|
||||
mm_item.offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=self.VIDEO_TOKEN_ID,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown modality: {mm_item.modality}")
|
||||
mm_item.offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id={
|
||||
Modality.IMAGE: mm_tokens.image_token_id,
|
||||
Modality.MULTI_IMAGES: mm_tokens.image_token_id,
|
||||
Modality.VIDEO: mm_tokens.video_token_id,
|
||||
Modality.AUDIO: mm_tokens.audio_token_id,
|
||||
}.get(mm_item.modality, None),
|
||||
)
|
||||
|
||||
return all_collected_items, input_ids, ret
|
||||
|
||||
@@ -33,7 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<image>"
|
||||
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
|
||||
_processor
|
||||
)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -47,7 +49,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
|
||||
base_output = self.load_mm_data(
|
||||
input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
res = self.process_mm_data(
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Dict, List, Union
|
||||
from sglang.srt.managers.multimodal_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
|
||||
|
||||
@@ -17,15 +16,17 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
# The single, pre-expanded image token.
|
||||
self.IMAGE_TOKEN = "<start_of_image>"
|
||||
# The regex that matches expanded image tokens.
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
|
||||
)
|
||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||
self.IM_TOKEN_ID = hf_config.image_token_index
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
# The single, pre-expanded image token.
|
||||
image_token="<start_of_image>",
|
||||
image_token_id=hf_config.image_token_index,
|
||||
# The regex that matches expanded image tokens.
|
||||
image_token_regex=re.compile(
|
||||
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
|
||||
),
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -39,14 +40,14 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
|
||||
),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
|
||||
@@ -30,23 +30,23 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
self.IMAGE_TOKEN = "<image_soft_token>"
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
|
||||
)
|
||||
|
||||
self.AUDIO_TOKEN = "<audio_soft_token>"
|
||||
self.AUDIO_TOKEN_REGEX = re.compile(
|
||||
r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
|
||||
)
|
||||
|
||||
self.IM_TOKEN_ID = hf_config.image_token_id
|
||||
self.IM_START_TOKEN_ID = hf_config.boi_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_id
|
||||
|
||||
self.AUDIO_TOKEN_ID = hf_config.audio_token_id
|
||||
self.AUDIO_START_TOKEN_ID = hf_config.boa_token_id
|
||||
self.AUDIO_END_TOKEN_ID = hf_config.eoa_token_id
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<image_soft_token>",
|
||||
image_token_id=hf_config.image_token_id,
|
||||
image_token_regex=re.compile(
|
||||
r"<start_of_image>(?:(?:<image_soft_token>)*<end_of_image>)?"
|
||||
),
|
||||
audio_token="<audio_soft_token>",
|
||||
audio_token_id=hf_config.audio_token_id,
|
||||
audio_token_regex=re.compile(
|
||||
r"<start_of_audio>(?:(?:<audio_soft_token>)*<end_of_audio>)?"
|
||||
),
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -64,19 +64,17 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
|
||||
image_data=image_data,
|
||||
audio_data=audio_data,
|
||||
max_req_input_len=max_req_input_len,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN,
|
||||
image_token_regex=self.IMAGE_TOKEN_REGEX,
|
||||
audio_token=self.AUDIO_TOKEN,
|
||||
audio_token_regex=self.AUDIO_TOKEN_REGEX,
|
||||
),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"audio_token_id": self.AUDIO_TOKEN_ID,
|
||||
# TODO(mick): could we return MultimodalSpecialTokens directly?
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
"audio_token_id": self.mm_tokens.audio_token_id,
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
|
||||
self.IMG_START_TOKEN = "<img>"
|
||||
self.IMG_END_TOKEN = "</img>"
|
||||
self.IMG_TOKEN = "<image>"
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
|
||||
)
|
||||
@@ -32,9 +31,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
tokenizer = self._processor
|
||||
self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
|
||||
self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
|
||||
self.img_context_token_id = tokenizer.convert_tokens_to_ids(
|
||||
self.IMG_CONTEXT_TOKEN
|
||||
)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<image>",
|
||||
image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
|
||||
).build(_image_processor)
|
||||
|
||||
@staticmethod
|
||||
def build_transform(input_size):
|
||||
@@ -175,7 +175,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
@@ -219,7 +219,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=self.img_context_token_id,
|
||||
mm_token_id=self.mm_tokens.image_token_id,
|
||||
)
|
||||
items = [
|
||||
MultimodalDataItem(
|
||||
@@ -234,5 +234,5 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
"mm_items": items,
|
||||
"im_start_id": self.img_start_token_id,
|
||||
"im_end_id": self.img_end_token_id,
|
||||
"im_token_id": self.img_context_token_id,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
}
|
||||
|
||||
@@ -11,8 +11,12 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
models = [MultiModalityCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
def __init__(self, hf_config, server_args, processor):
|
||||
super().__init__(hf_config, server_args, processor)
|
||||
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=processor.image_token
|
||||
).build(processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -27,9 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
|
||||
base_out = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=processor.image_token
|
||||
),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
@@ -17,9 +14,12 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<|media_pad|>"
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
|
||||
self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<|media_pad|>",
|
||||
# TODO: could we convert in MultimodalSpecialTokens?
|
||||
image_token_id=hf_config.media_placeholder_token_id,
|
||||
image_token_regex=re.compile(r"(?:<\|media_pad\|>)+"),
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -33,16 +33,16 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
|
||||
),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
}
|
||||
|
||||
@@ -17,9 +17,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.image_token = "(<image>./</image>)"
|
||||
self.audio_token = "(<audio>./</audio>)"
|
||||
self.video_token = "(<video>./</video>)"
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="(<image>./</image>)",
|
||||
audio_token="(<audio>./</audio>)",
|
||||
video_token="(<video>./</video>)",
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -35,11 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self.image_token,
|
||||
video_token=self.video_token,
|
||||
audio_token=self.audio_token,
|
||||
),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
@@ -26,8 +26,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
|
||||
self.eoi_token_index = hf_config.eoi_token_index
|
||||
self.image_token_index = hf_config.image_token_index
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token
|
||||
)
|
||||
image_token=_processor.image_token,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
|
||||
@@ -21,7 +21,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_IMAGE_SPECIAL_TOKEN,
|
||||
)
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
|
||||
@@ -55,7 +55,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
|
||||
self.patch_size = self.vision_config.patch_size
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_processor.image_token
|
||||
)
|
||||
).build(_processor)
|
||||
_processor.tokenizer.add_special_tokens(
|
||||
{
|
||||
"pad_token": getattr(hf_config, "pad_token", self.PAD_TOKEN),
|
||||
|
||||
@@ -203,16 +203,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
# The single, pre-expanded image token.
|
||||
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
# The regex that matches expanded image tokens.
|
||||
self.IMAGE_TOKEN_REGEX = re.compile(
|
||||
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
||||
)
|
||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||
self.IM_TOKEN_ID = hf_config.image_token_id
|
||||
self.VIDEO_TOKEN_ID = hf_config.video_token_id
|
||||
self.vision_start_token_id = hf_config.vision_start_token_id
|
||||
self.vision_end_token_id = hf_config.vision_end_token_id
|
||||
self.NUM_TOKEN_PER_FRAME = 770
|
||||
@@ -220,12 +213,14 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
self.MIN_PIXELS = 4 * 28 * 28
|
||||
self.MAX_PIXELS = 16384 * 28 * 28
|
||||
self.MAX_RATIO = 200
|
||||
# TODO(mick): move all MultimodalSpecialTokens initializations into processor init
|
||||
self.mm_special_tokens = MultimodalSpecialTokens(
|
||||
image_token=self.IMAGE_TOKEN,
|
||||
image_token_regex=self.IMAGE_TOKEN_REGEX,
|
||||
video_token=self.VIDEO_TOKEN_ID,
|
||||
)
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
||||
image_token_id=hf_config.image_token_id,
|
||||
image_token_regex=re.compile(
|
||||
r"<\|vision_start\|>(?:<\|image_pad\|>)+<\|vision_end\|>"
|
||||
),
|
||||
video_token_id=hf_config.video_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -241,7 +236,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
video_data=request_obj.video_data,
|
||||
multimodal_tokens=self.mm_special_tokens,
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
|
||||
@@ -255,13 +250,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
await preprocess_video(video) for video in base_output.videos
|
||||
]
|
||||
|
||||
mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids, ret = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
input_ids = input_ids.flatten()
|
||||
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
|
||||
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
|
||||
image_token_id=self.IM_TOKEN_ID,
|
||||
video_token_id=self.VIDEO_TOKEN_ID,
|
||||
image_token_id=self.mm_tokens.image_token_id,
|
||||
video_token_id=self.mm_tokens.video_token_id,
|
||||
vision_start_token_id=self.vision_start_token_id,
|
||||
model_type=self.hf_config.model_type,
|
||||
tokens_per_second=getattr(
|
||||
@@ -279,8 +276,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
|
||||
"mm_items": mm_items,
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"video_token_id": self.VIDEO_TOKEN_ID,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
"video_token_id": self.mm_tokens.video_token_id,
|
||||
"mrope_positions": mrope_positions,
|
||||
"mrope_position_delta": mrope_position_delta,
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Type, cast
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@@ -10,7 +10,6 @@ from sglang.srt.managers.io_struct import (
|
||||
GenerateReqInput,
|
||||
ImageDataInputItem,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.vila import VILAForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
@@ -37,8 +36,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
||||
_processor: VILAProcessor,
|
||||
) -> None:
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IM_TOKEN_ID = hf_config.image_token_id
|
||||
self.VIDEO_TOKEN_ID = hf_config.video_token_id
|
||||
self.mm_tokens = MultimodalSpecialTokens(
|
||||
image_token=self._processor.tokenizer.image_token,
|
||||
image_token_id=hf_config.image_token_id,
|
||||
video_token_id=hf_config.video_token_id,
|
||||
).build(_processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -50,18 +52,18 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
multimodal_tokens=MultimodalSpecialTokens(
|
||||
image_token=self._processor.tokenizer.image_token
|
||||
),
|
||||
multimodal_tokens=self.mm_tokens,
|
||||
max_req_input_len=max_req_input_len,
|
||||
image_data=image_data,
|
||||
)
|
||||
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.mm_tokens
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"video_token_id": self.VIDEO_TOKEN_ID,
|
||||
"im_token_id": self.mm_tokens.image_token_id,
|
||||
"video_token_id": self.mm_tokens.video_token_id,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user