Refactor vlm embedding routine to use precomputed feature (#6543)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
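The refactor splits the multimodal embedding path into small helpers (_get_precomputed_embedding, _get_chunked_prefill_embedding, _get_multimodal_mask, _adjust_embedding_length): get_embedding_and_mask now uses precomputed vision features whenever every item carries them, and only falls back to running the encoder under chunked prefill otherwise. The sketch below condenses the new test_vlm_input_format.py into a single script to show the resulting input format; the prompt string, image path, and device handling are illustrative assumptions, not part of this commit.

import asyncio

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from sglang import Engine

MODEL = "Qwen/Qwen2.5-VL-3B-Instruct"

# Run the HF vision encoder once, outside the engine.
processor = AutoProcessor.from_pretrained(MODEL, use_fast=True)
visual = (
    Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL, torch_dtype=torch.bfloat16)
    .eval()
    .visual.cuda()
)

# Assumed Qwen2-VL style chat prompt with an image placeholder; the tests build this
# with generate_chat_conv() from a ChatCompletionRequest instead of hard-coding it.
prompt = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What's in this picture?<|im_end|>\n<|im_start|>assistant\n"
)
image = Image.open("example_image.png")  # illustrative local path
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to("cuda")
with torch.inference_mode():
    features = visual(inputs["pixel_values"], inputs["image_grid_thw"])

# Hand the precomputed features to the engine instead of pixel values.
engine = Engine(model_path=MODEL, chat_template="qwen2-vl", enable_multimodal=True)
output = asyncio.run(
    engine.async_generate(
        input_ids=inputs["input_ids"][0].detach().cpu().tolist(),
        image_data=[dict(modality="IMAGE", precomputed_features=features)],
        sampling_params=dict(temperature=0.0),
    )
)
print(output["text"])

Passing pixel_values (plus image_grid_thws for Qwen) through the same image_data dict remains supported, as exercised by test_understands_pixel_values in the new test file below.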
@@ -252,40 +252,36 @@ def get_embedding_chunk(
     return embedding_chunk, start_index, end_index
 
 
-def get_embedding_and_mask(
-    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
-    embedding_items: List[MultimodalDataItem],
-    placeholder_tensor: torch.Tensor,
-    input_ids: torch.Tensor,
-    items_size: List[int],
-    prefix_length: List[int],
-    extend_length: List[int],
-    items_offset_list: List[List[Tuple[int, int]]],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
-
-    Args:
-        data_embedding_func: Function that generates embeddings for multimodal items
-        embedding_items: List of multimodal items to embed
-        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
-        input_ids: The input token IDs tensor
-        items_size: Cumulative sizes of multimodal items per request
-        prefix_length: Prefix lengths for each request
-        extend_length: Sequence lengths for each request
-        items_offset_list: List of offset ranges for multimodal items in each request
-
-    Returns:
-        A tuple containing:
-        - The generated embeddings tensor
-        - A boolean mask tensor indicating where these embeddings should be placed
-
-    Raises:
-        AssertionError: If the number of multimodal tokens in input_ids doesn't match
-            the number of tokens in the generated embeddings
-    """
-    # 1. Get the embedding
-    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
+def _get_precomputed_embedding(
+    items: List[MultimodalDataItem],
+) -> Optional[torch.Tensor]:
+    """
+    If all items have precomputed_features, return their concatenation.
+    If some but not all have precomputed_features, raise NotImplementedError.
+    If none have precomputed_features, return None.
+    """
+    precomputed_features = [item.precomputed_features for item in items]
+    if any(feature is not None for feature in precomputed_features):
+        if not all(feature is not None for feature in precomputed_features):
+            raise NotImplementedError(
+                "MM inputs where only some items are precomputed."
+            )
+        result = torch.concat(precomputed_features)
+        # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
+        result = result.reshape(-1, result.shape[-1])
+        return result
+    return None
+
+
+def _get_chunked_prefill_embedding(
+    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
+    embedding_items: List[MultimodalDataItem],
+    items_size: List[int],
+    prefix_length: List[int],
+    extend_length: List[int],
+    items_offset_list: List[List[Tuple[int, int]]],
+) -> Optional[torch.Tensor]:
+    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
     for i in range(len(items_size) - 1):
         if items_size[i] == items_size[i + 1]:
@@ -321,21 +317,28 @@ def get_embedding_and_mask(
             embedding_cache.free(embedding_items_hash)
         embedding_list.append(embedding_per_req_chunk)
     if len(embedding_list) == 0:
-        return None, None
-    embedding = torch.concat(embedding_list, dim=0)
-    # 2. Check the embedding
-    num_mm_tokens_in_embedding = embedding.shape[0]
-    special_multimodal_mask = torch.isin(
-        input_ids,
-        placeholder_tensor,
-    ).unsqueeze(-1)
-
-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
+        return None
+    return torch.concat(embedding_list, dim=0)
+
+
+def _get_multimodal_mask(
+    input_ids: torch.Tensor, placeholder_tensor: torch.Tensor
+) -> torch.Tensor:
+    return torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
+
+
+def _adjust_embedding_length(
+    embedding: torch.Tensor,
+    mask: torch.Tensor,
+    logger,
+) -> torch.Tensor:
+    num_mm_tokens_in_embedding = embedding.shape[0]
+    num_mm_tokens_in_input_ids = mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text. "
             f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
-            "tokens from multimodal embeddings."
+            f"tokens from multimodal embeddings."
         )
         if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
             chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
@@ -353,7 +356,54 @@ def get_embedding_and_mask(
             raise RuntimeError(
                 f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
             )
-
+    return embedding
+
+
+def get_embedding_and_mask(
+    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
+    embedding_items: List[MultimodalDataItem],
+    placeholder_tensor: torch.Tensor,
+    input_ids: torch.Tensor,
+    items_size: List[int],
+    prefix_length: List[int],
+    extend_length: List[int],
+    items_offset_list: List[List[Tuple[int, int]]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
+
+    Args:
+        data_embedding_func: Function that generates embeddings for multimodal items
+        embedding_items: List of multimodal items to embed
+        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
+        input_ids: The input token IDs tensor
+        items_size: Cumulative sizes of multimodal items per request
+        prefix_length: Prefix lengths for each request
+        extend_length: Sequence lengths for each request
+        items_offset_list: List of offset ranges for multimodal items in each request
+
+    Returns:
+        A tuple containing:
+        - The generated embeddings tensor
+        - A boolean mask tensor indicating where these embeddings should be placed
+    """
+    # 1. Get embedding
+    embedding = _get_precomputed_embedding(embedding_items)
+    if embedding is None:
+        embedding = _get_chunked_prefill_embedding(
+            data_embedding_func,
+            embedding_items,
+            items_size,
+            prefix_length,
+            extend_length,
+            items_offset_list,
+        )
+        if embedding is None:
+            return None, None
+    # 2. Get mask
+    special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
+    # 3. Adjust embedding length if needed
+    embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
     return embedding, special_multimodal_mask
 
 
@@ -144,12 +144,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
 
         if base_output.images:
             if images_are_preprocessed:
-                image_grid_thw = torch.concat(
-                    [
-                        torch.as_tensor(item.image_grid_thws)
-                        for item in base_output.images
-                    ]
-                )
+                all_image_grid_thws = [
+                    item.image_grid_thws
+                    for item in base_output.images
+                    if item.image_grid_thws is not None
+                ]
                 all_pixel_values = [
                     item.pixel_values
                     for item in base_output.images
@@ -160,6 +159,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                     for item in base_output.images
                     if item.precomputed_features is not None
                 ]
+                image_grid_thw = (
+                    torch.concat(all_image_grid_thws) if all_image_grid_thws else None
+                )
                 pixel_values = (
                     torch.concat(all_pixel_values) if all_pixel_values else None
                 )
@@ -282,13 +282,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
-
         # Process images one by one to handle flatten_batch=True constraint in vision_tower
         all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
         vision_outputs_list = []
@@ -499,12 +499,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
@@ -486,12 +486,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
-        if any(item.precomputed_features is not None for item in items):
-            if not all(item.precomputed_features is not None for item in items):
-                raise NotImplementedError(
-                    "MM inputs where only some items are precomputed."
-                )
-            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype
@@ -81,7 +81,7 @@ suites = {
         TestFile("test_update_weights_from_tensor.py", 48),
         TestFile("test_vertex_endpoint.py", 31),
         TestFile("test_vision_chunked_prefill.py", 175),
-        TestFile("test_vlm_accuracy.py", 60),
+        TestFile("test_vlm_input_format.py", 300),
         TestFile("test_vision_openai_server_a.py", 700),
         TestFile("test_vision_openai_server_b.py", 700),
         TestFile("test_w8a8_quantization.py", 46),
@@ -10,15 +10,8 @@ import requests
 import torch
 import torch.nn.functional as F
 from PIL import Image
-from transformers import (
-    AutoModel,
-    AutoProcessor,
-    AutoTokenizer,
-    Gemma3ForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-)
+from transformers import AutoModel, AutoProcessor, AutoTokenizer
 
-from sglang import Engine
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
@@ -41,9 +34,6 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
     def setUpClass(cls):
         cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
         cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        cls.model_path = ""
-        cls.chat_template = ""
-        cls.processor = ""
         response = requests.get(cls.image_url)
         cls.main_image = Image.open(BytesIO(response.content))
 
@@ -274,131 +264,3 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         )
 
         self.compare_outputs(sglang_output, hf_output)
-
-
-class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
-        cls.chat_template = "qwen2-vl"
-        cls.processor = AutoProcessor.from_pretrained(
-            cls.model_path, trust_remote_code=True, use_fast=True
-        )
-        cls.visual = (
-            Qwen2_5_VLForConditionalGeneration.from_pretrained(
-                cls.model_path, torch_dtype=torch.bfloat16
-            )
-            .eval()
-            .visual.to(cls.device)
-        )
-
-    def setUp(self):
-        self.engine = Engine(
-            model_path=self.model_path,
-            chat_template=self.chat_template,
-            device=self.device.type,
-            mem_fraction_static=0.8,
-        )
-
-    def tearDown(self):
-        self.engine.shutdown()
-
-    async def test_qwen_vl_understands_image(self):
-        req = self.get_completion_request()
-        conv = generate_chat_conv(req, template_name=self.chat_template)
-        text = conv.get_prompt()
-        output = await self.engine.async_generate(
-            prompt=text,
-            image_data=[self.main_image],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-    async def test_qwen_vl_understands_precomputed_features(self):
-        req = self.get_completion_request()
-        processor_output = self.get_processor_output(req=req)
-        with torch.inference_mode():
-            precomputed_features = self.visual(
-                processor_output["pixel_values"], processor_output["image_grid_thw"]
-            )
-        output = await self.engine.async_generate(
-            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
-            image_data=[
-                dict(
-                    modality="IMAGE",
-                    image_grid_thws=processor_output["image_grid_thw"],
-                    precomputed_features=precomputed_features,
-                )
-            ],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-
-class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
-
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        cls.model_path = "google/gemma-3-4b-it"
-        cls.chat_template = "gemma-it"
-        cls.processor = AutoProcessor.from_pretrained(
-            cls.model_path, trust_remote_code=True, use_fast=True
-        )
-        model = Gemma3ForConditionalGeneration.from_pretrained(
-            cls.model_path, torch_dtype=torch.bfloat16
-        )
-        cls.vision_tower = model.vision_tower.eval().to(cls.device)
-        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
-
-    @classmethod
-    def visual(cls, pixel_values):
-        vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
-        image_features = cls.mm_projector(vision_outputs)
-        return image_features
-
-    def setUp(self):
-        self.engine = Engine(
-            model_path=self.model_path,
-            chat_template=self.chat_template,
-            device=self.device.type,
-            mem_fraction_static=0.5,
-            enable_multimodal=True,
-        )
-
-    def tearDown(self):
-        self.engine.shutdown()
-
-    async def test_gemma_understands_image(self):
-        req = self.get_completion_request()
-        conv = generate_chat_conv(req, template_name=self.chat_template)
-        text = conv.get_prompt()
-        output = await self.engine.async_generate(
-            prompt=text,
-            image_data=[self.main_image],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-    async def test_gemma_understands_precomputed_features(self):
-        req = self.get_completion_request()
-        processor_output = self.get_processor_output(req=req)
-        with torch.inference_mode():
-            precomputed_features = self.visual(processor_output["pixel_values"])
-        output = await self.engine.async_generate(
-            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
-            image_data=[
-                dict(
-                    modality="IMAGE",
-                    precomputed_features=precomputed_features,
-                )
-            ],
-            sampling_params=dict(temperature=0.0),
-        )
-        self.assertIn("taxi", output["text"].lower())
-
-
-if __name__ == "__main__":
-    unittest.main()
test/srt/test_vlm_input_format.py (new file, +187 lines)
@@ -0,0 +1,187 @@
+import json
+import unittest
+from io import BytesIO
+from typing import Optional
+
+import requests
+import torch
+from PIL import Image
+from transformers import (
+    AutoProcessor,
+    Gemma3ForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
+)
+
+from sglang import Engine
+from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.openai_api.protocol import ChatCompletionRequest
+
+TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
+
+
+class VLMInputTestBase:
+    model_path = None
+    chat_template = None
+    processor = None
+    visual = None  # Should be a callable for precomputed features
+
+    @classmethod
+    def setUpClass(cls):
+        assert cls.model_path is not None, "Set model_path in subclass"
+        assert cls.chat_template is not None, "Set chat_template in subclass"
+        cls.image_url = TEST_IMAGE_URL
+        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        response = requests.get(cls.image_url)
+        cls.main_image = Image.open(BytesIO(response.content))
+        cls.processor = AutoProcessor.from_pretrained(
+            cls.model_path, trust_remote_code=True, use_fast=True
+        )
+        cls._init_visual()
+
+    @classmethod
+    def _init_visual(cls):
+        """Override in subclass to set up cls.visual as a callable for precomputed features."""
+        raise NotImplementedError
+
+    def setUp(self):
+        self.engine = Engine(
+            model_path=self.model_path,
+            chat_template=self.chat_template,
+            device=self.device.type,
+            mem_fraction_static=0.8,
+            enable_multimodal=True,
+            disable_cuda_graph=True,
+        )
+
+    def tearDown(self):
+        self.engine.shutdown()
+
+    def get_completion_request(self) -> ChatCompletionRequest:
+        json_structure = {
+            "model": self.model_path,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image_url", "image_url": {"url": self.image_url}},
+                        {"type": "text", "text": "What's in this picture?"},
+                    ],
+                }
+            ],
+        }
+        json_str = json.dumps(json_structure)
+        return ChatCompletionRequest.model_validate_json(json_str)
+
+    def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
+        if req is None:
+            req = self.get_completion_request()
+        conv = generate_chat_conv(req, template_name=self.chat_template)
+        text = conv.get_prompt()
+
+        # Process inputs using processor
+        inputs = self.processor(
+            text=[text],
+            images=[self.main_image],
+            return_tensors="pt",
+        ).to(self.device)
+
+        return inputs
+
+    async def test_understands_image(self):
+        req = self.get_completion_request()
+        conv = generate_chat_conv(req, template_name=self.chat_template)
+        text = conv.get_prompt()
+        output = await self.engine.async_generate(
+            prompt=text,
+            image_data=[self.main_image],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    async def test_understands_precomputed_features(self):
+        req = self.get_completion_request()
+        processor_output = self.get_processor_output(req=req)
+        with torch.inference_mode():
+            precomputed_features = self.__class__.visual(processor_output)
+        output = await self.engine.async_generate(
+            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
+            image_data=[
+                self._precomputed_image_data(processor_output, precomputed_features)
+            ],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    async def test_understands_pixel_values(self):
+        req = self.get_completion_request()
+        processor_output = self.get_processor_output(req=req)
+        output = await self.engine.async_generate(
+            input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
+            image_data=[self._pixel_values_image_data(processor_output)],
+            sampling_params=dict(temperature=0.0),
+        )
+        self.assertIn("taxi", output["text"].lower())
+
+    def _precomputed_image_data(self, processor_output, precomputed_features):
+        """This should not be overridden."""
+        return dict(
+            modality="IMAGE",
+            precomputed_features=precomputed_features,
+        )
+
+    def _pixel_values_image_data(self, processor_output):
+        """Override in subclass to pass the correct set of arguments."""
+        raise NotImplementedError
+
+
+class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
+    model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
+    chat_template = "qwen2-vl"
+
+    @classmethod
+    def _init_visual(cls):
+        cls.visual_model = (
+            Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                cls.model_path, torch_dtype=torch.bfloat16
+            )
+            .eval()
+            .visual.to(cls.device)
+        )
+        cls.visual = lambda processor_output: cls.visual_model(
+            processor_output["pixel_values"], processor_output["image_grid_thw"]
+        )
+
+    def _pixel_values_image_data(self, processor_output):
+        return dict(
+            modality="IMAGE",
+            image_grid_thws=processor_output["image_grid_thw"],
+            pixel_values=processor_output["pixel_values"],
+        )
+
+
+class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
+    model_path = "google/gemma-3-4b-it"
+    chat_template = "gemma-it"
+
+    @classmethod
+    def _init_visual(cls):
+        model = Gemma3ForConditionalGeneration.from_pretrained(
+            cls.model_path, torch_dtype=torch.bfloat16
+        )
+        cls.vision_tower = model.vision_tower.eval().to(cls.device)
+        cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
+        cls.visual = lambda processor_output: cls.mm_projector(
+            cls.vision_tower(
+                pixel_values=processor_output["pixel_values"]
+            ).last_hidden_state
+        )
+
+    def _pixel_values_image_data(self, processor_output):
+        return dict(
+            modality="IMAGE",
+            pixel_values=processor_output["pixel_values"][0],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()