Enable cuda graph by default (#612)
This commit is contained in:
@@ -1,45 +1,43 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=None)
|
||||
parser.add_argument("--backend", type=str, default="srt")
|
||||
parser.add_argument("--batch-size", type=int, default=1)
|
||||
parser.add_argument("--max-tokens", type=int, default=256)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.port is None:
|
||||
if args.backend == "srt":
|
||||
args.port = 30000
|
||||
elif args.backend == "vllm":
|
||||
args.port = 21000
|
||||
elif args.backend == "lightllm":
|
||||
args.port = 22000
|
||||
elif args.backend == "ginfer":
|
||||
args.port = 9988
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
def run_one_batch_size(bs):
|
||||
url = f"{args.host}:{args.port}"
|
||||
a = 20
|
||||
max_new_tokens = args.max_tokens
|
||||
|
||||
a = 20
|
||||
prompt = f"{a, }"
|
||||
|
||||
tic = time.time()
|
||||
if args.backend == "srt":
|
||||
if args.input_len:
|
||||
inputs = {"input_ids": [
|
||||
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
|
||||
]}
|
||||
else:
|
||||
inputs = {"text": [
|
||||
f"{i, }" for i in range(bs)
|
||||
]}
|
||||
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
"text": [prompt] * args.batch_size,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"ignore_eos": True,
|
||||
},
|
||||
**inputs,
|
||||
},
|
||||
)
|
||||
elif args.backend == "lightllm":
|
||||
@@ -91,5 +89,41 @@ if __name__ == "__main__":
|
||||
ret = response.json()
|
||||
print(ret)
|
||||
|
||||
speed = args.batch_size * max_new_tokens / latency
|
||||
print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
|
||||
output_throughput = bs * max_new_tokens / latency
|
||||
print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")
|
||||
|
||||
with open("tmp_output.txt", "a") as fout:
|
||||
res = {
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.max_tokens,
|
||||
"batch_size": bs,
|
||||
"latency": latency,
|
||||
"output_throughput": output_throughput
|
||||
}
|
||||
fout.write(json.dumps(res) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=None)
|
||||
parser.add_argument("--backend", type=str, default="srt")
|
||||
parser.add_argument("--input-len", type=int, default=None)
|
||||
parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
|
||||
parser.add_argument("--max-tokens", type=int, default=256)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.port is None:
|
||||
if args.backend == "srt":
|
||||
args.port = 30000
|
||||
elif args.backend == "vllm":
|
||||
args.port = 21000
|
||||
elif args.backend == "lightllm":
|
||||
args.port = 22000
|
||||
elif args.backend == "ginfer":
|
||||
args.port = 9988
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
for bs in args.batch_size:
|
||||
run_one_batch_size(bs)
|
||||
|
||||
Reference in New Issue
Block a user