oai: Adds support for OpenAI chat completions API in bench_serving (#7036)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com> Co-authored-by: yhyang201 <47235274+yhyang201@users.noreply.github.com> Co-authored-by: Mick <mickjagger19@icloud.com>
This commit is contained in:
@@ -265,6 +265,138 @@ async def async_request_openai_completions(
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_chat_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
"""Makes a request to the OpenAI Chat Completions API.
|
||||
|
||||
Handles both streaming and non-streaming responses, including support
|
||||
for image data in messages. Calculates and returns various performance
|
||||
metrics.
|
||||
|
||||
Args:
|
||||
request_func_input: Input parameters for the request.
|
||||
pbar: Optional tqdm progress bar to update.
|
||||
|
||||
Returns:
|
||||
RequestFuncOutput: Output of the request, including generated text,
|
||||
latency, TTFT, ITL, and success status.
|
||||
"""
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
"chat/completions"
|
||||
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||
|
||||
if request_func_input.image_data:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": request_func_input.image_data},
|
||||
},
|
||||
{"type": "text", "text": request_func_input.prompt},
|
||||
],
|
||||
},
|
||||
]
|
||||
else:
|
||||
messages = [{"role": "user", "content": request_func_input.prompt}]
|
||||
|
||||
async with _create_bench_client_session() as session:
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"messages": messages,
|
||||
"temperature": 0.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": not args.disable_stream,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
headers = get_auth_headers()
|
||||
|
||||
output = RequestFuncOutput.init_new(request_func_input)
|
||||
|
||||
generated_text = ""
|
||||
output_len = request_func_input.output_len
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
if args.disable_stream:
|
||||
# Non-streaming response
|
||||
response_json = await response.json()
|
||||
output.generated_text = response_json["choices"][0]["message"][
|
||||
"content"
|
||||
]
|
||||
output.success = True
|
||||
output.latency = time.perf_counter() - st
|
||||
output.ttft = (
|
||||
output.latency
|
||||
) # For non-streaming, TTFT = total latency
|
||||
output.output_len = response_json.get("usage", {}).get(
|
||||
"completion_tokens", output_len
|
||||
)
|
||||
else:
|
||||
# Streaming response
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||
latency = time.perf_counter() - st
|
||||
if chunk == "[DONE]":
|
||||
pass
|
||||
else:
|
||||
data = json.loads(chunk)
|
||||
|
||||
# Check if this chunk contains content
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
|
||||
if content:
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp
|
||||
)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += content
|
||||
|
||||
# Check for usage info in final chunk
|
||||
output_len = (data.get("usage") or {}).get(
|
||||
"completion_tokens", output_len
|
||||
)
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.output_len = output_len
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_truss(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
@@ -544,6 +676,7 @@ def get_dataset(args, tokenizer):
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.random_output_len,
|
||||
apply_chat_template=args.apply_chat_template,
|
||||
random_sample=True,
|
||||
)
|
||||
else:
|
||||
@@ -555,8 +688,11 @@ ASYNC_REQUEST_FUNCS = {
|
||||
"sglang": async_request_sglang_generate,
|
||||
"sglang-native": async_request_sglang_generate,
|
||||
"sglang-oai": async_request_openai_completions,
|
||||
"sglang-oai-chat": async_request_openai_chat_completions,
|
||||
"vllm": async_request_openai_completions,
|
||||
"vllm-chat": async_request_openai_chat_completions,
|
||||
"lmdeploy": async_request_openai_completions,
|
||||
"lmdeploy-chat": async_request_openai_chat_completions,
|
||||
"trt": async_request_trt_llm,
|
||||
"gserver": async_request_gserver,
|
||||
"truss": async_request_truss,
|
||||
@@ -661,6 +797,7 @@ def sample_mmmu_requests(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
apply_chat_template: bool = True,
|
||||
random_sample: bool = True,
|
||||
) -> List[DatasetRow]:
|
||||
"""
|
||||
@@ -670,6 +807,7 @@ def sample_mmmu_requests(
|
||||
num_requests: Number of requests to sample.
|
||||
tokenizer: Tokenizer to use for token counting.
|
||||
fixed_output_len: If provided, use this fixed output length for all requests.
|
||||
apply_chat_template: Whether to apply the chat template to the prompt.
|
||||
random_sample: Whether to randomly sample or take the first N.
|
||||
|
||||
Returns:
|
||||
@@ -739,28 +877,30 @@ def sample_mmmu_requests(
|
||||
|
||||
# Construct the prompt
|
||||
prompt = f"Question: {question}\n\nAnswer: "
|
||||
|
||||
try:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
except Exception as e:
|
||||
# Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
|
||||
print(f"Error applying chat template: {e}, fallback to <image> tag")
|
||||
prompt = f"<image>{prompt}"
|
||||
if apply_chat_template:
|
||||
try:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
],
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
except Exception as e:
|
||||
# Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
|
||||
print(
|
||||
f"Error applying chat template: {e}, fallback to <image> tag"
|
||||
)
|
||||
prompt = f"<image>{prompt}"
|
||||
|
||||
# Calculate token lengths for text only (without image data)
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
@@ -1538,12 +1678,19 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if args.base_url
|
||||
else f"http://{args.host}:{args.port}/generate"
|
||||
)
|
||||
args.apply_chat_template = True
|
||||
elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
|
||||
api_url = (
|
||||
f"{args.base_url}/v1/completions"
|
||||
if args.base_url
|
||||
else f"http://{args.host}:{args.port}/v1/completions"
|
||||
)
|
||||
elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
|
||||
api_url = (
|
||||
f"{args.base_url}/v1/chat/completions"
|
||||
if args.base_url
|
||||
else f"http://{args.host}:{args.port}/v1/chat/completions"
|
||||
)
|
||||
elif args.backend == "trt":
|
||||
api_url = (
|
||||
f"{args.base_url}/v2/models/ensemble/generate_stream"
|
||||
|
||||
Reference in New Issue
Block a user