[Refactor] image data process in bench_serving (#6879)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
@@ -388,7 +388,6 @@ async def async_request_sglang_generate(
                     chunk_bytes = chunk_bytes.strip()
                     if not chunk_bytes:
                         continue
-                    # print(chunk_bytes)

                     chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                     latency = time.perf_counter() - st
@@ -655,6 +654,7 @@ class DatasetRow:
     prompt: str
     prompt_len: int
     output_len: int
+    image_data: Optional[str] = None


 def sample_mmmu_requests(
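For reference, the complete dataclass after this hunk, as a sketch assembled from the fields visible above (the `dataclass` and `Optional` imports are assumed to already exist in bench_serving.py):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetRow:
    prompt: str
    prompt_len: int
    output_len: int
    # New in this commit: a base64 data URL ("data:image/jpeg;base64,...")
    # for multimodal samples; None for text-only rows.
    image_data: Optional[str] = None
```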
@@ -730,42 +730,50 @@ def sample_mmmu_requests(
                 buffered = io.BytesIO()
                 image.save(buffered, format="JPEG")
                 img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-                image_path = f"data:image/jpeg;base64,{img_str}"
+                image_data = f"data:image/jpeg;base64,{img_str}"
             else:
                 continue

             # Extract the question
             question = example.get("question")

-            # Create the prompt with image, question
+            # Construct the prompt
             prompt = f"Question: {question}\n\nAnswer: "
-            prompt = tokenizer.apply_chat_template(
-                [
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "image_url", "image_url": {"url": image_path}},
-                            {"type": "text", "text": prompt},
-                        ],
-                    }
-                ],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
-            prompt = f"<image>{image_path}</image>{prompt}"
-
-            # Calculate token lengths
-            # Note: This is approximate since we're not rendering the actual image tokens
+            try:
+                prompt = tokenizer.apply_chat_template(
+                    [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": image_data},
+                                },
+                                {"type": "text", "text": prompt},
+                            ],
+                        }
+                    ],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+            except Exception as e:
+                # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
+                print(f"Error applying chat template: {e}, fallback to <image> tag")
+                prompt = f"<image>{prompt}"
+
+            # Calculate token lengths for text only (without image data)
             prompt_token_ids = tokenizer.encode(prompt)
-            prompt_len = (
-                len(prompt_token_ids) + 512
-            )  # Add estimate for image tokens
+            prompt_len = len(prompt_token_ids)

             output_len = fixed_output_len if fixed_output_len is not None else 256

             filtered_dataset.append(
                 DatasetRow(
-                    prompt=prompt, prompt_len=prompt_len, output_len=output_len
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    output_len=output_len,
+                    image_data=image_data,
                 )
             )
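The heart of this hunk is the try/except around `apply_chat_template`. A self-contained sketch of the same fallback pattern; `tokenizer` is assumed to be any Hugging Face tokenizer object, and the function name is illustrative:

```python
def build_mm_prompt(tokenizer, image_data: str, question: str) -> str:
    """Render a multimodal chat prompt, falling back to a plain <image> tag for
    tokenizers (e.g. InternVL) whose chat templates reject list-valued content."""
    prompt = f"Question: {question}\n\nAnswer: "
    try:
        return tokenizer.apply_chat_template(
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": image_data}},
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
            add_generation_prompt=True,
            tokenize=False,
        )
    except Exception as e:
        # Some chat templates only accept string content; degrade gracefully.
        print(f"Error applying chat template: {e}, fallback to <image> tag")
        return f"<image>{prompt}"
```

Because the image now travels out-of-band in `image_data`, the old `+512` image-token estimate on `prompt_len` is dropped and only text tokens are counted.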
@@ -1199,34 +1207,21 @@ async def benchmark(

     # Use the first request for all warmup iterations
     test_request = input_requests[0]
-    test_prompt, test_prompt_len, test_output_len = (
-        test_request.prompt,
-        test_request.prompt_len,
-        test_request.output_len,
-    )

     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
         lora_name = None

-    if "<image>" in test_prompt:
-        import re
-
-        image_match = re.search(r"<image>(.*?)</image>(.*)", test_prompt, re.DOTALL)
-        image_data = image_match.group(1) if image_match else None
-        test_prompt = image_match.group(2) if image_match else test_prompt
-    else:
-        image_data = None
-
     # Create the test input once
     test_input = RequestFuncInput(
         model=model_id,
-        prompt=test_prompt,
+        prompt=test_request.prompt,
         api_url=api_url,
-        prompt_len=test_prompt_len,
-        output_len=min(test_output_len, 32),
+        prompt_len=test_request.prompt_len,
+        output_len=min(test_request.output_len, 32),
         lora_name=lora_name,
-        image_data=image_data,
+        image_data=test_request.image_data,
         extra_request_body=extra_request_body,
     )
@@ -1271,36 +1266,23 @@ async def benchmark(
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
-        prompt, prompt_len, output_len = (
-            request.prompt,
-            request.prompt_len,
-            request.output_len,
-        )
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
         else:
             lora_name = None

-        if "<image>" in prompt:
-            import re
-
-            image_match = re.search(r"<image>(.*?)</image>(.*)", prompt, re.DOTALL)
-            image_data = image_match.group(1) if image_match else None
-            prompt = image_match.group(2) if image_match else prompt
-        else:
-            image_data = None
-
         request_func_input = RequestFuncInput(
             model=model_id,
-            prompt=prompt,
+            prompt=request.prompt,
             api_url=api_url,
-            prompt_len=prompt_len,
-            output_len=output_len,
+            prompt_len=request.prompt_len,
+            output_len=request.output_len,
             lora_name=lora_name,
-            image_data=image_data,
+            image_data=request.image_data,
             extra_request_body=extra_request_body,
         )

         tasks.append(
             asyncio.create_task(
                 limited_request_func(request_func_input=request_func_input, pbar=pbar)
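Both benchmark hunks make the same simplification: rather than smuggling the image through the prompt string as `<image>...</image>` and regex-extracting it back out, each request reads `image_data` straight off its `DatasetRow`. A hypothetical mini-example of the resulting flow (all concrete values are illustrative; `DatasetRow` and `RequestFuncInput` are the bench_serving types shown above):

```python
row = DatasetRow(
    prompt="Question: What is shown?\n\nAnswer: ",
    prompt_len=12,
    output_len=256,
    image_data="data:image/jpeg;base64,AAAA",  # placeholder payload
)

request_func_input = RequestFuncInput(
    model="my-model",                           # illustrative model id
    prompt=row.prompt,                          # no <image> tag to strip
    api_url="http://127.0.0.1:30000/generate",  # illustrative endpoint
    prompt_len=row.prompt_len,
    output_len=row.output_len,
    lora_name=None,
    image_data=row.image_data,                  # passed through unchanged
    extra_request_body={},
)
```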
@@ -175,6 +175,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         if not image_data:
             return None

+        # Ensure image_data is a list
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
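The added guard is a standard normalization idiom, presumably needed because callers such as bench_serving may now pass a single data-URL string where the loader expects a list. A minimal standalone sketch of the same idea (function name hypothetical):

```python
from typing import List, Union


def as_image_list(image_data: Union[str, List[str]]) -> List[str]:
    # Wrap a lone base64 data URL in a one-element list so downstream
    # code can always iterate over images uniformly.
    if isinstance(image_data, str):
        return [image_data]
    return image_data


assert as_image_list("data:image/png;base64,AAAA") == ["data:image/png;base64,AAAA"]
assert as_image_list(["a", "b"]) == ["a", "b"]
```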