[Fix] Fix select by ensuring each request has at least one token (#1318)

This commit is contained in:
Lianmin Zheng
2024-09-03 06:31:45 -07:00
committed by GitHub
parent 12cb115d38
commit 1e495e0847
4 changed files with 120 additions and 3 deletions

View File

@@ -2,8 +2,12 @@
import json
import re
import time
import numpy as np
import sglang as sgl
from sglang.utils import fetch_and_cache_jsonl
def test_few_shot_qa():
@@ -447,3 +451,67 @@ def test_chat_completion_speculative():
)
gen_character_spec().sync()
def test_hellaswag_select():
"""Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
lines = fetch_and_cache_jsonl(url)
# Construct prompts
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:
ret += lines[i]["endings"][lines[i]["label"]]
return ret
def get_few_shot_examples(lines, k):
ret = ""
for i in range(k):
ret += get_one_example(lines, i, True) + "\n\n"
return ret
num_questions = 200
num_shots = 20
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
#####################################
######### SGL Program Begin #########
#####################################
import sglang as sgl
@sgl.function
def few_shot_hellaswag(s, question, choices):
s += few_shot_examples + question
s += sgl.select("answer", choices=choices)
#####################################
########## SGL Program End ##########
#####################################
# Run requests
tic = time.time()
rets = few_shot_hellaswag.run_batch(
arguments,
temperature=0,
num_threads=64,
progress_bar=True,
)
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
latency = time.time() - tic
# Compute accuracy
accuracy = np.mean(np.array(preds) == np.array(labels))
return accuracy, latency