Refactor vlm embedding routine to use precomputed feature (#6543)
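With this change the engine accepts vision features that were computed outside of it, so callers can skip sglang's built-in image preprocessing and vision encoding. A minimal end-to-end sketch of the new call shape, assuming a CUDA device and a Qwen2.5-VL model (the model path, field names, and encoding steps mirror the tests in this diff; the apply_chat_template call is an assumption, the tests build the prompt via generate_chat_conv instead):

    import asyncio
    from io import BytesIO

    import requests
    import torch
    from PIL import Image
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    from sglang import Engine

    MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"
    IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"

    async def main():
        processor = AutoProcessor.from_pretrained(MODEL, use_fast=True)
        image = Image.open(BytesIO(requests.get(IMAGE_URL).content))
        messages = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": "What's in this picture?"},
        ]}]
        text = processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        inputs = processor(text=[text], images=[image], return_tensors="pt").to("cuda")

        # Run the model's own vision tower once, outside the engine.
        visual = (
            Qwen2_5_VLForConditionalGeneration.from_pretrained(
                MODEL, torch_dtype=torch.bfloat16
            )
            .eval()
            .visual.to("cuda")
        )
        with torch.inference_mode():
            features = visual(inputs["pixel_values"], inputs["image_grid_thw"])

        # Pass the features instead of the image; the engine embeds them directly.
        engine = Engine(model_path=MODEL, chat_template="qwen2-vl", mem_fraction_static=0.8)
        output = await engine.async_generate(
            input_ids=inputs["input_ids"][0].detach().cpu().tolist(),
            image_data=[dict(
                modality="IMAGE",
                image_grid_thws=inputs["image_grid_thw"],
                precomputed_features=features,
            )],
            sampling_params=dict(temperature=0.0),
        )
        print(output["text"])
        engine.shutdown()

    asyncio.run(main())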
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
@@ -81,7 +81,7 @@ suites = {
         TestFile("test_update_weights_from_tensor.py", 48),
         TestFile("test_vertex_endpoint.py", 31),
         TestFile("test_vision_chunked_prefill.py", 175),
-        TestFile("test_vlm_accuracy.py", 60),
+        TestFile("test_vlm_input_format.py", 300),
         TestFile("test_vision_openai_server_a.py", 700),
         TestFile("test_vision_openai_server_b.py", 700),
         TestFile("test_w8a8_quantization.py", 46),
@@ -10,15 +10,8 @@ import requests
 import torch
 import torch.nn.functional as F
 from PIL import Image
-from transformers import (
-    AutoModel,
-    AutoProcessor,
-    AutoTokenizer,
-    Gemma3ForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-)
+from transformers import AutoModel, AutoProcessor, AutoTokenizer

 from sglang import Engine
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
@@ -41,9 +34,6 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
     def setUpClass(cls):
         cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
         cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        cls.model_path = ""
-        cls.chat_template = ""
-        cls.processor = ""
         response = requests.get(cls.image_url)
         cls.main_image = Image.open(BytesIO(response.content))
@@ -274,131 +264,3 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         )

         self.compare_outputs(sglang_output, hf_output)
-
-
-class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
-        cls.chat_template = "qwen2-vl"
-        cls.processor = AutoProcessor.from_pretrained(
-            cls.model_path, trust_remote_code=True, use_fast=True
-        )
-        cls.visual = (
-            Qwen2_5_VLForConditionalGeneration.from_pretrained(
-                cls.model_path, torch_dtype=torch.bfloat16
-            )
-            .eval()
-            .visual.to(cls.device)
-        )
-
-    def setUp(self):
-        self.engine = Engine(
-            model_path=self.model_path,
-            chat_template=self.chat_template,
-            device=self.device.type,
-            mem_fraction_static=0.8,
-        )
-
-    def tearDown(self):
-        self.engine.shutdown()
-
-    async def test_qwen_vl_understands_image(self):
-        req = self.get_completion_request()
-        conv = generate_chat_conv(req, template_name=self.chat_template)
-        text = conv.get_prompt()
-        output = await self.engine.async_generate(
-            prompt=text,
-            image_data=[self.main_image],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-    async def test_qwen_vl_understands_precomputed_features(self):
-        req = self.get_completion_request()
-        processor_output = self.get_processor_output(req=req)
-        with torch.inference_mode():
-            precomputed_features = self.visual(
-                processor_output["pixel_values"], processor_output["image_grid_thw"]
-            )
-        output = await self.engine.async_generate(
-            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
-            image_data=[
-                dict(
-                    modality="IMAGE",
-                    image_grid_thws=processor_output["image_grid_thw"],
-                    precomputed_features=precomputed_features,
-                )
-            ],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-
-class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.model_path = "google/gemma-3-4b-it"
-        cls.chat_template = "gemma-it"
-        cls.processor = AutoProcessor.from_pretrained(
-            cls.model_path, trust_remote_code=True, use_fast=True
-        )
-        model = Gemma3ForConditionalGeneration.from_pretrained(
-            cls.model_path, torch_dtype=torch.bfloat16
-        )
-        cls.vision_tower = model.vision_tower.eval().to(cls.device)
-        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
-
-    @classmethod
-    def visual(cls, pixel_values):
-        vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
-        image_features = cls.mm_projector(vision_outputs)
-        return image_features
-
-    def setUp(self):
-        self.engine = Engine(
-            model_path=self.model_path,
-            chat_template=self.chat_template,
-            device=self.device.type,
-            mem_fraction_static=0.5,
-            enable_multimodal=True,
-        )
-
-    def tearDown(self):
-        self.engine.shutdown()
-
-    async def test_gemma_understands_image(self):
-        req = self.get_completion_request()
-        conv = generate_chat_conv(req, template_name=self.chat_template)
-        text = conv.get_prompt()
-        output = await self.engine.async_generate(
-            prompt=text,
-            image_data=[self.main_image],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-    async def test_gemma_understands_precomputed_features(self):
-        req = self.get_completion_request()
-        processor_output = self.get_processor_output(req=req)
-        with torch.inference_mode():
-            precomputed_features = self.visual(processor_output["pixel_values"])
-        output = await self.engine.async_generate(
-            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
-            image_data=[
-                dict(
-                    modality="IMAGE",
-                    precomputed_features=precomputed_features,
-                )
-            ],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-
-if __name__ == "__main__":
-    unittest.main()
test/srt/test_vlm_input_format.py (new file, 187 lines)
@@ -0,0 +1,187 @@
+import json
+import unittest
+from io import BytesIO
+from typing import Optional
+
+import requests
+import torch
+from PIL import Image
+from transformers import (
+    AutoProcessor,
+    Gemma3ForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+)
+
+from sglang import Engine
+from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.openai_api.protocol import ChatCompletionRequest
+
+TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
+
+
+class VLMInputTestBase:
+    model_path = None
+    chat_template = None
+    processor = None
+    visual = None  # Should be a callable for precomputed features
+
+    @classmethod
+    def setUpClass(cls):
+        assert cls.model_path is not None, "Set model_path in subclass"
+        assert cls.chat_template is not None, "Set chat_template in subclass"
+        cls.image_url = TEST_IMAGE_URL
+        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        response = requests.get(cls.image_url)
+        cls.main_image = Image.open(BytesIO(response.content))
+        cls.processor = AutoProcessor.from_pretrained(
+            cls.model_path, trust_remote_code=True, use_fast=True
+        )
+        cls._init_visual()
+
+    @classmethod
+    def _init_visual(cls):
+        """Override in subclass to set up cls.visual as a callable for precomputed features."""
+        raise NotImplementedError
+
+    def setUp(self):
+        self.engine = Engine(
+            model_path=self.model_path,
+            chat_template=self.chat_template,
+            device=self.device.type,
+            mem_fraction_static=0.8,
+            enable_multimodal=True,
+            disable_cuda_graph=True,
+        )
+
+    def tearDown(self):
+        self.engine.shutdown()
+
+    def get_completion_request(self) -> ChatCompletionRequest:
+        json_structure = {
+            "model": self.model_path,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image_url", "image_url": {"url": self.image_url}},
+                        {"type": "text", "text": "What's in this picture?"},
+                    ],
+                }
+            ],
+        }
+        json_str = json.dumps(json_structure)
+        return ChatCompletionRequest.model_validate_json(json_str)
+
+    def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
+        if req is None:
+            req = self.get_completion_request()
+        conv = generate_chat_conv(req, template_name=self.chat_template)
+        text = conv.get_prompt()
+
+        # Process inputs using processor
+        inputs = self.processor(
+            text=[text],
+            images=[self.main_image],
+            return_tensors="pt",
+        ).to(self.device)
+
+        return inputs
+
+    async def test_understands_image(self):
+        req = self.get_completion_request()
+        conv = generate_chat_conv(req, template_name=self.chat_template)
+        text = conv.get_prompt()
+        output = await self.engine.async_generate(
+            prompt=text,
+            image_data=[self.main_image],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    async def test_understands_precomputed_features(self):
+        req = self.get_completion_request()
+        processor_output = self.get_processor_output(req=req)
+        with torch.inference_mode():
+            precomputed_features = self.__class__.visual(processor_output)
+        output = await self.engine.async_generate(
+            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
+            image_data=[
+                self._precomputed_image_data(processor_output, precomputed_features)
+            ],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    async def test_understands_pixel_values(self):
+        req = self.get_completion_request()
+        processor_output = self.get_processor_output(req=req)
+        output = await self.engine.async_generate(
+            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
+            image_data=[self._pixel_values_image_data(processor_output)],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    def _precomputed_image_data(self, processor_output, precomputed_features):
+        """This should not be overridden."""
+        return dict(
+            modality="IMAGE",
+            precomputed_features=precomputed_features,
+        )
+
+    def _pixel_values_image_data(self, processor_output):
+        """Override in subclass to pass the correct set of arguments."""
+        raise NotImplementedError
+
+
+class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
+    model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
+    chat_template = "qwen2-vl"
+
+    @classmethod
+    def _init_visual(cls):
+        cls.visual_model = (
+            Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                cls.model_path, torch_dtype=torch.bfloat16
+            )
+            .eval()
+            .visual.to(cls.device)
+        )
+        cls.visual = lambda processor_output: cls.visual_model(
+            processor_output["pixel_values"], processor_output["image_grid_thw"]
+        )
+
+    def _pixel_values_image_data(self, processor_output):
+        return dict(
+            modality="IMAGE",
+            image_grid_thws=processor_output["image_grid_thw"],
+            pixel_values=processor_output["pixel_values"],
+        )
+
+
+class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
+    model_path = "google/gemma-3-4b-it"
+    chat_template = "gemma-it"
+
+    @classmethod
+    def _init_visual(cls):
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            cls.model_path, torch_dtype=torch.bfloat16
+        )
+        cls.vision_tower = model.vision_tower.eval().to(cls.device)
+        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
+        cls.visual = lambda processor_output: cls.mm_projector(
+            cls.vision_tower(
+                pixel_values=processor_output["pixel_values"]
+            ).last_hidden_state
+        )
+
+    def _pixel_values_image_data(self, processor_output):
+        return dict(
+            modality="IMAGE",
+            pixel_values=processor_output["pixel_values"][0],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
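For reference, the three interchangeable image_data formats the new test exercises, side by side (a sketch using the Qwen-style fields; engine, text, image, input_ids, grid_thw, pixel_values, and features stand for the objects built in the tests above):

    # 1) Raw image: the engine runs preprocessing and the vision tower itself.
    out = await engine.async_generate(
        prompt=text,
        image_data=[image],
        sampling_params=dict(temperature=0.0),
    )

    # 2) Preprocessed pixel values: the caller runs the HF processor;
    #    the engine still runs the vision tower.
    out = await engine.async_generate(
        input_ids=input_ids,
        image_data=[dict(
            modality="IMAGE",
            image_grid_thws=grid_thw,
            pixel_values=pixel_values,
        )],
        sampling_params=dict(temperature=0.0),
    )

    # 3) Precomputed features: the caller runs the vision tower too;
    #    the engine only splices the embeddings into the prompt.
    out = await engine.async_generate(
        input_ids=input_ids,
        image_data=[dict(
            modality="IMAGE",
            precomputed_features=features,
        )],
        sampling_params=dict(temperature=0.0),
    )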