diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 9991a40ab..b1abcd35e 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -19,6 +19,7 @@ import traceback
 import warnings
 from argparse import ArgumentParser as FlexibleArgumentParser
 from dataclasses import dataclass, field
+from datetime import datetime
 from typing import AsyncGenerator, List, Optional, Tuple, Union
 
 import aiohttp
@@ -59,6 +60,72 @@ def remove_prefix(text: str, prefix: str) -> str:
     return text[len(prefix) :] if text.startswith(prefix) else text
 
 
+# trt llm not support ignore_eos
+# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
+async def async_request_trt_llm(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    assert api_url.endswith("generate_stream")
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        assert not request_func_input.use_beam_search
+        assert request_func_input.best_of == 1
+        payload = {
+            "accumulate_tokens": True,
+            "text_input": request_func_input.prompt,
+            "temperature": 0.0,
+            "top_p": 1.0,
+            "max_tokens": request_func_input.output_len,
+            "stream": True,
+        }
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(url=api_url, json=payload) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")
+
+                        data = json.loads(chunk)
+                        output.generated_text += data["text_output"]
+                        timestamp = time.perf_counter()
+                        # First token
+                        if ttft == 0.0:
+                            ttft = time.perf_counter() - st
+                            output.ttft = ttft
+
+                        # Decoding phase
+                        else:
+                            output.itl.append(timestamp - most_recent_timestamp)
+
+                        most_recent_timestamp = timestamp
+
+                    output.latency = most_recent_timestamp - st
+                    output.success = True
+
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 # set ignore_eos True by default
 async def async_request_openai_completions(
     request_func_input: RequestFuncInput,
@@ -167,6 +234,7 @@ ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
+    "trt": async_request_trt_llm,
 }
 
 
@@ -449,6 +517,7 @@ async def benchmark(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
     disable_tqdm: bool,
+    enable_multi: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -542,6 +611,38 @@
     print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
     print("=" * 50)
 
+    if enable_multi:
+        if (
+            metrics.median_ttft_ms is not None
+            and metrics.median_itl_ms is not None
+            and metrics.output_throughput is not None
+        ):
+            result = {
+                "dataset_name": args.dataset_name,
+                "request_rate": request_rate,
+                "median_ttft": metrics.median_ttft_ms,
+                "median_itl": metrics.median_itl_ms,
+                "output_token_throughput": metrics.output_throughput,
+                "sharegpt_output_len": args.sharegpt_output_len,
+                "random_input_len": args.random_input_len,
+                "random_output_len": args.random_output_len,
+            }
+
+            # Determine output file name
+            if args.output_file:
+                output_file_name = args.output_file
+            else:
+                now = datetime.now().strftime("%m%d%H")
+                output_file_name = f"{args.backend}_{now}.jsonl"
+
+            # Append results to a JSONL file; only reached when `result` was
+            # built above, so the write can never hit an undefined name.
+            with open(output_file_name, "a") as file:
+                file.write(json.dumps(result) + "\n")
+        else:
+            print(f"Error running benchmark for request rate: {request_rate}")
+            print("-" * 30)
+
     result = {
         "duration": benchmark_duration,
         "completed": metrics.completed,
@@ -572,6 +673,11 @@
     return result
 
 
+def parse_request_rate_range(request_rate_range):
+    start, stop, step = map(int, request_rate_range.split(","))
+    return list(range(start, stop, step))
+
+
 def fire(args: argparse.Namespace):
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -581,6 +687,7 @@
         "sglang": 30000,
         "lmdeploy": 23333,
         "vllm": 8000,
+        "trt": 8000,
     }.get(args.backend, 30000)
 
     api_url = (
@@ -594,6 +701,16 @@
         else f"http://{args.host}:{args.port}/v1/models"
     )
 
+    if args.backend == "trt":
+        api_url = (
+            f"{args.base_url}/v2/models/ensemble/generate_stream"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
+        )
+        if args.model is None:
+            print("Please provide a model using `--model` when using `trt` backend.")
+            sys.exit(1)
+
     if args.model is None:
         try:
             response = requests.get(model_url)
@@ -637,17 +754,35 @@
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
 
-    asyncio.run(
-        benchmark(
-            backend=backend,
-            api_url=api_url,
-            model_id=model_id,
-            tokenizer=tokenizer,
-            input_requests=input_requests,
-            request_rate=args.request_rate,
-            disable_tqdm=args.disable_tqdm,
+    if args.multi:
+        request_rates = parse_request_rate_range(args.request_rate_range)
+
+        for rate in request_rates:
+            asyncio.run(
+                benchmark(
+                    backend=backend,
+                    api_url=api_url,
+                    model_id=model_id,
+                    tokenizer=tokenizer,
+                    input_requests=input_requests,
+                    request_rate=rate,
+                    disable_tqdm=args.disable_tqdm,
+                    enable_multi=args.multi,
+                )
+            )
+    else:
+        asyncio.run(
+            benchmark(
+                backend=backend,
+                api_url=api_url,
+                model_id=model_id,
+                tokenizer=tokenizer,
+                input_requests=input_requests,
+                request_rate=args.request_rate,
+                disable_tqdm=args.disable_tqdm,
+                enable_multi=args.multi,
+            )
         )
-    )
 
 
 # to avoid relying on SGLang's components
@@ -751,6 +886,18 @@ if __name__ == "__main__":
         action="store_true",
         help="Specify to disable tqdm progress bar.",
     )
+    parser.add_argument(
+        "--multi",
+        action="store_true",
+        help="Use request rate range rather than single value.",
+    )
+    parser.add_argument(
+        "--request-rate-range",
+        type=str,
+        default="2,34,2",
+        help="Range of request rates in the format start,stop,step. Default is 2,34,2",
+    )
+    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
     set_ulimit()