diff --git a/python/sglang/bench_server_latency.py b/python/sglang/bench_server_latency.py
index 6bbb3954b..66e59d0d4 100644
--- a/python/sglang/bench_server_latency.py
+++ b/python/sglang/bench_server_latency.py
@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
 
 Usage:
 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+
+python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """
 
 import argparse
@@ -15,7 +17,7 @@ import json
 import multiprocessing
 import os
 import time
-from typing import Tuple
+from typing import Optional, Tuple
 
 import numpy as np
 import requests
@@ -32,6 +34,8 @@ class BenchArgs:
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
     result_filename: str = "result.jsonl"
+    base_url: str = ""
+    skip_warmup: bool = False
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -48,6 +52,8 @@ class BenchArgs:
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
+        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
+        parser.add_argument("--skip-warmup", action="store_true")
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -139,17 +145,21 @@ def run_one_case(
 
 
 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
-    proc, base_url = launch_server_process(server_args)
+    if bench_args.base_url:
+        proc, base_url = None, bench_args.base_url
+    else:
+        proc, base_url = launch_server_process(server_args)
 
     # warmup
-    run_one_case(
-        base_url,
-        batch_size=16,
-        input_len=1024,
-        output_len=16,
-        run_name="",
-        result_filename="",
-    )
+    if not bench_args.skip_warmup:
+        run_one_case(
+            base_url,
+            batch_size=16,
+            input_len=1024,
+            output_len=16,
+            run_name="",
+            result_filename="",
+        )
 
     # benchmark
     try:
@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 bench_args.result_filename,
             )
     finally:
-        kill_child_process(proc.pid)
+        if proc:
+            kill_child_process(proc.pid)
 
     print(f"\nResults are saved to {bench_args.result_filename}")
 
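Note: the bench_server_latency.py half of this change makes the script reusable against an already-running server: with --base-url set, no server subprocess is launched (proc stays None), and --skip-warmup suppresses the fixed warmup case. Before pointing --base-url at a server, a quick smoke test of the native /generate endpoint can confirm it is reachable. The sketch below is illustrative only: the endpoint and payload fields are taken from the async_request_sglang_generate function added in the bench_serving.py diff that follows, and the URL and prompt are assumptions for a local server on the default port.

import requests

# Hypothetical smoke test for a server about to be benchmarked via
# --base-url. Endpoint and payload shape mirror the native /generate
# API used in this PR; http://localhost:30000 is an assumed address.
base_url = "http://localhost:30000"
resp = requests.post(
    f"{base_url}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0.0, "max_new_tokens": 8},
    },
)
resp.raise_for_status()
print(resp.json()["text"])
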
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 2f07973b8..2ca35aca9 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output
 
+
+async def async_request_sglang_generate(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": request_func_input.output_len,
+                "ignore_eos": not args.disable_ignore_eos,
+            },
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = {}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        # print(chunk_bytes)
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text = data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -264,7 +343,9 @@ def get_tokenizer(
 
 
 ASYNC_REQUEST_FUNCS = {
-    "sglang": async_request_openai_completions,
+    "sglang": async_request_sglang_generate,
+    "sglang-native": async_request_sglang_generate,
+    "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))
 
+    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
     return filtered_dataset
 
 
@@ -784,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
     if args.port is None:
         args.port = {
             "sglang": 30000,
+            "sglang-native": 30000,
+            "sglang-oai": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
         }.get(args.backend, 30000)
 
-    api_url = (
-        f"{args.base_url}/v1/completions"
-        if args.base_url
-        else f"http://{args.host}:{args.port}/v1/completions"
-    )
     model_url = (
         f"{args.base_url}/v1/models"
        if args.base_url
         else f"http://{args.host}:{args.port}/v1/models"
     )
 
-    if args.backend == "trt":
+    if args.backend in ["sglang", "sglang-native"]:
+        api_url = (
+            f"{args.base_url}/generate"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/generate"
+        )
+    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
+        api_url = (
+            f"{args.base_url}/v1/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/completions"
+        )
+    elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
             if args.base_url
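For reference, the streaming format parsed by async_request_sglang_generate above can be exercised outside the harness. Below is a minimal synchronous sketch, not part of the PR: it assumes a local server on the default port and Python 3.9+, the prompt string is a placeholder, and the "data: " framing, "[DONE]" sentinel, and TTFT bookkeeping follow the parsing logic shown in the diff.

import json
import time

import requests

# Minimal streaming client mirroring the chunk handling in
# async_request_sglang_generate: strip the SSE "data: " prefix, stop at
# the "[DONE]" sentinel, and record time-to-first-token (TTFT).
# The URL and prompt are assumptions for a local test server.
st = time.perf_counter()
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0.0, "max_new_tokens": 16},
        "stream": True,
    },
    stream=True,
)
ttft = None
generated_text = ""
for line in resp.iter_lines():
    if not line:
        continue
    chunk = line.decode("utf-8").removeprefix("data: ")  # Python 3.9+
    if chunk == "[DONE]":
        break
    data = json.loads(chunk)
    if data["text"] and ttft is None:
        ttft = time.perf_counter() - st  # first token arrived
    generated_text = data["text"]  # "text" holds the cumulative generation

print(f"TTFT: {ttft:.3f}s")
print(generated_text)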