feat: support TRT LLM benchmark and multiple benchmarks (#670)
@@ -19,6 +19,7 @@ import traceback
import warnings
from argparse import ArgumentParser as FlexibleArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple, Union

import aiohttp
@@ -59,6 +60,72 @@ def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text


# TRT LLM does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
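For reference, the new request function consumes Triton's generate_stream server-sent events, where each chunk is a "data: {...}" line whose JSON body carries "text_output". A minimal, self-contained sketch of that parsing step, with an assumed example payload (not captured from a real server):

# Hedged illustration of the chunk format async_request_trt_llm expects.
import json

chunk_bytes = b'data: {"text_output": " Hello"}'  # assumed example chunk
chunk = chunk_bytes.decode("utf-8").strip()
chunk = chunk[len("data:"):] if chunk.startswith("data:") else chunk  # same as remove_prefix()
data = json.loads(chunk)
print(data["text_output"])  # " Hello"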
@@ -167,6 +234,7 @@ ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "trt": async_request_trt_llm,
}
@@ -449,6 +517,7 @@ async def benchmark(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
    enable_multi: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -542,6 +611,37 @@ async def benchmark(
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    if enable_multi:
        if (
            metrics.median_ttft_ms is not None
            and metrics.mean_itl_ms is not None
            and metrics.output_throughput is not None
        ):
            result = {
                "dataset_name": args.dataset_name,
                "request_rate": request_rate,
                "median_ttft": metrics.median_ttft_ms,
                "median_itl": metrics.mean_itl_ms,
                "output_token_throughput": metrics.output_throughput,
                "sharegpt_output_len": args.sharegpt_output_len,
                "random_input_len": args.random_input_len,
                "random_output_len": args.random_output_len,
            }
        else:
            print(f"Error running benchmark for request rate: {request_rate}")
            print("-" * 30)

        # Determine output file name
        if args.output_file:
            output_file_name = args.output_file
        else:
            now = datetime.now().strftime("%m%d%H")
            output_file_name = f"{args.backend}_{now}.jsonl"

        # Append results to a JSONL file
        with open(output_file_name, "a") as file:
            file.write(json.dumps(result) + "\n")

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
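In multi mode each request rate appends one JSON record to the output file, so a sweep can be read back with a few lines of standard-library code. A hedged sketch, not part of the commit; the file name is an example following the f"{args.backend}_{now}.jsonl" pattern above:

# Hedged helper for loading the per-rate records written in multi mode.
import json

def load_multi_results(path):
    with open(path) as f:
        return [json.loads(line) for line in f]

for r in load_multi_results("trt_073116.jsonl"):  # example file name
    print(r["request_rate"], r["median_ttft"], r["output_token_throughput"])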
@@ -572,6 +672,11 @@ async def benchmark(
    return result


def parse_request_rate_range(request_rate_range):
    start, stop, step = map(int, request_rate_range.split(","))
    return list(range(start, stop, step))


def fire(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)
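For example, the default range string "2,34,2" expands to sixteen rates, with the stop value excluded just as in Python's range:

parse_request_rate_range("2,34,2")
# -> [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32]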
@@ -581,6 +686,7 @@ def fire(args: argparse.Namespace):
            "sglang": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
            "trt": 8000,
        }.get(args.backend, 30000)

    api_url = (
@@ -594,6 +700,16 @@ def fire(args: argparse.Namespace):
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend == "trt":
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
            if args.base_url
            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
        )
        if args.model is None:
            print("Please provide a model using `--model` when using `trt` backend.")
            sys.exit(1)

    if args.model is None:
        try:
            response = requests.get(model_url)
@@ -637,17 +753,35 @@ def fire(args: argparse.Namespace):
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    if args.multi:
        request_rates = parse_request_rate_range(args.request_rate_range)

        for rate in request_rates:
            asyncio.run(
                benchmark(
                    backend=backend,
                    api_url=api_url,
                    model_id=model_id,
                    tokenizer=tokenizer,
                    input_requests=input_requests,
                    request_rate=rate,
                    disable_tqdm=args.disable_tqdm,
                    enable_multi=args.multi,
                )
            )
    else:
        asyncio.run(
            benchmark(
                backend=backend,
                api_url=api_url,
                model_id=model_id,
                tokenizer=tokenizer,
                input_requests=input_requests,
                request_rate=args.request_rate,
                disable_tqdm=args.disable_tqdm,
                enable_multi=args.multi,
            )
        )


# to avoid relying on SGLang's components
@@ -751,6 +885,18 @@ if __name__ == "__main__":
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--multi",
        action="store_true",
        help="Use request rate range rather than single value.",
    )
    parser.add_argument(
        "--request-rate-range",
        type=str,
        default="2,34,2",
        help="Range of request rates in the format start,stop,step. Default is 2,34,2",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")

    set_ulimit()
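Put together, a TRT LLM sweep over the default request-rate range could be launched roughly as below; the script path and model name are illustrative assumptions, while the flags themselves come from this commit:

# Hedged example invocation (script path and model name are assumptions).
python3 bench_serving.py \
    --backend trt \
    --model meta-llama/Llama-2-7b-hf \
    --multi \
    --request-rate-range 2,34,2 \
    --output-file trt_sweep.jsonl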