refactor: multimodal data (#4754)
This commit is contained in:
@@ -155,9 +155,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
|
||||
},
|
||||
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||
"modalities": "multi-images",
|
||||
},
|
||||
{
|
||||
@@ -399,14 +397,14 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {"url": f"{audio_file_name}"},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt,
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import unittest
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -14,7 +15,11 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
||||
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Modality,
|
||||
MultimodalDataItem,
|
||||
MultimodalInputs,
|
||||
)
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -195,14 +200,35 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
||||
# sglang
|
||||
model = self.get_sglang_model()
|
||||
input_ids = inputs["input_ids"].to(self.device).flatten()
|
||||
|
||||
pixel_values = inputs["pixel_values"]
|
||||
tgt_sizes = inputs["tgt_sizes"]
|
||||
pixel_values_flat: List[torch.Tensor] = []
|
||||
tgt_sizes_flat: List[torch.Tensor] = []
|
||||
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
|
||||
# per image
|
||||
if len(pixel_b) != len(tgt_b):
|
||||
raise ValueError(
|
||||
"Inconsistent N lengths, found: "
|
||||
f"{len(pixel_b)} vs {len(tgt_b)}"
|
||||
)
|
||||
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
||||
pixel_values_flat += [pixel_n]
|
||||
tgt_sizes_flat += [tgt_n]
|
||||
sglang_output = embed_mm_inputs(
|
||||
mm_input=MultimodalInputs(
|
||||
pixel_values=inputs["pixel_values"][0],
|
||||
tgt_sizes=inputs["tgt_sizes"][0],
|
||||
mm_inputs=MultimodalInputs(
|
||||
mm_items=[
|
||||
MultimodalDataItem(
|
||||
pixel_values=pixel_values_flat,
|
||||
tgt_size=tgt_sizes_flat,
|
||||
modality=Modality.IMAGE,
|
||||
pad_value=self.processor.tokenizer.unk_token_id,
|
||||
)
|
||||
]
|
||||
),
|
||||
input_ids=input_ids,
|
||||
input_embedding=model.get_input_embeddings(),
|
||||
mm_data_embedding_func=model.get_image_features,
|
||||
image_data_embedding_func=model.get_image_feature,
|
||||
placeholder_token_ids=[
|
||||
self.processor.tokenizer.unk_token_id,
|
||||
],
|
||||
|
||||
Reference in New Issue
Block a user