[Refactor] Image data processing in bench_serving (#6879)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -388,7 +388,6 @@ async def async_request_sglang_generate(
|
|||||||
chunk_bytes = chunk_bytes.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
# print(chunk_bytes)
|
|
||||||
|
|
||||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||||
latency = time.perf_counter() - st
|
latency = time.perf_counter() - st
|
||||||
@@ -655,6 +654,7 @@ class DatasetRow:
|
|||||||
prompt: str
|
prompt: str
|
||||||
prompt_len: int
|
prompt_len: int
|
||||||
output_len: int
|
output_len: int
|
||||||
|
image_data: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def sample_mmmu_requests(
|
def sample_mmmu_requests(
|
||||||
@@ -730,42 +730,50 @@ def sample_mmmu_requests(
|
|||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="JPEG")
|
image.save(buffered, format="JPEG")
|
||||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
image_path = f"data:image/jpeg;base64,{img_str}"
|
image_data = f"data:image/jpeg;base64,{img_str}"
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Extract the question
|
# Extract the question
|
||||||
question = example.get("question")
|
question = example.get("question")
|
||||||
|
|
||||||
# Create the prompt with image, question
|
# Construct the prompt
|
||||||
prompt = f"Question: {question}\n\nAnswer: "
|
prompt = f"Question: {question}\n\nAnswer: "
|
||||||
prompt = tokenizer.apply_chat_template(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "image_url", "image_url": {"url": image_path}},
|
|
||||||
{"type": "text", "text": prompt},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
add_generation_prompt=True,
|
|
||||||
tokenize=False,
|
|
||||||
)
|
|
||||||
prompt = f"<image>{image_path}</image>{prompt}"
|
|
||||||
|
|
||||||
# Calculate token lengths
|
try:
|
||||||
# Note: This is approximate since we're not rendering the actual image tokens
|
prompt = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": image_data},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=False,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
|
||||||
|
print(f"Error applying chat template: {e}, fallback to <image> tag")
|
||||||
|
prompt = f"<image>{prompt}"
|
||||||
|
|
||||||
|
# Calculate token lengths for text only (without image data)
|
||||||
prompt_token_ids = tokenizer.encode(prompt)
|
prompt_token_ids = tokenizer.encode(prompt)
|
||||||
prompt_len = (
|
prompt_len = len(prompt_token_ids)
|
||||||
len(prompt_token_ids) + 512
|
|
||||||
) # Add estimate for image tokens
|
|
||||||
|
|
||||||
output_len = fixed_output_len if fixed_output_len is not None else 256
|
output_len = fixed_output_len if fixed_output_len is not None else 256
|
||||||
|
|
||||||
filtered_dataset.append(
|
filtered_dataset.append(
|
||||||
DatasetRow(
|
DatasetRow(
|
||||||
prompt=prompt, prompt_len=prompt_len, output_len=output_len
|
prompt=prompt,
|
||||||
|
prompt_len=prompt_len,
|
||||||
|
output_len=output_len,
|
||||||
|
image_data=image_data,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1199,34 +1207,21 @@ async def benchmark(
|
|||||||
|
|
||||||
# Use the first request for all warmup iterations
|
# Use the first request for all warmup iterations
|
||||||
test_request = input_requests[0]
|
test_request = input_requests[0]
|
||||||
test_prompt, test_prompt_len, test_output_len = (
|
|
||||||
test_request.prompt,
|
|
||||||
test_request.prompt_len,
|
|
||||||
test_request.output_len,
|
|
||||||
)
|
|
||||||
if lora_names is not None and len(lora_names) != 0:
|
if lora_names is not None and len(lora_names) != 0:
|
||||||
lora_name = lora_names[0]
|
lora_name = lora_names[0]
|
||||||
else:
|
else:
|
||||||
lora_name = None
|
lora_name = None
|
||||||
|
|
||||||
if "<image>" in test_prompt:
|
|
||||||
import re
|
|
||||||
|
|
||||||
image_match = re.search(r"<image>(.*?)</image>(.*)", test_prompt, re.DOTALL)
|
|
||||||
image_data = image_match.group(1) if image_match else None
|
|
||||||
test_prompt = image_match.group(2) if image_match else test_prompt
|
|
||||||
else:
|
|
||||||
image_data = None
|
|
||||||
|
|
||||||
# Create the test input once
|
# Create the test input once
|
||||||
test_input = RequestFuncInput(
|
test_input = RequestFuncInput(
|
||||||
model=model_id,
|
model=model_id,
|
||||||
prompt=test_prompt,
|
prompt=test_request.prompt,
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
prompt_len=test_prompt_len,
|
prompt_len=test_request.prompt_len,
|
||||||
output_len=min(test_output_len, 32),
|
output_len=min(test_request.output_len, 32),
|
||||||
lora_name=lora_name,
|
lora_name=lora_name,
|
||||||
image_data=image_data,
|
image_data=test_request.image_data,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1271,36 +1266,23 @@ async def benchmark(
|
|||||||
benchmark_start_time = time.perf_counter()
|
benchmark_start_time = time.perf_counter()
|
||||||
tasks: List[asyncio.Task] = []
|
tasks: List[asyncio.Task] = []
|
||||||
async for request in get_request(input_requests, request_rate):
|
async for request in get_request(input_requests, request_rate):
|
||||||
prompt, prompt_len, output_len = (
|
|
||||||
request.prompt,
|
|
||||||
request.prompt_len,
|
|
||||||
request.output_len,
|
|
||||||
)
|
|
||||||
if lora_names is not None and len(lora_names) != 0:
|
if lora_names is not None and len(lora_names) != 0:
|
||||||
idx = random.randint(0, len(lora_names) - 1)
|
idx = random.randint(0, len(lora_names) - 1)
|
||||||
lora_name = lora_names[idx]
|
lora_name = lora_names[idx]
|
||||||
else:
|
else:
|
||||||
lora_name = None
|
lora_name = None
|
||||||
|
|
||||||
if "<image>" in prompt:
|
|
||||||
import re
|
|
||||||
|
|
||||||
image_match = re.search(r"<image>(.*?)</image>(.*)", prompt, re.DOTALL)
|
|
||||||
image_data = image_match.group(1) if image_match else None
|
|
||||||
prompt = image_match.group(2) if image_match else prompt
|
|
||||||
else:
|
|
||||||
image_data = None
|
|
||||||
|
|
||||||
request_func_input = RequestFuncInput(
|
request_func_input = RequestFuncInput(
|
||||||
model=model_id,
|
model=model_id,
|
||||||
prompt=prompt,
|
prompt=request.prompt,
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
prompt_len=prompt_len,
|
prompt_len=request.prompt_len,
|
||||||
output_len=output_len,
|
output_len=request.output_len,
|
||||||
lora_name=lora_name,
|
lora_name=lora_name,
|
||||||
image_data=image_data,
|
image_data=request.image_data,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
tasks.append(
|
tasks.append(
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
||||||
|
|||||||
@@ -175,6 +175,10 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
|
|||||||
if not image_data:
|
if not image_data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Ensure image_data is a list
|
||||||
|
if isinstance(image_data, str):
|
||||||
|
image_data = [image_data]
|
||||||
|
|
||||||
base_output = self.load_mm_data(
|
base_output = self.load_mm_data(
|
||||||
prompt=input_text,
|
prompt=input_text,
|
||||||
image_data=image_data,
|
image_data=image_data,
|
||||||
|
|||||||
Reference in New Issue
Block a user