[Minor] Many cleanups (#1357)

This commit is contained in:
Lianmin Zheng
2024-09-09 04:14:11 -07:00
committed by GitHub
parent c9b75917d5
commit e4d68afcf0
24 changed files with 416 additions and 296 deletions

View File

@@ -8,7 +8,7 @@ import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
from sglang.utils import read_jsonl
from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer):
@@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
call_select = get_call_select(args)
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
preds = [None] * len(labels)
# Select backend
call_select = get_call_select(args)
# Run requests
if args.backend != "lmql":
# Use thread pool
@@ -65,7 +69,6 @@ def main(args):
total=len(questions),
)
)
else:
# Use asyncio
async def batched_call(batch_size):
@@ -108,7 +111,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20)
parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)