refactor: bug fixes and refactor for vlm (#4661)

This commit is contained in:
Mick
2025-03-23 13:48:49 +08:00
committed by GitHub
parent ca75741e86
commit 11577cedb7
31 changed files with 770 additions and 735 deletions

View File

@@ -13,6 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_image_inputs
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs
@@ -168,10 +170,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
).eval()
cls.model.to(cls.device)
async def test_encode_output(self):
async def test_vlm_embedding_output(self):
"""
Compares the embedding output of vlm
"""
inputs = self.get_processor_output()
with torch.no_grad():
# hf
model_inputs = {
"input_ids": inputs.input_ids,
"image_bound": inputs.image_bound,
@@ -183,22 +189,20 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
)
hf_output = hf_output.squeeze(0)
with torch.no_grad():
# sglang
model = self.get_sglang_model()
input_ids = inputs["input_ids"].to(self.device).flatten()
image_inputs = model._parse_and_validate_inputs(
sglang_output = embed_image_inputs(
image_input=ImageInputs(
pixel_values=inputs["pixel_values"][0],
tgt_sizes=inputs["tgt_sizes"][0],
),
input_ids=input_ids,
**{
"pixel_values": [inputs["pixel_values"]],
"tgt_sizes": [inputs["tgt_sizes"]],
"im_start_id": self.tokenizer.im_start_id,
"im_end_id": self.tokenizer.im_end_id,
"slice_start_id": self.tokenizer.slice_start_id,
"slice_end_id": self.tokenizer.slice_end_id,
},
)
(sglang_output, _) = model.get_embedding(
input_ids=input_ids, image_inputs=image_inputs
input_embedding=model.get_input_embeddings(),
image_embedding_func=model.get_image_features,
placeholder_token_ids=[
self.processor.tokenizer.unk_token_id,
],
)
self.compare_outputs(sglang_output, hf_output)