Enable cuda graph by default (#612)

This commit is contained in:
Lianmin Zheng
2024-07-13 05:29:46 -07:00
committed by GitHub
parent 396a69240f
commit 665815969a
10 changed files with 331 additions and 84 deletions

View File

@@ -1,45 +1,43 @@
"""
Usage:
python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
"""
import argparse
import json
import time
import numpy as np
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--max-tokens", type=int, default=256)
args = parser.parse_args()
if args.port is None:
if args.backend == "srt":
args.port = 30000
elif args.backend == "vllm":
args.port = 21000
elif args.backend == "lightllm":
args.port = 22000
elif args.backend == "ginfer":
args.port = 9988
else:
raise ValueError(f"Invalid backend: {args.backend}")
def run_one_batch_size(bs):
url = f"{args.host}:{args.port}"
a = 20
max_new_tokens = args.max_tokens
a = 20
prompt = f"{a, }"
tic = time.time()
if args.backend == "srt":
if args.input_len:
inputs = {"input_ids": [
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
]}
else:
inputs = {"text": [
f"{i, }" for i in range(bs)
]}
response = requests.post(
url + "/generate",
json={
"text": [prompt] * args.batch_size,
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
**inputs,
},
)
elif args.backend == "lightllm":
@@ -91,5 +89,41 @@ if __name__ == "__main__":
ret = response.json()
print(ret)
speed = args.batch_size * max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
output_throughput = bs * max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")
with open("tmp_output.txt", "a") as fout:
res = {
"input_len": args.input_len,
"output_len": args.max_tokens,
"batch_size": bs,
"latency": latency,
"output_throughput": output_throughput
}
fout.write(json.dumps(res) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--input-len", type=int, default=None)
parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
parser.add_argument("--max-tokens", type=int, default=256)
args = parser.parse_args()
if args.port is None:
if args.backend == "srt":
args.port = 30000
elif args.backend == "vllm":
args.port = 21000
elif args.backend == "lightllm":
args.port = 22000
elif args.backend == "ginfer":
args.port = 9988
else:
raise ValueError(f"Invalid backend: {args.backend}")
for bs in args.batch_size:
run_one_batch_size(bs)