# File: sglang/benchmark/mmlu/bench_sglang.py (179 lines, 4.7 KiB, Python)

import argparse
import json
import os
import time
import numpy as np
import pandas as pd
import tiktoken
from tqdm import tqdm
# Local sglang test helpers, used below for CLI parsing and backend selection.
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
# Answer letters used to label each question's options, in column order
# (see format_example).
choices = ["A", "B", "C", "D"]
# Tokenizer used in evaluate() to measure the few-shot prompt length in tokens.
tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
def format_subject(subject):
    """Turn an underscore-separated subject id into a space-prefixed phrase.

    Each underscore-delimited part is emitted with one leading space, so
    ``"high_school_physics"`` becomes ``" high school physics"``.  The leading
    space is intentional: the result is spliced directly after the word
    "about" in the prompt header.

    Args:
        subject: Subject identifier such as ``"college_biology"``.

    Returns:
        The parts of ``subject`` separated by spaces, with a leading space.
    """
    return "".join(" " + part for part in subject.split("_"))
def format_example(df, idx, include_answer=True):
    """Render one table row as a multiple-choice prompt.

    Column 0 of the row is the question text, the last column is the answer,
    and every column in between is an option.  Options are labelled with the
    matching entry of the module-level ``choices`` list.

    Args:
        df: pandas DataFrame holding one question per row.
        idx: Positional index of the row to render.
        include_answer: When true, append the answer followed by a blank
            line (the few-shot form); when false, the prompt ends right
            after ``"Answer:"`` so the model supplies the answer.

    Returns:
        The formatted prompt string.

    Raises:
        IndexError: If the row has more option columns than ``choices`` has
            labels.
    """
    num_options = df.shape[1] - 2
    prompt = df.iloc[idx, 0]
    for j in range(num_options):
        # Scalar (row, column) access keeps each cell's own formatting.
        prompt += f"\n{choices[j]}. {df.iloc[idx, j + 1]}"
    prompt += "\nAnswer:"
    if include_answer:
        prompt += f" {df.iloc[idx, num_options + 1]}\n\n"
    return prompt
def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot prefix: a header line plus ``k`` worked examples.

    Args:
        train_df: pandas DataFrame of example questions with answers.
        subject: Underscore-separated subject id, inserted into the header
            via ``format_subject``.
        k: Number of leading rows of ``train_df`` to include.  The default
            ``-1`` means "all rows"; ``0`` (or any other negative value)
            yields the header only.

    Returns:
        The header followed by the formatted examples, each including its
        answer.
    """
    header = "The following are multiple choice questions (with answers) about{}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    return header + "".join(format_example(train_df, i) for i in range(k))
def evaluate(args, subject, dev_df, test_df):
    """Benchmark one subject through sglang and score the predictions.

    A few-shot prefix is built from ``dev_df`` and shrunk one example at a
    time until it fits the token budget.  Every row of ``test_df`` is then
    sent as a separate request sharing that prefix, generating a single token
    at temperature 0.  The first non-whitespace character of each completion
    is compared with the label in the last column of ``test_df``.

    Args:
        args: Parsed CLI namespace.  ``ntrain``, ``backend`` and ``parallel``
            are read here; the namespace is also handed to
            ``select_sglang_backend``.
        subject: Underscore-separated subject id (used in the prompt header
            and the progress line).
        dev_df: pandas DataFrame of few-shot examples.
        test_df: pandas DataFrame of questions to answer.

    Returns:
        A tuple ``(cors, acc, latency)``: a NumPy boolean array with one
        entry per question, its mean, and the wall-clock seconds spent
        running the batch and extracting the answers.
    """
    max_few_shot_tokens = 1536

    # Drop trailing examples until the shared prefix fits the token budget.
    k = args.ntrain
    few_shot_examples = gen_prompt(dev_df, subject, k)
    while len(tokenizer.encode(few_shot_examples)) > max_few_shot_tokens:
        k -= 1
        few_shot_examples = gen_prompt(dev_df, subject, k)

    num_questions = test_df.shape[0]
    label_col = test_df.shape[1] - 1
    prompts = [
        format_example(test_df, i, include_answer=False)
        for i in range(num_questions)
    ]
    labels = [test_df.iloc[i, label_col] for i in range(num_questions)]
    arguments = [{"question": p} for p in prompts]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    if args.backend.startswith("gpt-"):

        # Backends named "gpt-*" get the prompt wrapped in user/assistant turns.
        @sgl.function
        def few_shot_mmlu(s, examples, question):
            s += sgl.user(examples + question)
            s += sgl.assistant(sgl.gen("answer"))

    else:

        # All other backends get the prompt as plain text to complete.
        @sgl.function
        def few_shot_mmlu(s, examples, question):
            s += examples + question + sgl.gen("answer")

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Select backend
    backend = select_sglang_backend(args)

    tic = time.time()
    states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
        arguments,
        temperature=0,
        max_new_tokens=1,
        backend=backend,
        num_threads=args.parallel,
    )
    # [:1] yields the first character, or "" when the completion is blank.
    preds = [s["answer"].strip()[:1] for s in states]
    latency = time.time() - tic

    cors = np.array([pred == label for pred, label in zip(preds, labels)])
    acc = np.mean(cors)

    print(
        "Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
            acc, latency, len(prompts), subject
        )
    )

    return cors, acc, latency
def main(args):
    """Run the benchmark over the available subjects and record a summary.

    Subjects are discovered from the ``*_test.csv`` files in
    ``<data_dir>/test``, sorted, and limited to the first ``args.nsub``.  For
    each one the matching dev and test CSVs are loaded (no header row) and
    passed to ``evaluate``.  The summed latency and the accuracy over all
    questions are printed, and one JSON line is appended to
    ``args.result_file``.

    Args:
        args: Parsed CLI namespace.  Reads ``data_dir``, ``nsub``, ``ntrain``,
            ``backend``, ``parallel`` and ``result_file``.

    Raises:
        ValueError: If no subject was evaluated (no matching test files, or
            ``nsub`` selects none).
    """
    test_suffix = "_test.csv"
    test_dir = os.path.join(args.data_dir, "test")
    # Match on the suffix only, so names such as "x_test.csv.bak" are ignored.
    subjects = sorted(
        name[: -len(test_suffix)]
        for name in os.listdir(test_dir)
        if name.endswith(test_suffix)
    )

    all_cors = []
    all_latencies = []
    num_requests = 0

    for subject in tqdm(subjects[: args.nsub]):
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: args.ntrain]
        test_df = pd.read_csv(
            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, _, latency = evaluate(args, subject, dev_df, test_df)
        all_cors.append(cors)
        all_latencies.append(latency)
        num_requests += len(test_df)

    if not all_cors:
        # np.concatenate([]) would fail below with a much less helpful message.
        raise ValueError(
            "No subjects to evaluate: found no '*{}' files in {!r} "
            "(or --nsub selected none).".format(test_suffix, test_dir)
        )

    total_latency = np.sum(all_latencies)
    print("Total latency: {:.3f}".format(total_latency))

    # Averaging the concatenated per-question results weights each subject
    # by its number of questions.
    weighted_acc = np.mean(np.concatenate(all_cors))
    print("Average accuracy: {:.3f}".format(weighted_acc))

    # Write results
    with open(args.result_file, "a") as fout:
        value = {
            "task": "mmlu",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(total_latency, 3),
            "accuracy": round(weighted_acc, 3),
            "num_requests": num_requests,
            "other": {
                "nsub": args.nsub,
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")
if __name__ == "__main__":
    # Script entry point: declare the benchmark-specific options, let the
    # shared sglang helper add its common options and parse, then run.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ntrain",
        "-k",
        type=int,
        default=5,
        help="Number of few-shot examples taken from each subject's dev file.",
    )
    parser.add_argument(
        "--data_dir",
        "-d",
        type=str,
        default="data",
        help="Directory containing the 'dev' and 'test' CSV folders.",
    )
    # NOTE(review): --save_dir is not referenced elsewhere in this file;
    # presumably kept for command-line compatibility — confirm before removing.
    parser.add_argument("--save_dir", "-s", type=str, default="results")
    parser.add_argument(
        "--nsub",
        type=int,
        default=60,
        help="Maximum number of subjects to evaluate, in sorted order.",
    )
    args = add_common_sglang_args_and_parse(parser)
    main(args)