[VLM] Support chunk prefill for VLM (#6355)

Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
Chang Su
2025-05-22 20:32:41 -07:00
committed by GitHub
parent 0a4fc73b48
commit 4685fbb888
20 changed files with 510 additions and 184 deletions

View File

@@ -21,7 +21,10 @@ from transformers import (
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
@@ -188,6 +191,7 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
.eval()
.to(cls.device)
)
init_embedding_cache(0)
async def test_vlm_embedding_output(self):
"""
@@ -226,17 +230,41 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
im_start_id, im_end_id = (
self.tokenizer.im_start_id,
self.tokenizer.im_end_id,
)
slice_start_id, slice_end_id = (
self.tokenizer.slice_start_id,
self.tokenizer.slice_end_id,
)
image_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
)
slice_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
)
image_offsets.extend(slice_offsets)
image_offsets = sorted(image_offsets)
sglang_output = embed_mm_inputs(
mm_inputs=MultimodalInputs(
mm_items=[
MultimodalDataItem(
pixel_values=pixel_values_flat,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
pad_value=self.processor.tokenizer.unk_token_id,
)
]
),
mm_inputs_list=[
MultimodalInputs(
mm_items=[
MultimodalDataItem(
pixel_values=pixel_values_flat,
image_offsets=image_offsets,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
pad_value=self.processor.tokenizer.unk_token_id,
)
]
),
],
extend_prefix_lens=[0],
extend_seq_lens=[input_ids.shape[0]],
input_ids=input_ids,
input_embedding=model.get_input_embeddings(),
image_data_embedding_func=model.get_image_feature,