[Minor] Many cleanup (#1357)

This commit is contained in:
Lianmin Zheng
2024-09-09 04:14:11 -07:00
committed by GitHub
parent c9b75917d5
commit e4d68afcf0
24 changed files with 416 additions and 296 deletions

View File

@@ -4,11 +4,12 @@ import time
import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import read_jsonl
from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer):
@@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
set_default_backend(select_sglang_backend(args))
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
@@ -56,15 +64,11 @@ def main(args):
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
rets = few_shot_hellaswag.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
@@ -95,7 +99,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20)
parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)