refactor: bug fixes and refactor for vlm (#4661)
This commit is contained in:
@@ -23,6 +23,17 @@ from sglang.test.test_utils import (
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
# image
|
||||
IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
|
||||
IMAGE_SGL_LOGO_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/sgl_logo.png"
|
||||
|
||||
# video
|
||||
VIDEO_JOBS_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4"
|
||||
|
||||
# audio
|
||||
AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/Trump_WEF_2018_10s.mp3"
|
||||
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
|
||||
|
||||
|
||||
class TestOpenAIVisionServer(unittest.TestCase):
|
||||
@classmethod
|
||||
@@ -58,9 +69,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
},
|
||||
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
@@ -96,9 +105,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
},
|
||||
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
@@ -153,9 +160,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
|
||||
},
|
||||
"image_url": {"url": IMAGE_SGL_LOGO_URL},
|
||||
"modalities": "multi-images",
|
||||
},
|
||||
{
|
||||
@@ -242,10 +247,12 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
]
|
||||
return messages
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
|
||||
def get_or_download_file(self, url: str) -> str:
|
||||
cache_dir = os.path.expanduser("~/.cache")
|
||||
file_path = os.path.join(cache_dir, "jobs.mp4")
|
||||
if url is None:
|
||||
raise ValueError()
|
||||
file_name = url.split("/")[-1]
|
||||
file_path = os.path.join(cache_dir, file_name)
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
@@ -254,6 +261,11 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
return file_path
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
url = VIDEO_JOBS_URL
|
||||
file_path = self.get_or_download_file(url)
|
||||
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
@@ -289,6 +301,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
"present" in video_response
|
||||
or "examine" in video_response
|
||||
or "display" in video_response
|
||||
or "hold" in video_response
|
||||
)
|
||||
assert "black" in video_response or "dark" in video_response
|
||||
self.assertIsNotNone(video_response)
|
||||
@@ -312,9 +325,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
},
|
||||
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
@@ -344,18 +355,14 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
},
|
||||
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||
}
|
||||
)
|
||||
elif image_id == 1:
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
|
||||
},
|
||||
"image_url": {"url": IMAGE_SGL_LOGO_URL},
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -465,9 +472,7 @@ class TestVLMContextLengthIssue(unittest.TestCase):
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
},
|
||||
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
|
||||
@@ -13,6 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.managers.mm_utils import embed_image_inputs
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -168,10 +170,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
||||
).eval()
|
||||
cls.model.to(cls.device)
|
||||
|
||||
async def test_encode_output(self):
|
||||
async def test_vlm_embedding_output(self):
|
||||
"""
|
||||
Compares the embedding output of vlm
|
||||
"""
|
||||
inputs = self.get_processor_output()
|
||||
|
||||
with torch.no_grad():
|
||||
# hf
|
||||
model_inputs = {
|
||||
"input_ids": inputs.input_ids,
|
||||
"image_bound": inputs.image_bound,
|
||||
@@ -183,22 +189,20 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
|
||||
)
|
||||
hf_output = hf_output.squeeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
# sglang
|
||||
model = self.get_sglang_model()
|
||||
input_ids = inputs["input_ids"].to(self.device).flatten()
|
||||
image_inputs = model._parse_and_validate_inputs(
|
||||
sglang_output = embed_image_inputs(
|
||||
image_input=ImageInputs(
|
||||
pixel_values=inputs["pixel_values"][0],
|
||||
tgt_sizes=inputs["tgt_sizes"][0],
|
||||
),
|
||||
input_ids=input_ids,
|
||||
**{
|
||||
"pixel_values": [inputs["pixel_values"]],
|
||||
"tgt_sizes": [inputs["tgt_sizes"]],
|
||||
"im_start_id": self.tokenizer.im_start_id,
|
||||
"im_end_id": self.tokenizer.im_end_id,
|
||||
"slice_start_id": self.tokenizer.slice_start_id,
|
||||
"slice_end_id": self.tokenizer.slice_end_id,
|
||||
},
|
||||
)
|
||||
(sglang_output, _) = model.get_embedding(
|
||||
input_ids=input_ids, image_inputs=image_inputs
|
||||
input_embedding=model.get_input_embeddings(),
|
||||
image_embedding_func=model.get_image_features,
|
||||
placeholder_token_ids=[
|
||||
self.processor.tokenizer.unk_token_id,
|
||||
],
|
||||
)
|
||||
|
||||
self.compare_outputs(sglang_output, hf_output)
|
||||
|
||||
Reference in New Issue
Block a user