[VLM] Support chunk prefill for VLM (#6355)
Co-authored-by: yizhang2077 <1109276519@qq.com>
@@ -5,7 +5,7 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -343,6 +343,33 @@ class BaseMultimodalProcessor(ABC):
         out.normalize()
         return out

+    @staticmethod
+    def get_mm_items_offset(
+        input_ids: torch.Tensor, mm_token_id: int
+    ) -> List[Tuple[int, int]]:
+        """
+        Get a set of range for mm_items from input_ids
+        Example:
+            input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
+            mm_token_id = 3
+            return result = [(2,4),(6,7)]
+        """
+        mask = input_ids == mm_token_id
+
+        start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
+        end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
+
+        return list(zip(start_positions.tolist(), end_positions.tolist()))
+
+    @staticmethod
+    def get_mm_items_offset_by_pair(
+        input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
+    ) -> List[Tuple[int, int]]:
+        indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
+        indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1
+
+        return list(zip(indices_start.tolist(), indices_end.tolist()))
+
     def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
         """Returns true if all images are preprocessed, false if all are not, and error otherwise."""
         if not mm_inputs:
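A minimal standalone sketch of how the two new helpers behave (not part of the commit; the tensors and the token ids 3, 100, and 101 below are made up for illustration, the logic simply mirrors the diff above):

    import torch

    # get_mm_items_offset: inclusive (start, end) ranges of contiguous mm tokens
    input_ids = torch.tensor([1, 2, 3, 3, 3, 4, 3, 3])
    mask = input_ids == 3  # mm_token_id = 3
    starts = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
    ends = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
    print(list(zip(starts.tolist(), ends.tolist())))  # [(2, 4), (6, 7)]

    # get_mm_items_offset_by_pair: ranges covering the tokens strictly between
    # a start marker and an end marker
    ids = torch.tensor([9, 100, 7, 7, 7, 101, 9])  # 100 = start id, 101 = end id
    s = (ids == 100).nonzero(as_tuple=True)[0] + 1
    e = (ids == 101).nonzero(as_tuple=True)[0] - 1
    print(list(zip(s.tolist(), e.tolist())))  # [(2, 4)]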
@@ -70,8 +70,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)

         items = []
+        input_ids = res["input_ids"]
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids, mm_token_id=self._processor.image_token_id
+        )
         item = MultimodalDataItem(
             pixel_values=res["images"],
+            image_offsets=image_offsets,
             modality=Modality.IMAGE,
             image_emb_mask=images_seq_mask,
             image_spatial_crop=batched_images_spatial_crop,
@@ -80,6 +85,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):

         return {
             "mm_items": items,
-            "input_ids": res["input_ids"].tolist(),
+            "input_ids": input_ids.tolist(),
             "im_token_id": self._processor.image_token_id,
         }
@@ -61,6 +61,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         )

         items = []
+        input_ids = ret["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids,
+            mm_token_id=self.hf_config.image_token_index,
+        )
         for i, image in enumerate(base_output.images):
             if images_are_preprocessed:
                 pixel_values = image.pixel_values
@@ -73,12 +78,13 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
                 pixel_values=pixel_values,
                 precomputed_features=precomputed_features,
                 modality=Modality.IMAGE,
+                image_offsets=image_offsets[i],
             )
             items += [item]

         return {
             "mm_items": items,
-            "input_ids": ret["input_ids"].flatten().tolist(),
+            "input_ids": input_ids.tolist(),
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
@@ -209,7 +209,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             return None

         pixel_values = torch.cat(pixel_values, dim=0)
-        items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]

         for idx, num_patches in enumerate(num_patches_list):
             image_tokens = (
@@ -220,10 +219,21 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             input_text = input_text.replace("<image>", image_tokens, 1)

         tokenizer = self._processor
+        input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids,
+            mm_token_id=self.img_context_token_id,
+        )
+        items = [
+            MultimodalDataItem(
+                pixel_values=pixel_values,
+                modality=Modality.IMAGE,
+                image_offsets=image_offsets,
+            )
+        ]
+
         return {
-            "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
-            .flatten()
-            .tolist(),
+            "input_ids": input_ids.tolist(),
             "mm_items": items,
             "im_start_id": self.img_start_token_id,
             "im_end_id": self.img_end_token_id,
@@ -45,15 +45,21 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
             prompt=base_out.input_text,
             images=images,
         )

+        input_ids = res["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids, mm_token_id=processor.image_id
+        )
         return {
             "mm_items": [
                 MultimodalDataItem(
                     pixel_values=res["pixel_values"],
                     image_emb_mask=res["images_emb_mask"],
+                    image_offsets=image_offsets,
                     modality=Modality.IMAGE,
                 )
             ],
-            "input_ids": res["input_ids"].flatten().tolist(),
+            "input_ids": input_ids.tolist(),
             "im_start_id": processor.image_start_id,
             "im_end_id": processor.image_end_id,
             "im_token_id": processor.image_id,
@@ -1,10 +1,5 @@
-import asyncio
-import math
 from typing import List, Union

-import torch
-from PIL import Image
-
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -57,13 +52,19 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             input_text=base_output.input_text,
             images=base_output.images,
         )
+        input_ids = ret["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=input_ids,
+            mm_token_id=self.im_token_id,
+        )
         return {
-            "input_ids": ret["input_ids"].flatten().tolist(),
+            "input_ids": input_ids.tolist(),
             "mm_items": [
                 MultimodalDataItem(
                     pixel_values=ret["pixel_values"],
                     image_grid_thws=ret["image_grid_hws"],
                     modality=Modality.IMAGE,
+                    image_offsets=image_offsets,
                 )
             ],
             "im_token_id": self.im_token_id,
@@ -1,5 +1,4 @@
-import asyncio
 import importlib
 from typing import List, Optional, Union

 import numpy as np
@@ -69,6 +69,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         audio_start_id = tokenizer.audio_start_id
         audio_end_id = tokenizer.audio_end_id

+        im_start_id = tokenizer.im_start_id
+        im_end_id = tokenizer.im_end_id
         im_token_id = tokenizer.unk_id
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
@@ -104,9 +106,20 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         pixel_values = pixel_values_flat

         items = []
+        input_ids = res["input_ids"].flatten()
+        image_offsets = self.get_mm_items_offset_by_pair(
+            input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
+        )
+        slice_offsets = self.get_mm_items_offset_by_pair(
+            input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
+        )
+        image_offsets.extend(slice_offsets)
+        image_offsets = sorted(image_offsets)
+
         if len(pixel_values) != 0:
             item = MultimodalDataItem(
                 pixel_values=pixel_values,
+                image_offsets=image_offsets,
                 tgt_size=tgt_sizes_flat,
                 modality=Modality.IMAGE,
             )
@@ -117,21 +130,30 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             and res["audio_features"] is not None
             and len(res["audio_features"]) != 0
         ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
             item = MultimodalDataItem(
                 audio_features=[res["audio_features"]],
                 audio_feature_lens=res["audio_feature_lens"],
+                audio_offsets=audio_offsets,
                 modality=Modality.AUDIO,
             )
             items += [item]

         return {
             "mm_items": items,
-            "input_ids": res["input_ids"].flatten().tolist(),
+            "input_ids": input_ids.tolist(),
             "audio_start_id": audio_start_id,
             "audio_end_id": audio_end_id,
             "im_token_id": im_token_id,
-            "im_start_id": tokenizer.im_start_id,
-            "im_end_id": tokenizer.im_end_id,
+            "im_start_id": im_start_id,
+            "im_end_id": im_end_id,
             "slice_start_id": slice_start_id,
             "slice_end_id": slice_end_id,
         }
@@ -135,11 +135,17 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         processor_output["im_end_id"] = self.eoi_token_index
         processor_output["im_token_id"] = self.image_token_index

+        image_offsets = self.get_mm_items_offset(
+            input_ids=torch.tensor(processor_output["input_ids"]),
+            mm_token_id=self.image_token_index,
+        )
+
         # Add metadata for image processing
         processor_output["mm_items"] = [
             MultimodalDataItem(
                 pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
+                image_offsets=image_offsets,
             )
         ]

@@ -1,9 +1,7 @@
-import asyncio
-import math
-from typing import List, Optional, Union
+from typing import List, Union

 import numpy as np
 from transformers import PretrainedConfig
 from transformers.models.pixtral.image_processing_pixtral import (
     _num_image_tokens as _get_pixtral_hf_num_image_tokens,
 )
@@ -12,11 +10,7 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
 )
-from sglang.srt.managers.schedule_batch import (
-    Modality,
-    MultimodalDataItem,
-    MultimodalInputs,
-)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.pixtral import PixtralVisionModel


@@ -108,15 +102,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
         )

         if "pixel_values" in processor_output:
+            input_ids = processor_output["input_ids"].view(-1)
+            image_offsets = self.get_mm_items_offset(
+                input_ids=input_ids,
+                mm_token_id=self.image_token_id,
+            )
             mm_items = [
                 MultimodalDataItem(
                     pixel_values=processor_output["pixel_values"],
                     image_sizes=processor_output["image_sizes"],
                     modality=Modality.IMAGE,
+                    image_offsets=image_offsets,
                 )
             ]

-            input_ids = processor_output["input_ids"].view(-1).tolist()
+            input_ids = input_ids.tolist()
             processor_output.update(
                 input_ids=input_ids,
                 mm_items=mm_items,
@@ -135,6 +135,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
             images=None if images_are_preprocessed else base_output.images,
         )
         input_ids = ret["input_ids"].flatten().tolist()
+        image_offsets = self.get_mm_items_offset(
+            input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
+        )
         image_grid_thw = None
         video_grid_thw = None  # TODO
         items = []
@@ -175,6 +178,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 image_grid_thws=image_grid_thw,
                 video_grid_thws=video_grid_thw,
                 precomputed_features=precomputed_features,
+                image_offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ]