sglang/benchmark/latency_throughput/bench_one.py

"""
Usage:
python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
"""
import argparse
import json
import time

import numpy as np
import requests
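

# Note: `args` is assigned in the __main__ block below; run_one_batch_size reads
# it as a module-level global, so it must only be called after argument parsing.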
def run_one_batch_size(bs):
    url = f"{args.host}:{args.port}"
    max_new_tokens = args.max_tokens

    if args.input_len:
        # Synthetic prompts: random token ids of the requested length, one per batch element.
        input_ids = [
            [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
            for _ in range(bs)
        ]
    else:
        # Without --input-len, fall back to short unique text prompts ("(0,)", "(1,)", ...).
        text = [f"{i, }" for i in range(bs)]

    tic = time.time()
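    # SGLang (srt): POST to /generate with either raw `text` or pre-tokenized
    # `input_ids`. temperature=0 plus ignore_eos=True makes decoding deterministic
    # and forces every request to generate exactly max_new_tokens tokens.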
if args.backend == "srt":
if args.input_len:
inputs = {"input_ids": input_ids}
else:
inputs = {"text": text}
response = requests.post(
url + "/generate",
json={
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
**inputs,
},
)
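    # LightLLM: the request body carries a single prompt ("inputs"), so this
    # branch sends only the first text prompt rather than the whole batch.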
elif args.backend == "lightllm":
response = requests.post(
url + "/generate",
json={
"inputs": text[0],
"parameters": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
},
)
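    # vLLM: OpenAI-compatible /v1/completions endpoint. `prompt` may be a list of
    # strings or a list of token-id lists, so the whole batch goes in one request.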
elif args.backend == "vllm":
if args.input_len:
inputs = {"prompt": input_ids}
else:
inputs = {"prompt": text}
response = requests.post(
url + "/v1/completions",
json={
"model": args.vllm_model_name,
"temperature": 0,
"max_tokens": max_new_tokens,
"ignore_eos": True,
**inputs,
},
)
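    # ginfer: a gRPC backend. The imports are deferred so the gRPC stubs are only
    # required when this backend is selected, and the timer is restarted after the
    # channel is created so connection setup is excluded from the measured latency.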
elif args.backend == "ginfer":
import grpc
from ginfer import sampler_pb2, sampler_pb2_grpc
sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
tic = time.time()
sample_request = sampler_pb2.SampleTextRequest(
prompt=text[0],
settings=sampler_pb2.SampleSettings(
max_len=max_new_tokens,
rng_seed=0,
temperature=0,
nucleus_p=1,
),
)
stream = sampler.SampleText(sample_request)
response = "".join([x.text for x in stream])
    latency = time.time() - tic

    if isinstance(response, str):
        ret = response
    else:
        ret = response.json()
    print(ret)
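    # Throughput is reported two ways: generated tokens only, and prompt plus
    # generated tokens. When no --input-len is given, the prompt length is
    # approximated as 1 token.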
    input_len = args.input_len if args.input_len else 1
    output_len = max_new_tokens
    output_throughput = bs * max_new_tokens / latency
    overall_throughput = bs * (input_len + output_len) / latency
    print(f"latency: {latency:.2f} s")
    print(f"output throughput: {output_throughput:.2f} token/s")
    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")

    with open("results.jsonl", "a") as fout:
        res = {
            "backend": args.backend,
            "input_len": args.input_len,
            "output_len": args.max_tokens,
            "batch_size": bs,
            "latency": latency,
            "output_throughput": output_throughput,
            "overall_throughput": overall_throughput,
        }
        fout.write(json.dumps(res) + "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--input-len", type=int, default=None)
parser.add_argument("--batch-size", type=int, nargs="*", default=[1])
parser.add_argument("--max-tokens", type=int, default=256)
parser.add_argument(
"--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B"
)
args = parser.parse_args()
if args.port is None:
if args.backend == "srt":
args.port = 30000
elif args.backend == "vllm":
args.port = 21000
elif args.backend == "lightllm":
args.port = 22000
elif args.backend == "ginfer":
args.port = 9988
else:
raise ValueError(f"Invalid backend: {args.backend}")
for bs in args.batch_size:
run_one_batch_size(bs)
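
# Each run appends one JSON record per batch size to results.jsonl. A quick way
# to inspect the records afterwards (a one-liner, not part of this script):
#   python3 -c "import json; [print(json.loads(l)) for l in open('results.jsonl')]"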