Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)
Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
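Usage sketch (not part of this commit's diff): the new `precomputed_features` entry in `image_data` lets a caller run the vision encoder outside the engine and hand the resulting embeddings to SGLang, as the Qwen2.5-VL test below exercises. The prompt string, image path, device, and dtype handling in this sketch are illustrative assumptions only:

```python
import asyncio

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from sglang import Engine

MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
DEVICE = "cuda"  # assumption: a single GPU is available


async def main():
    # Assumption: the prompt already contains the model's image placeholder
    # tokens (the tests build it via generate_chat_conv); the image is any PIL image.
    prompt = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."
    image = Image.open("example.jpg")  # hypothetical local file

    processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
    processor_output = processor(text=[prompt], images=[image], return_tensors="pt")

    # Precompute image features with the HF vision encoder, outside the engine.
    visual = (
        Qwen2_5_VLForConditionalGeneration.from_pretrained(
            MODEL_PATH, torch_dtype=torch.bfloat16
        )
        .eval()
        .visual.to(DEVICE)
    )
    with torch.inference_mode():
        precomputed_features = visual(
            processor_output["pixel_values"].to(DEVICE, torch.bfloat16),
            processor_output["image_grid_thw"].to(DEVICE),
        )

    # Pass the features to SGLang instead of raw pixels.
    engine = Engine(model_path=MODEL_PATH, chat_template="qwen2-vl", device=DEVICE)
    output = await engine.async_generate(
        input_ids=processor_output["input_ids"][0].tolist(),
        image_data=[
            dict(
                modality="IMAGE",
                image_grid_thws=processor_output["image_grid_thw"],
                precomputed_features=precomputed_features,
            )
        ],
        sampling_params=dict(temperature=0.0),
    )
    print(output["text"])
    engine.shutdown()


asyncio.run(main())
```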
@@ -54,21 +54,17 @@ class TestSkipTokenizerInit(CustomTestCase):
     ):
         input_ids = self.get_input_ids(prompt_text)

+        request = self.get_request_json(
+            input_ids=input_ids,
+            return_logprob=return_logprob,
+            top_logprobs_num=top_logprobs_num,
+            max_new_tokens=max_new_tokens,
+            stream=False,
+            n=n,
+        )
         response = requests.post(
             self.base_url + "/generate",
-            json={
-                "input_ids": input_ids,
-                "sampling_params": {
-                    "temperature": 0 if n == 1 else 0.5,
-                    "max_new_tokens": max_new_tokens,
-                    "n": n,
-                    "stop_token_ids": [self.tokenizer.eos_token_id],
-                },
-                "stream": False,
-                "return_logprob": return_logprob,
-                "top_logprobs_num": top_logprobs_num,
-                "logprob_start_len": 0,
-            },
+            json=request,
         )
         ret = response.json()
         print(json.dumps(ret, indent=2))
@@ -87,9 +83,12 @@ class TestSkipTokenizerInit(CustomTestCase):
             self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))

             if return_logprob:
+                num_input_logprobs = len(input_ids) - request["logprob_start_len"]
+                if num_input_logprobs > len(input_ids):
+                    num_input_logprobs -= len(input_ids)
                 self.assertEqual(
                     len(item["meta_info"]["input_token_logprobs"]),
-                    len(input_ids),
+                    num_input_logprobs,
                     f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}',
                 )
                 self.assertEqual(
@@ -113,19 +112,14 @@ class TestSkipTokenizerInit(CustomTestCase):
         requests.post(self.base_url + "/flush_cache")
         response = requests.post(
             self.base_url + "/generate",
-            json={
-                "input_ids": input_ids,
-                "sampling_params": {
-                    "temperature": 0 if n == 1 else 0.5,
-                    "max_new_tokens": max_new_tokens,
-                    "n": n,
-                    "stop_token_ids": self.eos_token_id,
-                },
-                "stream": False,
-                "return_logprob": return_logprob,
-                "top_logprobs_num": top_logprobs_num,
-                "logprob_start_len": 0,
-            },
+            json=self.get_request_json(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                return_logprob=return_logprob,
+                top_logprobs_num=top_logprobs_num,
+                stream=False,
+                n=n,
+            ),
         )
         ret = response.json()
         print(json.dumps(ret))
@@ -137,19 +131,13 @@ class TestSkipTokenizerInit(CustomTestCase):
         requests.post(self.base_url + "/flush_cache")
         response_stream = requests.post(
             self.base_url + "/generate",
-            json={
-                "input_ids": input_ids,
-                "sampling_params": {
-                    "temperature": 0 if n == 1 else 0.5,
-                    "max_new_tokens": max_new_tokens,
-                    "n": n,
-                    "stop_token_ids": self.eos_token_id,
-                },
-                "stream": True,
-                "return_logprob": return_logprob,
-                "top_logprobs_num": top_logprobs_num,
-                "logprob_start_len": 0,
-            },
+            json=self.get_request_json(
+                input_ids=input_ids,
+                return_logprob=return_logprob,
+                top_logprobs_num=top_logprobs_num,
+                stream=True,
+                n=n,
+            ),
         )

         response_stream_json = []
@@ -188,6 +176,29 @@ class TestSkipTokenizerInit(CustomTestCase):
         ].tolist()
         return input_ids

+    def get_request_json(
+        self,
+        input_ids,
+        max_new_tokens=32,
+        return_logprob=False,
+        top_logprobs_num=0,
+        stream=False,
+        n=1,
+    ):
+        return {
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": 0 if n == 1 else 0.5,
+                "max_new_tokens": max_new_tokens,
+                "n": n,
+                "stop_token_ids": self.eos_token_id,
+            },
+            "stream": stream,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+            "logprob_start_len": 0,
+        }
+

 class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):
     @classmethod
@@ -218,6 +229,14 @@ class TestSkipTokenizerInitVLM(TestSkipTokenizerInit):

         return inputs.input_ids[0].tolist()

+    def get_request_json(self, *args, **kwargs):
+        ret = super().get_request_json(*args, **kwargs)
+        ret["image_data"] = [self.image_url]
+        ret["logprob_start_len"] = (
+            -1
+        )  # Do not try to calculate logprobs of image embeddings.
+        return ret
+
     def test_simple_decode_stream(self):
         # TODO mick
         pass

@@ -3,15 +3,22 @@

 import unittest
 from io import BytesIO
-from typing import List
+from typing import List, Optional

 import numpy as np
 import requests
 import torch
 import torch.nn.functional as F
 from PIL import Image
-from transformers import AutoModel, AutoProcessor, AutoTokenizer
+from transformers import (
+    AutoModel,
+    AutoProcessor,
+    AutoTokenizer,
+    Gemma3ForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+)

+from sglang import Engine
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.managers.mm_utils import embed_mm_inputs
@@ -100,7 +107,7 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):

         np.testing.assert_allclose(hf_np, sg_np)

-    def get_processor_output(self):
+    def get_completion_request(self) -> ChatCompletionRequest:
         json_str = f"""
         {{
             "model": "{self.model_path}",
@@ -124,10 +131,12 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
         }}
         """

-        req = ChatCompletionRequest.model_validate_json(json_str)
+        return ChatCompletionRequest.model_validate_json(json_str)

+    def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
+        if req is None:
+            req = self.get_completion_request()
         conv = generate_chat_conv(req, template_name=self.chat_template)

         text = conv.get_prompt()

         # Process inputs using processor
@@ -239,5 +248,129 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         self.compare_outputs(sglang_output, hf_output)


+class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
+        cls.chat_template = "qwen2-vl"
+        cls.processor = AutoProcessor.from_pretrained(
+            cls.model_path, trust_remote_code=True, use_fast=True
+        )
+        cls.visual = (
+            Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                cls.model_path, torch_dtype=torch.bfloat16
+            )
+            .eval()
+            .visual.to(cls.device)
+        )
+
+    def setUp(self):
+        self.engine = Engine(
+            model_path=self.model_path,
+            chat_template=self.chat_template,
+            device=self.device.type,
+            mem_fraction_static=0.8,
+        )
+
+    def tearDown(self):
+        self.engine.shutdown()
+
+    async def test_qwen_vl_understands_image(self):
+        req = self.get_completion_request()
+        conv = generate_chat_conv(req, template_name=self.chat_template)
+        text = conv.get_prompt()
+        output = await self.engine.async_generate(
+            prompt=text,
+            image_data=[self.main_image],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    async def test_qwen_vl_understands_precomputed_features(self):
+        req = self.get_completion_request()
+        processor_output = self.get_processor_output(req=req)
+        with torch.inference_mode():
+            precomputed_features = self.visual(
+                processor_output["pixel_values"], processor_output["image_grid_thw"]
+            )
+        output = await self.engine.async_generate(
+            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
+            image_data=[
+                dict(
+                    modality="IMAGE",
+                    image_grid_thws=processor_output["image_grid_thw"],
+                    precomputed_features=precomputed_features,
+                )
+            ],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+
+class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.model_path = "google/gemma-3-4b-it"
+        cls.chat_template = "gemma-it"
+        cls.processor = AutoProcessor.from_pretrained(
+            cls.model_path, trust_remote_code=True, use_fast=True
+        )
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            cls.model_path, torch_dtype=torch.bfloat16
+        )
+        cls.vision_tower = model.vision_tower.eval().to(cls.device)
+        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
+
+    @classmethod
+    def visual(cls, pixel_values):
+        vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
+        image_features = cls.mm_projector(vision_outputs)
+        return image_features
+
+    def setUp(self):
+        self.engine = Engine(
+            model_path=self.model_path,
+            chat_template=self.chat_template,
+            device=self.device.type,
+            mem_fraction_static=0.5,
+            enable_multimodal=True,
+        )
+
+    def tearDown(self):
+        self.engine.shutdown()
+
+    async def test_gemma_understands_image(self):
+        req = self.get_completion_request()
+        conv = generate_chat_conv(req, template_name=self.chat_template)
+        text = conv.get_prompt()
+        output = await self.engine.async_generate(
+            prompt=text,
+            image_data=[self.main_image],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    async def test_gemma_understands_precomputed_features(self):
+        req = self.get_completion_request()
+        processor_output = self.get_processor_output(req=req)
+        with torch.inference_mode():
+            precomputed_features = self.visual(processor_output["pixel_values"])
+        output = await self.engine.async_generate(
+            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
+            image_data=[
+                dict(
+                    modality="IMAGE",
+                    precomputed_features=precomputed_features,
+                )
+            ],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+
 if __name__ == "__main__":
     unittest.main()