feat: support TRT LLM benchmark and multiple benchmarks (#670)
This commit is contained in:
@@ -19,6 +19,7 @@ import traceback
|
|||||||
import warnings
|
import warnings
|
||||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -59,6 +60,72 @@ def remove_prefix(text: str, prefix: str) -> str:
|
|||||||
return text[len(prefix) :] if text.startswith(prefix) else text
|
return text[len(prefix) :] if text.startswith(prefix) else text
|
||||||
|
|
||||||
|
|
||||||
|
# TensorRT-LLM does not support ignore_eos
|
||||||
|
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
|
||||||
|
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming generation request to a TensorRT-LLM endpoint.

    The prompt is POSTed to a ``generate_stream`` URL and the streamed
    response is consumed chunk by chunk. While streaming, the time to first
    token (TTFT), the inter-token latencies (ITL) and the total latency are
    recorded on the returned output object.

    Args:
        request_func_input: Description of the request. This function reads
            ``api_url``, ``prompt``, ``prompt_len``, ``output_len``,
            ``use_beam_search`` and ``best_of``.
        pbar: Optional progress bar; advanced by one once the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput`` with ``success`` set. On success it carries
        ``generated_text``, ``ttft``, ``itl`` and ``latency``; on failure
        ``error`` holds the HTTP reason or the formatted traceback.

    Raises:
        AssertionError: If ``api_url`` does not end with ``generate_stream``,
            or if beam search or ``best_of != 1`` is requested.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            # Skip blank keep-alive / separator lines.
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        if ttft == 0.0:
                            # First token: reuse the timestamp captured for
                            # this chunk so that ttft + sum(itl) == latency.
                            ttft = timestamp - st
                            output.ttft = ttft
                        else:
                            # Decoding phase: gap since the previous chunk.
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best-effort benchmark client: record the failure on the output
            # instead of aborting the whole run.
            output.success = False
            output.error = traceback.format_exc()

        if pbar:
            pbar.update(1)
        return output
|
||||||
|
|
||||||
|
|
||||||
# set ignore_eos True by default
|
# set ignore_eos True by default
|
||||||
async def async_request_openai_completions(
|
async def async_request_openai_completions(
|
||||||
request_func_input: RequestFuncInput,
|
request_func_input: RequestFuncInput,
|
||||||
@@ -167,6 +234,7 @@ ASYNC_REQUEST_FUNCS = {
|
|||||||
"sglang": async_request_openai_completions,
|
"sglang": async_request_openai_completions,
|
||||||
"vllm": async_request_openai_completions,
|
"vllm": async_request_openai_completions,
|
||||||
"lmdeploy": async_request_openai_completions,
|
"lmdeploy": async_request_openai_completions,
|
||||||
|
"trt": async_request_trt_llm,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -449,6 +517,7 @@ async def benchmark(
|
|||||||
input_requests: List[Tuple[str, int, int]],
|
input_requests: List[Tuple[str, int, int]],
|
||||||
request_rate: float,
|
request_rate: float,
|
||||||
disable_tqdm: bool,
|
disable_tqdm: bool,
|
||||||
|
enable_multi: bool,
|
||||||
):
|
):
|
||||||
if backend in ASYNC_REQUEST_FUNCS:
|
if backend in ASYNC_REQUEST_FUNCS:
|
||||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||||
@@ -542,6 +611,37 @@ async def benchmark(
|
|||||||
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
|
if enable_multi:
|
||||||
|
if (
|
||||||
|
metrics.median_ttft_ms is not None
|
||||||
|
and metrics.mean_itl_ms is not None
|
||||||
|
and metrics.output_throughput is not None
|
||||||
|
):
|
||||||
|
result = {
|
||||||
|
"dataset_name": args.dataset_name,
|
||||||
|
"request_rate": request_rate,
|
||||||
|
"median_ttft": metrics.median_ttft_ms,
|
||||||
|
"median_itl": metrics.mean_itl_ms,
|
||||||
|
"output_token_throughput": metrics.output_throughput,
|
||||||
|
"sharegpt_output_len": args.sharegpt_output_len,
|
||||||
|
"random_input_len": args.random_input_len,
|
||||||
|
"random_output_len": args.random_output_len,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
print(f"Error running benchmark for request rate: {request_rate}")
|
||||||
|
print("-" * 30)
|
||||||
|
|
||||||
|
# Determine output file name
|
||||||
|
if args.output_file:
|
||||||
|
output_file_name = args.output_file
|
||||||
|
else:
|
||||||
|
now = datetime.now().strftime("%m%d%H")
|
||||||
|
output_file_name = f"{args.backend}_{now}.jsonl"
|
||||||
|
|
||||||
|
# Append results to a JSONL file
|
||||||
|
with open(output_file_name, "a") as file:
|
||||||
|
file.write(json.dumps(result) + "\n")
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"duration": benchmark_duration,
|
"duration": benchmark_duration,
|
||||||
"completed": metrics.completed,
|
"completed": metrics.completed,
|
||||||
@@ -572,6 +672,11 @@ async def benchmark(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def parse_request_rate_range(request_rate_range):
    """Expand a ``"start,stop,step"`` string into a list of request rates.

    The three comma-separated integers are handed to :func:`range`, so
    ``stop`` is exclusive: ``"2,34,2"`` yields ``[2, 4, ..., 32]``.

    Args:
        request_rate_range: String of exactly three comma-separated
            integers, e.g. ``"2,34,2"``. Whitespace around each number is
            tolerated.

    Returns:
        The list ``list(range(start, stop, step))``; empty when the range
        contains no values.

    Raises:
        ValueError: If the string does not contain exactly three integers,
            or if ``step`` is zero.
    """
    fields = request_rate_range.split(",")
    if len(fields) != 3:
        raise ValueError(
            "request rate range must be in the format start,stop,step; "
            f"got {request_rate_range!r}"
        )
    try:
        start, stop, step = (int(item) for item in fields)
    except ValueError as err:
        raise ValueError(
            "request rate range must be in the format start,stop,step "
            f"with integer values; got {request_rate_range!r}"
        ) from err
    if step == 0:
        raise ValueError(
            "request rate range must be in the format start,stop,step "
            f"with a non-zero step; got {request_rate_range!r}"
        )
    return list(range(start, stop, step))
|
||||||
|
|
||||||
|
|
||||||
def fire(args: argparse.Namespace):
|
def fire(args: argparse.Namespace):
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
@@ -581,6 +686,7 @@ def fire(args: argparse.Namespace):
|
|||||||
"sglang": 30000,
|
"sglang": 30000,
|
||||||
"lmdeploy": 23333,
|
"lmdeploy": 23333,
|
||||||
"vllm": 8000,
|
"vllm": 8000,
|
||||||
|
"trt": 8000,
|
||||||
}.get(args.backend, 30000)
|
}.get(args.backend, 30000)
|
||||||
|
|
||||||
api_url = (
|
api_url = (
|
||||||
@@ -594,6 +700,16 @@ def fire(args: argparse.Namespace):
|
|||||||
else f"http://{args.host}:{args.port}/v1/models"
|
else f"http://{args.host}:{args.port}/v1/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.backend == "trt":
|
||||||
|
api_url = (
|
||||||
|
f"{args.base_url}/v2/models/ensemble/generate_stream"
|
||||||
|
if args.base_url
|
||||||
|
else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
|
||||||
|
)
|
||||||
|
if args.model is None:
|
||||||
|
print("Please provide a model using `--model` when using `trt` backend.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
if args.model is None:
|
if args.model is None:
|
||||||
try:
|
try:
|
||||||
response = requests.get(model_url)
|
response = requests.get(model_url)
|
||||||
@@ -637,17 +753,35 @@ def fire(args: argparse.Namespace):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||||
|
|
||||||
asyncio.run(
|
if args.multi:
|
||||||
benchmark(
|
request_rates = parse_request_rate_range(args.request_rate_range)
|
||||||
backend=backend,
|
|
||||||
api_url=api_url,
|
for rate in request_rates:
|
||||||
model_id=model_id,
|
asyncio.run(
|
||||||
tokenizer=tokenizer,
|
benchmark(
|
||||||
input_requests=input_requests,
|
backend=backend,
|
||||||
request_rate=args.request_rate,
|
api_url=api_url,
|
||||||
disable_tqdm=args.disable_tqdm,
|
model_id=model_id,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
input_requests=input_requests,
|
||||||
|
request_rate=rate,
|
||||||
|
disable_tqdm=args.disable_tqdm,
|
||||||
|
enable_multi=args.multi,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
asyncio.run(
|
||||||
|
benchmark(
|
||||||
|
backend=backend,
|
||||||
|
api_url=api_url,
|
||||||
|
model_id=model_id,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
input_requests=input_requests,
|
||||||
|
request_rate=args.request_rate,
|
||||||
|
disable_tqdm=args.disable_tqdm,
|
||||||
|
enable_multi=args.multi,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# to avoid relying on SGLang's components
|
# to avoid relying on SGLang's components
|
||||||
@@ -751,6 +885,18 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Specify to disable tqdm progress bar.",
|
help="Specify to disable tqdm progress bar.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--multi",
|
||||||
|
action="store_true",
|
||||||
|
help="Use request rate range rather than single value.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--request-rate-range",
|
||||||
|
type=str,
|
||||||
|
default="2,34,2",
|
||||||
|
help="Range of request rates in the format start,stop,step. Default is 2,34,2",
|
||||||
|
)
|
||||||
|
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
||||||
|
|
||||||
set_ulimit()
|
set_ulimit()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user