Refactor vlm embedding routine to use precomputed feature (#6543)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-05-24 18:39:21 -07:00
committed by GitHub
parent 0d47788025
commit 681fdc264b
8 changed files with 285 additions and 203 deletions

View File

@@ -10,15 +10,8 @@ import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
AutoModel,
AutoProcessor,
AutoTokenizer,
Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
@@ -41,9 +34,6 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
def setUpClass(cls):
cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cls.model_path = ""
cls.chat_template = ""
cls.processor = ""
response = requests.get(cls.image_url)
cls.main_image = Image.open(BytesIO(response.content))
@@ -274,131 +264,3 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
)
self.compare_outputs(sglang_output, hf_output)
class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
cls.chat_template = "qwen2-vl"
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
cls.visual = (
Qwen2_5_VLForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
.eval()
.visual.to(cls.device)
)
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.8,
)
def tearDown(self):
self.engine.shutdown()
async def test_qwen_vl_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_qwen_vl_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.visual(
processor_output["pixel_values"], processor_output["image_grid_thw"]
)
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
dict(
modality="IMAGE",
image_grid_thws=processor_output["image_grid_thw"],
precomputed_features=precomputed_features,
)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = "google/gemma-3-4b-it"
cls.chat_template = "gemma-it"
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
model = Gemma3ForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
cls.vision_tower = model.vision_tower.eval().to(cls.device)
cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
@classmethod
def visual(cls, pixel_values):
vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = cls.mm_projector(vision_outputs)
return image_features
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.5,
enable_multimodal=True,
)
def tearDown(self):
self.engine.shutdown()
async def test_gemma_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_gemma_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.visual(processor_output["pixel_values"])
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
dict(
modality="IMAGE",
precomputed_features=precomputed_features,
)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
if __name__ == "__main__":
unittest.main()