[CI]Add performance CI for VLM (#6038)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -58,6 +58,7 @@ class RequestFuncInput:
|
||||
output_len: int
|
||||
model: str
|
||||
lora_name: str
|
||||
image_data: str
|
||||
extra_request_body: Dict[str, Any]
|
||||
|
||||
|
||||
@@ -347,6 +348,11 @@ async def async_request_sglang_generate(
|
||||
"logprob_start_len": -1,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
|
||||
# Add image data if available
|
||||
if request_func_input.image_data:
|
||||
payload["image_data"] = request_func_input.image_data
|
||||
|
||||
headers = get_auth_headers()
|
||||
|
||||
output = RequestFuncOutput()
|
||||
@@ -510,6 +516,13 @@ def get_dataset(args, tokenizer):
|
||||
tokenizer=tokenizer,
|
||||
args=args,
|
||||
)
|
||||
elif args.dataset_name == "mmmu":
|
||||
input_requests = sample_mmmu_requests(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.random_output_len,
|
||||
random_sample=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
return input_requests
|
||||
@@ -597,6 +610,121 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
|
||||
return filename
|
||||
|
||||
|
||||
def sample_mmmu_requests(
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    random_sample: bool = True,
) -> List[Tuple[str, int, int]]:
    """
    Sample requests from the MMMU dataset using HuggingFace datasets.

    Only the "Math" subset (split "test") is loaded, and only the first
    image of each example (``image_1``) is used. Each returned prompt has
    the form ``<image>{data_uri}</image>{chat_templated_text}`` so that the
    image payload travels inline with the prompt text.

    Args:
        num_requests: Number of requests to sample.
        tokenizer: Tokenizer used for chat templating and token counting.
        fixed_output_len: If provided, use this fixed output length for all
            requests; otherwise 256 is used.
        random_sample: Whether to randomly sample or take the first N.

    Returns:
        List of tuples (prompt, prompt_token_len, output_token_len).
        ``prompt_token_len`` is the token count of the templated text plus a
        fixed 512-token estimate for the image; the base64 payload itself is
        not tokenized.

    Raises:
        ImportError: If the ``datasets`` package is not installed.
        ValueError: If the MMMU dataset cannot be loaded.
    """
    try:
        from datasets import load_dataset
    except ImportError as err:
        raise ImportError("Please install datasets: pip install datasets") from err

    # Rough per-image token budget; the real count depends on the model's
    # vision preprocessing, which is not run here.
    image_token_estimate = 512
    default_output_len = 256

    print("Loading MMMU dataset from HuggingFace...")

    try:
        print("Attempting to load MMMU Math dataset...")
        mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test")
        print(
            f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples"
        )
    except Exception as e:
        print(f"Failed to load MMMU Math dataset: {e}")
        raise ValueError(f"Failed to load MMMU dataset: {e}") from e

    # Sample from the dataset
    if len(mmmu_dataset) > num_requests:
        if random_sample:
            indices = random.sample(range(len(mmmu_dataset)), num_requests)
            sample_dataset = mmmu_dataset.select(indices)
        else:
            # The enclosing condition already guarantees num_requests is
            # smaller than the dataset, so no clamping is needed.
            sample_dataset = mmmu_dataset.select(range(num_requests))
    else:
        print(f"Dataset has less than {num_requests} examples, using all examples")
        sample_dataset = mmmu_dataset

    print(f"Selected {len(sample_dataset)} examples for benchmarking")

    output_len = fixed_output_len if fixed_output_len is not None else default_output_len

    # Create prompts
    filtered_dataset = []

    for i, example in enumerate(sample_dataset):
        # Best-effort: a single malformed example is reported and skipped
        # rather than aborting the whole benchmark setup.
        try:
            image = example.get("image_1")

            # Skip examples without a usable image object.
            if image is None or not hasattr(image, "save"):
                continue

            image_path = _image_to_jpeg_data_uri(image)

            question = example.get("question")
            text_prompt = tokenizer.apply_chat_template(
                [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image_url", "image_url": {"url": image_path}},
                            {
                                "type": "text",
                                "text": f"Question: {question}\n\nAnswer: ",
                            },
                        ],
                    }
                ],
                add_generation_prompt=True,
                tokenize=False,
            )

            # Count tokens on the templated text only. Tokenizing the final
            # prompt would also count the (very large) base64 payload, which
            # inflates prompt_len far beyond what the model actually sees.
            prompt_len = len(tokenizer.encode(text_prompt)) + image_token_estimate

            prompt = f"<image>{image_path}</image>{text_prompt}"
            filtered_dataset.append((prompt, prompt_len, output_len))

        except Exception as e:
            print(f"Error processing example {i}: {e}")

    print(f"\nCreated {len(filtered_dataset)} MMMU prompts")
    return filtered_dataset


def _image_to_jpeg_data_uri(image) -> str:
    """Encode an image object as a base64 ``data:image/jpeg`` URI.

    ``image`` is expected to expose ``mode``, ``convert`` and ``save`` in the
    style of a PIL image (the caller checks for ``save``).
    """
    import base64
    import io

    # JPEG cannot store alpha or palette data; normalise anything that is not
    # plain RGB / greyscale so the save below does not fail on such images.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{img_str}"
|
||||
|
||||
|
||||
def sample_sharegpt_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
@@ -1004,6 +1132,15 @@ async def benchmark(
|
||||
else:
|
||||
lora_name = None
|
||||
|
||||
if "<image>" in test_prompt:
|
||||
import re
|
||||
|
||||
image_match = re.search(r"<image>(.*?)</image>(.*)", test_prompt)
|
||||
image_data = image_match.group(1) if image_match else None
|
||||
test_prompt = image_match.group(2) if image_match else test_prompt
|
||||
else:
|
||||
image_data = None
|
||||
|
||||
# Create the test input once
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
@@ -1012,6 +1149,7 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=min(test_output_len, 32),
|
||||
lora_name=lora_name,
|
||||
image_data=image_data,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
|
||||
@@ -1063,6 +1201,15 @@ async def benchmark(
|
||||
else:
|
||||
lora_name = None
|
||||
|
||||
if "<image>" in prompt:
|
||||
import re
|
||||
|
||||
image_match = re.search(r"<image>(.*?)</image>(.*)", prompt)
|
||||
image_data = image_match.group(1) if image_match else None
|
||||
prompt = image_match.group(2) if image_match else prompt
|
||||
else:
|
||||
image_data = None
|
||||
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
@@ -1070,6 +1217,7 @@ async def benchmark(
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
lora_name=lora_name,
|
||||
image_data=image_data,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
tasks.append(
|
||||
@@ -1444,7 +1592,7 @@ if __name__ == "__main__":
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
|
||||
choices=["sharegpt", "random", "random-ids", "generated-shared-prefix", "mmmu"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -79,7 +79,8 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Ins
|
||||
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4,hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"
|
||||
DEFAULT_SMALL_VLM_MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
DEFAULT_VLM_CHAT_TEMPLATE_FOR_TEST = "qwen2-vl"
|
||||
|
||||
DEFAULT_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
DEFAULT_VIDEO_URL = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
|
||||
|
||||
Reference in New Issue
Block a user