Improve benchmark scripts & fix llava (#613)

This commit is contained in:
Lianmin Zheng
2024-07-13 15:00:26 -07:00
committed by GitHub
parent 665815969a
commit 65c6577696
4 changed files with 43 additions and 22 deletions

View File

@@ -15,19 +15,19 @@ def run_one_batch_size(bs):
url = f"{args.host}:{args.port}"
max_new_tokens = args.max_tokens
a = 20
prompt = f"{a, }"
if args.input_len:
input_ids = [
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
]
else:
text = [f"{i, }" for i in range(bs)]
tic = time.time()
if args.backend == "srt":
if args.input_len:
inputs = {"input_ids": [
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
]}
inputs = {"input_ids": input_ids}
else:
inputs = {"text": [
f"{i, }" for i in range(bs)
]}
inputs = {"text": text}
response = requests.post(
url + "/generate",
@@ -44,7 +44,7 @@ def run_one_batch_size(bs):
response = requests.post(
url + "/generate",
json={
"inputs": prompt,
"inputs": text[0],
"parameters": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
@@ -53,13 +53,19 @@ def run_one_batch_size(bs):
},
)
elif args.backend == "vllm":
if args.input_len:
inputs = {"prompt": input_ids}
else:
inputs = {"prompt": text}
response = requests.post(
url + "/generate",
url + "/v1/completions",
json={
"prompt": prompt,
"model": args.vllm_model_name,
"temperature": 0,
"max_tokens": max_new_tokens,
"ignore_eos": True,
**inputs,
},
)
elif args.backend == "ginfer":
@@ -71,7 +77,7 @@ def run_one_batch_size(bs):
tic = time.time()
sample_request = sampler_pb2.SampleTextRequest(
prompt=prompt,
prompt=text[0],
settings=sampler_pb2.SampleSettings(
max_len=max_new_tokens,
rng_seed=0,
@@ -92,7 +98,7 @@ def run_one_batch_size(bs):
output_throughput = bs * max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")
with open("tmp_output.txt", "a") as fout:
with open("results.jsonl", "a") as fout:
res = {
"input_len": args.input_len,
"output_len": args.max_tokens,
@@ -111,6 +117,7 @@ if __name__ == "__main__":
parser.add_argument("--input-len", type=int, default=None)
parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
parser.add_argument("--max-tokens", type=int, default=256)
parser.add_argument("--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B")
args = parser.parse_args()
if args.port is None: