[Fix] Fix select by ensuring each request has at least one token (#1318)
This commit is contained in:
@@ -178,19 +178,22 @@ class Req:
|
||||
def adjust_max_prefix_ids(self):
|
||||
self.fill_ids = self.origin_input_ids + self.output_ids
|
||||
input_len = len(self.fill_ids)
|
||||
max_prefix_len = input_len
|
||||
|
||||
# FIXME: To work around some bugs in logprob computation, we need to ensure each
|
||||
# request has at least one token. Later, we can relax this requirement and use `input_len`.
|
||||
max_prefix_len = input_len - 1
|
||||
|
||||
if self.sampling_params.max_new_tokens > 0:
|
||||
# Need at least one token to compute logits
|
||||
max_prefix_len = min(max_prefix_len, input_len - 1)
|
||||
|
||||
if self.return_logprob:
|
||||
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
|
||||
|
||||
if self.normalized_prompt_logprob is None:
|
||||
# Need at least two tokens to compute normalized logprob
|
||||
max_prefix_len = min(max_prefix_len, input_len - 2)
|
||||
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
|
||||
|
||||
max_prefix_len = max(max_prefix_len, 0)
|
||||
return self.fill_ids[:max_prefix_len]
|
||||
|
||||
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.utils import fetch_and_cache_jsonl
|
||||
|
||||
|
||||
def test_few_shot_qa():
|
||||
@@ -447,3 +451,67 @@ def test_chat_completion_speculative():
|
||||
)
|
||||
|
||||
gen_character_spec().sync()
|
||||
|
||||
|
||||
def test_hellaswag_select():
|
||||
"""Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
|
||||
|
||||
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
|
||||
lines = fetch_and_cache_jsonl(url)
|
||||
|
||||
# Construct prompts
|
||||
def get_one_example(lines, i, include_answer):
|
||||
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||
if include_answer:
|
||||
ret += lines[i]["endings"][lines[i]["label"]]
|
||||
return ret
|
||||
|
||||
def get_few_shot_examples(lines, k):
|
||||
ret = ""
|
||||
for i in range(k):
|
||||
ret += get_one_example(lines, i, True) + "\n\n"
|
||||
return ret
|
||||
|
||||
num_questions = 200
|
||||
num_shots = 20
|
||||
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||
|
||||
questions = []
|
||||
choices = []
|
||||
labels = []
|
||||
for i in range(len(lines[:num_questions])):
|
||||
questions.append(get_one_example(lines, i, False))
|
||||
choices.append(lines[i]["endings"])
|
||||
labels.append(lines[i]["label"])
|
||||
arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
|
||||
|
||||
#####################################
|
||||
######### SGL Program Begin #########
|
||||
#####################################
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
@sgl.function
|
||||
def few_shot_hellaswag(s, question, choices):
|
||||
s += few_shot_examples + question
|
||||
s += sgl.select("answer", choices=choices)
|
||||
|
||||
#####################################
|
||||
########## SGL Program End ##########
|
||||
#####################################
|
||||
|
||||
# Run requests
|
||||
tic = time.time()
|
||||
rets = few_shot_hellaswag.run_batch(
|
||||
arguments,
|
||||
temperature=0,
|
||||
num_threads=64,
|
||||
progress_bar=True,
|
||||
)
|
||||
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
||||
latency = time.time() - tic
|
||||
|
||||
# Compute accuracy
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
|
||||
return accuracy, latency
|
||||
|
||||
@@ -4,6 +4,7 @@ import base64
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
@@ -15,6 +16,7 @@ from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -260,3 +262,40 @@ class LazyImport:
|
||||
def __call__(self, *args, **kwargs):
|
||||
module = self._load()
|
||||
return module(*args, **kwargs)
|
||||
|
||||
|
||||
def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
|
||||
"""Read and cache a jsonl file from a url."""
|
||||
|
||||
# Check if the cache file already exists
|
||||
if os.path.exists(cache_file):
|
||||
print("Loading data from cache...")
|
||||
with open(cache_file, "r") as f:
|
||||
data = [json.loads(line) for line in f]
|
||||
else:
|
||||
print("Downloading data from URL...")
|
||||
# Stream the response to show the progress bar
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status() # Check for request errors
|
||||
|
||||
# Total size of the file in bytes
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
chunk_size = 1024 # Download in chunks of 1KB
|
||||
|
||||
# Use tqdm to display the progress bar
|
||||
with open(cache_file, "wb") as f, tqdm(
|
||||
desc=cache_file,
|
||||
total=total_size,
|
||||
unit="B",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar:
|
||||
for chunk in response.iter_content(chunk_size=chunk_size):
|
||||
f.write(chunk)
|
||||
bar.update(len(chunk))
|
||||
|
||||
# Convert the data to a list of dictionaries
|
||||
with open(cache_file, "r") as f:
|
||||
data = [json.loads(line) for line in f]
|
||||
|
||||
return data
|
||||
|
||||
Reference in New Issue
Block a user