Format Benchmark Code (#399)

This commit is contained in:
Liangsheng Yin
2024-04-28 21:06:22 +08:00
committed by GitHub
parent 19818b9c2f
commit 95c4e0dfac
41 changed files with 1169 additions and 608 deletions

View File

@@ -1,17 +1,22 @@
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
from functools import partial
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np
from sglang.test.test_utils import add_common_other_args_and_parse, call_select_lightllm, call_select_vllm
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_select_lightllm,
call_select_vllm,
)
from sglang.utils import read_jsonl
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:
ret += lines[i]["endings"][lines[i]["label"]]
return ret
@@ -34,7 +39,7 @@ def main(args):
questions = []
choices = []
labels = []
for i in range(len(lines[:args.num_questions])):
for i in range(len(lines[: args.num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
@@ -51,7 +56,11 @@ def main(args):
elif args.backend == "guidance":
from guidance import models, select
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def call_select(context, choices):
out = model + context + select(choices, name="answer")
@@ -61,8 +70,10 @@ def main(args):
elif args.backend == "lmql":
import lmql
model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
endpoint=f"{args.host}:{args.port}")
model = lmql.model(
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
)
@lmql.query(model=model)
async def program(ctx, choices):
@@ -83,8 +94,8 @@ def main(args):
# Use thread pool
def get_one_answer(i):
preds[i] = call_select(
context=few_shot_examples + questions[i],
choices=choices[i])
context=few_shot_examples + questions[i], choices=choices[i]
)
tic = time.time()
if args.parallel == 1:
@@ -98,13 +109,13 @@ def main(args):
async def batched_call(batch_size):
for i in range(0, len(questions), batch_size):
tasks = []
for q, c in zip(questions[i:i+batch_size], choices[i:i+batch_size]):
tasks.append(call_select(
context=few_shot_examples + q,
choices=c))
for q, c in zip(
questions[i : i + batch_size], choices[i : i + batch_size]
):
tasks.append(call_select(context=few_shot_examples + q, choices=c))
rets = await asyncio.gather(*tasks)
for j in range(len(rets)):
preds[i+j] = rets[j]
preds[i + j] = rets[j]
tic = time.time()
asyncio.run(batched_call(batch_size=args.parallel))
@@ -128,7 +139,7 @@ def main(args):
"other": {
"num_questions": args.num_questions,
"parallel": args.parallel,
}
},
}
fout.write(json.dumps(value) + "\n")