Support for Qwen2.5-VL Model in bitsandbytes Format (#5003)
This commit is contained in:
1
.github/workflows/vllm-dependency-test.yml
vendored
1
.github/workflows/vllm-dependency-test.yml
vendored
@@ -33,6 +33,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
pip install "vllm>=0.6.4.post1,<=0.7.2"
|
pip install "vllm>=0.6.4.post1,<=0.7.2"
|
||||||
|
pip install "bitsandbytes>=0.44.0"
|
||||||
|
|
||||||
- name: Run VLLM dependency tests
|
- name: Run VLLM dependency tests
|
||||||
timeout-minutes: 60
|
timeout-minutes: 60
|
||||||
|
|||||||
@@ -1071,6 +1071,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
param_dict = dict(model.named_parameters())
|
param_dict = dict(model.named_parameters())
|
||||||
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
|
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
|
||||||
|
model_type = model_config.hf_config.model_type
|
||||||
for quant_param_name in quant_state_dict:
|
for quant_param_name in quant_state_dict:
|
||||||
non_stacked_param_name = quant_param_name
|
non_stacked_param_name = quant_param_name
|
||||||
|
|
||||||
@@ -1079,11 +1080,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
weight_name,
|
weight_name,
|
||||||
index,
|
index,
|
||||||
) in model.bitsandbytes_stacked_params_mapping.items():
|
) in model.bitsandbytes_stacked_params_mapping.items():
|
||||||
|
if (
|
||||||
|
model_type in ["qwen2_vl", "qwen2_5_vl"]
|
||||||
|
and "visual" in quant_param_name
|
||||||
|
):
|
||||||
|
break
|
||||||
if shard_name in quant_param_name:
|
if shard_name in quant_param_name:
|
||||||
shard_index = index
|
shard_index = index
|
||||||
quant_param_name = quant_param_name.replace(shard_name, weight_name)
|
quant_param_name = quant_param_name.replace(shard_name, weight_name)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if (
|
||||||
|
model_type in ["qwen2_vl", "qwen2_5_vl"]
|
||||||
|
and "visual" in quant_param_name
|
||||||
|
):
|
||||||
|
quant_param_name = quant_param_name.replace(
|
||||||
|
r"attn.qkv.", r"attn.qkv_proj."
|
||||||
|
)
|
||||||
|
|
||||||
if quant_param_name not in param_dict:
|
if quant_param_name not in param_dict:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Parameter {quant_param_name} not found in the model."
|
f"Parameter {quant_param_name} not found in the model."
|
||||||
@@ -1111,6 +1125,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
|
num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
|
||||||
|
|
||||||
offsets = np.concatenate(([0], np.cumsum(num_elements)))
|
offsets = np.concatenate(([0], np.cumsum(num_elements)))
|
||||||
|
# Make torch infer_schema happy(Compatible with vLLM)
|
||||||
|
offsets = torch.tensor(offsets).cpu()
|
||||||
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
|
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
|
||||||
|
|
||||||
if load_8bit:
|
if load_8bit:
|
||||||
|
|||||||
@@ -141,7 +141,7 @@ class Qwen2_5_VisionBlock(nn.Module):
|
|||||||
embed_dim=dim,
|
embed_dim=dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
projection_size=dim,
|
projection_size=dim,
|
||||||
use_qkv_parallel=False,
|
use_qkv_parallel=True,
|
||||||
use_context_forward=use_context_forward,
|
use_context_forward=use_context_forward,
|
||||||
softmax_in_single_precision=softmax_in_single_precision,
|
softmax_in_single_precision=softmax_in_single_precision,
|
||||||
flatten_batch=flatten_batch,
|
flatten_batch=flatten_batch,
|
||||||
@@ -325,7 +325,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
return self.blocks[0].mlp.gate_proj.weight.dtype
|
return self.patch_embed.proj.weight.dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
@@ -429,6 +429,25 @@ cached_get_processor = lru_cache(get_processor)
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||||
|
# BitandBytes specific attributes
|
||||||
|
default_bitsandbytes_target_modules = [
|
||||||
|
".gate_proj.",
|
||||||
|
".down_proj.",
|
||||||
|
".up_proj.",
|
||||||
|
".q_proj.",
|
||||||
|
".k_proj.",
|
||||||
|
".v_proj.",
|
||||||
|
".o_proj.",
|
||||||
|
]
|
||||||
|
bitsandbytes_stacked_params_mapping = {
|
||||||
|
# shard_name, weight_name, index
|
||||||
|
"q_proj": ("qkv_proj", 0),
|
||||||
|
"k_proj": ("qkv_proj", 1),
|
||||||
|
"v_proj": ("qkv_proj", 2),
|
||||||
|
"gate_proj": ("gate_up_proj", 0),
|
||||||
|
"up_proj": ("gate_up_proj", 1),
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen2_5_VLConfig,
|
config: Qwen2_5_VLConfig,
|
||||||
@@ -441,9 +460,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
self.visual = Qwen2_5_VisionTransformer(
|
self.visual = Qwen2_5_VisionTransformer(
|
||||||
config.vision_config,
|
config.vision_config,
|
||||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||||
# NOTE: Qwen2-VL vision encoder does not support any
|
# NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
||||||
# quantization method now.
|
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
||||||
quant_config=None,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("visual", prefix),
|
prefix=add_prefix("visual", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -573,23 +592,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if "visual" in name and "qkv.weight" in name:
|
|
||||||
visual_num_heads = self.config.vision_config.num_heads
|
|
||||||
visual_embed_dim = self.config.vision_config.hidden_size
|
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
|
||||||
loaded_weight = loaded_weight.view(
|
|
||||||
3, visual_num_heads, head_size, visual_embed_dim
|
|
||||||
)
|
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
|
||||||
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
|
||||||
elif "visual" in name and "qkv.bias" in name:
|
|
||||||
visual_num_heads = self.config.vision_config.num_heads
|
|
||||||
visual_embed_dim = self.config.vision_config.hidden_size
|
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
|
||||||
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
|
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
|
||||||
loaded_weight = loaded_weight.reshape(-1)
|
|
||||||
|
|
||||||
if "visual" in name:
|
if "visual" in name:
|
||||||
# adapt to VisionAttention
|
# adapt to VisionAttention
|
||||||
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ class Qwen2VisionBlock(nn.Module):
|
|||||||
embed_dim=dim,
|
embed_dim=dim,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
projection_size=dim,
|
projection_size=dim,
|
||||||
use_qkv_parallel=False,
|
use_qkv_parallel=True,
|
||||||
use_context_forward=use_context_forward,
|
use_context_forward=use_context_forward,
|
||||||
softmax_in_single_precision=softmax_in_single_precision,
|
softmax_in_single_precision=softmax_in_single_precision,
|
||||||
flatten_batch=True,
|
flatten_batch=True,
|
||||||
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> torch.dtype:
|
def dtype(self) -> torch.dtype:
|
||||||
return next(self.parameters()).dtype
|
return self.patch_embed.proj.weight.dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
@@ -423,6 +423,25 @@ cached_get_processor = lru_cache(get_processor)
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2VLForConditionalGeneration(nn.Module):
|
class Qwen2VLForConditionalGeneration(nn.Module):
|
||||||
|
# BitandBytes specific attributes
|
||||||
|
default_bitsandbytes_target_modules = [
|
||||||
|
".gate_proj.",
|
||||||
|
".down_proj.",
|
||||||
|
".up_proj.",
|
||||||
|
".q_proj.",
|
||||||
|
".k_proj.",
|
||||||
|
".v_proj.",
|
||||||
|
".o_proj.",
|
||||||
|
]
|
||||||
|
bitsandbytes_stacked_params_mapping = {
|
||||||
|
# shard_name, weight_name, index
|
||||||
|
"q_proj": ("qkv_proj", 0),
|
||||||
|
"k_proj": ("qkv_proj", 1),
|
||||||
|
"v_proj": ("qkv_proj", 2),
|
||||||
|
"gate_proj": ("gate_up_proj", 0),
|
||||||
|
"up_proj": ("gate_up_proj", 1),
|
||||||
|
}
|
||||||
|
|
||||||
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
||||||
processor = cached_get_processor(self.config._name_or_path)
|
processor = cached_get_processor(self.config._name_or_path)
|
||||||
grid_t, grid_h, grid_w = image_grid_thw
|
grid_t, grid_h, grid_w = image_grid_thw
|
||||||
@@ -447,9 +466,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
self.visual = Qwen2VisionTransformer(
|
self.visual = Qwen2VisionTransformer(
|
||||||
config.vision_config,
|
config.vision_config,
|
||||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||||
# NOTE: Qwen2-VL vision encoder does not support any
|
# NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
|
||||||
# quantization method now.
|
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
|
||||||
quant_config=None,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("visual", prefix),
|
prefix=add_prefix("visual", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -578,24 +597,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|||||||
weight_loader(param, loaded_weight, shard_id)
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|
||||||
if "visual" in name and "qkv.weight" in name:
|
|
||||||
visual_num_heads = self.config.vision_config.num_heads
|
|
||||||
visual_embed_dim = self.config.vision_config.embed_dim
|
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
|
||||||
loaded_weight = loaded_weight.view(
|
|
||||||
3, visual_num_heads, head_size, visual_embed_dim
|
|
||||||
)
|
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
|
||||||
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
|
||||||
elif "visual" in name and "qkv.bias" in name:
|
|
||||||
visual_num_heads = self.config.vision_config.num_heads
|
|
||||||
visual_embed_dim = self.config.vision_config.embed_dim
|
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
|
||||||
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
|
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
|
||||||
loaded_weight = loaded_weight.reshape(-1)
|
|
||||||
|
|
||||||
if "visual" in name:
|
if "visual" in name:
|
||||||
# adapt to VisionAttention
|
# adapt to VisionAttention
|
||||||
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ suites = {
|
|||||||
TestFile("test_awq.py"),
|
TestFile("test_awq.py"),
|
||||||
TestFile("test_gguf.py", 78),
|
TestFile("test_gguf.py", 78),
|
||||||
TestFile("test_gptqmodel_dynamic.py", 72),
|
TestFile("test_gptqmodel_dynamic.py", 72),
|
||||||
|
TestFile("test_bnb.py"),
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
309
test/srt/test_bnb.py
Normal file
309
test/srt/test_bnb.py
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
|
||||||
|
python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_multi_images_chat_completion
|
||||||
|
"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.run_eval import run_eval
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
VISION_MODELS = [
|
||||||
|
("unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
|
||||||
|
("unsloth/Qwen2-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
|
||||||
|
]
|
||||||
|
LANGUAGE_MODELS = [
|
||||||
|
"unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
|
||||||
|
"unsloth/Qwen2-7B-Instruct-bnb-4bit",
|
||||||
|
]
|
||||||
|
|
||||||
|
# image
|
||||||
|
IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
|
||||||
|
IMAGE_SGL_LOGO_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/sgl_logo.png"
|
||||||
|
|
||||||
|
# video
|
||||||
|
VIDEO_JOBS_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4"
|
||||||
|
|
||||||
|
# audio
|
||||||
|
AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/Trump_WEF_2018_10s.mp3"
|
||||||
|
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
|
||||||
|
|
||||||
|
|
||||||
|
def popen_launch_server_wrapper(base_url, model, other_args):
|
||||||
|
process = popen_launch_server(
|
||||||
|
model,
|
||||||
|
base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=other_args,
|
||||||
|
)
|
||||||
|
return process
|
||||||
|
|
||||||
|
|
||||||
|
class TestVisionModel(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.base_url += "/v1"
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
|
||||||
|
def _run_single_image_chat_completion(self):
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this image in a very short sentence.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices[0].message.role == "assistant"
|
||||||
|
text = response.choices[0].message.content
|
||||||
|
assert isinstance(text, str)
|
||||||
|
# `driver` is for gemma-3-it
|
||||||
|
assert "man" in text or "person" or "driver" in text, text
|
||||||
|
assert "cab" in text or "taxi" in text or "SUV" in text, text
|
||||||
|
# MiniCPMO fails to recognize `iron`, but `hanging`
|
||||||
|
assert "iron" in text or "hang" in text, text
|
||||||
|
assert response.id
|
||||||
|
assert response.created
|
||||||
|
assert response.usage.prompt_tokens > 0
|
||||||
|
assert response.usage.completion_tokens > 0
|
||||||
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
def _run_multi_turn_chat_completion(self):
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this image in a very short sentence.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "There is a man at the back of a yellow cab ironing his clothes.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Repeat your previous answer."}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices[0].message.role == "assistant"
|
||||||
|
text = response.choices[0].message.content
|
||||||
|
assert isinstance(text, str)
|
||||||
|
assert "man" in text or "cab" in text, text
|
||||||
|
assert response.id
|
||||||
|
assert response.created
|
||||||
|
assert response.usage.prompt_tokens > 0
|
||||||
|
assert response.usage.completion_tokens > 0
|
||||||
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
def _run_multi_images_chat_completion(self):
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||||
|
"modalities": "multi-images",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": IMAGE_SGL_LOGO_URL},
|
||||||
|
"modalities": "multi-images",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "I have two very different images. They are not related at all. "
|
||||||
|
"Please describe the first image in one sentence, and then describe the second image in another sentence.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices[0].message.role == "assistant"
|
||||||
|
text = response.choices[0].message.content
|
||||||
|
assert isinstance(text, str)
|
||||||
|
print("-" * 30)
|
||||||
|
print(f"Multi images response:\n{text}")
|
||||||
|
print("-" * 30)
|
||||||
|
assert "man" in text or "cab" in text or "SUV" in text or "taxi" in text, text
|
||||||
|
assert "logo" in text or '"S"' in text or "SG" in text, text
|
||||||
|
assert response.id
|
||||||
|
assert response.created
|
||||||
|
assert response.usage.prompt_tokens > 0
|
||||||
|
assert response.usage.completion_tokens > 0
|
||||||
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
def run_decode_with_image(self, image_id):
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
content = []
|
||||||
|
if image_id == 0:
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": IMAGE_MAN_IRONING_URL},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif image_id == 1:
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": IMAGE_SGL_LOGO_URL},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
content.append(
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this image in a very short sentence.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": content},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.choices[0].message.role == "assistant"
|
||||||
|
text = response.choices[0].message.content
|
||||||
|
assert isinstance(text, str)
|
||||||
|
|
||||||
|
def _run_test_mixed_batch(self):
|
||||||
|
image_ids = [0, 1, 2] * 4
|
||||||
|
with ThreadPoolExecutor(4) as executor:
|
||||||
|
list(executor.map(self.run_decode_with_image, image_ids))
|
||||||
|
|
||||||
|
def test_vlm(self):
|
||||||
|
models_to_test = VISION_MODELS
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
models_to_test = [random.choice(VISION_MODELS)]
|
||||||
|
|
||||||
|
for model, template in models_to_test:
|
||||||
|
with self.subTest(model=model):
|
||||||
|
other_args = [
|
||||||
|
"--chat-template",
|
||||||
|
template,
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.6",
|
||||||
|
"--load-format",
|
||||||
|
"bitsandbytes",
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
process = popen_launch_server_wrapper(
|
||||||
|
DEFAULT_URL_FOR_TEST, model, other_args
|
||||||
|
)
|
||||||
|
self._run_test_mixed_batch()
|
||||||
|
self._run_multi_images_chat_completion()
|
||||||
|
self._run_multi_turn_chat_completion()
|
||||||
|
self._run_single_image_chat_completion()
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLanguageModel(CustomTestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
# cls.base_url += "/v1"
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
|
||||||
|
def test_mmlu(self):
|
||||||
|
models_to_test = LANGUAGE_MODELS
|
||||||
|
|
||||||
|
if is_in_ci():
|
||||||
|
models_to_test = [random.choice(LANGUAGE_MODELS)]
|
||||||
|
|
||||||
|
for model in models_to_test:
|
||||||
|
with self.subTest(model=model):
|
||||||
|
other_args = [
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.6",
|
||||||
|
"--load-format",
|
||||||
|
"bitsandbytes",
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
process = popen_launch_server_wrapper(
|
||||||
|
DEFAULT_URL_FOR_TEST, model, other_args
|
||||||
|
)
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=32,
|
||||||
|
num_threads=16,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
self.assertGreater(metrics["score"], 0.3)
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
Reference in New Issue
Block a user