diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 5057f3e15..4e606dfa2 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -265,6 +265,138 @@ async def async_request_openai_completions( return output +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + """Makes a request to the OpenAI Chat Completions API. + + Handles both streaming and non-streaming responses, including support + for image data in messages. Calculates and returns various performance + metrics. + + Args: + request_func_input: Input parameters for the request. + pbar: Optional tqdm progress bar to update. + + Returns: + RequestFuncOutput: Output of the request, including generated text, + latency, TTFT, ITL, and success status. + """ + api_url = request_func_input.api_url + assert api_url.endswith( + "chat/completions" + ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
+ + if request_func_input.image_data: + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": request_func_input.image_data}, + }, + {"type": "text", "text": request_func_input.prompt}, + ], + }, + ] + else: + messages = [{"role": "user", "content": request_func_input.prompt}] + + async with _create_bench_client_session() as session: + payload = { + "model": request_func_input.model, + "messages": messages, + "temperature": 0.0, + "max_tokens": request_func_input.output_len, + "stream": not args.disable_stream, + **request_func_input.extra_request_body, + } + headers = get_auth_headers() + + output = RequestFuncOutput.init_new(request_func_input) + + generated_text = "" + output_len = request_func_input.output_len + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + if args.disable_stream: + # Non-streaming response + response_json = await response.json() + output.generated_text = response_json["choices"][0]["message"][ + "content" + ] + output.success = True + output.latency = time.perf_counter() - st + output.ttft = ( + output.latency + ) # For non-streaming, TTFT = total latency + output.output_len = response_json.get("usage", {}).get( + "completion_tokens", output_len + ) + else: + # Streaming response + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + # Check if this chunk contains content + delta = data.get("choices", [{}])[0].get("delta", {}) + content = delta.get("content", "") + + if content: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + 
output.itl.append( + timestamp - most_recent_timestamp + ) + + most_recent_timestamp = timestamp + generated_text += content + + # Check for usage info in final chunk + output_len = (data.get("usage") or {}).get( + "completion_tokens", output_len + ) + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + async def async_request_truss( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, @@ -544,6 +676,7 @@ def get_dataset(args, tokenizer): num_requests=args.num_prompts, tokenizer=tokenizer, fixed_output_len=args.random_output_len, + apply_chat_template=args.apply_chat_template, random_sample=True, ) else: @@ -555,8 +688,11 @@ ASYNC_REQUEST_FUNCS = { "sglang": async_request_sglang_generate, "sglang-native": async_request_sglang_generate, "sglang-oai": async_request_openai_completions, + "sglang-oai-chat": async_request_openai_chat_completions, "vllm": async_request_openai_completions, + "vllm-chat": async_request_openai_chat_completions, "lmdeploy": async_request_openai_completions, + "lmdeploy-chat": async_request_openai_chat_completions, "trt": async_request_trt_llm, "gserver": async_request_gserver, "truss": async_request_truss, @@ -661,6 +797,7 @@ def sample_mmmu_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, + apply_chat_template: bool = True, random_sample: bool = True, ) -> List[DatasetRow]: """ @@ -670,6 +807,7 @@ def sample_mmmu_requests( num_requests: Number of requests to sample. tokenizer: Tokenizer to use for token counting. fixed_output_len: If provided, use this fixed output length for all requests. 
+        apply_chat_template: Whether to apply the chat template to the prompt.         random_sample: Whether to randomly sample or take the first N.      Returns: @@ -739,28 +877,30 @@ def sample_mmmu_requests(              # Construct the prompt             prompt = f"Question: {question}\n\nAnswer: " -
-            try: -                prompt = tokenizer.apply_chat_template( -                    [ -                        { -                            "role": "user", -                            "content": [ -                                { -                                    "type": "image_url", -                                    "image_url": {"url": image_data}, -                                }, -                                {"type": "text", "text": prompt}, -                            ], -                        } -                    ], -                    add_generation_prompt=True, -                    tokenize=False, -                ) -            except Exception as e: -                # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL) -                print(f"Error applying chat template: {e}, fallback to <image> tag") -                prompt = f"<image>{prompt}" +            if apply_chat_template: +                try: +                    prompt = tokenizer.apply_chat_template( +                        [ +                            { +                                "role": "user", +                                "content": [ +                                    { +                                        "type": "image_url", +                                        "image_url": {"url": image_data}, +                                    }, +                                    {"type": "text", "text": prompt}, +                                ], +                            } +                        ], +                        add_generation_prompt=True, +                        tokenize=False, +                    ) +                except Exception as e: +                    # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. 
InternVL) +                    print( +                        f"Error applying chat template: {e}, fallback to <image> tag" +                    ) +                    prompt = f"<image>{prompt}" # Calculate token lengths for text only (without image data) prompt_token_ids = tokenizer.encode(prompt) @@ -1538,12 +1678,19 @@ def run_benchmark(args_: argparse.Namespace): if args.base_url else f"http://{args.host}:{args.port}/generate" ) + args.apply_chat_template = True elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]: api_url = ( f"{args.base_url}/v1/completions" if args.base_url else f"http://{args.host}:{args.port}/v1/completions" ) + elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]: + api_url = ( + f"{args.base_url}/v1/chat/completions" + if args.base_url + else f"http://{args.host}:{args.port}/v1/chat/completions" + ) elif args.backend == "trt": api_url = ( f"{args.base_url}/v2/models/ensemble/generate_stream"