[VLM] Support chunk prefill for VLM (#6355)

Co-authored-by: yizhang2077 <1109276519@qq.com>

commit 4685fbb888 (parent 0a4fc73b48)
Author: Chang Su
Date: 2025-05-22 20:32:41 -07:00
Committed by: GitHub
20 changed files with 510 additions and 184 deletions


@@ -5,7 +5,7 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -343,6 +343,33 @@ class BaseMultimodalProcessor(ABC):
out.normalize()
return out
@staticmethod
def get_mm_items_offset(
input_ids: torch.Tensor, mm_token_id: int
) -> List[Tuple[int, int]]:
"""
Get a set of range for mm_items from input_ids
Example:
input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
mm_token_id = 3
return result = [(2,4),(6,7)]
"""
mask = input_ids == mm_token_id
start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
return list(zip(start_positions.tolist(), end_positions.tolist()))
@staticmethod
def get_mm_items_offset_by_pair(
input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
) -> List[Tuple[int, int]]:
"""
Get the inclusive (start, end) index ranges of mm_items delimited by
mm_start_id/mm_end_id pairs in input_ids; the delimiter tokens themselves
are excluded from each range.
"""
indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1
return list(zip(indices_start.tolist(), indices_end.tolist()))
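
For reference, a minimal standalone sketch of what the two helpers above compute. The helper is re-implemented inline so it runs outside sglang, and the token IDs (3 as the multimodal placeholder, 5/6 as start/end delimiters) are invented for the example.

import torch

def get_mm_items_offset(input_ids, mm_token_id):
    # Same logic as the method above: find where runs of mm_token_id start and end.
    # (torch.roll wraps around, so this assumes the sequence does not both start
    # and end with mm_token_id.)
    mask = input_ids == mm_token_id
    starts = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
    ends = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
    return list(zip(starts.tolist(), ends.tolist()))

ids = torch.tensor([1, 2, 3, 3, 3, 4, 3, 3])
print(get_mm_items_offset(ids, mm_token_id=3))  # [(2, 4), (6, 7)]

# Pair variant: offsets between start/end delimiters, excluding the delimiters.
ids2 = torch.tensor([9, 5, 3, 3, 3, 6, 1, 5, 3, 6])  # 5 = start token, 6 = end token
starts = (ids2 == 5).nonzero(as_tuple=True)[0] + 1
ends = (ids2 == 6).nonzero(as_tuple=True)[0] - 1
print(list(zip(starts.tolist(), ends.tolist())))  # [(2, 4), (8, 8)]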
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
if not mm_inputs:


@@ -70,8 +70,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
items = []
input_ids = res["input_ids"]
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=self._processor.image_token_id
)
item = MultimodalDataItem(
pixel_values=res["images"],
image_offsets=image_offsets,
modality=Modality.IMAGE,
image_emb_mask=images_seq_mask,
image_spatial_crop=batched_images_spatial_crop,
@@ -80,6 +85,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
return {
"mm_items": items,
"input_ids": res["input_ids"].tolist(),
"input_ids": input_ids.tolist(),
"im_token_id": self._processor.image_token_id,
}


@@ -61,6 +61,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
)
items = []
input_ids = ret["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.hf_config.image_token_index,
)
for i, image in enumerate(base_output.images):
if images_are_preprocessed:
pixel_values = image.pixel_values
@@ -73,12 +78,13 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
pixel_values=pixel_values,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
image_offsets=image_offsets[i],
)
items += [item]
return {
"mm_items": items,
"input_ids": ret["input_ids"].flatten().tolist(),
"input_ids": input_ids.tolist(),
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}


@@ -209,7 +209,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
return None
pixel_values = torch.cat(pixel_values, dim=0)
items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
for idx, num_patches in enumerate(num_patches_list):
image_tokens = (
@@ -220,10 +219,21 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
input_text = input_text.replace("<image>", image_tokens, 1)
tokenizer = self._processor
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.img_context_token_id,
)
items = [
MultimodalDataItem(
pixel_values=pixel_values,
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
]
return {
"input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
.flatten()
.tolist(),
"input_ids": input_ids.tolist(),
"mm_items": items,
"im_start_id": self.img_start_token_id,
"im_end_id": self.img_end_token_id,


@@ -45,15 +45,21 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
prompt=base_out.input_text,
images=images,
)
input_ids = res["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=processor.image_id
)
return {
"mm_items": [
MultimodalDataItem(
pixel_values=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
image_offsets=image_offsets,
modality=Modality.IMAGE,
)
],
"input_ids": res["input_ids"].flatten().tolist(),
"input_ids": input_ids.tolist(),
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,


@@ -1,10 +1,5 @@
import asyncio
import math
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
@@ -57,13 +52,19 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
input_text=base_output.input_text,
images=base_output.images,
)
input_ids = ret["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.im_token_id,
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"input_ids": input_ids.tolist(),
"mm_items": [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=ret["image_grid_hws"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
],
"im_token_id": self.im_token_id,


@@ -1,5 +1,4 @@
import asyncio
import importlib
from typing import List, Optional, Union
import numpy as np


@@ -69,6 +69,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_start_id = tokenizer.audio_start_id
audio_end_id = tokenizer.audio_end_id
im_start_id = tokenizer.im_start_id
im_end_id = tokenizer.im_end_id
im_token_id = tokenizer.unk_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
@@ -104,9 +106,20 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
pixel_values = pixel_values_flat
items = []
input_ids = res["input_ids"].flatten()
image_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
)
slice_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
)
image_offsets.extend(slice_offsets)
image_offsets = sorted(image_offsets)
if len(pixel_values) != 0:
item = MultimodalDataItem(
pixel_values=pixel_values,
image_offsets=image_offsets,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
)
@@ -117,21 +130,30 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
and res["audio_features"] is not None
and len(res["audio_features"]) != 0
):
if audio_start_id is not None and audio_end_id is not None:
audio_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids,
mm_start_id=audio_start_id,
mm_end_id=audio_end_id,
)
else:
audio_offsets = None
item = MultimodalDataItem(
audio_features=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
audio_offsets=audio_offsets,
modality=Modality.AUDIO,
)
items += [item]
return {
"mm_items": items,
"input_ids": res["input_ids"].flatten().tolist(),
"input_ids": input_ids.tolist(),
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"im_token_id": im_token_id,
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
"im_start_id": im_start_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}
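
To make the MiniCPM offset handling above concrete, a toy walk-through with invented token IDs (10/11 for image start/end, 12/13 for slice start/end, 0 as the placeholder token): image and slice ranges are computed separately, then merged and sorted into a single list.

import torch

def offsets_by_pair(ids, start_id, end_id):
    # Inclusive ranges between each start/end delimiter pair, delimiters excluded.
    starts = (ids == start_id).nonzero(as_tuple=True)[0] + 1
    ends = (ids == end_id).nonzero(as_tuple=True)[0] - 1
    return list(zip(starts.tolist(), ends.tolist()))

ids = torch.tensor([1, 10, 0, 0, 11, 12, 0, 13, 2])
image_offsets = offsets_by_pair(ids, 10, 11)  # [(2, 3)]
slice_offsets = offsets_by_pair(ids, 12, 13)  # [(6, 6)]
image_offsets.extend(slice_offsets)
image_offsets = sorted(image_offsets)
print(image_offsets)  # [(2, 3), (6, 6)]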


@@ -135,11 +135,17 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
processor_output["im_end_id"] = self.eoi_token_index
processor_output["im_token_id"] = self.image_token_index
image_offsets = self.get_mm_items_offset(
input_ids=torch.tensor(processor_output["input_ids"]),
mm_token_id=self.image_token_index,
)
# Add metadata for image processing
processor_output["mm_items"] = [
MultimodalDataItem(
pixel_values=processor_output["pixel_values"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
]


@@ -1,9 +1,7 @@
import asyncio
import math
from typing import List, Optional, Union
from typing import List, Union
import numpy as np
from transformers import PretrainedConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
@@ -12,11 +10,7 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.pixtral import PixtralVisionModel
@@ -108,15 +102,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
)
if "pixel_values" in processor_output:
input_ids = processor_output["input_ids"].view(-1)
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.image_token_id,
)
mm_items = [
MultimodalDataItem(
pixel_values=processor_output["pixel_values"],
image_sizes=processor_output["image_sizes"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
]
input_ids = processor_output["input_ids"].view(-1).tolist()
input_ids = input_ids.tolist()
processor_output.update(
input_ids=input_ids,
mm_items=mm_items,


@@ -135,6 +135,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
images=None if images_are_preprocessed else base_output.images,
)
input_ids = ret["input_ids"].flatten().tolist()
image_offsets = self.get_mm_items_offset(
input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
)
image_grid_thw = None
video_grid_thw = None # TODO
items = []
@@ -175,6 +178,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_grid_thws=image_grid_thw,
video_grid_thws=video_grid_thw,
precomputed_features=precomputed_features,
image_offsets=image_offsets,
modality=Modality.IMAGE,
)
]